Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def deconstruct(layer, n_feature_groups):
+ str(start_token + token)
+ "_pca"
+ str(n_pca_dims)
+ ".pt"
+ ".pt",
weights_only=False,
)
flat_activations = activations[order, :] # problem, pca
activations = flat_activations.reshape([mod, mod, n_pca_dims])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def deconstruct(layer, n_feature_groups):
+ str(start_token + token)
+ "_pca"
+ str(n_pca_dims)
+ ".pt"
+ ".pt",
weights_only=False,
)
flat_activations = activations[order, :] # problem, pca
activations = flat_activations.reshape([mod, mod, n_pca_dims])
Expand Down
175 changes: 132 additions & 43 deletions intervention/compare_circle_intervention_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,21 @@

# %%

mistral_pcas = pickle.load(open("../sae_multid_feature_discovery/fit_pca_days.pkl", "rb")).components_[1:3, :]
mistral_pcas = pickle.load(
open("../sae_multid_feature_discovery/fit_pca_days.pkl", "rb")
).components_[1:3, :]

# %%

# Get original probe data

original_probe = torch.load(f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_8_token_{token}_pca_5.pt")
original_probe = torch.load(
f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_8_token_{token}_pca_5.pt",
weights_only=False,
)
original_probe_data = []

for layer in [6, 7, 8, 9, 10]:

(
logit_diffs_before,
logit_diffs_after,
Expand All @@ -98,8 +102,30 @@
average_zero_circle = np.mean(logit_diffs_zero_circle)
average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle)

original_probe_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle))
original_probe_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle))
original_probe_data.append(
(
layer,
average_before,
average_after,
average_replace_pca,
average_replace_all,
average_average_ablate,
average_zero_circle,
average_zero_everything_but_circle,
)
)
original_probe_data.append(
(
layer,
logit_diffs_before,
logit_diffs_after,
logit_diffs_replace_pca,
logit_diffs_replace_all,
logit_diffs_average_ablate,
logit_diffs_zero_circle,
logit_diffs_zero_everything_but_circle,
)
)

# %%

Expand All @@ -119,21 +145,15 @@
current_probe_dimension = 0
if probe_on_cos:
multid_targets[:, current_probe_dimension] = torch.cos(w * oned_targets)
target_to_embedding[:, current_probe_dimension] = torch.cos(
w * torch.arange(p)
)
target_to_embedding[:, current_probe_dimension] = torch.cos(w * torch.arange(p))
current_probe_dimension += 1
if probe_on_sin:
multid_targets[:, current_probe_dimension] = torch.sin(w * oned_targets)
target_to_embedding[:, current_probe_dimension] = torch.sin(
w * torch.arange(p)
)
target_to_embedding[:, current_probe_dimension] = torch.sin(w * torch.arange(p))
current_probe_dimension += 1
if probe_on_centered_linear:
multid_targets[:, current_probe_dimension] = oned_targets - (p - 1) / 2
target_to_embedding[:, current_probe_dimension] = (
torch.arange(p) - (p - 1) / 2
)
target_to_embedding[:, current_probe_dimension] = torch.arange(p) - (p - 1) / 2
current_probe_dimension += 1

assert current_probe_dimension == probe_dimension
Expand All @@ -144,9 +164,7 @@

projections = (acts_train @ mistral_pcas.T).float()

least_squares_sol = torch.linalg.lstsq(
projections, multid_targets_train
).solution
least_squares_sol = torch.linalg.lstsq(projections, multid_targets_train).solution

probe_q, probe_r = torch.linalg.qr(least_squares_sol)

Expand All @@ -159,7 +177,6 @@
mistral_data = []

for layer in [6, 7, 8, 9, 10]:

(
logit_diffs_before,
logit_diffs_after,
Expand Down Expand Up @@ -187,18 +204,41 @@
average_zero_circle = np.mean(logit_diffs_zero_circle)
average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle)

mistral_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle))
mistral_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle))
mistral_data.append(
(
layer,
average_before,
average_after,
average_replace_pca,
average_replace_all,
average_average_ablate,
average_zero_circle,
average_zero_everything_but_circle,
)
)
mistral_data.append(
(
layer,
logit_diffs_before,
logit_diffs_after,
logit_diffs_replace_pca,
logit_diffs_replace_all,
logit_diffs_average_ablate,
logit_diffs_zero_circle,
logit_diffs_zero_everything_but_circle,
)
)

# %%



original_probe_varying_layer_data = []

for layer in [6, 7, 8, 9, 10]:

original_probe = torch.load(f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_5.pt")
original_probe = torch.load(
f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_5.pt",
weights_only=False,
)

(
logit_diffs_before,
Expand Down Expand Up @@ -226,16 +266,41 @@
average_zero_circle = np.mean(logit_diffs_zero_circle)
average_zero_everything_but_circle = np.mean(logit_diffs_zero_everything_but_circle)

original_probe_varying_layer_data.append((layer, average_before, average_after, average_replace_pca, average_replace_all, average_average_ablate, average_zero_circle, average_zero_everything_but_circle))
original_probe_varying_layer_data.append((layer, logit_diffs_before, logit_diffs_after, logit_diffs_replace_pca, logit_diffs_replace_all, logit_diffs_average_ablate, logit_diffs_zero_circle, logit_diffs_zero_everything_but_circle))
original_probe_varying_layer_data.append(
(
layer,
average_before,
average_after,
average_replace_pca,
average_replace_all,
average_average_ablate,
average_zero_circle,
average_zero_everything_but_circle,
)
)
original_probe_varying_layer_data.append(
(
layer,
logit_diffs_before,
logit_diffs_after,
logit_diffs_replace_pca,
logit_diffs_replace_all,
logit_diffs_average_ablate,
logit_diffs_zero_circle,
logit_diffs_zero_everything_but_circle,
)
)

# %%

# Save data

pickle.dump(original_probe_data, open("figs/original_probe_data.pkl", "wb"))
pickle.dump(mistral_data, open("figs/mistral_data.pkl", "wb"))
pickle.dump(original_probe_varying_layer_data, open("figs/original_probe_varying_layer_data.pkl", "wb"))
pickle.dump(
original_probe_varying_layer_data,
open("figs/original_probe_varying_layer_data.pkl", "wb"),
)

# %%

Expand All @@ -246,19 +311,25 @@
# Get means
average_after_original_probe = [x[2] for x in original_probe_data[::2]]
average_after_mistral = [x[2] for x in mistral_data[::2]]
average_after_original_probe_varying_layer = [x[2] for x in original_probe_varying_layer_data[::2]]
average_after_original_probe_varying_layer = [
x[2] for x in original_probe_varying_layer_data[::2]
]

print(average_after_original_probe[0])
print(average_after_mistral[0])
print(average_after_original_probe_varying_layer[0])

import scipy


def mean_confidence_interval(data, confidence=0.96):
    """Return (mean, lower, upper) of a two-sided t confidence interval.

    Parameters
    ----------
    data : sequence of numbers
        Sample to summarize.
    confidence : float
        Two-sided confidence level (default 0.96).

    Returns
    -------
    (mean, mean - h, mean + h) where h is the half-width of the interval
    based on the Student t distribution with len(data) - 1 degrees of freedom.
    """
    values = np.asarray(data, dtype=float)
    sample_size = len(values)
    mean = np.mean(values)
    stderr = scipy.stats.sem(values)
    # Half-width: standard error scaled by the two-sided t critical value.
    half_width = stderr * scipy.stats.t.ppf((1 + confidence) / 2.0, sample_size - 1)
    return mean, mean - half_width, mean + half_width


# Get confidence intervals
original_probe_means = []
original_probe_lower = []
Expand Down Expand Up @@ -288,22 +359,13 @@ def mean_confidence_interval(data, confidence=0.96):
varying_layer_upper.append(upper)

ax.plot(x, original_probe_means, label="Intervene with Layer 8 Probe", marker="o")
ax.fill_between(x,
original_probe_lower,
original_probe_upper,
alpha=0.3)
ax.fill_between(x, original_probe_lower, original_probe_upper, alpha=0.3)

ax.plot(x, mistral_means, label="Intervene with SAE Subspace", marker="o")
ax.fill_between(x,
mistral_lower,
mistral_upper,
alpha=0.3)
ax.fill_between(x, mistral_lower, mistral_upper, alpha=0.3)

ax.plot(x, varying_layer_means, label="Intervene with Probe", marker="o")
ax.fill_between(x,
varying_layer_lower,
varying_layer_upper,
alpha=0.3)
ax.fill_between(x, varying_layer_lower, varying_layer_upper, alpha=0.3)

ax.set_xlabel("Layer")
ax.set_xticks(x)
Expand All @@ -318,19 +380,46 @@ def mean_confidence_interval(data, confidence=0.96):
# Map each target value to a consistent color based on its position in the circle
cmap = plt.get_cmap("tab10")

days_of_week = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
days_of_week = [
"Monday",
"Tuesday",
"Wednesday",
"Thursday",
"Friday",
"Saturday",
"Sunday",
]
added_labels = set()
for i in range(len(projections)):
if int(oned_targets[i]) not in added_labels:
added_labels.add(int(oned_targets[i]))
plt.plot(projections[i, 0], projections[i, 1], ".", color=cmap(int(oned_targets[i])), markersize=10, label=days_of_week[int(oned_targets[i])])
plt.plot(
projections[i, 0],
projections[i, 1],
".",
color=cmap(int(oned_targets[i])),
markersize=10,
label=days_of_week[int(oned_targets[i])],
)
else:
plt.plot(projections[i, 0], projections[i, 1], ".", color=cmap(int(oned_targets[i])), markersize=10)
plt.plot(
projections[i, 0],
projections[i, 1],
".",
color=cmap(int(oned_targets[i])),
markersize=10,
)

# Sort legend by days of the week
handles, labels = ax.get_legend_handles_labels()
order = np.argsort([days_of_week.index(label) for label in labels])
ax.legend([handles[idx] for idx in order], [labels[idx] for idx in order], loc="upper left", bbox_to_anchor=(-0.1, 1.2), ncol=4)
ax.legend(
[handles[idx] for idx in order],
[labels[idx] for idx in order],
loc="upper left",
bbox_to_anchor=(-0.1, 1.2),
ncol=4,
)

ax.set_xlabel("Projection onto second SAE PCA component")
ax.set_ylabel("Projection onto third SAE PCA component")
Expand Down
3 changes: 2 additions & 1 deletion intervention/intervene_in_middle_of_circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def vary_wthin_circle(circle_letter, duration, layer, token, pca_k, all_points):
model = task.get_model()

circle_projection_qr = torch.load(
f"{task.prefix}/circle_probes_{circle_letter}/cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt"
f"{task.prefix}/circle_probes_{circle_letter}/cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt",
weights_only=False,
)

for problem in task.generate_problems():
Expand Down
14 changes: 10 additions & 4 deletions intervention/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def get_all_acts(
all_acts = []
for i in range(0, len(all_problems)):
tensors = torch.load(
f"{task.prefix}{save_file_prefix}{i}.pt", map_location="cpu"
f"{task.prefix}{save_file_prefix}{i}.pt",
map_location="cpu",
weights_only=False,
)
all_acts.append(tensors)
if len(all_acts) > 1:
Expand Down Expand Up @@ -201,7 +203,7 @@ def get_acts(
torch.save(
all_acts[:, layer, token, :].detach().cpu().clone(), file_name
)
data = torch.load(file_name)
data = torch.load(file_name, weights_only=False)
if normalize_rms:
eps = 1e-5
scale = (data.pow(2).mean(-1, keepdim=True) + eps).sqrt()
Expand Down Expand Up @@ -235,7 +237,9 @@ def get_acts_pca(
pca_acts = pca_object.transform(acts)
torch.save(pca_acts, act_file_name)
pkl.dump(pca_object, open(pca_pkl_file_name, "wb"))
return torch.load(act_file_name), pkl.load(open(pca_pkl_file_name, "rb"))
return torch.load(act_file_name, weights_only=False), pkl.load(
open(pca_pkl_file_name, "rb")
)


def get_acts_pls(task, layer, token, pls_k, normalize_rms=False):
Expand All @@ -255,7 +259,9 @@ def get_acts_pls(task, layer, token, pls_k, normalize_rms=False):
torch.save(torch.tensor(pls_acts), act_file_name)
pkl.dump(pls, open(pls_pkl_file_name, "wb"))

return torch.load(act_file_name), pkl.load(open(pls_pkl_file_name, "rb"))
return torch.load(act_file_name, weights_only=False), pkl.load(
open(pls_pkl_file_name, "rb")
)


def _set_plotting_sizes():
Expand Down
6 changes: 4 additions & 2 deletions sae_multid_feature_discovery/saes/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,12 @@ def load_from_pretrained(cls, path: str):
if path.endswith(".pt"):
try:
if torch.backends.mps.is_available():
state_dict = torch.load(path, map_location="mps")
state_dict = torch.load(
path, map_location="mps", weights_only=False
)
state_dict["cfg"].device = "mps"
else:
state_dict = torch.load(path)
state_dict = torch.load(path, weights_only=False)
except Exception as e:
raise IOError(f"Error loading the state dictionary from .pt file: {e}")

Expand Down