From 56164f8a1f733fa72cca1d46415ebd7c3a25b486 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Tue, 3 Feb 2026 16:07:29 +0000 Subject: [PATCH] add json_serializable arg --- src/fiddledyn/core/context.py | 8 ++- src/fiddledyn/serialization/serializer.py | 78 +++++++++++++++++++++-- src/fiddledyn/utils.py | 48 ++++++++++++++ tests/test_serialization.py | 43 +++++++++++++ 4 files changed, 169 insertions(+), 8 deletions(-) diff --git a/src/fiddledyn/core/context.py b/src/fiddledyn/core/context.py index 88ad0b4..6994083 100644 --- a/src/fiddledyn/core/context.py +++ b/src/fiddledyn/core/context.py @@ -76,10 +76,12 @@ class SerializationContext(BaseContext): Tracks which objects have been serialized and assigns unique IDs for objects that appear multiple times (DAG structure). - Attributes: + Attributes: ref_counts: Pre-computed reference counts for each object. include_defaults: Whether to include default parameter values. deep_defaults: Whether to recursively expand callable defaults. + json_serializable: Whether to coerce leaf values into JSON-safe + scalar/list forms. memo: Maps object id() to assigned _id_ string. Example: @@ -93,6 +95,7 @@ def __init__( ref_counts: dict[int, int], include_defaults: bool = False, deep_defaults: bool = True, + json_serializable: bool = False, ) -> None: """Initialize a serialization context. @@ -107,10 +110,13 @@ def __init__( include_defaults is True. If False, callable defaults are serialized with just `_target_` and `_call_: false` without their own defaults. + json_serializable: Whether to coerce leaf values into JSON-safe + scalar/list forms. """ self.ref_counts = ref_counts self.include_defaults = include_defaults self.deep_defaults = deep_defaults + self.json_serializable = json_serializable self.memo: dict[int, str] = {} self._counter = 0 diff --git a/src/fiddledyn/serialization/serializer.py b/src/fiddledyn/serialization/serializer.py index 3fc0566..c1bf9d1 100644 --- a/src/fiddledyn/serialization/serializer.py +++ b/src/fiddledyn/serialization/serializer.py @@ -17,6 +17,16 @@ import fiddle as fdl import yaml +try: + import numpy as np +except ImportError: # pragma: no cover - optional dependency + np = None + +try: + import torch +except ImportError: # pragma: no cover - optional dependency + torch = None + from ..core import SerializationContext from ..core.types import HAS_NEMO_RUN, run from .base import BaseSerializer @@ -77,6 +87,33 @@ def _get_target_path(fn_or_cls: Any) -> str: return f"{module}.{qualname}" +def _is_numpy_array(obj: Any) -> bool: + return np is not None and isinstance(obj, np.ndarray) + + +def _is_torch_tensor(obj: Any) -> bool: + return torch is not None and isinstance(obj, torch.Tensor) + + +def _to_json_value(obj: Any) -> Any: + """Convert a value to a JSON-serializable scalar/list form. + + The output is restricted to strings, floats, or lists of such values. + """ + if _is_torch_tensor(obj) or _is_numpy_array(obj): + return obj.tolist() + + if isinstance(obj, (str, bool, int, float)): + return obj + + if obj is None: + return obj + + raise ValueError( + f"Cannot convert object of type {type(obj)} to JSON-serializable value." + ) + + class ConfigSerializer(BaseSerializer): """Serializes Config/Partial objects to dictionaries. @@ -91,6 +128,8 @@ class ConfigSerializer(BaseSerializer): Attributes: include_defaults: Whether to include default parameter values. deep_defaults: Whether to recursively expand callable defaults. + json_serializable: Whether to coerce leaf values into JSON-safe + scalar/list forms. Example: >>> serializer = ConfigSerializer() @@ -109,7 +148,10 @@ class ConfigSerializer(BaseSerializer): """ def __init__( - self, include_defaults: bool = False, deep_defaults: bool = True + self, + include_defaults: bool = False, + deep_defaults: bool = True, + json_serializable: bool = False, ) -> None: """Initialize the serializer. @@ -122,9 +164,12 @@ def __init__( expanded to include their own parameter defaults. If False, callable defaults are serialized with just `_target_` and `_call_: false` without their own defaults. + json_serializable: If True, coerce leaf values into JSON-safe + scalar/list forms. """ self.include_defaults = include_defaults self.deep_defaults = deep_defaults + self.json_serializable = json_serializable def serialize(self, obj: Any) -> Any: """Serialize a configuration object to a dictionary. @@ -143,7 +188,12 @@ def serialize(self, obj: Any) -> Any: ref_counts: dict[int, int] = collections.defaultdict(int) self._tally_references(obj, ref_counts, set()) - ctx = SerializationContext(ref_counts, self.include_defaults, self.deep_defaults) + ctx = SerializationContext( + ref_counts, + self.include_defaults, + self.deep_defaults, + json_serializable=self.json_serializable, + ) return self._serialize_recursive(obj, ctx) def to_yaml(self, obj: Any, file_path: str | None = None) -> str | None: @@ -292,6 +342,8 @@ def _serialize_recursive( return output if not _is_config_like(cfg): + if ctx.json_serializable: + return _to_json_value(cfg) return cfg obj_id = id(cfg) @@ -377,7 +429,10 @@ def _serialize_recursive( def config_to_dict( - cfg: Any, include_defaults: bool = False, deep_defaults: bool = True + cfg: Any, + include_defaults: bool = False, + deep_defaults: bool = True, + json_serializable: bool = False, ) -> Any: """Convert a Config object to a dictionary. @@ -388,10 +443,12 @@ def config_to_dict( Args: cfg: The configuration object to serialize. include_defaults: If True, include default parameter values. - deep_defaults: If True (default) and include_defaults is True, + deep_defaults: If True (default) and include_defaults is True, callable defaults are recursively expanded to include their own parameter defaults. If False, callable defaults are serialized with just `_target_` and `_call_: false`. + json_serializable: If True, coerce leaf values into JSON-safe + scalar/list forms. Returns: A dictionary representation of the configuration. @@ -409,7 +466,9 @@ def config_to_dict( >>> data = config_to_dict(config, include_defaults=True, deep_defaults=False) """ serializer = ConfigSerializer( - include_defaults=include_defaults, deep_defaults=deep_defaults + include_defaults=include_defaults, + deep_defaults=deep_defaults, + json_serializable=json_serializable, ) return serializer.serialize(cfg) @@ -419,6 +478,7 @@ def dump_yaml( file_path: str | None = None, include_defaults: bool = False, deep_defaults: bool = True, + json_serializable: bool = False, ) -> str | None: """Serialize a Config object to YAML format. @@ -430,10 +490,12 @@ def dump_yaml( cfg: The configuration object to serialize. file_path: Optional path to write the YAML output. include_defaults: If True, include default parameter values. - deep_defaults: If True (default) and include_defaults is True, + deep_defaults: If True (default) and include_defaults is True, callable defaults are recursively expanded to include their own parameter defaults. If False, callable defaults are serialized with just `_target_` and `_call_: false`. + json_serializable: If True, coerce leaf values into JSON-safe + scalar/list forms. Returns: The YAML string if file_path is None, otherwise None. @@ -449,6 +511,8 @@ def dump_yaml( >>> yaml_str = dump_yaml(config, include_defaults=True, deep_defaults=False) """ serializer = ConfigSerializer( - include_defaults=include_defaults, deep_defaults=deep_defaults + include_defaults=include_defaults, + deep_defaults=deep_defaults, + json_serializable=json_serializable, ) return serializer.to_yaml(cfg, file_path) diff --git a/src/fiddledyn/utils.py b/src/fiddledyn/utils.py index 0f4f676..b5b9fd0 100644 --- a/src/fiddledyn/utils.py +++ b/src/fiddledyn/utils.py @@ -85,3 +85,51 @@ def get_target_name(fn_or_cls: Any) -> str: if hasattr(fn_or_cls, "__module__") and hasattr(fn_or_cls, "__qualname__"): return f"{fn_or_cls.__module__}.{fn_or_cls.__qualname__}" return str(fn_or_cls) + + +_JSON_SERIALIZABLE_TYPES = (int, float, str, bool, type(None)) + + +def is_jsonable(obj: Any, _visited: set[int] | None = None) -> bool: + """Check if an object is JSON serializable. + + This is a weak check, as it does not check for the actual JSON serialization, but only for the types of the object. + It works correctly for basic use cases but do not guarantee an exhaustive check. + + Object is considered to be recursively json serializable if: + - it is an instance of int, float, str, bool, or NoneType + - it is a list or tuple and all its items are json serializable + - it is a dict and all its keys are strings and all its values are json serializable + + Uses a visited set to avoid infinite recursion on circular references. If object has already been visited, it is + considered not json serializable. + """ + # Initialize visited set to track object ids and detect circular references + if _visited is None: + _visited = set() + + # Detect circular reference + obj_id = id(obj) + if obj_id in _visited: + return False + + # Add current object to visited before recursive checks + _visited.add(obj_id) + try: + if isinstance(obj, _JSON_SERIALIZABLE_TYPES): + return True + if isinstance(obj, (list, tuple)): + return all(is_jsonable(item, _visited) for item in obj) + if isinstance(obj, dict): + return all( + isinstance(key, _JSON_SERIALIZABLE_TYPES) and is_jsonable(value, _visited) + for key, value in obj.items() + ) + if hasattr(obj, "__json__"): + return True + return False + except RecursionError: + return False + finally: + # Remove the object id from visited to avoid side‑effects for other branches + _visited.discard(obj_id) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 8849d8a..c30d1a5 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -3,9 +3,11 @@ from math import sqrt import nemo_run as run +import pytest from fiddledyn import config_to_dict from fiddledyn.serialization import ConfigSerializer +from fiddledyn.utils import is_jsonable from .conftest import ( Decoder, @@ -384,3 +386,44 @@ def test_deeply_nested(self) -> None: assert result["model"]["encoder"]["vocab_size"] == 10000 assert result["optimizer"]["lr"] == 0.001 + + +class TestJsonSerializableCheck: + """Tests for json_serializable mode.""" + + def test_json_serializable_converts_np_array(self) -> None: + np = pytest.importorskip("numpy") + + arr = np.array([1, 2, 3]) + config = run.Config(SimpleClass, x=arr) + result = config_to_dict(config, json_serializable=True) + + assert is_jsonable(result) + assert result["x"] == [1.0, 2.0, 3.0] + + def test_json_serializable_converts_torch_tensor(self) -> None: + torch = pytest.importorskip("torch") + + arr = torch.tensor([1, 2, 3]) + config = run.Config(SimpleClass, x=arr) + result = config_to_dict(config, json_serializable=True) + + assert is_jsonable(result) + assert result["x"] == [1.0, 2.0, 3.0] + + def test_json_serializable_converts_none(self) -> None: + + arr = None + config = run.Config(SimpleClass, x=arr) + result = config_to_dict(config, json_serializable=True) + + assert is_jsonable(result) + assert result["x"] == None + + def test_json_serializable_catch_none_serialization(self) -> None: + + arr = SimpleClass(x=1) + config = run.Config(SimpleClass, x=arr) + # catch that the serialization raises an error + with pytest.raises(ValueError): + _ = config_to_dict(config, json_serializable=True)