From 9eb52b9ebb10ec16bc4558d443f087bf275a0635 Mon Sep 17 00:00:00 2001
From: jiito
Date: Tue, 27 Jan 2026 16:32:39 -0800
Subject: [PATCH 1/2] add basic from config tests

---
Review notes (text between the three-dash line and the diff is not part of
the commit message):

  * serializable_module.py was amended during review:
      - class-path resolution is shared through _resolve_class_path(), which
        reports a path without a dot with an explicit ValueError;
      - from_config tolerates a missing or null "init_args";
      - save_pretrained and from_config_and_checkpoint accept str as well as
        Path, matching from_pretrained;
      - an empty or non-mapping config.yaml now raises the same ValueError as
        a missing "model" key instead of an AttributeError, and the error is
        built as one formatted message;
      - config files are opened with an explicit UTF-8 encoding;
      - yaml.load(..., FullLoader) is kept but flagged for follow-up.
  * The blob ids on the "index" lines come from the original commit and were
    not recomputed for the amended content.

 .../model/serializable_module.py              | 130 ++++++++++++++++++
 tests/serializable-module/test_from_config.py |  74 ++++++++++
 2 files changed, 204 insertions(+)
 create mode 100644 crosslayer_transcoder/model/serializable_module.py
 create mode 100644 tests/serializable-module/test_from_config.py

diff --git a/crosslayer_transcoder/model/serializable_module.py b/crosslayer_transcoder/model/serializable_module.py
new file mode 100644
index 0000000..3b4f778
--- /dev/null
+++ b/crosslayer_transcoder/model/serializable_module.py
@@ -0,0 +1,130 @@
+import importlib
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any, Dict, Self, TypedDict, Union
+
+import yaml
+from safetensors.torch import load_file, save_file
+from torch import nn
+
+
+class ConfigDict(TypedDict):
+    """Serialized module: dotted import path of the class plus its constructor kwargs."""
+
+    class_path: str
+    init_args: Dict[str, Any]
+
+
+def _resolve_class_path(class_path: str) -> Any:
+    """Import and return the object named by a dotted ``module.ClassName`` path."""
+    module_name, _, class_name = class_path.rpartition(".")
+    if not module_name:
+        raise ValueError(
+            f"class_path must be a dotted 'module.ClassName' path, got {class_path!r}"
+        )
+    return getattr(importlib.import_module(module_name), class_name)
+
+
+class SerializableModule(nn.Module, ABC):
+    """Base class for modules that can serialize to/from config and save/load to disk."""
+
+    @abstractmethod
+    def to_config(self) -> ConfigDict:
+        """Serialize module configuration to a dict."""
+        raise NotImplementedError
+
+    @classmethod
+    def from_config(cls, config: ConfigDict) -> Self:
+        """Construct module from a config dict. Weights are not loaded.
+
+        Any ``init_args`` value that is itself a dict with a ``class_path`` key
+        is built through that class's own ``from_config`` first; every other
+        value is passed to the constructor unchanged.
+
+        Raises:
+            ValueError: if ``class_path`` does not resolve to ``cls`` itself.
+        """
+        resolved_cls = _resolve_class_path(config["class_path"])
+        if resolved_cls is not cls:
+            raise ValueError(
+                f"Incorrect class_path specified for building {cls}. Classpath specified: {config['class_path']}"
+            )
+
+        # ``init_args`` may be absent or explicitly null in a hand-written YAML.
+        init_args = config.get("init_args") or {}
+
+        resolved_args = {}
+        for key, value in init_args.items():
+            if isinstance(value, dict) and "class_path" in value:
+                target_cls = _resolve_class_path(value["class_path"])
+                resolved_args[key] = target_cls.from_config(value)
+            else:
+                resolved_args[key] = value
+
+        return cls(**resolved_args)
+
+    def save_pretrained(
+        self,
+        directory: Union[Path, str],
+    ) -> None:
+        """Save config and weights to directory.
+
+        Writes ``config.yaml`` (``to_config()`` under a top-level ``model`` key)
+        and ``checkpoint.safetensors`` (the ``state_dict``), creating the
+        directory if needed.
+        """
+        # Accept plain strings too, mirroring from_pretrained.
+        directory = Path(directory)
+        directory.mkdir(parents=True, exist_ok=True)
+
+        with open(directory / "config.yaml", "w", encoding="utf-8") as f:
+            yaml.dump({"model": self.to_config()}, f)
+
+        save_file(self.state_dict(), directory / "checkpoint.safetensors")
+
+    @classmethod
+    def from_pretrained(cls, directory: Union[Path, str]) -> Self:
+        """Load model from a directory written by ``save_pretrained``."""
+        directory = Path(directory)
+        checkpoint = directory / "checkpoint.safetensors"
+        config = directory / "config.yaml"
+        return cls.from_config_and_checkpoint(config, checkpoint)
+
+    @classmethod
+    def from_config_and_checkpoint(
+        cls, config: Union[Path, str], checkpoint: Union[Path, str]
+    ) -> Self:
+        """Build the module from a YAML config file and load safetensors weights.
+
+        Raises:
+            FileNotFoundError: if either file does not exist.
+            ValueError: if the YAML document has no top-level ``model`` entry.
+        """
+        config = Path(config)
+        checkpoint = Path(checkpoint)
+        if not checkpoint.exists():
+            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint}")
+        if not config.exists():
+            raise FileNotFoundError(f"Config file not found: {config}")
+
+        # NOTE(review): FullLoader accepts more than plain YAML data; prefer
+        # yaml.safe_load if configs can come from untrusted sources (confirm
+        # that to_config() only ever emits plain dicts/lists/scalars first).
+        with open(config, "r", encoding="utf-8") as f:
+            full_config = yaml.load(f, Loader=yaml.FullLoader)
+
+        # An empty file loads as None and a non-mapping document has no .get();
+        # report both the same way as a missing "model" key.
+        model_config = (
+            full_config.get("model") if isinstance(full_config, dict) else None
+        )
+        if model_config is None:
+            raise ValueError(
+                f"Model config not found in config.yaml: {full_config!r}"
+            )
+
+        model = cls.from_config(model_config)
+        model.load_state_dict(load_file(checkpoint))
+
+        model._is_folded = model_config.get("is_folded", False)
+        return model
diff --git a/tests/serializable-module/test_from_config.py b/tests/serializable-module/test_from_config.py
new file mode 100644
index 0000000..21c17f4
--- /dev/null
+++ b/tests/serializable-module/test_from_config.py
@@ -0,0 +1,74 @@
+# Tests needed
+# 1. from config
+# 2. save pretrained
+# 3. from pretrained
+# 4. from config and checkpoint
+
+import pytest
+
+from crosslayer_transcoder.model.clt import (
+    CrossLayerTranscoder,
+    Encoder,
+)
+from crosslayer_transcoder.model.serializable_module import (
+    ConfigDict,
+)
+
+
+@pytest.fixture
+def config() -> ConfigDict:
+    return {
+        "class_path": "crosslayer_transcoder.model.clt.CrossLayerTranscoder",
+        "init_args": {
+            "encoder": {
+                "class_path": "crosslayer_transcoder.model.clt.Encoder",
+                "init_args": {"d_acts": 110, "d_features": 32, "n_layers": 2},
+            },
+            "decoder": {
+                "class_path": "crosslayer_transcoder.model.clt.CrosslayerDecoder",
+                "init_args": {"d_acts": 110, "d_features": 32, "n_layers": 2},
+            },
+            "nonlinearity": {
+                "class_path": "crosslayer_transcoder.model.jumprelu.JumpReLU",
+                "init_args": {
+                    "theta": 0.03,
+                    "bandwidth": 0.01,
+                    "n_layers": 2,
+                    "d_features": 32,
+                },
+            },
+            "input_standardizer": {
+                "class_path": "crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer",
+                "init_args": {"n_layers": 2, "activation_dim": 110},
+            },
+            "output_standardizer": {
+                "class_path": "crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer",
+                "init_args": {"n_layers": 2, "activation_dim": 110},
+            },
+        },
+    }
+
+
+def test_from_config_success(config):
+    module = CrossLayerTranscoder.from_config(config)
+
+    assert isinstance(module, CrossLayerTranscoder)
+
+
+def test_from_config_bad_class_path(config):
+    with pytest.raises(ValueError):
+        Encoder.from_config(config)
+
+
+def test_from_config_class_path_dne(config):
+    config["init_args"]["encoder"]["class_path"] = "path/that/doesnt/exist.py"
+
+    with pytest.raises(Exception):
+        CrossLayerTranscoder.from_config(config)
+
+
+def test_init_args_missing(config):
+    del config["init_args"]["encoder"]
+
+    with pytest.raises(Exception):
+        CrossLayerTranscoder.from_config(config)
From 783a77b62afc4f031f6b3045711cc92ef029106d Mon Sep 17 00:00:00 2001
From: jiito
Date: Wed, 4 Feb 2026 17:16:59 -0800
Subject: [PATCH 2/2] add generated test coverage

---
 tests/serializable-module/conftest.py         | 56 +++++++++++
 tests/serializable-module/test_from_config.py | 57 ++---------
 .../test_from_config_and_checkpoint.py        | 98 +++++++++++++++++++
 .../test_from_pretrained.py                   | 28 ++++++
 4 files changed, 192 insertions(+), 47 deletions(-)
 create mode 100644 tests/serializable-module/conftest.py
 create mode 100644 tests/serializable-module/test_from_config_and_checkpoint.py
 create mode 100644 tests/serializable-module/test_from_pretrained.py

diff --git a/tests/serializable-module/conftest.py b/tests/serializable-module/conftest.py
new file mode 100644
index 0000000..9778e48
--- /dev/null
+++ b/tests/serializable-module/conftest.py
@@ -0,0 +1,56 @@
+import tempfile
+import pytest
+
+from crosslayer_transcoder.model.serializable_module import ConfigDict
+
+
+@pytest.fixture
+def config_dict() -> ConfigDict:
+    return {
+        "class_path": "crosslayer_transcoder.model.clt.CrossLayerTranscoder",
+        "init_args": {
+            "encoder": {
+                "class_path": "crosslayer_transcoder.model.clt.Encoder",
+                "init_args": {"d_acts": 110, "d_features": 32, "n_layers": 2},
+            },
+            "decoder": {
+                "class_path": "crosslayer_transcoder.model.clt.CrosslayerDecoder",
+                "init_args": {"d_acts": 110, "d_features": 32, "n_layers": 2},
+            },
+            "nonlinearity": {
+                "class_path": "crosslayer_transcoder.model.jumprelu.JumpReLU",
+                "init_args": {
+                    "theta": 0.03,
+                    "bandwidth": 0.01,
+                    "n_layers": 2,
+                    "d_features": 32,
+                },
+            },
+            "input_standardizer": {
+                "class_path": "crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer",
+                "init_args": {"n_layers": 2, "activation_dim": 110},
+            },
+            "output_standardizer": {
+                "class_path": "crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer",
+                "init_args": {"n_layers": 2, "activation_dim": 110},
+            },
+        },
+    }
+
+
+@pytest.fixture
+def tmp_model_dir():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        yield tmpdir
+
+
+@pytest.fixture
+def tmp_config_path():
+    with tempfile.NamedTemporaryFile(suffix=".yaml") as tmpfile:
+        yield tmpfile.name
+
+
+@pytest.fixture
+def tmp_checkpoint_path():
+    with tempfile.NamedTemporaryFile(suffix=".safetensors") as tmpfile:
+        yield tmpfile.name
diff --git a/tests/serializable-module/test_from_config.py b/tests/serializable-module/test_from_config.py
index 21c17f4..a244292 100644
--- a/tests/serializable-module/test_from_config.py
+++ b/tests/serializable-module/test_from_config.py
@@ -10,65 +10,28 @@
     CrossLayerTranscoder,
     Encoder,
 )
-from crosslayer_transcoder.model.serializable_module import (
-    ConfigDict,
-)
-
-
-@pytest.fixture
-def config() -> ConfigDict:
-    return {
-        "class_path": "crosslayer_transcoder.model.clt.CrossLayerTranscoder",
-        "init_args": {
-            "encoder": {
-                "class_path": "crosslayer_transcoder.model.clt.Encoder",
-                "init_args": {"d_acts": 110, "d_features": 32, "n_layers": 2},
-            },
-            "decoder": {
-                "class_path": "crosslayer_transcoder.model.clt.CrosslayerDecoder",
-                "init_args": {"d_acts": 110, "d_features": 32, "n_layers": 2},
-            },
-            "nonlinearity": {
-                "class_path": "crosslayer_transcoder.model.jumprelu.JumpReLU",
-                "init_args": {
-                    "theta": 0.03,
-                    "bandwidth": 0.01,
-                    "n_layers": 2,
-                    "d_features": 32,
-                },
-            },
-            "input_standardizer": {
-                "class_path": "crosslayer_transcoder.model.standardize.DimensionwiseInputStandardizer",
-                "init_args": {"n_layers": 2, "activation_dim": 110},
-            },
-            "output_standardizer": {
-                "class_path": "crosslayer_transcoder.model.standardize.DimensionwiseOutputStandardizer",
-                "init_args": {"n_layers": 2, "activation_dim": 110},
-            },
-        },
-    }
 
 
-def test_from_config_success(config):
-    module = CrossLayerTranscoder.from_config(config)
+def test_from_config_success(config_dict):
+    module = CrossLayerTranscoder.from_config(config_dict)
 
     assert isinstance(module, CrossLayerTranscoder)
 
 
-def test_from_config_bad_class_path(config):
+def test_from_config_bad_class_path(config_dict):
     with pytest.raises(ValueError):
-        Encoder.from_config(config)
+        Encoder.from_config(config_dict)
 
 
-def test_from_config_class_path_dne(config):
-    config["init_args"]["encoder"]["class_path"] = "path/that/doesnt/exist.py"
+def test_from_config_class_path_dne(config_dict):
+    config_dict["init_args"]["encoder"]["class_path"] = "path/that/doesnt/exist.py"
 
     with pytest.raises(Exception):
-        CrossLayerTranscoder.from_config(config)
+        CrossLayerTranscoder.from_config(config_dict)
 
 
-def test_init_args_missing(config):
-    del config["init_args"]["encoder"]
+def test_init_args_missing(config_dict):
+    del config_dict["init_args"]["encoder"]
 
     with pytest.raises(Exception):
-        CrossLayerTranscoder.from_config(config)
+        CrossLayerTranscoder.from_config(config_dict)
diff --git a/tests/serializable-module/test_from_config_and_checkpoint.py b/tests/serializable-module/test_from_config_and_checkpoint.py
new file mode 100644
index 0000000..99087ca
--- /dev/null
+++ b/tests/serializable-module/test_from_config_and_checkpoint.py
@@ -0,0 +1,98 @@
+from pathlib import Path
+
+import pytest
+import yaml
+from safetensors.torch import save_file
+
+from crosslayer_transcoder.model.clt import CrossLayerTranscoder
+
+
+def test_missing_checkpoint(config_dict, tmp_model_dir):
+    config_path = Path(tmp_model_dir) / "config.yaml"
+    with open(config_path, "w") as f:
+        yaml.dump({"model": config_dict}, f)
+
+    with pytest.raises(FileNotFoundError, match="Checkpoint file not found"):
+        CrossLayerTranscoder.from_config_and_checkpoint(
+            config_path, Path(tmp_model_dir) / "checkpoint.safetensors"
+        )
+
+
+def test_missing_config(config_dict, tmp_model_dir):
+    model = CrossLayerTranscoder.from_config(config_dict)
+    checkpoint_path = Path(tmp_model_dir) / "checkpoint.safetensors"
+    save_file(model.state_dict(), checkpoint_path)
+
+    with pytest.raises(FileNotFoundError, match="Config file not found"):
+        CrossLayerTranscoder.from_config_and_checkpoint(
+            Path(tmp_model_dir) / "config.yaml", checkpoint_path
+        )
+
+
+def test_missing_model_key_in_config(config_dict, tmp_model_dir):
+    model = CrossLayerTranscoder.from_config(config_dict)
+    dir_path = Path(tmp_model_dir)
+    config_path = dir_path / "config.yaml"
+    checkpoint_path = dir_path / "checkpoint.safetensors"
+
+    with open(config_path, "w") as f:
+        yaml.dump({"wrong_key": config_dict}, f)
+    save_file(model.state_dict(), checkpoint_path)
+
+    with pytest.raises(ValueError, match="Model config not found"):
+        CrossLayerTranscoder.from_config_and_checkpoint(config_path, checkpoint_path)
+
+
+def test_is_folded_true(config_dict, tmp_model_dir):
+    model = CrossLayerTranscoder.from_config(config_dict)
+    dir_path = Path(tmp_model_dir)
+    config_path = dir_path / "config.yaml"
+    checkpoint_path = dir_path / "checkpoint.safetensors"
+
+    config_with_folded = config_dict.copy()
+    config_with_folded["is_folded"] = True
+    with open(config_path, "w") as f:
+        yaml.dump({"model": config_with_folded}, f)
+    save_file(model.state_dict(), checkpoint_path)
+
+    loaded = CrossLayerTranscoder.from_config_and_checkpoint(
+        config_path, checkpoint_path
+    )
+
+    assert loaded._is_folded is True
+
+
+def test_is_folded_false(config_dict, tmp_model_dir):
+    model = CrossLayerTranscoder.from_config(config_dict)
+    dir_path = Path(tmp_model_dir)
+    config_path = dir_path / "config.yaml"
+    checkpoint_path = dir_path / "checkpoint.safetensors"
+
+    config_with_folded = config_dict.copy()
+    config_with_folded["is_folded"] = False
+    with open(config_path, "w") as f:
+        yaml.dump({"model": config_with_folded}, f)
+    save_file(model.state_dict(), checkpoint_path)
+
+    loaded = CrossLayerTranscoder.from_config_and_checkpoint(
+        config_path, checkpoint_path
+    )
+
+    assert loaded._is_folded is False
+
+
+def test_is_folded_defaults_false_when_missing(config_dict, tmp_model_dir):
+    model = CrossLayerTranscoder.from_config(config_dict)
+    dir_path = Path(tmp_model_dir)
+    config_path = dir_path / "config.yaml"
+    checkpoint_path = dir_path / "checkpoint.safetensors"
+
+    with open(config_path, "w") as f:
+        yaml.dump({"model": config_dict}, f)
+    save_file(model.state_dict(), checkpoint_path)
+
+    loaded = CrossLayerTranscoder.from_config_and_checkpoint(
+        config_path, checkpoint_path
+    )
+
+    assert loaded._is_folded is False
diff --git a/tests/serializable-module/test_from_pretrained.py b/tests/serializable-module/test_from_pretrained.py
new file mode 100644
index 0000000..82d23e6
--- /dev/null
+++ b/tests/serializable-module/test_from_pretrained.py
@@ -0,0 +1,28 @@
+from pathlib import Path
+
+import torch
+
+from crosslayer_transcoder.model.clt import CrossLayerTranscoder
+
+
+def test_round_trip(config_dict, tmp_model_dir):
+    original = CrossLayerTranscoder.from_config(config_dict)
+    original.save_pretrained(Path(tmp_model_dir))
+
+    loaded = CrossLayerTranscoder.from_pretrained(tmp_model_dir)
+
+    assert isinstance(loaded, CrossLayerTranscoder)
+    original_state = original.state_dict()
+    loaded_state = loaded.state_dict()
+    assert original_state.keys() == loaded_state.keys()
+    for key in original_state:
+        assert torch.allclose(original_state[key], loaded_state[key], equal_nan=True)
+
+
+def test_accepts_string_path(config_dict, tmp_model_dir):
+    original = CrossLayerTranscoder.from_config(config_dict)
+    original.save_pretrained(Path(tmp_model_dir))
+
+    loaded = CrossLayerTranscoder.from_pretrained(str(tmp_model_dir))
+
+    assert isinstance(loaded, CrossLayerTranscoder)