intervention/appendix_plots.ipynb (5 changes: 3 additions & 2 deletions)
@@ -7,6 +7,7 @@
"outputs": [],
"source": [
"# %%\n",
+ "from pathlib import Path\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from task import get_acts, get_acts_pca, get_all_acts\n",
@@ -462,7 +463,7 @@
"\n",
"for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
"    results_mistral = pd.read_csv(\n",
- "        f\"{BASE_DIR}/mistral_{task_name}/results.csv\", skipinitialspace=True\n",
+ "        Path(BASE_DIR) / f\"mistral_{task_name}\" / \"results.csv\", skipinitialspace=True\n",
"    )\n",
"\n",
" results_mistral = results_mistral.rename(\n",
@@ -480,7 +481,7 @@
"    print(sum(results_mistral[\"mistral_correct\"]))\n",
"\n",
"    results_llama = pd.read_csv(\n",
- "        f\"{BASE_DIR}/llama_{task_name}/results.csv\", skipinitialspace=True\n",
+ "        Path(BASE_DIR) / f\"llama_{task_name}\" / \"results.csv\", skipinitialspace=True\n",
"    )\n",
"\n",
" results_llama = results_llama.rename(\n",
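Note: `pandas.read_csv` accepts any `os.PathLike`, so the joined `Path` can be passed straight through with no `str()` conversion. A minimal runnable sketch of the pattern above, with illustrative stand-ins for `BASE_DIR` and `task_name`:

```python
from pathlib import Path
import pandas as pd

BASE_DIR = Path("cache")    # illustrative; the repo derives this in utils.py
task_name = "days_of_week"  # one of the two tasks the notebook loops over

csv_path = BASE_DIR / f"mistral_{task_name}" / "results.csv"
csv_path.parent.mkdir(parents=True, exist_ok=True)
csv_path.write_text("prompt, ground_truth\n0, Monday\n")  # tiny stand-in file

# Path implements os.PathLike, so read_csv takes it directly.
results = pd.read_csv(csv_path, skipinitialspace=True)
print(list(results.columns))  # ['prompt', 'ground_truth'] once the spaces are skipped
```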
intervention/circle_probe_interventions.py (4 changes: 2 additions & 2 deletions)
@@ -171,7 +171,7 @@
probe_projections = {}
target_to_embeddings = {}

- os.makedirs(f"{task.prefix}/circle_probes_{circle_letter}", exist_ok=True)
+ (task.prefix / f"circle_probes_{circle_letter}").mkdir(exist_ok=True)

all_maes = []
all_r_squareds = []
@@ -262,7 +262,7 @@
"probe_r": probe_r,
"target_to_embedding": target_to_embedding,
},
- f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_{pca_k}.pt",
+ task.prefix / f"circle_probes_{circle_letter}" / f"{probe_file_extension}_layer_{layer}_token_{token}_pca_{pca_k}.pt",
)

mae = (predictions - multid_targets_train).abs().mean()
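Note: unlike the `os.makedirs(..., exist_ok=True)` it replaces, the new `.mkdir(exist_ok=True)` call does not create missing parents; it raises `FileNotFoundError` when `task.prefix` is absent. That is safe here because the task constructors create the prefix first. A small sketch of the distinction, with a hypothetical `prefix`:

```python
from pathlib import Path

prefix = Path("cache/mistral_days_of_week")  # stand-in for task.prefix

# Closest equivalent of os.makedirs(..., exist_ok=True): creates parents as needed.
(prefix / "circle_probes_a").mkdir(parents=True, exist_ok=True)

# The form used in the diff: succeeds only because prefix already exists
# (the task __init__ guarantees that); otherwise it raises FileNotFoundError.
(prefix / "circle_probes_a").mkdir(exist_ok=True)
```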
intervention/days_of_week_task.py (6 changes: 3 additions & 3 deletions)
@@ -1,5 +1,6 @@
# %%

+ from pathlib import Path
import os
from utils import setup_notebook, BASE_DIR

@@ -49,9 +50,8 @@ def __init__(self, device, model_name="mistral", n_devices=None):
# Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall)
self.allowable_tokens = days_of_week

- self.prefix = f"{BASE_DIR}{model_name}_days_of_week/"
- if not os.path.exists(self.prefix):
-     os.makedirs(self.prefix)
+ self.prefix = Path(BASE_DIR) / f"{model_name}_days_of_week"
+ self.prefix.mkdir(parents=True, exist_ok=True)

self.num_tokens_in_answer = 1

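Note: the check-then-create pair collapses into one idempotent call; `parents=True, exist_ok=True` also avoids the race where another process creates the directory between the `os.path.exists` test and `os.makedirs`. Roughly, with an illustrative path:

```python
import os
from pathlib import Path

prefix = Path("cache/mistral_days_of_week")  # stand-in for self.prefix

# Old pattern: two steps, racy if anything else creates the directory in between.
if not os.path.exists(prefix):
    os.makedirs(prefix)

# New pattern: a single call that is a no-op when the directory already exists.
prefix.mkdir(parents=True, exist_ok=True)
```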
intervention/intervene_in_middle_of_circle.py (2 changes: 1 addition & 1 deletion)
@@ -40,7 +40,7 @@ def vary_wthin_circle(circle_letter, duration, layer, token, pca_k, all_points):
model = task.get_model()

circle_projection_qr = torch.load(
- f"{task.prefix}/circle_probes_{circle_letter}/cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt"
+ task.prefix / f"circle_probes_{circle_letter}" / f"cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt"
)

for problem in task.generate_problems():
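Note: `torch.save` and `torch.load` also accept `os.PathLike`, so the probe checkpoint path can stay a `Path` end to end. A minimal sketch with made-up names and a trivial payload:

```python
from pathlib import Path
import torch

prefix = Path("cache/mistral_days_of_week")  # stand-in for task.prefix
ckpt = prefix / "circle_probes_a" / "cos_sin_layer_5_token_0_pca_5.pt"

ckpt.parent.mkdir(parents=True, exist_ok=True)  # .parent is the enclosing directory
torch.save({"probe": torch.zeros(3)}, ckpt)     # PathLike accepted directly
state = torch.load(ckpt, map_location="cpu")    # likewise for loading
```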
intervention/main_text_plots.ipynb (5 changes: 3 additions & 2 deletions)
@@ -7,6 +7,7 @@
"outputs": [],
"source": [
"# %%\n",
+ "from pathlib import Path\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from task import get_acts, get_acts_pca\n",
@@ -516,7 +517,7 @@
"for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
"    for model_name in [\"mistral\", \"llama\"]:\n",
"        results = pd.read_csv(\n",
- "            f\"{BASE_DIR}/{model_name}_{task_name}/results.csv\", skipinitialspace=True\n",
+ "            Path(BASE_DIR) / f\"{model_name}_{task_name}\" / \"results.csv\", skipinitialspace=True\n",
"        )\n",
" number_correct = results[\"best_token\"] == results[\"ground_truth\"]\n",
" print(task_name, model_name, np.sum(number_correct))\n",
@@ -560,7 +561,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.7"
+ "version": "3.11.9"
}
},
"nbformat": 4,
intervention/months_of_year_task.py (6 changes: 3 additions & 3 deletions)
@@ -1,6 +1,7 @@
# %%

import os
+ from pathlib import Path
from utils import setup_notebook, BASE_DIR

setup_notebook()
@@ -71,9 +72,8 @@ def __init__(self, device, model_name="mistral", n_devices=None):
# Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall)
self.allowable_tokens = months_of_year

- self.prefix = f"{BASE_DIR}{model_name}_months_of_year/"
- if not os.path.exists(self.prefix):
-     os.makedirs(self.prefix)
+ self.prefix = Path(BASE_DIR) / f"{model_name}_months_of_year"
+ self.prefix.mkdir(parents=True, exist_ok=True)

self.num_tokens_in_answer = 1

intervention/task.py (32 changes: 16 additions & 16 deletions)
@@ -1,3 +1,4 @@
+ from pathlib import Path
from utils import BASE_DIR # Need this import to set the huggingface cache directory
import os
import numpy as np
@@ -24,7 +25,6 @@ def __str__(self):
def __repr__(self):
return str(self)


def generate_and_save_acts(
task,
names_filter,
@@ -39,10 +39,10 @@
forward_batch_size = 2
num_tokens_to_generate = task.num_tokens_in_answer
all_problems = task.generate_problems()
- output_file = task.prefix + "results.csv"
+ output_file = task.prefix / "results.csv"

if save_results_csv:
- os.makedirs(task.prefix, exist_ok=True)
+ task.prefix.mkdir(parents=True, exist_ok=True)
model_best_addition = "" if not save_best_logit else ", best_token"
with open(output_file, "w") as f:
f.write(
@@ -98,7 +98,7 @@ def generate_and_save_acts(
print(tensors.shape)
torch.save(
tensors,
- f"{task.prefix}{save_file_prefix}{current_problem_index}.pt",
+ task.prefix / f"{save_file_prefix}{current_problem_index}.pt",
)

if save_results_csv:
@@ -146,7 +146,7 @@ def get_all_acts(
all_problems = task.generate_problems()
all_problems_already_generated = True
for i in range(len(all_problems)):
- if not os.path.exists(f"{task.prefix}{save_file_prefix}{i}.pt"):
+ if not (task.prefix / f"{save_file_prefix}{i}.pt").exists():
all_problems_already_generated = False
break
if not all_problems_already_generated or force_regenerate:
@@ -163,7 +163,7 @@
all_acts = []
for i in range(0, len(all_problems)):
tensors = torch.load(
- f"{task.prefix}{save_file_prefix}{i}.pt", map_location="cpu"
+ task.prefix / f"{save_file_prefix}{i}.pt", map_location="cpu"
)
all_acts.append(tensors)
if len(all_acts) > 1:
@@ -186,17 +186,17 @@ def get_acts(
if save_file_prefix != "" and save_file_prefix[-1] != "_":
save_file_prefix += "_"
file_name = (
- f"{task.prefix}{save_file_prefix}layer{layer_fetch}_token{token_fetch}.pt"
+ task.prefix / f"{save_file_prefix}layer{layer_fetch}_token{token_fetch}.pt"
)
- if not os.path.exists(file_name) or force_regenerate:
+ if not file_name.exists() or force_regenerate:
print(file_name, "not exists")
all_acts = get_all_acts(
task, names_filter=names_filter, save_file_prefix=save_file_prefix
)
for layer in range(all_acts.shape[1]):
for token in range(all_acts.shape[2]):
file_name = (
- f"{task.prefix}{save_file_prefix}layer{layer}_token{token}.pt"
+ task.prefix / f"{save_file_prefix}layer{layer}_token{token}.pt"
)
torch.save(
all_acts[:, layer, token, :].detach().cpu().clone(), file_name
@@ -218,11 +218,11 @@ def get_acts_pca(
names_filter=lambda x: "resid_post" in x or "hook_embed" in x,
save_file_prefix="",
):
- act_file_name = f"{task.prefix}pca/{save_file_prefix}/layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pt"
- pca_pkl_file_name = f"{task.prefix}pca/{save_file_prefix}/layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pkl"
- os.makedirs(f"{task.prefix}/pca/{save_file_prefix}", exist_ok=True)
+ act_file_name = task.prefix / "pca" / save_file_prefix / f"layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pt"
+ pca_pkl_file_name = task.prefix / "pca" / save_file_prefix / f"layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pkl"
+ (task.prefix / "pca" / save_file_prefix).mkdir(parents=True, exist_ok=True)

- if not os.path.exists(act_file_name) or not os.path.exists(pca_pkl_file_name):
+ if not act_file_name.exists() or not pca_pkl_file_name.exists():
acts = get_acts(
task,
layer,
@@ -239,9 +239,9 @@


def get_acts_pls(task, layer, token, pls_k, normalize_rms=False):
- act_file_name = f"{task.prefix}/pls/layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pt"
- pls_pkl_file_name = f"{task.prefix}/pls/layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pkl"
- os.makedirs(f"{task.prefix}/pls", exist_ok=True)
+ act_file_name = task.prefix / "pls" / f"layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pt"
+ pls_pkl_file_name = task.prefix / "pls" / f"layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pkl"
+ (task.prefix / "pls").mkdir(parents=True, exist_ok=True)

# if not os.path.exists(act_file_name) or not os.path.exists(pls_pkl_file_name):
if True:
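Note: the recurring pattern in `task.py` is `/` for directory components plus an f-string for the leaf file name, with `os.path.exists(p)` replaced by the method form `p.exists()`. Condensed into a sketch (names are illustrative, not the real activation files):

```python
from pathlib import Path

prefix = Path("cache/mistral_days_of_week")  # stand-in for task.prefix
save_file_prefix = "acts_"                   # illustrative

for i in range(3):
    f = prefix / f"{save_file_prefix}{i}.pt"  # '/' joins directories; the f-string names the file
    if not f.exists():                        # Path method replacing os.path.exists(f)
        print(f, "missing; the real code would regenerate activations here")
```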
intervention/utils.py (5 changes: 3 additions & 2 deletions)
@@ -1,9 +1,10 @@
import os
import dill as pickle
+ from pathlib import Path

- BASE_DIR = "/data/scratch/jae/"
+ BASE_DIR = Path(__file__).parent.parent / "cache"

- os.environ["TRANSFORMERS_CACHE"] = f"{BASE_DIR}/.cache/"
+ os.environ["TRANSFORMERS_CACHE"] = f"{(Path(BASE_DIR) / '.cache').absolute()}/"


def setup_notebook():
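Note: `BASE_DIR` changes from a hard-coded scratch path to a `Path` derived from `__file__`, so the cache lands in `<repo>/cache` on any machine. Two details worth noting: `os.environ` values must be strings, which the f-string handles by rendering the `Path`; and `str()` of a `Path` never carries a trailing slash, which is why the f-string appends one explicitly. A sketch of the same idea:

```python
import os
from pathlib import Path

# Resolved relative to this file, not the current working directory.
BASE_DIR = Path(__file__).parent.parent / "cache"

# Environment variables only store str, so render the path to text explicitly.
os.environ["TRANSFORMERS_CACHE"] = str((BASE_DIR / ".cache").absolute())
```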
@@ -1,12 +1,12 @@
# %%

+ from pathlib import Path
import os
from utils import BASE_DIR


# hopefully this will help with memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
- os.environ["TRANSFORMERS_CACHE"] = f"{BASE_DIR}.cache/"
+ os.environ["TRANSFORMERS_CACHE"] = f"{(Path(BASE_DIR) / '.cache').absolute()}/"

import einops
import numpy as np
@@ -52,8 +52,8 @@

num_sae_activations_to_save = 10**9

- save_folder = f"{BASE_DIR}{model_name}"
- os.makedirs(save_folder, exist_ok=True)
+ save_folder = Path(BASE_DIR) / model_name
+ save_folder.mkdir(exist_ok=True, parents=True)

t.set_grad_enabled(False)

sae_multid_feature_discovery/utils.py (4 changes: 3 additions & 1 deletion)
@@ -1,7 +1,9 @@

+ from pathlib import Path
from huggingface_hub import hf_hub_download
import os

- BASE_DIR = "/data/scratch/jae/"
+ BASE_DIR = Path(__file__).parent.parent / "cache"

def get_gpt2_sae(device, layer):
from sae_lens import SAE
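Note: with `BASE_DIR` now a `Path`, any leftover f-string concatenation like `f"{BASE_DIR}sub"` would silently drop the separator the old trailing slash provided (yielding `...cachesub`), which is why the touched call sites all switch to `/` joins. Downstream, `hf_hub_download` accepts a `str` or `Path` for `cache_dir`, so the value flows through unchanged; a hypothetical call for illustration (the repo and file names are made up, not the ones `get_gpt2_sae` uses):

```python
from pathlib import Path
from huggingface_hub import hf_hub_download

BASE_DIR = Path(__file__).parent.parent / "cache"

local_path = hf_hub_download(
    repo_id="gpt2",          # hypothetical repo id
    filename="config.json",  # hypothetical file
    cache_dir=BASE_DIR / ".cache",
)
```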