diff --git a/crosslayer_transcoder/model/serializable_module.py b/crosslayer_transcoder/model/serializable_module.py index 80ee9a4..3b4f778 100644 --- a/crosslayer_transcoder/model/serializable_module.py +++ b/crosslayer_transcoder/model/serializable_module.py @@ -24,14 +24,26 @@ def to_config(self) -> ConfigDict: @classmethod def from_config(cls, config: ConfigDict) -> Self: """Construct module from a config dict. Weights are not loaded.""" + + module_name, class_name = config["class_path"].rsplit(".", 1) + resolved_cls = getattr(importlib.import_module(module_name), class_name) + if cls != resolved_cls: + raise ValueError( + f"Incorrect class_path specified for building {cls}. Classpath specified: {config['class_path']}" + ) + init_args = config.get("init_args", {}) resolved_args = {} for key, value in init_args.items(): if isinstance(value, dict) and "class_path" in value: - target_module_name, target_class_name = value["class_path"].rsplit(".", 1) - target_cls = getattr(importlib.import_module(target_module_name), target_class_name) + target_module_name, target_class_name = value["class_path"].rsplit( + ".", 1 + ) + target_cls = getattr( + importlib.import_module(target_module_name), target_class_name + ) resolved_args[key] = target_cls.from_config(value) else: resolved_args[key] = value @@ -54,16 +66,26 @@ def save_pretrained( def from_pretrained(cls, directory: Union[Path, str]) -> Self: """Load model from directory.""" directory = Path(directory) - with open(directory / "config.yaml") as f: - full_config = yaml.safe_load(f) + checkpoint = directory / "checkpoint.safetensors" + config = directory / "config.yaml" + return cls.from_config_and_checkpoint(config, checkpoint) + + @classmethod + def from_config_and_checkpoint(cls, config: Path, checkpoint: Path) -> Self: + if not checkpoint.exists(): + raise FileNotFoundError(f"Checkpoint file not found: {checkpoint}") + if not config.exists(): + raise FileNotFoundError(f"Config file not found: {config}") + + with open(config, "r") as f: + full_config = yaml.load(f, Loader=yaml.FullLoader) model_config = full_config.get("model") if model_config is None: raise ValueError("Model config not found in config.yaml", full_config) model = cls.from_config(model_config) - model.load_state_dict(load_file(directory / "checkpoint.safetensors")) + model.load_state_dict(load_file(checkpoint)) model._is_folded = model_config.get("is_folded", False) - return model diff --git a/tests/serializable-module/conftest.py b/tests/serializable-module/conftest.py new file mode 100644 index 0000000..9778e48 --- /dev/null +++ b/tests/serializable-module/conftest.py @@ -0,0 +1,56 @@ +import tempfile +import pytest + +from crosslayer_transcoder.model.serializable_module import ConfigDict + + +@pytest.fixture +def config_dict() -> ConfigDict: + return { + "class_path": "crosslayer_transcoder.model.clt.CrossLayerTranscoder", + "init_args": { + "encoder": { + "class_path": "crosslayer_transcoder.model.clt.Encoder", + "init_args": {"d_acts": 110, "d_features": 32, "n_layers": 2}, + }, + "decoder": { + "class_path": "crosslayer_transcoder.model.clt.CrosslayerDecoder", + "init_args": {"d_acts": 110, "d_features": 32, "n_layers": 2}, + }, + "nonlinearity": { + "class_path": "crosslayer_transcoder.model.jumprelu.JumpReLU", + "init_args": { + "theta": 0.03, + "bandwidth": 0.01, + "n_layers": 2, + "d_features": 32, + }, + }, + "input_standardizer": { + "class_path": "crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer", + "init_args": {"n_layers": 2, "activation_dim": 110}, + }, + "output_standardizer": { + "class_path": "crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer", + "init_args": {"n_layers": 2, "activation_dim": 110}, + }, + }, + } + + +@pytest.fixture +def tmp_model_dir(): + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + +@pytest.fixture +def tmp_config_path(): + with tempfile.NamedTemporaryFile(suffix=".yaml") as tmpfile: + yield tmpfile.name + + +@pytest.fixture +def tmp_checkpoint_path(): + with tempfile.NamedTemporaryFile(suffix=".safetensors") as tmpfile: + yield tmpfile.name diff --git a/tests/serializable-module/test_from_config.py b/tests/serializable-module/test_from_config.py new file mode 100644 index 0000000..a244292 --- /dev/null +++ b/tests/serializable-module/test_from_config.py @@ -0,0 +1,37 @@ +# Tests needed +# 1. from config +# 2. save pretained +# 3. from pretrained +# 4. from config and checkpoint + +import pytest + +from crosslayer_transcoder.model.clt import ( + CrossLayerTranscoder, + Encoder, +) + + +def test_from_config_success(config_dict): + module = CrossLayerTranscoder.from_config(config_dict) + + assert isinstance(module, CrossLayerTranscoder) + + +def test_from_config_bad_class_path(config_dict): + with pytest.raises(ValueError): + Encoder.from_config(config_dict) + + +def test_from_config_class_path_dne(config_dict): + config_dict["init_args"]["encoder"]["class_path"] = "path/that/doesnt/exist.py" + + with pytest.raises(Exception): + CrossLayerTranscoder.from_config(config_dict) + + +def test_init_args_missing(config_dict): + del config_dict["init_args"]["encoder"] + + with pytest.raises(Exception): + CrossLayerTranscoder.from_config(config_dict) diff --git a/tests/serializable-module/test_from_config_and_checkpoint.py b/tests/serializable-module/test_from_config_and_checkpoint.py new file mode 100644 index 0000000..99087ca --- /dev/null +++ b/tests/serializable-module/test_from_config_and_checkpoint.py @@ -0,0 +1,98 @@ +from pathlib import Path + +import pytest +import yaml +from safetensors.torch import save_file + +from crosslayer_transcoder.model.clt import CrossLayerTranscoder + + +def test_missing_checkpoint(config_dict, tmp_model_dir): + config_path = Path(tmp_model_dir) / "config.yaml" + with open(config_path, "w") as f: + yaml.dump({"model": config_dict}, f) + + with pytest.raises(FileNotFoundError, match="Checkpoint file not found"): + CrossLayerTranscoder.from_config_and_checkpoint( + config_path, Path(tmp_model_dir) / "checkpoint.safetensors" + ) + + +def test_missing_config(config_dict, tmp_model_dir): + model = CrossLayerTranscoder.from_config(config_dict) + checkpoint_path = Path(tmp_model_dir) / "checkpoint.safetensors" + save_file(model.state_dict(), checkpoint_path) + + with pytest.raises(FileNotFoundError, match="Config file not found"): + CrossLayerTranscoder.from_config_and_checkpoint( + Path(tmp_model_dir) / "config.yaml", checkpoint_path + ) + + +def test_missing_model_key_in_config(config_dict, tmp_model_dir): + model = CrossLayerTranscoder.from_config(config_dict) + dir_path = Path(tmp_model_dir) + config_path = dir_path / "config.yaml" + checkpoint_path = dir_path / "checkpoint.safetensors" + + with open(config_path, "w") as f: + yaml.dump({"wrong_key": config_dict}, f) + save_file(model.state_dict(), checkpoint_path) + + with pytest.raises(ValueError, match="Model config not found"): + CrossLayerTranscoder.from_config_and_checkpoint(config_path, checkpoint_path) + + +def test_is_folded_true(config_dict, tmp_model_dir): + model = CrossLayerTranscoder.from_config(config_dict) + dir_path = Path(tmp_model_dir) + config_path = dir_path / "config.yaml" + checkpoint_path = dir_path / "checkpoint.safetensors" + + config_with_folded = config_dict.copy() + config_with_folded["is_folded"] = True + with open(config_path, "w") as f: + yaml.dump({"model": config_with_folded}, f) + save_file(model.state_dict(), checkpoint_path) + + loaded = CrossLayerTranscoder.from_config_and_checkpoint( + config_path, checkpoint_path + ) + + assert loaded._is_folded is True + + +def test_is_folded_false(config_dict, tmp_model_dir): + model = CrossLayerTranscoder.from_config(config_dict) + dir_path = Path(tmp_model_dir) + config_path = dir_path / "config.yaml" + checkpoint_path = dir_path / "checkpoint.safetensors" + + config_with_folded = config_dict.copy() + config_with_folded["is_folded"] = False + with open(config_path, "w") as f: + yaml.dump({"model": config_with_folded}, f) + save_file(model.state_dict(), checkpoint_path) + + loaded = CrossLayerTranscoder.from_config_and_checkpoint( + config_path, checkpoint_path + ) + + assert loaded._is_folded is False + + +def test_is_folded_defaults_false_when_missing(config_dict, tmp_model_dir): + model = CrossLayerTranscoder.from_config(config_dict) + dir_path = Path(tmp_model_dir) + config_path = dir_path / "config.yaml" + checkpoint_path = dir_path / "checkpoint.safetensors" + + with open(config_path, "w") as f: + yaml.dump({"model": config_dict}, f) + save_file(model.state_dict(), checkpoint_path) + + loaded = CrossLayerTranscoder.from_config_and_checkpoint( + config_path, checkpoint_path + ) + + assert loaded._is_folded is False diff --git a/tests/serializable-module/test_from_pretrained.py b/tests/serializable-module/test_from_pretrained.py new file mode 100644 index 0000000..82d23e6 --- /dev/null +++ b/tests/serializable-module/test_from_pretrained.py @@ -0,0 +1,28 @@ +from pathlib import Path + +import torch + +from crosslayer_transcoder.model.clt import CrossLayerTranscoder + + +def test_round_trip(config_dict, tmp_model_dir): + original = CrossLayerTranscoder.from_config(config_dict) + original.save_pretrained(Path(tmp_model_dir)) + + loaded = CrossLayerTranscoder.from_pretrained(tmp_model_dir) + + assert isinstance(loaded, CrossLayerTranscoder) + original_state = original.state_dict() + loaded_state = loaded.state_dict() + assert original_state.keys() == loaded_state.keys() + for key in original_state: + assert torch.allclose(original_state[key], loaded_state[key], equal_nan=True) + + +def test_accepts_string_path(config_dict, tmp_model_dir): + original = CrossLayerTranscoder.from_config(config_dict) + original.save_pretrained(Path(tmp_model_dir)) + + loaded = CrossLayerTranscoder.from_pretrained(str(tmp_model_dir)) + + assert isinstance(loaded, CrossLayerTranscoder)