From 2fd7cf4cd6fae0b80dc994db687edcbd2071996f Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Sat, 7 Sep 2024 15:10:25 -0700 Subject: [PATCH 01/19] Fix some typos in README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7e7f6a8..bf5ec4c 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,11 @@ This is the github repo for our paper ["Not All Language Model Features Are Line Below are instructions to reproduce each figure (aspirationally). -The required pthon packages to run this repo are +The required Python packages to run this repo are ``` transformer_lens sae_lens transformers datasets torch adjustText circuitsvis ipython ``` -We recommend you creat a new python venv named multid and install these packages, +We recommend you create a new python venv named multid and install these packages, either manually using pip or using the existing requirements.txt if you are on a linux machine with Cuda 12.1: ``` From 694f222e4e46b6c3fd823d4ce5683f84fee45bb2 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Sat, 7 Sep 2024 15:13:38 -0700 Subject: [PATCH 02/19] README: add some hyperlinks and code font --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 7e7f6a8..4ea328f 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Let us know if anything does not work with this environment! ### Intervention Experiments -Before running experiments, you should change BASE_DIR in intervention/utils.py to point to a location on your machine where large artifacts can be downloaded and saved (Mistral and Llama 3 take ~60GB and experiment artifacts are ~100GB). +Before running experiments, you should change `BASE_DIR` in [`intervention/utils.py`](./intervention/utils.py) to point to a location on your machine where large artifacts can be downloaded and saved (Mistral and Llama 3 take ~60GB and experiment artifacts are ~100GB). 
To reproduce the intervention results, you will first need to run intervention experiments with the following commands: @@ -37,7 +37,7 @@ python3 circle_probe_interventions.py day a llama --device 0 --intervention_pca_ python3 circle_probe_interventions.py month a llama --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin ``` -You can then reproduce *Figure 3*, *Figure 5*, *Figure 6*, and *Table 1* by running the corresponding cells in intervention/main_text_plots.ipynb. +You can then reproduce *Figure 3*, *Figure 5*, *Figure 6*, and *Table 1* by running the corresponding cells in [`intervention/main_text_plots.ipynb`](./intervention/main_text_plots.ipynb). After running these intervention experiments, you can reproduce *Figure 6* by running @@ -45,16 +45,16 @@ After running these intervention experiments, you can reproduce *Figure 6* by ru cd intervention python3 intervene_in_middle_of_circle.py --only_paper_plots ``` -and then running the corresponding cell in intervention/main_text_plots.ipynb. +and then running the corresponding cell in [`intervention/main_text_plots.ipynb`](./intervention/main_text_plots.ipynb). You can reproduce *Figure 13*, *Figure 14*, *Figure 15*, *Table 2*, *Table 3*, and *Table 4* (all from the appendix) by running cells in intervention/appendix_plots.ipynb. ### SAE feature search experiments -Before running experiments, you should again change BASE_DIR in sae_multid_feature_discovery/utils.py to point to a location on your machine where large artifacts can be downloaded and saved. +Before running experiments, you should again change `BASE_DIR` in [`sae_multid_feature_discovery/utils.py`](./sae_multid_feature_discovery/utils.py) to point to a location on your machine where large artifacts can be downloaded and saved. -You will need to generate SAE feature activations to generate the cluster reconstructions. 
The GPT-2 SAEs will be automatically downloaded when you run the below scripts, while for Mistral you will need to download our pretrained Mistral SAEs from https://www.dropbox.com/scl/fo/hznwqj4fkqvpr7jtx9uxz/AJUe0wKmJS1-fD982PuHb5A?rlkey=ffnq6pm6syssf2p7t98q9kuh1&dl=0 to sae_multid_feature_discovery/saes/mistral_saes. You can generate SAE feature activations with one of the following two commands: +You will need to generate SAE feature activations to generate the cluster reconstructions. The GPT-2 SAEs will be automatically downloaded when you run the below scripts, while for Mistral you will need to download our pretrained Mistral SAEs from https://www.dropbox.com/scl/fo/hznwqj4fkqvpr7jtx9uxz/AJUe0wKmJS1-fD982PuHb5A?rlkey=ffnq6pm6syssf2p7t98q9kuh1&dl=0 to [`sae_multid_feature_discovery/saes/mistral_saes`](./sae_multid_feature_discovery/saes/mistral_saes). You can generate SAE feature activations with one of the following two commands: ``` cd sae_multid_feature_discovery @@ -64,7 +64,7 @@ python3 generate_feature_occurence_data.py --model_name mistral You can also directly download the gpt-2 layer 7 and Mistral-7B layer 8 activations data from this Dropbox folder: https://www.dropbox.com/scl/fo/frn4tihzkvyesqoumtl9u/AFPEAa6KFb8mY3NTXIEStnA?rlkey=z60j3g45jzhxwc5s5qxmbjvxs&st=da2tzqk5&dl=0. You should put them in the `sae_multid_feature_discovery` directory. -You will also need to generate the actual clusters by running clustering.py, e.g. +You will also need to generate the actual clusters by running `clustering.py`, e.g. ``` python3 clustering.py --model_name gpt-2 --clustering_type spectral --layer 7 python3 clustering.py --model_name mistral --clustering_type graph --layer 8 @@ -120,7 +120,7 @@ To reproduce the residual RGB plots in the paper (*Figure 8*, and *Figure 16*), ## Contact -If you have any questions about the paper or reproducing results, feel free to email jengels@mit.edu. 
+If you have any questions about the paper or reproducing results, feel free to email [jengels@mit.edu](mailto:jengels@mit.edu). ## Citation From 4b01f898db0f0f62a0b2cc58ac4321dedc848b43 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Fri, 13 Sep 2024 18:00:44 -0700 Subject: [PATCH 03/19] `legendHandles` => `legend_handles` As per https://github.com/matplotlib/matplotlib/blob/42b88d01fdd93846d925b1097167d36ea31c7733/doc/api/prev_api_changes/api_changes_3.9.0/removals.rst?plain=1#L147, `legendHandles` was removed in matplotlib 3.9 after previously being deprecated in favor of `legend_handles`. Since `requirements.txt` requires matplotlib 3.9, we just use the new name. Fixes #12 Done with ``` git grep --name-only legendHandles | xargs sed -i.bak 's/legendHandles/legend_handles/g' ``` --- intervention/appendix_plots.ipynb | 4 ++-- intervention/main_text_plots.ipynb | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/intervention/appendix_plots.ipynb b/intervention/appendix_plots.ipynb index e64984b..9771126 100644 --- a/intervention/appendix_plots.ipynb +++ b/intervention/appendix_plots.ipynb @@ -108,8 +108,8 @@ " frameon=False,\n", " )\n", "\n", - " for i in range(len(legend.legendHandles)):\n", - " legend.legendHandles[i]._sizes = [2]\n", + " for i in range(len(legend.legend_handles)):\n", + " legend.legend_handles[i]._sizes = [2]\n", "\n", " plt.tight_layout()\n", "\n", diff --git a/intervention/main_text_plots.ipynb b/intervention/main_text_plots.ipynb index 2266226..e83e129 100644 --- a/intervention/main_text_plots.ipynb +++ b/intervention/main_text_plots.ipynb @@ -309,7 +309,7 @@ " columnspacing=1,\n", " handlelength=0.8,\n", ")\n", - "for legobj in leg.legendHandles:\n", + "for legobj in leg.legend_handles:\n", " legobj.set_linewidth(1.5)\n", "\n", "fig.add_artist(\n", @@ -404,7 +404,7 @@ " columnspacing=0,\n", ")\n", "for i in range(circle_size):\n", - " lgnd.legendHandles[i]._sizes = [10]\n", + " lgnd.legend_handles[i]._sizes = [10]\n", 
"\n", "plt.show()\n", "\n", From e2bf9ad73b1b7d6cb856800545f255e23f9f50c7 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Sat, 7 Sep 2024 15:02:27 -0700 Subject: [PATCH 04/19] Fix to support MPS: convert to float32 earlier --- intervention/circle_finding_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intervention/circle_finding_utils.py b/intervention/circle_finding_utils.py index 53f2ce6..64da461 100644 --- a/intervention/circle_finding_utils.py +++ b/intervention/circle_finding_utils.py @@ -113,7 +113,7 @@ def get_logit_diffs_from_subspace_formula_resid_intervention( probe_r = probe_r.to(device) target_embedding_in_q_space = target_to_embedding.to(device) @ probe_r.inverse() - pca_projection_matrix = torch.tensor(pca_projection_matrix).to(device).T.float() + pca_projection_matrix = torch.tensor(pca_projection_matrix).float().to(device).T all_pcas = ( torch.tensor( From ebd14079278beda8236f2b1bad65beb69f8f7eaf Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 27 Aug 2024 22:28:44 -0700 Subject: [PATCH 05/19] Don't fail with "AssertionError: Not enough CUDA devices to support n_devices 2" when there aren't 2 GPUs ``` File "/home/jason/MultiDimensionalFeatures/multid/lib/python3.10/site-packages/transformer_lens/HookedTransformerConfig.py", line 315, in __post_init__ torch.cuda.device_count() >= self.n_devices AssertionError: Not enough CUDA devices to support n_devices 2 ``` --- intervention/days_of_week_task.py | 7 ++++++- intervention/months_of_year_task.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py index cbb5fa8..0819ab3 100644 --- a/intervention/days_of_week_task.py +++ b/intervention/days_of_week_task.py @@ -7,6 +7,7 @@ import numpy as np import transformer_lens +import torch from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca from task import activation_patching @@ -148,7 +149,11 @@ def 
generate_problems(self): def get_model(self): if self.n_devices is None: - self.n_devices = 2 if "llama" == self.model_name else 1 + self.n_devices = ( + min(2, max(1, torch.cuda.device_count())) + if "llama" == self.model_name + else 1 + ) if self._lazy_model is None: if self.model_name == "mistral": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py index 161ef29..49e9e2b 100644 --- a/intervention/months_of_year_task.py +++ b/intervention/months_of_year_task.py @@ -7,6 +7,7 @@ import numpy as np import transformer_lens +import torch from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca from task import activation_patching @@ -159,7 +160,11 @@ def generate_problems(self): def get_model(self): if self.n_devices is None: - self.n_devices = 2 if "llama" == self.model_name else 1 + self.n_devices = ( + min(2, max(1, torch.cuda.device_count())) + if "llama" == self.model_name + else 1 + ) if self._lazy_model is None: if self.model_name == "mistral": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( From f3ae2c7391bc112d2708f571b75b9c14ad16cc92 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Sat, 7 Sep 2024 14:06:41 -0700 Subject: [PATCH 06/19] Revert to cpu if cuda is not available Also allow passing `--device cpu`, `--device mps`, etc to `circle_probe_interventions.py`, and remove the logic for special handling of numbers *n* as `cuda:n` to simplify logic. 
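The device-selection change this patch describes can be sketched in isolation as follows. This is a minimal stand-in script, not the repo's exact code: the `--device` flag name matches the patch, but `torch` is imported defensively here so the sketch also runs on machines where it is not installed.

```python
import argparse

# Probe for CUDA defensively: the fallback must work on CPU-only machines
# (and, for this standalone sketch, even where torch itself is absent).
try:
    import torch
    cuda_available = torch.cuda.is_available()
except ImportError:
    cuda_available = False

parser = argparse.ArgumentParser()
# A device *string* ("cpu", "mps", "cuda:0", ...) replaces the old integer
# argument, so no special-casing of n -> f"cuda:{n}" is needed.
parser.add_argument(
    "--device",
    type=str,
    default="cuda:0" if cuda_available else "cpu",
    help="Device to use (e.g. cpu, mps, cuda:0)",
)
args = parser.parse_args([])  # empty argv: take the default in this sketch
print(args.device)
```

On a GPU-less machine this prints `cpu`; passing `--device mps` on Apple silicon then needs no extra handling, which is exactly why the integer-to-string switch simplifies the logic.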
--- README.md | 16 +++++++--------- intervention/circle_probe_interventions.py | 11 ++++++++--- intervention/days_of_week_task.py | 2 +- intervention/months_of_year_task.py | 3 ++- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 7e7f6a8..a5b5fb3 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ This is the github repo for our paper ["Not All Language Model Features Are Line ## Reproducing each figure -Below are instructions to reproduce each figure (aspirationally). +Below are instructions to reproduce each figure (aspirationally). The required pthon packages to run this repo are ``` @@ -17,7 +17,7 @@ either manually using pip or using the existing requirements.txt if you are on a machine with Cuda 12.1: ``` python -m venv multid -pip install -r requirements.txt +pip install -r requirements.txt OR pip install transformer_lens sae_lens transformers datasets torch adjustText circuitsvis ipython ``` @@ -31,16 +31,16 @@ To reproduce the intervention results, you will first need to run intervention e ``` cd intervention -python3 circle_probe_interventions.py day a mistral --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin -python3 circle_probe_interventions.py month a mistral --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin -python3 circle_probe_interventions.py day a llama --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin -python3 circle_probe_interventions.py month a llama --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py day a mistral --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py month a mistral --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py day a llama --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py month a llama --device cuda:0 
--intervention_pca_k 5 --probe_on_cos --probe_on_sin ``` You can then reproduce *Figure 3*, *Figure 5*, *Figure 6*, and *Table 1* by running the corresponding cells in intervention/main_text_plots.ipynb. -After running these intervention experiments, you can reproduce *Figure 6* by running +After running these intervention experiments, you can reproduce *Figure 6* by running ``` cd intervention python3 intervene_in_middle_of_circle.py --only_paper_plots @@ -132,5 +132,3 @@ If you have any questions about the paper or reproducing results, feel free to e year={2024} } ``` - - diff --git a/intervention/circle_probe_interventions.py b/intervention/circle_probe_interventions.py index d936af3..ba481c5 100644 --- a/intervention/circle_probe_interventions.py +++ b/intervention/circle_probe_interventions.py @@ -41,7 +41,12 @@ choices=["llama", "mistral"], help="Choose 'llama' or 'mistral' model", ) - parser.add_argument("--device", type=int, default=4, help="CUDA device number") + parser.add_argument( + "--device", + type=str, + default="cuda:4" if torch.cuda.is_available() else "cpu", + help="Device to use", + ) parser.add_argument( "--use_inverse_regression_probe", action="store_true", @@ -73,7 +78,7 @@ help="Probe on linear representation with center of 0.", ) args = parser.parse_args() - device = f"cuda:{args.device}" + device = args.device day_month_choice = args.problem_type circle_letter = args.intervene_on model_name = args.model @@ -100,7 +105,7 @@ # use_inverse_regression_probe = False # intervention_pca_k = 5 - device = "cuda:4" + device = "cuda:4" if torch.cuda.is_available() else "cpu" circle_letter = "c" day_month_choice = "day" model_name = "mistral" diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py index cbb5fa8..4e8e061 100644 --- a/intervention/days_of_week_task.py +++ b/intervention/days_of_week_task.py @@ -11,7 +11,7 @@ from task import activation_patching -device = "cuda:4" +device = "cuda:4" if torch.cuda.is_available() 
else "cpu" # # %% diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py index 161ef29..6adfef3 100644 --- a/intervention/months_of_year_task.py +++ b/intervention/months_of_year_task.py @@ -5,13 +5,14 @@ setup_notebook() +import torch import numpy as np import transformer_lens from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca from task import activation_patching -device = "cuda:4" +device = "cuda:4" if torch.cuda.is_available() else "cpu" # # %% From cbfac5b42875f39408b355d7486be80172746ba5 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Sat, 7 Sep 2024 13:56:17 -0700 Subject: [PATCH 07/19] Use pathlib more and have default cache paths be relative to the repo directory --- intervention/appendix_plots.ipynb | 5 +-- intervention/circle_probe_interventions.py | 4 +-- intervention/days_of_week_task.py | 6 ++-- intervention/intervene_in_middle_of_circle.py | 2 +- intervention/main_text_plots.ipynb | 5 +-- intervention/months_of_year_task.py | 6 ++-- intervention/task.py | 32 +++++++++---------- intervention/utils.py | 5 +-- .../generate_feature_occurence_data.py | 8 ++--- sae_multid_feature_discovery/utils.py | 4 ++- 10 files changed, 41 insertions(+), 36 deletions(-) diff --git a/intervention/appendix_plots.ipynb b/intervention/appendix_plots.ipynb index e64984b..9b872e2 100644 --- a/intervention/appendix_plots.ipynb +++ b/intervention/appendix_plots.ipynb @@ -7,6 +7,7 @@ "outputs": [], "source": [ "# %%\n", + "from pathlib import Path\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from task import get_acts, get_acts_pca, get_all_acts\n", @@ -462,7 +463,7 @@ "\n", "for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " results_mistral = pd.read_csv(\n", - " f\"{BASE_DIR}/mistral_{task_name}/results.csv\", skipinitialspace=True\n", + " Path(BASE_DIR) / f\"mistral_{task_name}\" / \"results.csv\", skipinitialspace=True\n", " )\n", "\n", " results_mistral = results_mistral.rename(\n", @@ 
-480,7 +481,7 @@ " print(sum(results_mistral[\"mistral_correct\"]))\n", "\n", " results_llama = pd.read_csv(\n", - " f\"{BASE_DIR}/llama_{task_name}/results.csv\", skipinitialspace=True\n", + " Path(BASE_DIR) / f\"llama_{task_name}\" / \"results.csv\", skipinitialspace=True\n", " )\n", "\n", " results_llama = results_llama.rename(\n", diff --git a/intervention/circle_probe_interventions.py b/intervention/circle_probe_interventions.py index d936af3..5012fd9 100644 --- a/intervention/circle_probe_interventions.py +++ b/intervention/circle_probe_interventions.py @@ -171,7 +171,7 @@ probe_projections = {} target_to_embeddings = {} -os.makedirs(f"{task.prefix}/circle_probes_{circle_letter}", exist_ok=True) +(task.prefix / f"circle_probes_{circle_letter}").mkdir(exist_ok=True) all_maes = [] all_r_squareds = [] @@ -262,7 +262,7 @@ "probe_r": probe_r, "target_to_embedding": target_to_embedding, }, - f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_{pca_k}.pt", + task.prefix / f"circle_probes_{circle_letter}" / f"{probe_file_extension}_layer_{layer}_token_{token}_pca_{pca_k}.pt", ) mae = (predictions - multid_targets_train).abs().mean() diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py index cbb5fa8..fa159da 100644 --- a/intervention/days_of_week_task.py +++ b/intervention/days_of_week_task.py @@ -1,5 +1,6 @@ # %% +from pathlib import Path import os from utils import setup_notebook, BASE_DIR @@ -49,9 +50,8 @@ def __init__(self, device, model_name="mistral", n_devices=None): # Tokens we expect as possible answers. 
Best of these can optionally be saved (as opposed to best logit overall) self.allowable_tokens = days_of_week - self.prefix = f"{BASE_DIR}{model_name}_days_of_week/" - if not os.path.exists(self.prefix): - os.makedirs(self.prefix) + self.prefix = Path(BASE_DIR) / f"{model_name}_days_of_week" + self.prefix.mkdir(parents=True, exist_ok=True) self.num_tokens_in_answer = 1 diff --git a/intervention/intervene_in_middle_of_circle.py b/intervention/intervene_in_middle_of_circle.py index 2fb75fb..95e1ddd 100644 --- a/intervention/intervene_in_middle_of_circle.py +++ b/intervention/intervene_in_middle_of_circle.py @@ -40,7 +40,7 @@ def vary_wthin_circle(circle_letter, duration, layer, token, pca_k, all_points): model = task.get_model() circle_projection_qr = torch.load( - f"{task.prefix}/circle_probes_{circle_letter}/cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt" + task.prefix / f"circle_probes_{circle_letter}" / f"cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt" ) for problem in task.generate_problems(): diff --git a/intervention/main_text_plots.ipynb b/intervention/main_text_plots.ipynb index 2266226..bb65ad4 100644 --- a/intervention/main_text_plots.ipynb +++ b/intervention/main_text_plots.ipynb @@ -7,6 +7,7 @@ "outputs": [], "source": [ "# %%\n", + "from pathlib import Path\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from task import get_acts, get_acts_pca\n", @@ -516,7 +517,7 @@ "for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " for model_name in [\"mistral\", \"llama\"]:\n", " results = pd.read_csv(\n", - " f\"{BASE_DIR}/{model_name}_{task_name}/results.csv\", skipinitialspace=True\n", + " Path(BASE_DIR) / f\"{model_name}_{task_name}\" / \"results.csv\", skipinitialspace=True\n", " )\n", " number_correct = results[\"best_token\"] == results[\"ground_truth\"]\n", " print(task_name, model_name, np.sum(number_correct))\n", @@ -560,7 +561,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - 
"version": "3.11.7" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py index 161ef29..5f9d411 100644 --- a/intervention/months_of_year_task.py +++ b/intervention/months_of_year_task.py @@ -1,6 +1,7 @@ # %% import os +from pathlib import Path from utils import setup_notebook, BASE_DIR setup_notebook() @@ -71,9 +72,8 @@ def __init__(self, device, model_name="mistral", n_devices=None): # Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall) self.allowable_tokens = months_of_year - self.prefix = f"{BASE_DIR}{model_name}_months_of_year/" - if not os.path.exists(self.prefix): - os.makedirs(self.prefix) + self.prefix = Path(BASE_DIR) / f"{model_name}_months_of_year" + self.prefix.mkdir(parents=True, exist_ok=True) self.num_tokens_in_answer = 1 diff --git a/intervention/task.py b/intervention/task.py index 5c02857..93f798f 100644 --- a/intervention/task.py +++ b/intervention/task.py @@ -1,3 +1,4 @@ +from pathlib import Path from utils import BASE_DIR # Need this import to set the huggingface cache directory import os import numpy as np @@ -24,7 +25,6 @@ def __str__(self): def __repr__(self): return str(self) - def generate_and_save_acts( task, names_filter, @@ -39,10 +39,10 @@ def generate_and_save_acts( forward_batch_size = 2 num_tokens_to_generate = task.num_tokens_in_answer all_problems = task.generate_problems() - output_file = task.prefix + "results.csv" + output_file = task.prefix / "results.csv" if save_results_csv: - os.makedirs(task.prefix, exist_ok=True) + task.prefix.mkdir(parents=True, exist_ok=True) model_best_addition = "" if not save_best_logit else ", best_token" with open(output_file, "w") as f: f.write( @@ -98,7 +98,7 @@ def generate_and_save_acts( print(tensors.shape) torch.save( tensors, - f"{task.prefix}{save_file_prefix}{current_problem_index}.pt", + task.prefix / f"{save_file_prefix}{current_problem_index}.pt", ) if 
save_results_csv: @@ -146,7 +146,7 @@ def get_all_acts( all_problems = task.generate_problems() all_problems_already_generated = True for i in range(len(all_problems)): - if not os.path.exists(f"{task.prefix}{save_file_prefix}{i}.pt"): + if not (task.prefix / f"{save_file_prefix}{i}.pt").exists(): all_problems_already_generated = False break if not all_problems_already_generated or force_regenerate: @@ -163,7 +163,7 @@ def get_all_acts( all_acts = [] for i in range(0, len(all_problems)): tensors = torch.load( - f"{task.prefix}{save_file_prefix}{i}.pt", map_location="cpu" + task.prefix / f"{save_file_prefix}{i}.pt", map_location="cpu" ) all_acts.append(tensors) if len(all_acts) > 1: @@ -186,9 +186,9 @@ def get_acts( if save_file_prefix != "" and save_file_prefix[-1] != "_": save_file_prefix += "_" file_name = ( - f"{task.prefix}{save_file_prefix}layer{layer_fetch}_token{token_fetch}.pt" + task.prefix / f"{save_file_prefix}layer{layer_fetch}_token{token_fetch}.pt" ) - if not os.path.exists(file_name) or force_regenerate: + if not file_name.exists() or force_regenerate: print(file_name, "not exists") all_acts = get_all_acts( task, names_filter=names_filter, save_file_prefix=save_file_prefix @@ -196,7 +196,7 @@ def get_acts( for layer in range(all_acts.shape[1]): for token in range(all_acts.shape[2]): file_name = ( - f"{task.prefix}{save_file_prefix}layer{layer}_token{token}.pt" + task.prefix / f"{save_file_prefix}layer{layer}_token{token}.pt" ) torch.save( all_acts[:, layer, token, :].detach().cpu().clone(), file_name @@ -218,11 +218,11 @@ def get_acts_pca( names_filter=lambda x: "resid_post" in x or "hook_embed" in x, save_file_prefix="", ): - act_file_name = f"{task.prefix}pca/{save_file_prefix}/layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pt" - pca_pkl_file_name = f"{task.prefix}pca/{save_file_prefix}/layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pkl" - os.makedirs(f"{task.prefix}/pca/{save_file_prefix}", 
exist_ok=True) + act_file_name = task.prefix / "pca" / save_file_prefix / f"layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pt" + pca_pkl_file_name = task.prefix / "pca" / save_file_prefix / f"layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pkl" + (task.prefix / "pca" / save_file_prefix).mkdir(parents=True, exist_ok=True) - if not os.path.exists(act_file_name) or not os.path.exists(pca_pkl_file_name): + if not act_file_name.exists() or not pca_pkl_file_name.exists(): acts = get_acts( task, layer, @@ -239,9 +239,9 @@ def get_acts_pca( def get_acts_pls(task, layer, token, pls_k, normalize_rms=False): - act_file_name = f"{task.prefix}/pls/layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pt" - pls_pkl_file_name = f"{task.prefix}/pls/layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pkl" - os.makedirs(f"{task.prefix}/pls", exist_ok=True) + act_file_name = task.prefix / "pls" / f"layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pt" + pls_pkl_file_name = task.prefix / "pls" / f"layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pkl" + (task.prefix / "pls").mkdir(parents=True, exist_ok=True) # if not os.path.exists(act_file_name) or not os.path.exists(pls_pkl_file_name): if True: diff --git a/intervention/utils.py b/intervention/utils.py index 216a26c..d3bfd1f 100644 --- a/intervention/utils.py +++ b/intervention/utils.py @@ -1,9 +1,10 @@ import os import dill as pickle +from pathlib import Path -BASE_DIR = "/data/scratch/jae/" +BASE_DIR = Path(__file__).parent.parent / "cache" -os.environ["TRANSFORMERS_CACHE"] = f"{BASE_DIR}/.cache/" +os.environ["TRANSFORMERS_CACHE"] = f"{(Path(BASE_DIR) / '.cache').absolute()}/" def setup_notebook(): diff --git a/sae_multid_feature_discovery/generate_feature_occurence_data.py b/sae_multid_feature_discovery/generate_feature_occurence_data.py index 63733aa..022c09c 100644 --- 
a/sae_multid_feature_discovery/generate_feature_occurence_data.py +++ b/sae_multid_feature_discovery/generate_feature_occurence_data.py @@ -1,12 +1,12 @@ # %% - +from pathlib import Path import os from utils import BASE_DIR # hopefully this will help with memory fragmentation os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" -os.environ["TRANSFORMERS_CACHE"] = f"{BASE_DIR}.cache/" +os.environ["TRANSFORMERS_CACHE"] = f"{(Path(BASE_DIR) / '.cache').absolute()}/" import einops import numpy as np @@ -52,8 +52,8 @@ num_sae_activations_to_save = 10**9 -save_folder = f"{BASE_DIR}{model_name}" -os.makedirs(save_folder, exist_ok=True) +save_folder = Path(BASE_DIR) / model_name +save_folder.mkdir(exist_ok=True, parents=True) t.set_grad_enabled(False) diff --git a/sae_multid_feature_discovery/utils.py b/sae_multid_feature_discovery/utils.py index f69f0ab..23f40dc 100644 --- a/sae_multid_feature_discovery/utils.py +++ b/sae_multid_feature_discovery/utils.py @@ -1,7 +1,9 @@ + +from pathlib import Path from huggingface_hub import hf_hub_download import os -BASE_DIR = "/data/scratch/jae/" +BASE_DIR = Path(__file__).parent.parent / "cache" def get_gpt2_sae(device, layer): from sae_lens import SAE From 6357ecda5a7523052d6ae0397053c8c107f63c29 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Sat, 7 Sep 2024 14:01:29 -0700 Subject: [PATCH 08/19] Add support for loading 16-bit models --- intervention/circle_probe_interventions.py | 7 +++++-- intervention/days_of_week_task.py | 7 +++++-- intervention/intervene_in_middle_of_circle.py | 7 ++++--- intervention/months_of_year_task.py | 10 ++++++++-- 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/intervention/circle_probe_interventions.py b/intervention/circle_probe_interventions.py index d936af3..71f748d 100644 --- a/intervention/circle_probe_interventions.py +++ b/intervention/circle_probe_interventions.py @@ -42,6 +42,7 @@ help="Choose 'llama' or 'mistral' model", ) parser.add_argument("--device", type=int, 
default=4, help="CUDA device number") + parser.add_argument("--dtype", type=str, default="float32", help="Data type for torch tensors") parser.add_argument( "--use_inverse_regression_probe", action="store_true", @@ -74,6 +75,7 @@ ) args = parser.parse_args() device = f"cuda:{args.device}" + dtype = args.dtype day_month_choice = args.problem_type circle_letter = args.intervene_on model_name = args.model @@ -101,6 +103,7 @@ # intervention_pca_k = 5 device = "cuda:4" + dtype = "float32" circle_letter = "c" day_month_choice = "day" model_name = "mistral" @@ -131,9 +134,9 @@ # %% if day_month_choice == "day": - task = DaysOfWeekTask(device, model_name=model_name) + task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype) else: - task = MonthsOfYearTask(device, model_name=model_name) + task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype) # %% diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py index cbb5fa8..c82c01a 100644 --- a/intervention/days_of_week_task.py +++ b/intervention/days_of_week_task.py @@ -39,13 +39,15 @@ class DaysOfWeekTask: - def __init__(self, device, model_name="mistral", n_devices=None): + def __init__(self, device, model_name="mistral", n_devices=None, dtype="float32"): self.device = device self.model_name = model_name self.n_devices = n_devices + self.dtype = dtype + # Tokens we expect as possible answers. 
Best of these can optionally be saved (as opposed to best logit overall) self.allowable_tokens = days_of_week @@ -152,7 +154,7 @@ def get_model(self): if self._lazy_model is None: if self.model_name == "mistral": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( - "mistral-7b", device=self.device, n_devices=self.n_devices + "mistral-7b", device=self.device, n_devices=self.n_devices, dtype=self.dtype ) elif self.model_name == "llama": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( @@ -160,6 +162,7 @@ def get_model(self): "meta-llama/Meta-Llama-3-8B", device=self.device, n_devices=self.n_devices, + dtype=self.dtype, ) return self._lazy_model diff --git a/intervention/intervene_in_middle_of_circle.py b/intervention/intervene_in_middle_of_circle.py index 2fb75fb..6e056ca 100644 --- a/intervention/intervene_in_middle_of_circle.py +++ b/intervention/intervene_in_middle_of_circle.py @@ -257,6 +257,7 @@ def get_circle_hook(layer, circle_point): parser.add_argument( "--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu" ) + parser.add_argument("--dtype", type=str, default="float32") args = parser.parse_args() @@ -265,7 +266,7 @@ def get_circle_hook(layer, circle_point): if args.only_paper_plots: task_level_granularity = "day" model_name = "mistral" - task = DaysOfWeekTask(device, model_name=model_name) + task = DaysOfWeekTask(device, model_name=model_name, dtype=args.dtype) layer = 5 bs = range(2, 6) pca_k = 5 @@ -282,9 +283,9 @@ def get_circle_hook(layer, circle_point): bs = range(1, 13) for b in bs: if task_level_granularity == "day": - task = DaysOfWeekTask(device, model_name=model_name) + task = DaysOfWeekTask(device, model_name=model_name, dtype=args.dtype) elif task_level_granularity == "month": - task = MonthsOfYearTask(device, model_name=model_name) + task = MonthsOfYearTask(device, model_name=model_name, dtype=args.dtype) else: raise ValueError(f"Unknown {task_level_granularity}") for pca_k in 
[5]: diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py index 161ef29..b98018b 100644 --- a/intervention/months_of_year_task.py +++ b/intervention/months_of_year_task.py @@ -61,13 +61,15 @@ class MonthsOfYearTask: - def __init__(self, device, model_name="mistral", n_devices=None): + def __init__(self, device, model_name="mistral", n_devices=None, dtype="float32"): self.device = device self.model_name = model_name self.n_devices = n_devices + self.dtype = dtype + # Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall) self.allowable_tokens = months_of_year @@ -163,7 +165,10 @@ def get_model(self): if self._lazy_model is None: if self.model_name == "mistral": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( - "mistral-7b", device=self.device, n_devices=self.n_devices + "mistral-7b", + device=self.device, + n_devices=self.n_devices, + dtype=self.dtype, ) elif self.model_name == "llama": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( @@ -171,6 +176,7 @@ def get_model(self): "meta-llama/Meta-Llama-3-8B", device=self.device, n_devices=self.n_devices, + dtype=self.dtype, ) return self._lazy_model From f9730770b7c7491aa69ee8b70e28d0fbcb94e886 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Fri, 13 Sep 2024 18:45:07 -0700 Subject: [PATCH 09/19] Add use of dtype in ipynb --- intervention/appendix_plots.ipynb | 19 ++++++++++--------- intervention/main_text_plots.ipynb | 20 +++++++++++--------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/intervention/appendix_plots.ipynb b/intervention/appendix_plots.ipynb index e64984b..9c9f7e1 100644 --- a/intervention/appendix_plots.ipynb +++ b/intervention/appendix_plots.ipynb @@ -28,7 +28,8 @@ "\n", "os.makedirs(\"figs/paper_plots\", exist_ok=True)\n", "\n", - "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"" + "device = \"cuda:0\" if torch.cuda.is_available() else 
\"cpu\"\n", + "dtype = \"float32\"" ] }, { @@ -130,9 +131,9 @@ "for model_name in [\"mistral\", \"llama\"]:\n", " for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(model_name=model_name, device=device)\n", + " task = DaysOfWeekTask(model_name=model_name, device=device, dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(model_name=model_name, device=device)\n", + " task = MonthsOfYearTask(model_name=model_name, device=device, dtype=dtype)\n", "\n", " for keep_same_index in [0, 1]:\n", " for layer_type in [\"mlp\", \"attention\", \"resid\", \"attention_head\"]:\n", @@ -186,9 +187,9 @@ "for model_name in [\"mistral\", \"llama\"]:\n", " for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(device, model_name=model_name)\n", + " task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(device, model_name=model_name)\n", + " task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n", "\n", " for patching_type in [\"mlp\", \"attention\"]:\n", " fig, ax = plt.subplots(figsize=(10, 5))\n", @@ -283,9 +284,9 @@ "for model_name in [\"mistral\", \"llama\"]:\n", " for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(device, model_name=model_name)\n", + " task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(device, model_name=model_name)\n", + " task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n", "\n", " fig, ax = plt.subplots(figsize=(10, 5))\n", "\n", @@ -369,9 +370,9 @@ "data = []\n", "for model_name, task_name in all_top_heads.keys():\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(device, model_name=model_name)\n", + " task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n", " 
else:\n", - " task = MonthsOfYearTask(device, model_name=model_name)\n", + " task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n", "\n", " acts = get_all_acts(\n", " task,\n", diff --git a/intervention/main_text_plots.ipynb b/intervention/main_text_plots.ipynb index 2266226..19cf6d4 100644 --- a/intervention/main_text_plots.ipynb +++ b/intervention/main_text_plots.ipynb @@ -25,7 +25,9 @@ "\n", "os.makedirs(\"figs/paper_plots\", exist_ok=True)\n", "\n", - "torch.set_grad_enabled(False)" + "torch.set_grad_enabled(False)\n", + "device = \"cpu\"\n", + "dtype = \"float32\"" ] }, { @@ -57,7 +59,7 @@ "\n", "\n", "# Left plot\n", - "task = DaysOfWeekTask(\"cpu\", \"mistral\")\n", + "task = DaysOfWeekTask(device, \"mistral\", dtype=dtype)\n", "problems = task.generate_problems()\n", "tokens = task.allowable_tokens\n", "acts = get_acts_pca(task, layer=30, token=task.a_token, pca_k=2)[0]\n", @@ -88,7 +90,7 @@ "ax1.set_ylim(-8, 8)\n", "\n", "# Right plot\n", - "task = MonthsOfYearTask(\"cpu\", \"llama\")\n", + "task = MonthsOfYearTask(device, \"llama\", dtype=dtype)\n", "problems = task.generate_problems()\n", "tokens = task.allowable_tokens\n", "acts = get_acts_pca(task, layer=3, token=task.a_token, pca_k=2)[0]\n", @@ -350,7 +352,7 @@ "s = 0.1\n", "\n", "\n", - "task = DaysOfWeekTask(\"cpu\", model_name=\"mistral\")\n", + "task = DaysOfWeekTask(device, model_name=\"mistral\", dtype=dtype)\n", "layer = 5\n", "token = task.a_token\n", "durations = range(2, 6)\n", @@ -430,7 +432,7 @@ "fig = plt.figure(figsize=(1.65, 1.5))\n", "ax = plt.gca()\n", "\n", - "task = DaysOfWeekTask(\"cpu\", model_name=\"mistral\")\n", + "task = DaysOfWeekTask(device, model_name=\"mistral\", dtype=dtype)\n", "acts = get_acts(task, layer_fetch=25, token_fetch=task.before_c_token)\n", "\n", "problems = task.generate_problems()\n", @@ -524,13 +526,13 @@ "# GPT 2\n", "from transformer_lens import HookedTransformer\n", "\n", - "model = HookedTransformer.from_pretrained(\"gpt2\")\n", + 
"model = HookedTransformer.from_pretrained(\"gpt2\", device=device, dtype=dtype)\n", "\n", "for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(\"cpu\", model_name=\"gpt2\")\n", + " task = DaysOfWeekTask(device, model_name=\"gpt2\", dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(\"cpu\", model_name=\"gpt2\")\n", + " task = MonthsOfYearTask(device, model_name=\"gpt2\", dtype=dtype)\n", " problems = task.generate_problems()\n", " answer_logits = [model.to_single_token(token) for token in task.allowable_tokens]\n", " num_correct = 0\n", @@ -546,7 +548,7 @@ ], "metadata": { "kernelspec": { - "display_name": "multiplexing", + "display_name": "multid", "language": "python", "name": "python3" }, From fe5e53189341d05df9764c00b37c1bd9e596c93f Mon Sep 17 00:00:00 2001 From: Tauheed Elahee Date: Fri, 20 Feb 2026 14:05:19 -0500 Subject: [PATCH 10/19] Update torch.load calls to include weights_only parameter --- intervention/task.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/intervention/task.py b/intervention/task.py index 5c02857..562ab6b 100644 --- a/intervention/task.py +++ b/intervention/task.py @@ -163,7 +163,9 @@ def get_all_acts( all_acts = [] for i in range(0, len(all_problems)): tensors = torch.load( - f"{task.prefix}{save_file_prefix}{i}.pt", map_location="cpu" + f"{task.prefix}{save_file_prefix}{i}.pt", + map_location="cpu", + weights_only=False, ) all_acts.append(tensors) if len(all_acts) > 1: @@ -201,7 +203,7 @@ def get_acts( torch.save( all_acts[:, layer, token, :].detach().cpu().clone(), file_name ) - data = torch.load(file_name) + data = torch.load(file_name, weights_only=False) if normalize_rms: eps = 1e-5 scale = (data.pow(2).mean(-1, keepdim=True) + eps).sqrt() @@ -235,7 +237,9 @@ def get_acts_pca( pca_acts = pca_object.transform(acts) torch.save(pca_acts, act_file_name) pkl.dump(pca_object, open(pca_pkl_file_name, "wb")) - return 
torch.load(act_file_name), pkl.load(open(pca_pkl_file_name, "rb")) + return torch.load(act_file_name, weights_only=False), pkl.load( + open(pca_pkl_file_name, "rb") + ) def get_acts_pls(task, layer, token, pls_k, normalize_rms=False): @@ -255,7 +259,9 @@ def get_acts_pls(task, layer, token, pls_k, normalize_rms=False): torch.save(torch.tensor(pls_acts), act_file_name) pkl.dump(pls, open(pls_pkl_file_name, "wb")) - return torch.load(act_file_name), pkl.load(open(pls_pkl_file_name, "rb")) + return torch.load(act_file_name, weights_only=False), pkl.load( + open(pls_pkl_file_name, "rb") + ) def _set_plotting_sizes(): From 1abd74ee377fc2543b6b8ccd6a2da7d0e541e6f8 Mon Sep 17 00:00:00 2001 From: Tauheed Elahee Date: Thu, 12 Mar 2026 13:40:09 -0400 Subject: [PATCH 11/19] Update all calls to torch.load with required weights_only parameter. --- .../days_of_the_week_deconstruction.py | 3 +- .../months_of_the_year_deconstruction.py | 3 +- .../compare_circle_intervention_types.py | 175 +++++++++++++----- intervention/intervene_in_middle_of_circle.py | 3 +- .../saes/sparse_autoencoder.py | 6 +- 5 files changed, 142 insertions(+), 48 deletions(-) diff --git a/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py b/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py index 2c85289..9706026 100644 --- a/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py +++ b/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py @@ -42,7 +42,8 @@ def deconstruct(layer, n_feature_groups): + str(start_token + token) + "_pca" + str(n_pca_dims) - + ".pt" + + ".pt", + weights_only=False, ) flat_activations = activations[order, :] # problem, pca activations = flat_activations.reshape([mod, mod, n_pca_dims]) diff --git a/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py b/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py index 419e824..b99bba1 100644 --- 
a/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py +++ b/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py @@ -42,7 +42,8 @@ def deconstruct(layer, n_feature_groups): + str(start_token + token) + "_pca" + str(n_pca_dims) - + ".pt" + + ".pt", + weights_only=False, ) flat_activations = activations[order, :] # problem, pca activations = flat_activations.reshape([mod, mod, n_pca_dims]) diff --git a/intervention/compare_circle_intervention_types.py b/intervention/compare_circle_intervention_types.py index 621c2d7..a2f8799 100644 --- a/intervention/compare_circle_intervention_types.py +++ b/intervention/compare_circle_intervention_types.py @@ -61,17 +61,21 @@ # %% -mistral_pcas = pickle.load(open("../sae_multid_feature_discovery/fit_pca_days.pkl", "rb")).components_[1:3, :] +mistral_pcas = pickle.load( + open("../sae_multid_feature_discovery/fit_pca_days.pkl", "rb") +).components_[1:3, :] # %% # Get original probe data -original_probe = torch.load(f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_8_token_{token}_pca_5.pt") +original_probe = torch.load( + f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_8_token_{token}_pca_5.pt", + weights_only=False, +) original_probe_data = [] for layer in [6, 7, 8, 9, 10]: - ( logit_diffs_before, logit_diffs_after, @@ -98,8 +102,30 @@ average_zero_circle = np.mean(logit_diffs_zero_circle) average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle) - original_probe_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle)) - original_probe_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle)) + original_probe_data.append( + ( + layer, + 
average_before, + average_after, + average_replace_pca, + average_replace_all, + average_average_ablate, + average_zero_circle, + average_zero_everything_but_circle, + ) + ) + original_probe_data.append( + ( + layer, + logit_diffs_before, + logit_diffs_after, + logit_diffs_replace_pca, + logit_diffs_replace_all, + logit_diffs_average_ablate, + logit_diffs_zero_circle, + logit_diffs_zero_everything_but_circle, + ) + ) # %% @@ -119,21 +145,15 @@ current_probe_dimension = 0 if probe_on_cos: multid_targets[:, current_probe_dimension] = torch.cos(w * oned_targets) - target_to_embedding[:, current_probe_dimension] = torch.cos( - w * torch.arange(p) - ) + target_to_embedding[:, current_probe_dimension] = torch.cos(w * torch.arange(p)) current_probe_dimension += 1 if probe_on_sin: multid_targets[:, current_probe_dimension] = torch.sin(w * oned_targets) - target_to_embedding[:, current_probe_dimension] = torch.sin( - w * torch.arange(p) - ) + target_to_embedding[:, current_probe_dimension] = torch.sin(w * torch.arange(p)) current_probe_dimension += 1 if probe_on_centered_linear: multid_targets[:, current_probe_dimension] = oned_targets - (p - 1) / 2 - target_to_embedding[:, current_probe_dimension] = ( - torch.arange(p) - (p - 1) / 2 - ) + target_to_embedding[:, current_probe_dimension] = torch.arange(p) - (p - 1) / 2 current_probe_dimension += 1 assert current_probe_dimension == probe_dimension @@ -144,9 +164,7 @@ projections = (acts_train @ mistral_pcas.T).float() -least_squares_sol = torch.linalg.lstsq( - projections, multid_targets_train -).solution +least_squares_sol = torch.linalg.lstsq(projections, multid_targets_train).solution probe_q, probe_r = torch.linalg.qr(least_squares_sol) @@ -159,7 +177,6 @@ mistral_data = [] for layer in [6, 7, 8, 9, 10]: - ( logit_diffs_before, logit_diffs_after, @@ -187,18 +204,41 @@ average_zero_circle = np.mean(logit_diffs_zero_circle) average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle) - 
mistral_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle)) - mistral_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle)) + mistral_data.append( + ( + layer, + average_before, + average_after, + average_replace_pca, + average_replace_all, + average_average_ablate, + average_zero_circle, + average_zero_everything_but_circle, + ) + ) + mistral_data.append( + ( + layer, + logit_diffs_before, + logit_diffs_after, + logit_diffs_replace_pca, + logit_diffs_replace_all, + logit_diffs_average_ablate, + logit_diffs_zero_circle, + logit_diffs_zero_everything_but_circle, + ) + ) # %% - original_probe_varying_layer_data = [] for layer in [6, 7, 8, 9, 10]: - - original_probe = torch.load(f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_5.pt") + original_probe = torch.load( + f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_5.pt", + weights_only=False, + ) ( logit_diffs_before, @@ -226,8 +266,30 @@ average_zero_circle = np.mean(logit_diffs_zero_circle) average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle) - original_probe_varying_layer_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle)) - original_probe_varying_layer_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle)) + original_probe_varying_layer_data.append( + ( + layer, + average_before, + average_after, + average_replace_pca, + average_replace_all, + 
average_average_ablate, + average_zero_circle, + average_zero_everything_but_circle, + ) + ) + original_probe_varying_layer_data.append( + ( + layer, + logit_diffs_before, + logit_diffs_after, + logit_diffs_replace_pca, + logit_diffs_replace_all, + logit_diffs_average_ablate, + logit_diffs_zero_circle, + logit_diffs_zero_everything_but_circle, + ) + ) # %% @@ -235,7 +297,10 @@ pickle.dump(original_probe_data, open("figs/original_probe_data.pkl", "wb")) pickle.dump(mistral_data, open("figs/mistral_data.pkl", "wb")) -pickle.dump(original_probe_varying_layer_data, open("figs/original_probe_varying_layer_data.pkl", "wb")) +pickle.dump( + original_probe_varying_layer_data, + open("figs/original_probe_varying_layer_data.pkl", "wb"), +) # %% @@ -246,19 +311,25 @@ # Get means average_after_original_probe = [x[2] for x in original_probe_data[::2]] average_after_mistral = [x[2] for x in mistral_data[::2]] -average_after_original_probe_varying_layer = [x[2] for x in original_probe_varying_layer_data[::2]] +average_after_original_probe_varying_layer = [ + x[2] for x in original_probe_varying_layer_data[::2] +] print(average_after_original_probe[0]) print(average_after_mistral[0]) print(average_after_original_probe_varying_layer[0]) import scipy + + def mean_confidence_interval(data, confidence=0.96): a = 1.0 * np.array(data) n = len(a) m, se = np.mean(a), scipy.stats.sem(a) h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1) return m, m - h, m + h + + # Get confidence intervals original_probe_means = [] original_probe_lower = [] @@ -288,22 +359,13 @@ def mean_confidence_interval(data, confidence=0.96): varying_layer_upper.append(upper) ax.plot(x, original_probe_means, label="Intervene with Layer 8 Probe", marker="o") -ax.fill_between(x, - original_probe_lower, - original_probe_upper, - alpha=0.3) +ax.fill_between(x, original_probe_lower, original_probe_upper, alpha=0.3) ax.plot(x, mistral_means, label="Intervene with SAE Subspace", marker="o") -ax.fill_between(x, - 
mistral_lower, - mistral_upper, - alpha=0.3) +ax.fill_between(x, mistral_lower, mistral_upper, alpha=0.3) ax.plot(x, varying_layer_means, label="Intervene with Probe", marker="o") -ax.fill_between(x, - varying_layer_lower, - varying_layer_upper, - alpha=0.3) +ax.fill_between(x, varying_layer_lower, varying_layer_upper, alpha=0.3) ax.set_xlabel("Layer") ax.set_xticks(x) @@ -318,19 +380,46 @@ def mean_confidence_interval(data, confidence=0.96): # Map each target value to a consistent color based on its position in the circle cmap = plt.get_cmap("tab10") -days_of_week = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"] +days_of_week = [ + "Monday", + "Tuesday", + "Wednesday", + "Thursday", + "Friday", + "Saturday", + "Sunday", +] added_labels = set() for i in range(len(projections)): if int(oned_targets[i]) not in added_labels: added_labels.add(int(oned_targets[i])) - plt.plot(projections[i, 0], projections[i, 1], ".", color=cmap(int(oned_targets[i])), markersize=10, label=days_of_week[int(oned_targets[i])]) + plt.plot( + projections[i, 0], + projections[i, 1], + ".", + color=cmap(int(oned_targets[i])), + markersize=10, + label=days_of_week[int(oned_targets[i])], + ) else: - plt.plot(projections[i, 0], projections[i, 1], ".", color=cmap(int(oned_targets[i])), markersize=10) + plt.plot( + projections[i, 0], + projections[i, 1], + ".", + color=cmap(int(oned_targets[i])), + markersize=10, + ) # Sort legend by days of the week handles, labels = ax.get_legend_handles_labels() order = np.argsort([days_of_week.index(label) for label in labels]) -ax.legend([handles[idx] for idx in order], [labels[idx] for idx in order], loc="upper left", bbox_to_anchor=(-0.1, 1.2), ncol=4) +ax.legend( + [handles[idx] for idx in order], + [labels[idx] for idx in order], + loc="upper left", + bbox_to_anchor=(-0.1, 1.2), + ncol=4, +) ax.set_xlabel("Projection onto second SAE PCA component") ax.set_ylabel("Projection onto third SAE PCA component") diff --git 
a/intervention/intervene_in_middle_of_circle.py b/intervention/intervene_in_middle_of_circle.py index 2fb75fb..6a7dcb4 100644 --- a/intervention/intervene_in_middle_of_circle.py +++ b/intervention/intervene_in_middle_of_circle.py @@ -40,7 +40,8 @@ def vary_wthin_circle(circle_letter, duration, layer, token, pca_k, all_points): model = task.get_model() circle_projection_qr = torch.load( - f"{task.prefix}/circle_probes_{circle_letter}/cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt" + f"{task.prefix}/circle_probes_{circle_letter}/cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt", + weights_only=False, ) for problem in task.generate_problems(): diff --git a/sae_multid_feature_discovery/saes/sparse_autoencoder.py b/sae_multid_feature_discovery/saes/sparse_autoencoder.py index 85a9f57..ba61316 100755 --- a/sae_multid_feature_discovery/saes/sparse_autoencoder.py +++ b/sae_multid_feature_discovery/saes/sparse_autoencoder.py @@ -182,10 +182,12 @@ def load_from_pretrained(cls, path: str): if path.endswith(".pt"): try: if torch.backends.mps.is_available(): - state_dict = torch.load(path, map_location="mps") + state_dict = torch.load( + path, map_location="mps", weights_only=False + ) state_dict["cfg"].device = "mps" else: - state_dict = torch.load(path) + state_dict = torch.load(path, weights_only=False) except Exception as e: raise IOError(f"Error loading the state dictionary from .pt file: {e}") From c251a5fe57b3519dbe650bfeb1605cf48cdcd621 Mon Sep 17 00:00:00 2001 From: Tauheed Elahee Date: Thu, 12 Mar 2026 21:12:52 -0400 Subject: [PATCH 12/19] Fix TransformerLens model name for GPT-2 TransformerLens expects "gpt2" not "gpt-2" as the model identifier. Add tl_model_name mapping so HookedTransformer.from_pretrained() receives the correct name. 
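The mapping this patch introduces is a one-line normalization. As a standalone sketch (the helper name `to_transformer_lens_name` is illustrative, not part of the patch):

```python
def to_transformer_lens_name(model_name: str) -> str:
    # "gpt-2" is this repo's internal label; TransformerLens registers
    # the model under "gpt2". All other names pass through unchanged.
    return "gpt2" if model_name == "gpt-2" else model_name

print(to_transformer_lens_name("gpt-2"))       # gpt2
print(to_transformer_lens_name("mistral-7b"))  # mistral-7b
```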
--- .../generate_feature_occurence_data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sae_multid_feature_discovery/generate_feature_occurence_data.py b/sae_multid_feature_discovery/generate_feature_occurence_data.py index 63733aa..59bbd31 100644 --- a/sae_multid_feature_discovery/generate_feature_occurence_data.py +++ b/sae_multid_feature_discovery/generate_feature_occurence_data.py @@ -44,8 +44,9 @@ num_workers = 8 sae_hidden_size = 24576 +tl_model_name = "gpt2" if model_name == "gpt-2" else model_name model = transformer_lens.HookedTransformer.from_pretrained( - model_name, device=device, n_devices=num_devices + tl_model_name, device=device, n_devices=num_devices ) ctx_len = 256 From 5890d59db0c1b81bcd606ffa6d4b31a6d16b473e Mon Sep 17 00:00:00 2001 From: Tauheed Elahee Date: Thu, 12 Mar 2026 21:13:00 -0400 Subject: [PATCH 13/19] Use dynamic device count instead of hardcoded num_devices=2 Replace hardcoded num_devices=2 for Mistral with max(1, t.cuda.device_count()) to support single-GPU and CPU-only machines while using all available GPUs when present. 
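The clamping logic can be sketched in isolation; here the CUDA device count is passed in as a parameter so the example runs without a GPU (in the patch itself it comes from `t.cuda.device_count()`):

```python
def pick_num_devices(cuda_device_count: int) -> int:
    # torch.cuda.device_count() returns 0 on CPU-only machines,
    # so clamp to at least one device; multi-GPU hosts use all GPUs.
    return max(1, cuda_device_count)

print(pick_num_devices(0))  # 1  (CPU-only fallback)
print(pick_num_devices(4))  # 4  (use all available GPUs)
```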
--- sae_multid_feature_discovery/generate_feature_occurence_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_multid_feature_discovery/generate_feature_occurence_data.py b/sae_multid_feature_discovery/generate_feature_occurence_data.py index 63733aa..f5b2a21 100644 --- a/sae_multid_feature_discovery/generate_feature_occurence_data.py +++ b/sae_multid_feature_discovery/generate_feature_occurence_data.py @@ -32,7 +32,7 @@ model_name = "mistral-7b" batch_size = 16 layers_to_evaluate = [8, 16, 24] - num_devices = 2 + num_devices = max(1, t.cuda.device_count()) sae_hidden_size = 65536 else: From 1ff27dbe843991d12a4eb8177b66edf03164abdf Mon Sep 17 00:00:00 2001 From: Tauheed Elahee Date: Thu, 12 Mar 2026 21:13:06 -0400 Subject: [PATCH 14/19] Add compatibility for newer sae_lens forward() return type Newer versions of sae_lens changed forward() to return a reconstructed tensor instead of an object with feature_acts. Add hasattr check and fall back to encode() to get feature activations. 
--- .../generate_feature_occurence_data.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sae_multid_feature_discovery/generate_feature_occurence_data.py b/sae_multid_feature_discovery/generate_feature_occurence_data.py index 63733aa..7301f76 100644 --- a/sae_multid_feature_discovery/generate_feature_occurence_data.py +++ b/sae_multid_feature_discovery/generate_feature_occurence_data.py @@ -139,8 +139,12 @@ def next_batch_activations(): forward_pass = ae.forward(activations) if isinstance(forward_pass, tuple): hidden_sae = forward_pass[1] - else: + elif hasattr(forward_pass, "feature_acts"): hidden_sae = forward_pass.feature_acts + else: + # Newer sae_lens returns reconstructed tensor from forward(); + # use encode() to get feature activations instead + hidden_sae = ae.encode(activations) nonzero_sae = hidden_sae.abs() > 1e-6 nonzero_sae_values = hidden_sae[nonzero_sae] From ae86fa82730325cb576954de841a894ebdaa01f4 Mon Sep 17 00:00:00 2001 From: Tauheed Elahee Date: Thu, 12 Mar 2026 21:13:11 -0400 Subject: [PATCH 15/19] Fix incorrect --clustering_type arg in README examples The actual CLI argument in clustering.py is --method, not --clustering_type. --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7e7f6a8..d0d1175 100644 --- a/README.md +++ b/README.md @@ -66,8 +66,8 @@ You can also directly download the gpt-2 layer 7 and Mistral-7B layer 8 activati You will also need to generate the actual clusters by running clustering.py, e.g. 
``` -python3 clustering.py --model_name gpt-2 --clustering_type spectral --layer 7 -python3 clustering.py --model_name mistral --clustering_type graph --layer 8 +python3 clustering.py --model_name gpt-2 --method spectral --layer 7 +python3 clustering.py --model_name mistral --method graph --layer 8 ``` Unfortunately, we did not set a seed when we ran spectral clustering in our original experiments, so the clusters you get from the above command may not be the same as the ones we used in the paper. In the `sae_multid_feature_discovery` directory, we provide the GPT-2 (`gpt-2_layer_7_clusters_spectral_n1000.pkl`) and Mistral-7B (`mistral_layer_8_clusters_cutoff_0.5.pkl`) clusters that were used in the paper. For easy reference, here are the GPT-2 SAE feature indices for the days, weeks, and years clusters we reported in the paper (Figure 1): From bd6ed813bc356b3bd36db31a5d87978384a1490b Mon Sep 17 00:00:00 2001 From: Tauheed Elahee Date: Fri, 13 Mar 2026 13:57:39 -0400 Subject: [PATCH 16/19] Read BASE_DIR from environment variable with fallback to default Allow overriding BASE_DIR via environment variable for cloud/remote environments. Falls back to the existing relative cache/ path when the variable is not set. Depends on PR #5 (pathlib refactor). 
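The lookup-with-fallback pattern is small enough to show standalone; passing the environment mapping as an argument (rather than reading `os.environ` directly, as the patch does) keeps the sketch testable:

```python
from pathlib import Path

def resolve_base_dir(env: dict, default: Path) -> Path:
    # An explicit BASE_DIR entry wins; otherwise fall back to the
    # repo-relative default (cache/ in the patch).
    return Path(env.get("BASE_DIR", default))

print(resolve_base_dir({}, Path("cache")))                              # cache
print(resolve_base_dir({"BASE_DIR": "/mnt/artifacts"}, Path("cache")))  # /mnt/artifacts
```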
--- intervention/utils.py | 2 +- sae_multid_feature_discovery/utils.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/intervention/utils.py b/intervention/utils.py index d3bfd1f..dcc1e97 100644 --- a/intervention/utils.py +++ b/intervention/utils.py @@ -2,7 +2,7 @@ import dill as pickle from pathlib import Path -BASE_DIR = Path(__file__).parent.parent / "cache" +BASE_DIR = Path(os.environ.get("BASE_DIR", Path(__file__).parent.parent / "cache")) os.environ["TRANSFORMERS_CACHE"] = f"{(Path(BASE_DIR) / '.cache').absolute()}/" diff --git a/sae_multid_feature_discovery/utils.py b/sae_multid_feature_discovery/utils.py index b44ed3b..370290f 100644 --- a/sae_multid_feature_discovery/utils.py +++ b/sae_multid_feature_discovery/utils.py @@ -1,9 +1,9 @@ - from pathlib import Path from huggingface_hub import hf_hub_download import os -BASE_DIR = Path(__file__).parent.parent / "cache" +BASE_DIR = Path(os.environ.get("BASE_DIR", Path(__file__).parent.parent / "cache")) + def get_gpt2_sae(device, layer): from sae_lens import SAE @@ -11,7 +11,7 @@ def get_gpt2_sae(device, layer): return SAE.from_pretrained( release="gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml sae_id=f"blocks.{layer}.hook_resid_pre", # won't always be a hook point - device=device + device=device, )[0] From 18e77fe7fb28a728ea6cb0bda293ee4c48f63a97 Mon Sep 17 00:00:00 2001 From: Tauheed Elahee Date: Fri, 13 Mar 2026 13:58:10 -0400 Subject: [PATCH 17/19] Change default CUDA device from cuda:4 to cuda:0 Most machines index GPUs from 0. The previous default of cuda:4 assumed a specific multi-GPU server setup and would fail on machines with fewer than 5 GPUs. Depends on PR #11 (CPU fallback). 
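The device-selection default reduces to one conditional; as a minimal sketch with the availability flag passed in so it runs anywhere:

```python
def default_device(cuda_available: bool) -> str:
    # cuda:0 exists on any machine with at least one GPU,
    # whereas cuda:4 requires five.
    return "cuda:0" if cuda_available else "cpu"

print(default_device(True))   # cuda:0
print(default_device(False))  # cpu
```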
--- intervention/circle_probe_interventions.py | 4 ++-- intervention/days_of_week_task.py | 2 +- intervention/months_of_year_task.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/intervention/circle_probe_interventions.py b/intervention/circle_probe_interventions.py index c74243b..4f894e9 100644 --- a/intervention/circle_probe_interventions.py +++ b/intervention/circle_probe_interventions.py @@ -44,7 +44,7 @@ parser.add_argument( "--device", type=str, - default="cuda:4" if torch.cuda.is_available() else "cpu", + default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use", ) parser.add_argument( @@ -109,7 +109,7 @@ # use_inverse_regression_probe = False # intervention_pca_k = 5 - device = "cuda:4" if torch.cuda.is_available() else "cpu" + device = "cuda:0" if torch.cuda.is_available() else "cpu" dtype = "float32" circle_letter = "c" day_month_choice = "day" diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py index e9b0cea..9ee675a 100644 --- a/intervention/days_of_week_task.py +++ b/intervention/days_of_week_task.py @@ -13,7 +13,7 @@ from task import activation_patching -device = "cuda:4" if torch.cuda.is_available() else "cpu" +device = "cuda:0" if torch.cuda.is_available() else "cpu" # # %% diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py index c767e89..1dea61b 100644 --- a/intervention/months_of_year_task.py +++ b/intervention/months_of_year_task.py @@ -14,7 +14,7 @@ from task import activation_patching -device = "cuda:4" if torch.cuda.is_available() else "cpu" +device = "cuda:0" if torch.cuda.is_available() else "cpu" # # %% From 6365a40f107239f93e164a8e255e067c2a74fb97 Mon Sep 17 00:00:00 2001 From: Tauheed Elahee Date: Fri, 13 Mar 2026 13:58:42 -0400 Subject: [PATCH 18/19] Remove NVIDIA CUDA 12.1 hard pins from requirements Strip explicit NVIDIA library version pins (nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, etc.) and triton. 
These are automatically resolved by PyTorch based on the system's CUDA installation, and hard-pinning them to CUDA 12.1 prevents installation on systems with different CUDA versions. --- requirements.txt | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/requirements.txt b/requirements.txt index 11d7bfb..d475998 100644 --- a/requirements.txt +++ b/requirements.txt @@ -70,18 +70,7 @@ nest-asyncio==1.6.0 networkx==3.3 nltk==3.8.1 numpy==1.26.4 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==8.9.2.26 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.18.1 -nvidia-nvjitlink-cu12==12.6.20 -nvidia-nvtx-cu12==12.1.105 + orjson==3.10.6 packaging==24.1 pandas==2.2.2 @@ -139,7 +128,7 @@ tqdm==4.66.4 traitlets==5.14.3 transformer-lens==2.3.0 transformers==4.43.3 -triton==2.1.0 + typeguard==2.13.3 typer==0.12.3 typing_extensions==4.12.2 From b9dc4efca8a993d3cdf52b6c777af6cb8f6ffa49 Mon Sep 17 00:00:00 2001 From: Tauheed Elahee Date: Fri, 13 Mar 2026 15:21:49 -0400 Subject: [PATCH 19/19] Add requirements.in and update requirements.txt via pip-compile - Add requirements.in as the unpinned input file for pip-compile - Remove nvidia-* packages from requirements.in (transitive deps of torch) - Replace requirements.txt with pip-compile output (Python 3.12) - Major version updates: torch 2.10.0, sae-lens 6.37.1, transformer-lens 2.17.0, transformers 4.57.6, datasets 4.5.0 - Adds jupyterlab, ipywidgets, tensorflow as direct dependencies --- requirements.in | 144 ++++++ requirements.txt | 1109 +++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 1131 insertions(+), 122 deletions(-) create mode 100644 requirements.in diff --git a/requirements.in b/requirements.in new file mode 100644 index 0000000..233f48f --- /dev/null +++ b/requirements.in 
@@ -0,0 +1,144 @@ +jupyterlab +ipywidgets +numpy +pandas +torch +tensorflow +transformers +huggingface_hub +hf +accelerate +adjustText +aiohappyeyeballs +aiohttp +aiosignal +anyio +asttokens +async-timeout +attrs +automated-interpretability +babe +beartype +better-abc +blobfile +boostedblob +certifi +charset-normalizer +circuitsvis +click +comm +config2py +contourpy +cycler +datasets +debugpy +decorator +dill +docker-pycreds +dol +einops +exceptiongroup +executing +fancy-einsum +filelock +fonttools +frozenlist +fsspec +gitdb +GitPython +gprof2dot +graze +h11 +httpcore +httpx +i2 +idna +importlib_metadata +importlib_resources +iniconfig +jaxtyping +jedi +Jinja2 +joblib +jupyter_client +jupyter_core +kiwisolver +lxml +markdown-it-py +MarkupSafe +matplotlib +matplotlib-inline +mdurl +mpmath +multidict +multiprocess +nest-asyncio +networkx +nltk +orjson +packaging +parso +patsy +pexpect +pillow +platformdirs +plotly +plotly-express +pluggy +prompt_toolkit +protobuf +psutil +ptyprocess +pure_eval +py2store +pyarrow +pyarrow-hotfix +pycryptodomex +Pygments +pyparsing +pytest +pytest-profiling +python-dateutil +python-dotenv +pytz +PyYAML +pyzmq +regex +requests +rich +sae-lens +safetensors +scikit-learn +scipy +sentencepiece +sentry-sdk +setproctitle +shellingham +six +smmap +sniffio +stack-data +statsmodels +sympy +tenacity +threadpoolctl +tiktoken +tokenizers +tomli +tornado +tqdm +traitlets +transformer-lens +triton +typeguard +typer +typing_extensions +tzdata +urllib3 +uvloop +wandb +wcwidth +xxhash +yarl +zipp +zstandard diff --git a/requirements.txt b/requirements.txt index d475998..ea3f7c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,143 +1,1008 @@ -accelerate==0.33.0 -adjustText==1.2.0 -aiohappyeyeballs==2.3.4 -aiohttp==3.10.0 -aiosignal==1.3.1 -anyio==4.4.0 -asttokens==2.4.1 -async-timeout==4.0.3 -attrs==24.1.0 -automated-interpretability==0.0.5 +# +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: +# +# 
pip-compile requirements.in +# +#euporie==2.10.4 + # euporie +absl-py==2.4.0 + # via + # keras + # tensorboard + # tensorflow +accelerate==1.12.0 + # via + # -r requirements.in + # transformer-lens +adjusttext==1.3.0 + # via -r requirements.in +aiohappyeyeballs==2.6.1 + # via + # -r requirements.in + # aiohttp +aiohttp==3.13.3 + # via + # -r requirements.in + # boostedblob + # fsspec +aiosignal==1.4.0 + # via + # -r requirements.in + # aiohttp +annotated-doc==0.0.4 + # via typer +annotated-types==0.7.0 + # via pydantic +anyio==4.12.1 + # via + # -r requirements.in + # httpx + # jupyter-server +argon2-cffi==25.1.0 + # via jupyter-server +argon2-cffi-bindings==25.1.0 + # via argon2-cffi +arrow==1.4.0 + # via isoduration +asttokens==3.0.1 + # via + # -r requirements.in + # stack-data +astunparse==1.6.3 + # via tensorflow +async-lru==2.1.0 + # via jupyterlab +async-timeout==5.0.1 + # via -r requirements.in +attrs==25.4.0 + # via + # -r requirements.in + # aiohttp + # jsonschema + # referencing +automated-interpretability==0.0.23 + # via -r requirements.in babe==0.0.7 + # via + # -r requirements.in + # sae-lens +babel==2.18.0 + # via jupyterlab-server beartype==0.14.1 + # via + # -r requirements.in + # transformer-lens +beautifulsoup4==4.14.3 + # via nbconvert better-abc==0.0.3 + # via + # -r requirements.in + # transformer-lens +bleach[css]==6.3.0 + # via nbconvert blobfile==2.1.1 -boostedblob==0.15.4 -certifi==2024.7.4 -charset-normalizer==3.3.2 -circuitsvis==1.43.2 -click==8.1.7 -comm==0.2.2 -config2py==0.1.36 -contourpy==1.2.1 + # via + # -r requirements.in + # automated-interpretability +boostedblob==0.15.6 + # via + # -r requirements.in + # automated-interpretability +certifi==2026.1.4 + # via + # -r requirements.in + # httpcore + # httpx + # requests + # sentry-sdk +cffi==2.0.0 + # via argon2-cffi-bindings +charset-normalizer==3.4.4 + # via + # -r requirements.in + # requests +circuitsvis==1.43.3 + # via -r requirements.in +click==8.3.1 + # via + # -r 
requirements.in + # nltk + # typer + # wandb +comm==0.2.3 + # via + # -r requirements.in + # ipykernel + # ipywidgets +config2py==0.1.46 + # via + # -r requirements.in + # py2store +contourpy==1.3.3 + # via + # -r requirements.in + # matplotlib +cuda-bindings==12.9.4 + # via torch +cuda-pathfinder==1.3.4 + # via cuda-bindings cycler==0.12.1 -datasets==2.20.0 -debugpy>=1.6.5 -decorator==5.1.1 -dill==0.3.8 + # via + # -r requirements.in + # matplotlib +datasets==4.5.0 + # via + # -r requirements.in + # sae-lens + # transformer-lens +debugpy==1.8.20 + # via + # -r requirements.in + # ipykernel +decorator==5.2.1 + # via + # -r requirements.in + # ipython +defusedxml==0.7.1 + # via nbconvert +dill==0.4.0 + # via + # -r requirements.in + # datasets + # multiprocess docker-pycreds==0.4.0 -dol==0.2.55 -einops==0.8.0 -exceptiongroup==1.2.2 -executing==2.0.1 + # via -r requirements.in +docstring-parser==0.17.0 + # via simple-parsing +dol==0.3.38 + # via + # -r requirements.in + # config2py + # graze + # py2store +einops==0.8.2 + # via + # -r requirements.in + # transformer-lens +exceptiongroup==1.3.1 + # via -r requirements.in +executing==2.2.1 + # via + # -r requirements.in + # stack-data fancy-einsum==0.0.3 -filelock==3.15.4 -fonttools==4.53.1 -frozenlist==1.4.1 -fsspec==2024.5.0 -gitdb==4.0.11 -GitPython==3.1.43 -gprof2dot==2024.6.6 -graze==0.1.24 -h11==0.14.0 -httpcore==1.0.5 -httpx==0.27.0 -huggingface-hub==0.24.5 -i2==0.1.18 -idna==3.7 -importlib_metadata==8.2.0 -importlib_resources==6.4.0 -iniconfig==2.0.0 -ipykernel==6.29.5 -ipython==8.26.0 -jaxtyping==0.2.33 -jedi==0.19.1 -Jinja2==3.1.4 -joblib==1.4.2 -jupyter_client==8.6.2 -jupyter_core==5.7.2 -kiwisolver==1.4.5 + # via + # -r requirements.in + # transformer-lens +fastjsonschema==2.21.2 + # via nbformat +filelock==3.24.2 + # via + # -r requirements.in + # blobfile + # datasets + # hf + # huggingface-hub + # torch + # transformers +flatbuffers==25.12.19 + # via tensorflow +fonttools==4.61.1 + # via + # -r 
requirements.in + # matplotlib +fqdn==1.5.1 + # via jsonschema +frozenlist==1.8.0 + # via + # -r requirements.in + # aiohttp + # aiosignal +fsspec[http]==2025.10.0 + # via + # -r requirements.in + # datasets + # hf + # huggingface-hub + # torch +gast==0.7.0 + # via tensorflow +gitdb==4.0.12 + # via + # -r requirements.in + # gitpython +gitpython==3.1.46 + # via + # -r requirements.in + # wandb +google-pasta==0.2.0 + # via tensorflow +gprof2dot==2025.4.14 + # via + # -r requirements.in + # pytest-profiling +graze==0.1.39 + # via + # -r requirements.in + # babe +grpcio==1.78.0 + # via + # tensorboard + # tensorflow +h11==0.16.0 + # via + # -r requirements.in + # httpcore +h5py==3.15.1 + # via + # keras + # tensorflow +hf==1.1.0 + # via -r requirements.in +hf-xet==1.2.0 + # via + # hf + # huggingface-hub +httpcore==1.0.9 + # via + # -r requirements.in + # httpx +httpx==0.28.1 + # via + # -r requirements.in + # automated-interpretability + # datasets + # hf + # jupyterlab +huggingface-hub==0.36.2 + # via + # -r requirements.in + # accelerate + # datasets + # tokenizers + # transformer-lens + # transformers +i2==0.1.63 + # via + # -r requirements.in + # config2py +idna==3.11 + # via + # -r requirements.in + # anyio + # httpx + # jsonschema + # requests + # yarl +importlib-metadata==8.7.1 + # via + # -r requirements.in + # circuitsvis +importlib-resources==6.5.2 + # via + # -r requirements.in + # py2store +iniconfig==2.3.0 + # via + # -r requirements.in + # pytest +ipykernel==7.2.0 + # via jupyterlab +ipython==9.10.0 + # via + # ipykernel + # ipywidgets +ipython-pygments-lexers==1.1.1 + # via ipython +ipywidgets==8.1.8 + # via -r requirements.in +isoduration==20.11.0 + # via jsonschema +jaxtyping==0.3.9 + # via + # -r requirements.in + # transformer-lens +jedi==0.19.2 + # via + # -r requirements.in + # ipython +jinja2==3.1.6 + # via + # -r requirements.in + # jupyter-server + # jupyterlab + # jupyterlab-server + # nbconvert + # torch +joblib==1.5.3 + # via + # -r 
requirements.in + # nltk + # scikit-learn +json5==0.13.0 + # via jupyterlab-server +jsonpointer==3.0.0 + # via jsonschema +jsonschema[format-nongpl]==4.26.0 + # via + # jupyter-events + # jupyterlab-server + # nbformat +jsonschema-specifications==2025.9.1 + # via jsonschema +jupyter-client==8.8.0 + # via + # -r requirements.in + # ipykernel + # jupyter-server + # nbclient +jupyter-core==5.9.1 + # via + # -r requirements.in + # ipykernel + # jupyter-client + # jupyter-server + # jupyterlab + # nbclient + # nbconvert + # nbformat +jupyter-events==0.12.0 + # via jupyter-server +jupyter-lsp==2.3.0 + # via jupyterlab +jupyter-server==2.17.0 + # via + # jupyter-lsp + # jupyterlab + # jupyterlab-server + # notebook-shim +jupyter-server-terminals==0.5.4 + # via jupyter-server +jupyterlab==4.5.3 + # via -r requirements.in +jupyterlab-pygments==0.3.0 + # via nbconvert +jupyterlab-server==2.28.0 + # via jupyterlab +jupyterlab-widgets==3.0.16 + # via ipywidgets +keras==3.13.2 + # via tensorflow +kiwisolver==1.4.9 + # via + # -r requirements.in + # matplotlib +lark==1.3.1 + # via rfc3987-syntax +libclang==18.1.1 + # via tensorflow lxml==4.9.4 -markdown-it-py==3.0.0 -MarkupSafe==2.1.5 -matplotlib==3.9.1 -matplotlib-inline==0.1.7 + # via + # -r requirements.in + # blobfile + # boostedblob +markdown==3.10.2 + # via tensorboard +markdown-it-py==4.0.0 + # via + # -r requirements.in + # rich +markupsafe==3.0.3 + # via + # -r requirements.in + # jinja2 + # nbconvert + # werkzeug +matplotlib==3.10.8 + # via + # -r requirements.in + # adjusttext +matplotlib-inline==0.2.1 + # via + # -r requirements.in + # ipykernel + # ipython mdurl==0.1.2 + # via + # -r requirements.in + # markdown-it-py +mistune==3.2.0 + # via nbconvert +ml-dtypes==0.5.4 + # via + # keras + # tensorflow mpmath==1.3.0 -multidict==6.0.5 -multiprocess==0.70.16 + # via + # -r requirements.in + # sympy +multidict==6.7.1 + # via + # -r requirements.in + # aiohttp + # yarl +multiprocess==0.70.18 + # via + # -r 
requirements.in + # datasets +namex==0.1.0 + # via keras +narwhals==2.16.0 + # via plotly +nbclient==0.10.4 + # via nbconvert +nbconvert==7.17.0 + # via jupyter-server +nbformat==5.10.4 + # via + # jupyter-server + # nbclient + # nbconvert nest-asyncio==1.6.0 -networkx==3.3 -nltk==3.8.1 + # via + # -r requirements.in + # ipykernel +networkx==3.6.1 + # via + # -r requirements.in + # torch +nltk==3.9.2 + # via + # -r requirements.in + # sae-lens +notebook-shim==0.2.4 + # via jupyterlab numpy==1.26.4 - -orjson==3.10.6 -packaging==24.1 -pandas==2.2.2 -parso==0.8.4 -patsy==0.5.6 + # via + # -r requirements.in + # accelerate + # adjusttext + # automated-interpretability + # circuitsvis + # contourpy + # datasets + # h5py + # keras + # matplotlib + # ml-dtypes + # pandas + # patsy + # plotly-express + # scikit-learn + # scipy + # statsmodels + # tensorboard + # tensorflow + # transformer-lens + # transformers +nvidia-cublas-cu12==12.8.4.1 + # via + # -r requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.8.90 + # via + # -r requirements.in + # torch +nvidia-cuda-nvrtc-cu12==12.8.93 + # via + # -r requirements.in + # torch +nvidia-cuda-runtime-cu12==12.8.90 + # via + # -r requirements.in + # torch +nvidia-cudnn-cu12==9.10.2.21 + # via + # -r requirements.in + # torch +nvidia-cufft-cu12==11.3.3.83 + # via + # -r requirements.in + # torch +nvidia-cufile-cu12==1.13.1.3 + # via torch +nvidia-curand-cu12==10.3.9.90 + # via + # -r requirements.in + # torch +nvidia-cusolver-cu12==11.7.3.90 + # via + # -r requirements.in + # torch +nvidia-cusparse-cu12==12.5.8.93 + # via + # -r requirements.in + # nvidia-cusolver-cu12 + # torch +nvidia-cusparselt-cu12==0.7.1 + # via torch +nvidia-nccl-cu12==2.27.5 + # via + # -r requirements.in + # torch +nvidia-nvjitlink-cu12==12.8.93 + # via + # -r requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvshmem-cu12==3.4.5 + # via torch 
+nvidia-nvtx-cu12==12.8.90 + # via + # -r requirements.in + # torch +opt-einsum==3.4.0 + # via tensorflow +optree==0.18.0 + # via keras +orjson==3.11.7 + # via + # -r requirements.in + # automated-interpretability +packaging==26.0 + # via + # -r requirements.in + # accelerate + # datasets + # hf + # huggingface-hub + # ipykernel + # jupyter-events + # jupyter-server + # jupyterlab + # jupyterlab-server + # keras + # matplotlib + # nbconvert + # plotly + # pytest + # statsmodels + # tensorboard + # tensorflow + # transformers + # wandb + # wheel +pandas==3.0.1 + # via + # -r requirements.in + # babe + # datasets + # plotly-express + # statsmodels + # transformer-lens +pandocfilters==1.5.1 + # via nbconvert +parso==0.8.6 + # via + # -r requirements.in + # jedi +patsy==1.0.2 + # via + # -r requirements.in + # plotly-express + # statsmodels pexpect==4.9.0 -pillow==10.4.0 -platformdirs==4.2.2 -plotly==5.23.0 + # via + # -r requirements.in + # ipython +pillow==12.1.1 + # via + # -r requirements.in + # matplotlib + # tensorboard +platformdirs==4.9.2 + # via + # -r requirements.in + # jupyter-core + # wandb +plotly==6.5.2 + # via + # -r requirements.in + # plotly-express + # sae-lens plotly-express==0.4.1 -pluggy==1.5.0 -prompt_toolkit==3.0.47 -protobuf==5.27.3 -psutil==6.0.0 + # via + # -r requirements.in + # sae-lens +pluggy==1.6.0 + # via + # -r requirements.in + # pytest +prometheus-client==0.24.1 + # via jupyter-server +prompt-toolkit==3.0.52 + # via + # -r requirements.in + # ipython +propcache==0.4.1 + # via + # aiohttp + # yarl +protobuf==6.33.5 + # via + # -r requirements.in + # tensorboard + # tensorflow + # transformer-lens + # wandb +psutil==7.2.2 + # via + # -r requirements.in + # accelerate + # ipykernel ptyprocess==0.7.0 -pure_eval==0.2.3 -py2store==0.1.20 -pyarrow==17.0.0 -pyarrow-hotfix==0.6 -pycryptodomex==3.20.0 -Pygments==2.18.0 -pyparsing==3.1.2 -pytest==8.3.2 -pytest-profiling==1.7.0 + # via + # -r requirements.in + # pexpect + # terminado 
+pure-eval==0.2.3 + # via + # -r requirements.in + # stack-data +py2store==0.1.22 + # via + # -r requirements.in + # babe +pyarrow==23.0.1 + # via + # -r requirements.in + # datasets +pyarrow-hotfix==0.7 + # via -r requirements.in +pycparser==3.0 + # via cffi +pycryptodomex==3.23.0 + # via + # -r requirements.in + # blobfile + # boostedblob +pydantic==2.12.5 + # via wandb +pydantic-core==2.41.5 + # via pydantic +pygments==2.19.2 + # via + # -r requirements.in + # ipython + # ipython-pygments-lexers + # nbconvert + # pytest + # rich +pyparsing==3.3.2 + # via + # -r requirements.in + # matplotlib +pytest==9.0.2 + # via + # -r requirements.in + # pytest-profiling +pytest-profiling==1.8.1 + # via -r requirements.in python-dateutil==2.9.0.post0 -python-dotenv==1.0.1 -pytz==2024.1 -PyYAML==6.0.1 -pyzmq==26.0.0 -regex==2024.7.24 -requests==2.32.3 -rich==13.7.1 -sae-lens==3.13.1 -safetensors==0.4.3 -scikit-learn==1.5.1 -scipy==1.14.0 -sentencepiece==0.2.0 -sentry-sdk==2.12.0 -setproctitle==1.3.3 + # via + # -r requirements.in + # arrow + # jupyter-client + # matplotlib + # pandas +python-dotenv==1.2.1 + # via + # -r requirements.in + # sae-lens +python-json-logger==4.0.0 + # via jupyter-events +pytz==2025.2 + # via -r requirements.in +pyyaml==6.0.3 + # via + # -r requirements.in + # accelerate + # datasets + # hf + # huggingface-hub + # jupyter-events + # sae-lens + # transformers + # wandb +pyzmq==27.1.0 + # via + # -r requirements.in + # ipykernel + # jupyter-client + # jupyter-server +referencing==0.37.0 + # via + # jsonschema + # jsonschema-specifications + # jupyter-events +regex==2026.1.15 + # via + # -r requirements.in + # nltk + # tiktoken + # transformers +requests==2.32.5 + # via + # -r requirements.in + # datasets + # graze + # huggingface-hub + # jupyterlab-server + # tensorflow + # tiktoken + # transformers + # wandb +rfc3339-validator==0.1.4 + # via + # jsonschema + # jupyter-events +rfc3986-validator==0.1.1 + # via + # jsonschema + # jupyter-events 
+rfc3987-syntax==1.1.0 + # via jsonschema +rich==14.3.2 + # via + # -r requirements.in + # keras + # transformer-lens + # typer +rpds-py==0.30.0 + # via + # jsonschema + # referencing +sae-lens==6.37.1 + # via -r requirements.in +safetensors==0.7.0 + # via + # -r requirements.in + # accelerate + # sae-lens + # transformers +scikit-learn==1.8.0 + # via + # -r requirements.in + # automated-interpretability +scipy==1.17.0 + # via + # -r requirements.in + # adjusttext + # plotly-express + # scikit-learn + # statsmodels +send2trash==2.1.0 + # via jupyter-server +sentencepiece==0.2.1 + # via + # -r requirements.in + # transformer-lens +sentry-sdk==2.53.0 + # via + # -r requirements.in + # wandb +setproctitle==1.3.7 + # via -r requirements.in shellingham==1.5.4 -six==1.16.0 -smmap==5.0.1 + # via + # -r requirements.in + # hf + # typer +simple-parsing==0.1.8 + # via sae-lens +six==1.17.0 + # via + # -r requirements.in + # astunparse + # docker-pycreds + # google-pasta + # pytest-profiling + # python-dateutil + # rfc3339-validator + # tensorflow +smmap==5.0.2 + # via + # -r requirements.in + # gitdb sniffio==1.3.1 + # via -r requirements.in +soupsieve==2.8.3 + # via beautifulsoup4 stack-data==0.6.3 -statsmodels==0.14.2 -sympy==1.13.1 -tenacity==9.0.0 -threadpoolctl==3.5.0 -tiktoken==0.6.0 -tokenizers==0.19.1 -tomli==2.0.1 -torch==2.1.2 -tornado==6.4.1 -tqdm==4.66.4 + # via + # -r requirements.in + # ipython +statsmodels==0.14.6 + # via + # -r requirements.in + # plotly-express +sympy==1.14.0 + # via + # -r requirements.in + # torch +tenacity==9.1.4 + # via + # -r requirements.in + # sae-lens +tensorboard==2.20.0 + # via tensorflow +tensorboard-data-server==0.7.2 + # via tensorboard +tensorflow==2.20.0 + # via -r requirements.in +termcolor==3.3.0 + # via tensorflow +terminado==0.18.1 + # via + # jupyter-server + # jupyter-server-terminals +threadpoolctl==3.6.0 + # via + # -r requirements.in + # scikit-learn +tiktoken==0.12.0 + # via + # -r requirements.in + # 
automated-interpretability +tinycss2==1.4.0 + # via bleach +tokenizers==0.22.2 + # via + # -r requirements.in + # transformers +tomli==2.4.0 + # via -r requirements.in +torch==2.10.0 + # via + # -r requirements.in + # accelerate + # circuitsvis + # transformer-lens +tornado==6.5.4 + # via + # -r requirements.in + # ipykernel + # jupyter-client + # jupyter-server + # jupyterlab + # terminado +tqdm==4.67.3 + # via + # -r requirements.in + # datasets + # hf + # huggingface-hub + # nltk + # transformer-lens + # transformers traitlets==5.14.3 -transformer-lens==2.3.0 -transformers==4.43.3 + # via + # -r requirements.in + # ipykernel + # ipython + # ipywidgets + # jupyter-client + # jupyter-core + # jupyter-events + # jupyter-server + # jupyterlab + # matplotlib-inline + # nbclient + # nbconvert + # nbformat +transformer-lens==2.17.0 + # via + # -r requirements.in + # sae-lens +transformers==4.57.6 + # via + # -r requirements.in + # sae-lens + # transformer-lens + # transformers-stream-generator +transformers-stream-generator==0.0.5 + # via transformer-lens +triton==3.6.0 + # via + # -r requirements.in + # torch +typeguard==4.5.0 + # via + # -r requirements.in + # transformer-lens +typer==0.24.0 + # via + # -r requirements.in + # typer-slim +typer-slim==0.24.0 + # via + # hf + # transformers +typing-extensions==4.15.0 + # via + # -r requirements.in + # aiosignal + # anyio + # beautifulsoup4 + # exceptiongroup + # grpcio + # hf + # huggingface-hub + # optree + # pydantic + # pydantic-core + # referencing + # sae-lens + # simple-parsing + # tensorflow + # torch + # transformer-lens + # typeguard + # typing-inspection + # wandb +typing-inspection==0.4.2 + # via pydantic +tzdata==2025.3 + # via + # -r requirements.in + # arrow +uri-template==1.3.0 + # via jsonschema +urllib3==2.6.3 + # via + # -r requirements.in + # blobfile + # requests + # sentry-sdk +uvloop==0.22.1 + # via + # -r requirements.in + # boostedblob +wadler-lindig==0.1.7 + # via jaxtyping +wandb==0.25.0 + # 
via + # -r requirements.in + # transformer-lens +wcwidth==0.6.0 + # via + # -r requirements.in + # prompt-toolkit +webcolors==25.10.0 + # via jsonschema +webencodings==0.5.1 + # via + # bleach + # tinycss2 +websocket-client==1.9.0 + # via jupyter-server +werkzeug==3.1.5 + # via tensorboard +wheel==0.46.3 + # via astunparse +widgetsnbextension==4.0.15 + # via ipywidgets +wrapt==2.1.1 + # via tensorflow +xxhash==3.6.0 + # via + # -r requirements.in + # datasets +yarl==1.22.0 + # via + # -r requirements.in + # aiohttp +zipp==3.23.0 + # via + # -r requirements.in + # importlib-metadata +zstandard==0.25.0 + # via -r requirements.in -typeguard==2.13.3 -typer==0.12.3 -typing_extensions==4.12.2 -tzdata==2024.1 -urllib3==2.2.2 -uvloop==0.19.0 -wandb==0.17.5 -wcwidth==0.2.13 -xxhash==3.4.1 -yarl==1.9.4 -zipp==3.19.2 -zstandard==0.22.0 +# The following packages are considered to be unsafe in a requirements file: +# setuptools
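The pip-compile workflow described in the patch-19 commit message above can be sketched as a minimal example (the package names below are placeholders, and `pip-compile` is assumed to be available from the `pip-tools` package):

```shell
# requirements.in lists only direct, unpinned dependencies, one per line.
# Transitive dependencies (e.g. the nvidia-* wheels pulled in by torch on
# Linux) are omitted: pip-compile resolves and pins them automatically.
printf '%s\n' numpy pandas torch > requirements.in

# Regenerating the fully pinned lockfile needs network access to resolve
# versions, so the command is shown here rather than executed:
echo "pip-compile --output-file requirements.txt requirements.in"

# Installing from the pinned output then reproduces the exact environment:
echo "python3 -m pip install -r requirements.txt"
```

With this split, version bumps are made by re-running `pip-compile` (optionally with `--upgrade`) rather than by hand-editing `requirements.txt`, which is why the nvidia-* pins no longer need to be maintained manually.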