diff --git a/sae_multid_feature_discovery/generate_feature_occurence_data.py b/sae_multid_feature_discovery/generate_feature_occurence_data.py index 63733aa..59bbd31 100644 --- a/sae_multid_feature_discovery/generate_feature_occurence_data.py +++ b/sae_multid_feature_discovery/generate_feature_occurence_data.py @@ -44,8 +44,9 @@ num_workers = 8 sae_hidden_size = 24576 +tl_model_name = "gpt2" if model_name == "gpt-2" else model_name model = transformer_lens.HookedTransformer.from_pretrained( - model_name, device=device, n_devices=num_devices + tl_model_name, device=device, n_devices=num_devices ) ctx_len = 256