diff --git a/src/fiddledyn/core/context.py b/src/fiddledyn/core/context.py index 88ad0b4..6994083 100644 --- a/src/fiddledyn/core/context.py +++ b/src/fiddledyn/core/context.py @@ -76,10 +76,12 @@ class SerializationContext(BaseContext): Tracks which objects have been serialized and assigns unique IDs for objects that appear multiple times (DAG structure). - Attributes: + Attributes: ref_counts: Pre-computed reference counts for each object. include_defaults: Whether to include default parameter values. deep_defaults: Whether to recursively expand callable defaults. + json_serializable: Whether to coerce leaf values into JSON-safe + scalar/list forms. memo: Maps object id() to assigned _id_ string. Example: @@ -93,6 +95,7 @@ def __init__( ref_counts: dict[int, int], include_defaults: bool = False, deep_defaults: bool = True, + json_serializable: bool = False, ) -> None: """Initialize a serialization context. @@ -107,10 +110,13 @@ def __init__( include_defaults is True. If False, callable defaults are serialized with just `_target_` and `_call_: false` without their own defaults. + json_serializable: Whether to coerce leaf values into JSON-safe + scalar/list forms. 
""" self.ref_counts = ref_counts self.include_defaults = include_defaults self.deep_defaults = deep_defaults + self.json_serializable = json_serializable self.memo: dict[int, str] = {} self._counter = 0 diff --git a/src/fiddledyn/serialization/serializer.py b/src/fiddledyn/serialization/serializer.py index 8fb2f64..57675c7 100644 --- a/src/fiddledyn/serialization/serializer.py +++ b/src/fiddledyn/serialization/serializer.py @@ -17,6 +17,16 @@ import fiddle as fdl import yaml +try: + import numpy as np +except ImportError: # pragma: no cover - optional dependency + np = None + +try: + import torch +except ImportError: # pragma: no cover - optional dependency + torch = None + from ..core import SerializationContext from ..core.types import HAS_NEMO_RUN, run from ..utils import get_target_name @@ -78,6 +88,33 @@ def _get_target_path(fn_or_cls: Any) -> str: return get_target_name(fn_or_cls) +def _is_numpy_array(obj: Any) -> bool: + return np is not None and isinstance(obj, np.ndarray) + + +def _is_torch_tensor(obj: Any) -> bool: + return torch is not None and isinstance(obj, torch.Tensor) + + +def _to_json_value(obj: Any) -> Any: + """Convert a value to a JSON-serializable scalar/list form. + + The output is restricted to strings, floats, or lists of such values. + """ + if _is_torch_tensor(obj) or _is_numpy_array(obj): + return obj.tolist() + + if isinstance(obj, (str, bool, int, float)): + return obj + + if obj is None: + return obj + + raise ValueError( + f"Cannot convert object of type {type(obj)} to JSON-serializable value." + ) + + class ConfigSerializer(BaseSerializer): """Serializes Config/Partial objects to dictionaries. @@ -92,6 +129,8 @@ class ConfigSerializer(BaseSerializer): Attributes: include_defaults: Whether to include default parameter values. deep_defaults: Whether to recursively expand callable defaults. + json_serializable: Whether to coerce leaf values into JSON-safe + scalar/list forms. 
Example: >>> serializer = ConfigSerializer() @@ -110,7 +149,10 @@ class ConfigSerializer(BaseSerializer): """ def __init__( - self, include_defaults: bool = False, deep_defaults: bool = True + self, + include_defaults: bool = False, + deep_defaults: bool = True, + json_serializable: bool = False, ) -> None: """Initialize the serializer. @@ -123,9 +165,12 @@ def __init__( expanded to include their own parameter defaults. If False, callable defaults are serialized with just `_target_` and `_call_: false` without their own defaults. + json_serializable: If True, coerce leaf values into JSON-safe + scalar/list forms. """ self.include_defaults = include_defaults self.deep_defaults = deep_defaults + self.json_serializable = json_serializable def serialize(self, obj: Any) -> Any: """Serialize a configuration object to a dictionary. @@ -144,7 +189,12 @@ def serialize(self, obj: Any) -> Any: ref_counts: dict[int, int] = collections.defaultdict(int) self._tally_references(obj, ref_counts, set()) - ctx = SerializationContext(ref_counts, self.include_defaults, self.deep_defaults) + ctx = SerializationContext( + ref_counts, + self.include_defaults, + self.deep_defaults, + json_serializable=self.json_serializable, + ) return self._serialize_recursive(obj, ctx) def to_yaml(self, obj: Any, file_path: str | None = None) -> str | None: @@ -293,6 +343,8 @@ def _serialize_recursive( return output if not _is_config_like(cfg): + if ctx.json_serializable: + return _to_json_value(cfg) return cfg obj_id = id(cfg) @@ -378,7 +430,10 @@ def _serialize_recursive( def config_to_dict( - cfg: Any, include_defaults: bool = False, deep_defaults: bool = True + cfg: Any, + include_defaults: bool = False, + deep_defaults: bool = True, + json_serializable: bool = False, ) -> Any: """Convert a Config object to a dictionary. @@ -389,10 +444,12 @@ def config_to_dict( Args: cfg: The configuration object to serialize. include_defaults: If True, include default parameter values. 
- deep_defaults: If True (default) and include_defaults is True, + deep_defaults: If True (default) and include_defaults is True, callable defaults are recursively expanded to include their own parameter defaults. If False, callable defaults are serialized with just `_target_` and `_call_: false`. + json_serializable: If True, coerce leaf values into JSON-safe + scalar/list forms. Returns: A dictionary representation of the configuration. @@ -410,7 +467,9 @@ def config_to_dict( >>> data = config_to_dict(config, include_defaults=True, deep_defaults=False) """ serializer = ConfigSerializer( - include_defaults=include_defaults, deep_defaults=deep_defaults + include_defaults=include_defaults, + deep_defaults=deep_defaults, + json_serializable=json_serializable, ) return serializer.serialize(cfg) @@ -420,6 +479,7 @@ def dump_yaml( file_path: str | None = None, include_defaults: bool = False, deep_defaults: bool = True, + json_serializable: bool = False, ) -> str | None: """Serialize a Config object to YAML format. @@ -431,10 +491,12 @@ def dump_yaml( cfg: The configuration object to serialize. file_path: Optional path to write the YAML output. include_defaults: If True, include default parameter values. - deep_defaults: If True (default) and include_defaults is True, + deep_defaults: If True (default) and include_defaults is True, callable defaults are recursively expanded to include their own parameter defaults. If False, callable defaults are serialized with just `_target_` and `_call_: false`. + json_serializable: If True, coerce leaf values into JSON-safe + scalar/list forms. Returns: The YAML string if file_path is None, otherwise None. 
@@ -450,6 +512,8 @@ def dump_yaml( >>> yaml_str = dump_yaml(config, include_defaults=True, deep_defaults=False) """ serializer = ConfigSerializer( - include_defaults=include_defaults, deep_defaults=deep_defaults + include_defaults=include_defaults, + deep_defaults=deep_defaults, + json_serializable=json_serializable, ) return serializer.to_yaml(cfg, file_path) diff --git a/src/fiddledyn/utils.py b/src/fiddledyn/utils.py index c16bf90..182d596 100644 --- a/src/fiddledyn/utils.py +++ b/src/fiddledyn/utils.py @@ -114,3 +114,51 @@ def get_target_name(fn_or_cls: Any) -> str: pass return f"{module}.{name}" + + +_JSON_SERIALIZABLE_TYPES = (int, float, str, bool, type(None)) + + +def is_jsonable(obj: Any, _visited: set[int] | None = None) -> bool: + """Check if an object is JSON serializable. + + This is a weak check, as it does not check for the actual JSON serialization, but only for the types of the object. + It works correctly for basic use cases but does not guarantee an exhaustive check. + + An object is considered to be recursively json serializable if: + - it is an instance of int, float, str, bool, or NoneType + - it is a list or tuple and all its items are json serializable + - it is a dict and all its keys are instances of int, float, str, bool, or NoneType and all its values are json serializable + + Uses a visited set to avoid infinite recursion on circular references. If an object is reached again while it is still being visited, it is + considered not json serializable. 
+ """ + # Initialize visited set to track object ids and detect circular references + if _visited is None: + _visited = set() + + # Detect circular reference + obj_id = id(obj) + if obj_id in _visited: + return False + + # Add current object to visited before recursive checks + _visited.add(obj_id) + try: + if isinstance(obj, _JSON_SERIALIZABLE_TYPES): + return True + if isinstance(obj, (list, tuple)): + return all(is_jsonable(item, _visited) for item in obj) + if isinstance(obj, dict): + return all( + isinstance(key, _JSON_SERIALIZABLE_TYPES) and is_jsonable(value, _visited) + for key, value in obj.items() + ) + if hasattr(obj, "__json__"): + return True + return False + except RecursionError: + return False + finally: + # Remove the object id from visited to avoid side‑effects for other branches + _visited.discard(obj_id) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 0115f1a..bff0e71 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -3,9 +3,11 @@ from math import sqrt import nemo_run as run +import pytest from fiddledyn import config_to_dict from fiddledyn.serialization import ConfigSerializer +from fiddledyn.utils import is_jsonable from .conftest import ( Decoder, @@ -446,3 +448,44 @@ def test_deeply_nested(self) -> None: assert result["model"]["encoder"]["vocab_size"] == 10000 assert result["optimizer"]["lr"] == 0.001 + + +class TestJsonSerializableCheck: + """Tests for json_serializable mode.""" + + def test_json_serializable_converts_np_array(self) -> None: + np = pytest.importorskip("numpy") + + arr = np.array([1, 2, 3]) + config = run.Config(SimpleClass, x=arr) + result = config_to_dict(config, json_serializable=True) + + assert is_jsonable(result) + assert result["x"] == [1.0, 2.0, 3.0] + + def test_json_serializable_converts_torch_tensor(self) -> None: + torch = pytest.importorskip("torch") + + arr = torch.tensor([1, 2, 3]) + config = run.Config(SimpleClass, x=arr) + result = 
config_to_dict(config, json_serializable=True) + + assert is_jsonable(result) + assert result["x"] == [1.0, 2.0, 3.0] + + def test_json_serializable_converts_none(self) -> None: + + arr = None + config = run.Config(SimpleClass, x=arr) + result = config_to_dict(config, json_serializable=True) + + assert is_jsonable(result) + assert result["x"] == None + + def test_json_serializable_catch_none_serialization(self) -> None: + + arr = SimpleClass(x=1) + config = run.Config(SimpleClass, x=arr) + # catch that the serialization raises an error + with pytest.raises(ValueError): + _ = config_to_dict(config, json_serializable=True)