From 9213652a596d00fe19bf51b0d9edfa6beb160166 Mon Sep 17 00:00:00 2001
From: Tauheed Elahee
Date: Fri, 20 Feb 2026 14:05:19 -0500
Subject: [PATCH 1/2] Update torch.load calls to include the weights_only parameter

---
 intervention/task.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/intervention/task.py b/intervention/task.py
index 5c02857..562ab6b 100644
--- a/intervention/task.py
+++ b/intervention/task.py
@@ -163,7 +163,9 @@ def get_all_acts(
     all_acts = []
     for i in range(0, len(all_problems)):
         tensors = torch.load(
-            f"{task.prefix}{save_file_prefix}{i}.pt", map_location="cpu"
+            f"{task.prefix}{save_file_prefix}{i}.pt",
+            map_location="cpu",
+            weights_only=False,
         )
         all_acts.append(tensors)
     if len(all_acts) > 1:
@@ -201,7 +203,7 @@ def get_acts(
         torch.save(
             all_acts[:, layer, token, :].detach().cpu().clone(), file_name
         )
-    data = torch.load(file_name)
+    data = torch.load(file_name, weights_only=False)
     if normalize_rms:
         eps = 1e-5
         scale = (data.pow(2).mean(-1, keepdim=True) + eps).sqrt()
@@ -235,7 +237,9 @@ def get_acts_pca(
         pca_acts = pca_object.transform(acts)
         torch.save(pca_acts, act_file_name)
         pkl.dump(pca_object, open(pca_pkl_file_name, "wb"))
-    return torch.load(act_file_name), pkl.load(open(pca_pkl_file_name, "rb"))
+    return torch.load(act_file_name, weights_only=False), pkl.load(
+        open(pca_pkl_file_name, "rb")
+    )
 
 
 def get_acts_pls(task, layer, token, pls_k, normalize_rms=False):
@@ -255,7 +259,9 @@ def get_acts_pls(task, layer, token, pls_k, normalize_rms=False):
         torch.save(torch.tensor(pls_acts), act_file_name)
         pkl.dump(pls, open(pls_pkl_file_name, "wb"))
 
-    return torch.load(act_file_name), pkl.load(open(pls_pkl_file_name, "rb"))
+    return torch.load(act_file_name, weights_only=False), pkl.load(
+        open(pls_pkl_file_name, "rb")
+    )
 
 
 def _set_plotting_sizes():

From 29b759e0c439a0f4a52aa984ebc8466f2b1b4b39 Mon Sep 17 00:00:00 2001
From: Tauheed Elahee
Date: Thu, 12 Mar 2026 13:40:09 -0400
Subject: [PATCH 2/2] Update all remaining torch.load calls with the required
 weights_only parameter
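
PyTorch 2.6 changed the default of torch.load's weights_only argument to
True, which rejects checkpoints containing pickled Python objects (for
example the PCA-transformed arrays saved by get_acts_pca, or the SAE cfg
object stored next to the state dict). Passing weights_only=False restores
the previous behaviour for these trusted local files. A minimal sketch of
the failure mode (the file path and Config class below are illustrative
only, not code from this repository):

    import torch

    class Config:  # stands in for the cfg object stored in the SAE checkpoint
        device = "cpu"

    torch.save({"state_dict": {}, "cfg": Config()}, "/tmp/example.pt")

    torch.load("/tmp/example.pt")                      # UnpicklingError on PyTorch >= 2.6
    torch.load("/tmp/example.pt", weights_only=False)  # explicit opt-in, loads as before
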
---
 .../days_of_the_week_deconstruction.py        |   3 +-
 .../months_of_the_year_deconstruction.py      |   3 +-
 .../compare_circle_intervention_types.py      | 175 +++++++++++++-----
 intervention/intervene_in_middle_of_circle.py |   3 +-
 .../saes/sparse_autoencoder.py                |   6 +-
 5 files changed, 142 insertions(+), 48 deletions(-)

diff --git a/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py b/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py
index 2c85289..9706026 100644
--- a/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py
+++ b/feature_deconstruction/days_of_the_week/days_of_the_week_deconstruction.py
@@ -42,7 +42,8 @@ def deconstruct(layer, n_feature_groups):
         + str(start_token + token)
         + "_pca"
         + str(n_pca_dims)
-        + ".pt"
+        + ".pt",
+        weights_only=False,
     )
     flat_activations = activations[order, :]  # problem, pca
     activations = flat_activations.reshape([mod, mod, n_pca_dims])
diff --git a/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py b/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py
index 419e824..b99bba1 100644
--- a/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py
+++ b/feature_deconstruction/months_of_the_year/months_of_the_year_deconstruction.py
@@ -42,7 +42,8 @@ def deconstruct(layer, n_feature_groups):
         + str(start_token + token)
         + "_pca"
         + str(n_pca_dims)
-        + ".pt"
+        + ".pt",
+        weights_only=False,
     )
     flat_activations = activations[order, :]  # problem, pca
     activations = flat_activations.reshape([mod, mod, n_pca_dims])
diff --git a/intervention/compare_circle_intervention_types.py b/intervention/compare_circle_intervention_types.py
index 621c2d7..a2f8799 100644
--- a/intervention/compare_circle_intervention_types.py
+++ b/intervention/compare_circle_intervention_types.py
@@ -61,17 +61,21 @@
 
 # %%
 
-mistral_pcas = pickle.load(open("../sae_multid_feature_discovery/fit_pca_days.pkl", "rb")).components_[1:3, :]
+mistral_pcas = pickle.load(
+    open("../sae_multid_feature_discovery/fit_pca_days.pkl", "rb")
+).components_[1:3, :]
 
 # %%
 
 # Get original probe data
 
-original_probe = torch.load(f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_8_token_{token}_pca_5.pt")
+original_probe = torch.load(
+    f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_8_token_{token}_pca_5.pt",
+    weights_only=False,
+)
 
 original_probe_data = []
 for layer in [6, 7, 8, 9, 10]:
-
     (
         logit_diffs_before,
         logit_diffs_after,
@@ -98,8 +102,30 @@
     average_zero_circle = np.mean(logit_diffs_zero_circle)
     average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle)
 
-    original_probe_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle))
-    original_probe_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle))
+    original_probe_data.append(
+        (
+            layer,
+            average_before,
+            average_after,
+            average_replace_pca,
+            average_replace_all,
+            average_average_ablate,
+            average_zero_circle,
+            average_zero_everything_but_circle,
+        )
+    )
+    original_probe_data.append(
+        (
+            layer,
+            logit_diffs_before,
+            logit_diffs_after,
+            logit_diffs_replace_pca,
+            logit_diffs_replace_all,
+            logit_diffs_average_ablate,
+            logit_diffs_zero_circle,
+            logit_diffs_zero_everything_but_circle,
+        )
+    )
 
 
 # %%
@@ -119,21 +145,15 @@
 current_probe_dimension = 0
 if probe_on_cos:
     multid_targets[:, current_probe_dimension] = torch.cos(w * oned_targets)
-    target_to_embedding[:, current_probe_dimension] = torch.cos(
-        w * torch.arange(p)
-    )
+    target_to_embedding[:, current_probe_dimension] = torch.cos(w * torch.arange(p))
     current_probe_dimension += 1
 if probe_on_sin:
     multid_targets[:, current_probe_dimension] = torch.sin(w * oned_targets)
-    target_to_embedding[:, current_probe_dimension] = torch.sin(
-        w * torch.arange(p)
-    )
+    target_to_embedding[:, current_probe_dimension] = torch.sin(w * torch.arange(p))
     current_probe_dimension += 1
 if probe_on_centered_linear:
     multid_targets[:, current_probe_dimension] = oned_targets - (p - 1) / 2
-    target_to_embedding[:, current_probe_dimension] = (
-        torch.arange(p) - (p - 1) / 2
-    )
+    target_to_embedding[:, current_probe_dimension] = torch.arange(p) - (p - 1) / 2
     current_probe_dimension += 1
 
 assert current_probe_dimension == probe_dimension
@@ -144,9 +164,7 @@
 
 projections = (acts_train @ mistral_pcas.T).float()
 
-least_squares_sol = torch.linalg.lstsq(
-    projections, multid_targets_train
-).solution
+least_squares_sol = torch.linalg.lstsq(projections, multid_targets_train).solution
 
 probe_q, probe_r = torch.linalg.qr(least_squares_sol)
 
@@ -159,7 +177,6 @@
 
 mistral_data = []
 for layer in [6, 7, 8, 9, 10]:
-
     (
         logit_diffs_before,
         logit_diffs_after,
@@ -187,18 +204,41 @@
     average_zero_circle = np.mean(logit_diffs_zero_circle)
     average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle)
 
-    mistral_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle))
-    mistral_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle))
+    mistral_data.append(
+        (
+            layer,
+            average_before,
+            average_after,
+            average_replace_pca,
+            average_replace_all,
+            average_average_ablate,
+            average_zero_circle,
+            average_zero_everything_but_circle,
+        )
+    )
+    mistral_data.append(
+        (
+            layer,
+            logit_diffs_before,
+            logit_diffs_after,
+            logit_diffs_replace_pca,
+            logit_diffs_replace_all,
+            logit_diffs_average_ablate,
+            logit_diffs_zero_circle,
+            logit_diffs_zero_everything_but_circle,
+        )
+    )
 
 
 # %%
 
 
-
 original_probe_varying_layer_data = []
 for layer in [6, 7, 8, 9, 10]:
-
-    original_probe = torch.load(f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_5.pt")
+    original_probe = torch.load(
+        f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_5.pt",
+        weights_only=False,
+    )
 
     (
         logit_diffs_before,
@@ -226,8 +266,30 @@
     average_zero_circle = np.mean(logit_diffs_zero_circle)
     average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle)
 
-    original_probe_varying_layer_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle))
-    original_probe_varying_layer_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle))
+    original_probe_varying_layer_data.append(
+        (
+            layer,
+            average_before,
+            average_after,
+            average_replace_pca,
+            average_replace_all,
+            average_average_ablate,
+            average_zero_circle,
+            average_zero_everything_but_circle,
+        )
+    )
+    original_probe_varying_layer_data.append(
+        (
+            layer,
+            logit_diffs_before,
+            logit_diffs_after,
+            logit_diffs_replace_pca,
+            logit_diffs_replace_all,
+            logit_diffs_average_ablate,
+            logit_diffs_zero_circle,
+            logit_diffs_zero_everything_but_circle,
+        )
+    )
 
 
 # %%
@@ -235,7 +297,10 @@
 
 pickle.dump(original_probe_data, open("figs/original_probe_data.pkl", "wb"))
 pickle.dump(mistral_data, open("figs/mistral_data.pkl", "wb"))
-pickle.dump(original_probe_varying_layer_data, open("figs/original_probe_varying_layer_data.pkl", "wb"))
+pickle.dump(
+    original_probe_varying_layer_data,
+    open("figs/original_probe_varying_layer_data.pkl", "wb"),
+)
 
 
 # %%
@@ -246,19 +311,25 @@
 # Get means
 average_after_original_probe = [x[2] for x in original_probe_data[::2]]
 average_after_mistral = [x[2] for x in mistral_data[::2]]
-average_after_original_probe_varying_layer = [x[2] for x in original_probe_varying_layer_data[::2]]
+average_after_original_probe_varying_layer = [
+    x[2] for x in original_probe_varying_layer_data[::2]
+]
 
 print(average_after_original_probe[0])
 print(average_after_mistral[0])
 print(average_after_original_probe_varying_layer[0])
 
 import scipy
+
+
 def mean_confidence_interval(data, confidence=0.96):
     a = 1.0 * np.array(data)
     n = len(a)
     m, se = np.mean(a), scipy.stats.sem(a)
     h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1)
     return m, m - h, m + h
+
+
 # Get confidence intervals
 original_probe_means = []
 original_probe_lower = []
@@ -288,22 +359,13 @@ def mean_confidence_interval(data, confidence=0.96):
     varying_layer_upper.append(upper)
 
 ax.plot(x, original_probe_means, label="Intervene with Layer 8 Probe", marker="o")
-ax.fill_between(x,
-                original_probe_lower,
-                original_probe_upper,
-                alpha=0.3)
+ax.fill_between(x, original_probe_lower, original_probe_upper, alpha=0.3)
 
 ax.plot(x, mistral_means, label="Intervene with SAE Subspace", marker="o")
-ax.fill_between(x,
-                mistral_lower,
-                mistral_upper,
-                alpha=0.3)
+ax.fill_between(x, mistral_lower, mistral_upper, alpha=0.3)
 
 ax.plot(x, varying_layer_means, label="Intervene with Probe", marker="o")
-ax.fill_between(x,
-                varying_layer_lower,
-                varying_layer_upper,
-                alpha=0.3)
+ax.fill_between(x, varying_layer_lower, varying_layer_upper, alpha=0.3)
 
 ax.set_xlabel("Layer")
 ax.set_xticks(x)
@@ -318,19 +380,46 @@ def mean_confidence_interval(data, confidence=0.96):
 
 # Map each target value to a consistent color based on its position in the circle
 cmap = plt.get_cmap("tab10")
-days_of_week = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
+days_of_week = [
+    "Monday",
+    "Tuesday",
+    "Wednesday",
+    "Thursday",
+    "Friday",
+    "Saturday",
+    "Sunday",
+]
 added_labels = set()
 for i in range(len(projections)):
     if int(oned_targets[i]) not in added_labels:
         added_labels.add(int(oned_targets[i]))
-        plt.plot(projections[i, 0], projections[i, 1], ".", color=cmap(int(oned_targets[i])), markersize=10, label=days_of_week[int(oned_targets[i])])
+        plt.plot(
+            projections[i, 0],
+            projections[i, 1],
+            ".",
+            color=cmap(int(oned_targets[i])),
+            markersize=10,
+            label=days_of_week[int(oned_targets[i])],
+        )
     else:
-        plt.plot(projections[i, 0], projections[i, 1], ".", color=cmap(int(oned_targets[i])), markersize=10)
+        plt.plot(
+            projections[i, 0],
+            projections[i, 1],
+            ".",
+            color=cmap(int(oned_targets[i])),
+            markersize=10,
+        )
 
 # Sort legend by days of the week
 handles, labels = ax.get_legend_handles_labels()
 order = np.argsort([days_of_week.index(label) for label in labels])
-ax.legend([handles[idx] for idx in order], [labels[idx] for idx in order], loc="upper left", bbox_to_anchor=(-0.1, 1.2), ncol=4)
+ax.legend(
+    [handles[idx] for idx in order],
+    [labels[idx] for idx in order],
+    loc="upper left",
+    bbox_to_anchor=(-0.1, 1.2),
+    ncol=4,
+)
 
 ax.set_xlabel("Projection onto second SAE PCA component")
 ax.set_ylabel("Projection onto third SAE PCA component")
diff --git a/intervention/intervene_in_middle_of_circle.py b/intervention/intervene_in_middle_of_circle.py
index 2fb75fb..6a7dcb4 100644
--- a/intervention/intervene_in_middle_of_circle.py
+++ b/intervention/intervene_in_middle_of_circle.py
@@ -40,7 +40,8 @@ def vary_wthin_circle(circle_letter, duration, layer, token, pca_k, all_points):
 
     model = task.get_model()
     circle_projection_qr = torch.load(
-        f"{task.prefix}/circle_probes_{circle_letter}/cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt"
+        f"{task.prefix}/circle_probes_{circle_letter}/cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt",
+        weights_only=False,
     )
 
     for problem in task.generate_problems():
diff --git a/sae_multid_feature_discovery/saes/sparse_autoencoder.py b/sae_multid_feature_discovery/saes/sparse_autoencoder.py
index 85a9f57..ba61316 100755
--- a/sae_multid_feature_discovery/saes/sparse_autoencoder.py
+++ b/sae_multid_feature_discovery/saes/sparse_autoencoder.py
@@ -182,10 +182,12 @@ def load_from_pretrained(cls, path: str):
         if path.endswith(".pt"):
             try:
                 if torch.backends.mps.is_available():
-                    state_dict = torch.load(path, map_location="mps")
+                    state_dict = torch.load(
+                        path, map_location="mps", weights_only=False
+                    )
                     state_dict["cfg"].device = "mps"
                 else:
-                    state_dict = torch.load(path)
+                    state_dict = torch.load(path, weights_only=False)
             except Exception as e:
                 raise IOError(f"Error loading the state dictionary from .pt file: {e}")
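
Note: for checkpoints whose pickled contents are limited to known, trusted
classes, newer PyTorch (2.5 and later) can also keep the safer
weights_only=True default by allow-listing those classes rather than passing
weights_only=False everywhere. A minimal sketch with a hypothetical stand-in
class (not code from this repository):

    import torch

    class ProbeConfig:  # hypothetical placeholder for whatever class is pickled
        pass

    torch.serialization.add_safe_globals([ProbeConfig])

    torch.save({"cfg": ProbeConfig()}, "/tmp/probe.pt")
    torch.load("/tmp/probe.pt")  # loads under the weights_only=True default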