Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions crosslayer_transcoder/model/serializable_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,26 @@ def to_config(self) -> ConfigDict:
@classmethod
def from_config(cls, config: ConfigDict) -> Self:
"""Construct module from a config dict. Weights are not loaded."""

module_name, class_name = config["class_path"].rsplit(".", 1)
resolved_cls = getattr(importlib.import_module(module_name), class_name)
if cls != resolved_cls:
raise ValueError(
f"Incorrect class_path specified for building {cls}. Classpath specified: {config['class_path']}"
)

init_args = config.get("init_args", {})

resolved_args = {}

for key, value in init_args.items():
if isinstance(value, dict) and "class_path" in value:
target_module_name, target_class_name = value["class_path"].rsplit(".", 1)
target_cls = getattr(importlib.import_module(target_module_name), target_class_name)
target_module_name, target_class_name = value["class_path"].rsplit(
".", 1
)
target_cls = getattr(
importlib.import_module(target_module_name), target_class_name
)
resolved_args[key] = target_cls.from_config(value)
else:
resolved_args[key] = value
Expand All @@ -54,16 +66,26 @@ def save_pretrained(
def from_pretrained(cls, directory: Union[Path, str]) -> Self:
"""Load model from directory."""
directory = Path(directory)
with open(directory / "config.yaml") as f:
full_config = yaml.safe_load(f)
checkpoint = directory / "checkpoint.safetensors"
config = directory / "config.yaml"
return cls.from_config_and_checkpoint(config, checkpoint)

@classmethod
def from_config_and_checkpoint(cls, config: Path, checkpoint: Path) -> Self:
if not checkpoint.exists():
raise FileNotFoundError(f"Checkpoint file not found: {checkpoint}")
if not config.exists():
raise FileNotFoundError(f"Config file not found: {config}")

with open(config, "r") as f:
full_config = yaml.load(f, Loader=yaml.FullLoader)

model_config = full_config.get("model")
if model_config is None:
raise ValueError("Model config not found in config.yaml", full_config)

model = cls.from_config(model_config)
model.load_state_dict(load_file(directory / "checkpoint.safetensors"))
model.load_state_dict(load_file(checkpoint))

model._is_folded = model_config.get("is_folded", False)

return model
56 changes: 56 additions & 0 deletions tests/serializable-module/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import tempfile
import pytest

from crosslayer_transcoder.model.serializable_module import ConfigDict


@pytest.fixture
def config_dict() -> ConfigDict:
    """Return a small CrossLayerTranscoder config with nested sub-module configs.

    Every entry has the ``{"class_path": ..., "init_args": {...}}`` shape.
    A fresh dict is built on each call, so tests may mutate it freely.
    """

    def spec(class_path, **init_args):
        # Keyword order is preserved, so the resulting dicts match a literal.
        return {"class_path": class_path, "init_args": init_args}

    return spec(
        "crosslayer_transcoder.model.clt.CrossLayerTranscoder",
        encoder=spec(
            "crosslayer_transcoder.model.clt.Encoder",
            d_acts=110,
            d_features=32,
            n_layers=2,
        ),
        decoder=spec(
            "crosslayer_transcoder.model.clt.CrosslayerDecoder",
            d_acts=110,
            d_features=32,
            n_layers=2,
        ),
        nonlinearity=spec(
            "crosslayer_transcoder.model.jumprelu.JumpReLU",
            theta=0.03,
            bandwidth=0.01,
            n_layers=2,
            d_features=32,
        ),
        input_standardizer=spec(
            "crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer",
            n_layers=2,
            activation_dim=110,
        ),
        output_standardizer=spec(
            "crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer",
            n_layers=2,
            activation_dim=110,
        ),
    )


@pytest.fixture
def tmp_model_dir():
    """Yield the path (a ``str``) of a fresh temporary directory.

    The directory and its contents are removed when the fixture finalizes.
    """
    scratch = tempfile.TemporaryDirectory()
    try:
        yield scratch.name
    finally:
        scratch.cleanup()


@pytest.fixture
def tmp_config_path():
    """Yield the name of a temporary ``.yaml`` file.

    The file stays open for the duration of the test and is closed (and
    thereby deleted) when the fixture finalizes.
    """
    with tempfile.NamedTemporaryFile(suffix=".yaml") as yaml_handle:
        config_name = yaml_handle.name
        yield config_name


@pytest.fixture
def tmp_checkpoint_path():
    """Yield the name of a temporary ``.safetensors`` file.

    The file stays open for the duration of the test and is closed (and
    thereby deleted) when the fixture finalizes.
    """
    with tempfile.NamedTemporaryFile(suffix=".safetensors") as ckpt_handle:
        checkpoint_name = ckpt_handle.name
        yield checkpoint_name
37 changes: 37 additions & 0 deletions tests/serializable-module/test_from_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Tests needed
# 1. from config
# 2. save pretrained
# 3. from pretrained
# 4. from config and checkpoint

import pytest

from crosslayer_transcoder.model.clt import (
CrossLayerTranscoder,
Encoder,
)


def test_from_config_success(config_dict):
    """A well-formed config builds an instance of the requesting class."""
    built = CrossLayerTranscoder.from_config(config_dict)
    assert isinstance(built, CrossLayerTranscoder)


def test_from_config_bad_class_path(config_dict):
    """Building one class from a config that names another raises ValueError."""
    # The fixture's top-level class_path names CrossLayerTranscoder, not Encoder.
    wrong_cls = Encoder
    with pytest.raises(ValueError):
        wrong_cls.from_config(config_dict)


def test_from_config_class_path_dne(config_dict):
    """A nested class_path that cannot be resolved makes from_config fail."""
    encoder_cfg = config_dict["init_args"]["encoder"]
    encoder_cfg["class_path"] = "path/that/doesnt/exist.py"

    # NOTE(review): the concrete exception type is not pinned here; any
    # Exception satisfies the check.
    with pytest.raises(Exception):
        CrossLayerTranscoder.from_config(config_dict)


def test_init_args_missing(config_dict):
    """Dropping the ``encoder`` entry from init_args makes from_config fail."""
    config_dict["init_args"].pop("encoder")

    # NOTE(review): the concrete exception type is not pinned here; any
    # Exception satisfies the check.
    with pytest.raises(Exception):
        CrossLayerTranscoder.from_config(config_dict)
98 changes: 98 additions & 0 deletions tests/serializable-module/test_from_config_and_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from pathlib import Path

import pytest
import yaml
from safetensors.torch import save_file

from crosslayer_transcoder.model.clt import CrossLayerTranscoder


def test_missing_checkpoint(config_dict, tmp_model_dir):
    """A valid config with no checkpoint file raises FileNotFoundError."""
    model_dir = Path(tmp_model_dir)
    config_path = model_dir / "config.yaml"
    absent_checkpoint = model_dir / "checkpoint.safetensors"

    # Only the config is written; the checkpoint path is never created.
    with config_path.open("w") as handle:
        yaml.dump({"model": config_dict}, handle)

    with pytest.raises(FileNotFoundError, match="Checkpoint file not found"):
        CrossLayerTranscoder.from_config_and_checkpoint(
            config_path, absent_checkpoint
        )


def test_missing_config(config_dict, tmp_model_dir):
    """An existing checkpoint with no config file raises FileNotFoundError."""
    model_dir = Path(tmp_model_dir)
    checkpoint_path = model_dir / "checkpoint.safetensors"
    absent_config = model_dir / "config.yaml"

    # Only the checkpoint is written; the config path is never created.
    weights = CrossLayerTranscoder.from_config(config_dict).state_dict()
    save_file(weights, checkpoint_path)

    with pytest.raises(FileNotFoundError, match="Config file not found"):
        CrossLayerTranscoder.from_config_and_checkpoint(
            absent_config, checkpoint_path
        )


def test_missing_model_key_in_config(config_dict, tmp_model_dir):
    """A config file without a top-level ``model`` key raises ValueError."""
    reference = CrossLayerTranscoder.from_config(config_dict)
    model_dir = Path(tmp_model_dir)
    config_path = model_dir / "config.yaml"
    checkpoint_path = model_dir / "checkpoint.safetensors"

    # Both files exist, but the config is stored under the wrong key.
    with config_path.open("w") as handle:
        yaml.dump({"wrong_key": config_dict}, handle)
    save_file(reference.state_dict(), checkpoint_path)

    with pytest.raises(ValueError, match="Model config not found"):
        CrossLayerTranscoder.from_config_and_checkpoint(
            config_path, checkpoint_path
        )


# Sentinel meaning "do not write an is_folded key at all"; distinct from
# True/False so the default-handling path can be exercised.
_UNSET = object()


def _load_with_is_folded(config_dict, tmp_model_dir, is_folded=_UNSET):
    """Write config.yaml and checkpoint.safetensors, then load them back.

    Args:
        config_dict: Model config used both to build the reference model and
            as the ``model`` section of the written YAML.
        tmp_model_dir: Directory the two files are written into.
        is_folded: Value stored under the model config's ``is_folded`` key;
            when left at the sentinel the key is omitted.

    Returns:
        The model returned by
        ``CrossLayerTranscoder.from_config_and_checkpoint``.
    """
    model = CrossLayerTranscoder.from_config(config_dict)
    dir_path = Path(tmp_model_dir)
    config_path = dir_path / "config.yaml"
    checkpoint_path = dir_path / "checkpoint.safetensors"

    # Shallow copy is enough: only a top-level key is added.
    model_config = config_dict.copy()
    if is_folded is not _UNSET:
        model_config["is_folded"] = is_folded

    with open(config_path, "w") as f:
        yaml.dump({"model": model_config}, f)
    save_file(model.state_dict(), checkpoint_path)

    return CrossLayerTranscoder.from_config_and_checkpoint(
        config_path, checkpoint_path
    )


def test_is_folded_true(config_dict, tmp_model_dir):
    """``is_folded: true`` in the config is reflected on the loaded model."""
    loaded = _load_with_is_folded(config_dict, tmp_model_dir, is_folded=True)

    assert loaded._is_folded is True


def test_is_folded_false(config_dict, tmp_model_dir):
    """``is_folded: false`` in the config is reflected on the loaded model."""
    loaded = _load_with_is_folded(config_dict, tmp_model_dir, is_folded=False)

    assert loaded._is_folded is False


def test_is_folded_defaults_false_when_missing(config_dict, tmp_model_dir):
    """Without an ``is_folded`` key the loaded model reports ``False``."""
    loaded = _load_with_is_folded(config_dict, tmp_model_dir)

    assert loaded._is_folded is False
28 changes: 28 additions & 0 deletions tests/serializable-module/test_from_pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from pathlib import Path

import torch

from crosslayer_transcoder.model.clt import CrossLayerTranscoder


def test_round_trip(config_dict, tmp_model_dir):
    """save_pretrained followed by from_pretrained reproduces the state dict."""
    saved = CrossLayerTranscoder.from_config(config_dict)
    saved.save_pretrained(Path(tmp_model_dir))

    restored = CrossLayerTranscoder.from_pretrained(tmp_model_dir)
    assert isinstance(restored, CrossLayerTranscoder)

    expected = saved.state_dict()
    actual = restored.state_dict()
    assert expected.keys() == actual.keys()
    # equal_nan=True lets NaN entries compare equal to themselves.
    for name, tensor in expected.items():
        assert torch.allclose(tensor, actual[name], equal_nan=True)


def test_accepts_string_path(config_dict, tmp_model_dir):
    """from_pretrained accepts the directory as a plain ``str``."""
    CrossLayerTranscoder.from_config(config_dict).save_pretrained(
        Path(tmp_model_dir)
    )

    dir_as_str = str(tmp_model_dir)
    restored = CrossLayerTranscoder.from_pretrained(dir_as_str)

    assert isinstance(restored, CrossLayerTranscoder)
Loading