Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/fiddledyn/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,12 @@ class SerializationContext(BaseContext):
Tracks which objects have been serialized and assigns unique IDs
for objects that appear multiple times (DAG structure).

Attributes:
Attributes:
ref_counts: Pre-computed reference counts for each object.
include_defaults: Whether to include default parameter values.
deep_defaults: Whether to recursively expand callable defaults.
json_serializable: Whether to coerce leaf values into JSON-safe
scalar/list forms.
memo: Maps object id() to assigned _id_ string.

Example:
Expand All @@ -93,6 +95,7 @@ def __init__(
ref_counts: dict[int, int],
include_defaults: bool = False,
deep_defaults: bool = True,
json_serializable: bool = False,
) -> None:
"""Initialize a serialization context.

Expand All @@ -107,10 +110,13 @@ def __init__(
include_defaults is True. If False, callable defaults are
serialized with just `_target_` and `_call_: false` without
their own defaults.
json_serializable: Whether to coerce leaf values into JSON-safe
scalar/list forms.
"""
self.ref_counts = ref_counts
self.include_defaults = include_defaults
self.deep_defaults = deep_defaults
self.json_serializable = json_serializable
self.memo: dict[int, str] = {}
self._counter = 0

Expand Down
78 changes: 71 additions & 7 deletions src/fiddledyn/serialization/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
import fiddle as fdl
import yaml

try:
import numpy as np
except ImportError: # pragma: no cover - optional dependency
np = None

try:
import torch

Check failure on line 26 in src/fiddledyn/serialization/serializer.py

View workflow job for this annotation

GitHub Actions / static-analysis

Import "torch" could not be resolved (reportMissingImports)

Check failure on line 26 in src/fiddledyn/serialization/serializer.py

View workflow job for this annotation

GitHub Actions / static-analysis

Import "torch" could not be resolved (reportMissingImports)
except ImportError: # pragma: no cover - optional dependency
torch = None

from ..core import SerializationContext
from ..core.types import HAS_NEMO_RUN, run
from ..utils import get_target_name
Expand Down Expand Up @@ -78,6 +88,33 @@
return get_target_name(fn_or_cls)


def _is_numpy_array(obj: Any) -> bool:
return np is not None and isinstance(obj, np.ndarray)


def _is_torch_tensor(obj: Any) -> bool:
return torch is not None and isinstance(obj, torch.Tensor)


def _to_json_value(obj: Any) -> Any:
"""Convert a value to a JSON-serializable scalar/list form.

The output is restricted to strings, floats, or lists of such values.
"""
if _is_torch_tensor(obj) or _is_numpy_array(obj):
return obj.tolist()

if isinstance(obj, (str, bool, int, float)):
return obj

if obj is None:
return obj

raise ValueError(
f"Cannot convert object of type {type(obj)} to JSON-serializable value."
)


class ConfigSerializer(BaseSerializer):
"""Serializes Config/Partial objects to dictionaries.

Expand All @@ -92,6 +129,8 @@
Attributes:
include_defaults: Whether to include default parameter values.
deep_defaults: Whether to recursively expand callable defaults.
json_serializable: Whether to coerce leaf values into JSON-safe
scalar/list forms.

Example:
>>> serializer = ConfigSerializer()
Expand All @@ -110,7 +149,10 @@
"""

def __init__(
self, include_defaults: bool = False, deep_defaults: bool = True
self,
include_defaults: bool = False,
deep_defaults: bool = True,
json_serializable: bool = False,
) -> None:
"""Initialize the serializer.

Expand All @@ -123,9 +165,12 @@
expanded to include their own parameter defaults. If False,
callable defaults are serialized with just `_target_` and
`_call_: false` without their own defaults.
json_serializable: If True, coerce leaf values into JSON-safe
scalar/list forms.
"""
self.include_defaults = include_defaults
self.deep_defaults = deep_defaults
self.json_serializable = json_serializable

def serialize(self, obj: Any) -> Any:
"""Serialize a configuration object to a dictionary.
Expand All @@ -144,7 +189,12 @@
ref_counts: dict[int, int] = collections.defaultdict(int)
self._tally_references(obj, ref_counts, set())

ctx = SerializationContext(ref_counts, self.include_defaults, self.deep_defaults)
ctx = SerializationContext(
ref_counts,
self.include_defaults,
self.deep_defaults,
json_serializable=self.json_serializable,
)
return self._serialize_recursive(obj, ctx)

def to_yaml(self, obj: Any, file_path: str | None = None) -> str | None:
Expand Down Expand Up @@ -293,6 +343,8 @@
return output

if not _is_config_like(cfg):
if ctx.json_serializable:
return _to_json_value(cfg)
return cfg

obj_id = id(cfg)
Expand Down Expand Up @@ -378,7 +430,10 @@


def config_to_dict(
cfg: Any, include_defaults: bool = False, deep_defaults: bool = True
cfg: Any,
include_defaults: bool = False,
deep_defaults: bool = True,
json_serializable: bool = False,
) -> Any:
"""Convert a Config object to a dictionary.

Expand All @@ -389,10 +444,12 @@
Args:
cfg: The configuration object to serialize.
include_defaults: If True, include default parameter values.
deep_defaults: If True (default) and include_defaults is True,
deep_defaults: If True (default) and include_defaults is True,
callable defaults are recursively expanded to include their
own parameter defaults. If False, callable defaults are
serialized with just `_target_` and `_call_: false`.
json_serializable: If True, coerce leaf values into JSON-safe
scalar/list forms.

Returns:
A dictionary representation of the configuration.
Expand All @@ -410,7 +467,9 @@
>>> data = config_to_dict(config, include_defaults=True, deep_defaults=False)
"""
serializer = ConfigSerializer(
include_defaults=include_defaults, deep_defaults=deep_defaults
include_defaults=include_defaults,
deep_defaults=deep_defaults,
json_serializable=json_serializable,
)
return serializer.serialize(cfg)

Expand All @@ -420,6 +479,7 @@
file_path: str | None = None,
include_defaults: bool = False,
deep_defaults: bool = True,
json_serializable: bool = False,
) -> str | None:
"""Serialize a Config object to YAML format.

Expand All @@ -431,10 +491,12 @@
cfg: The configuration object to serialize.
file_path: Optional path to write the YAML output.
include_defaults: If True, include default parameter values.
deep_defaults: If True (default) and include_defaults is True,
deep_defaults: If True (default) and include_defaults is True,
callable defaults are recursively expanded to include their
own parameter defaults. If False, callable defaults are
serialized with just `_target_` and `_call_: false`.
json_serializable: If True, coerce leaf values into JSON-safe
scalar/list forms.

Returns:
The YAML string if file_path is None, otherwise None.
Expand All @@ -450,6 +512,8 @@
>>> yaml_str = dump_yaml(config, include_defaults=True, deep_defaults=False)
"""
serializer = ConfigSerializer(
include_defaults=include_defaults, deep_defaults=deep_defaults
include_defaults=include_defaults,
deep_defaults=deep_defaults,
json_serializable=json_serializable,
)
return serializer.to_yaml(cfg, file_path)
48 changes: 48 additions & 0 deletions src/fiddledyn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,51 @@ def get_target_name(fn_or_cls: Any) -> str:
pass

return f"{module}.{name}"


_JSON_SERIALIZABLE_TYPES = (int, float, str, bool, type(None))


def is_jsonable(obj: Any, _visited: set[int] | None = None) -> bool:
"""Check if an object is JSON serializable.

This is a weak check, as it does not check for the actual JSON serialization, but only for the types of the object.
It works correctly for basic use cases but do not guarantee an exhaustive check.

Object is considered to be recursively json serializable if:
- it is an instance of int, float, str, bool, or NoneType
- it is a list or tuple and all its items are json serializable
- it is a dict and all its keys are strings and all its values are json serializable

Uses a visited set to avoid infinite recursion on circular references. If object has already been visited, it is
considered not json serializable.
"""
# Initialize visited set to track object ids and detect circular references
if _visited is None:
_visited = set()

# Detect circular reference
obj_id = id(obj)
if obj_id in _visited:
return False

# Add current object to visited before recursive checks
_visited.add(obj_id)
try:
if isinstance(obj, _JSON_SERIALIZABLE_TYPES):
return True
if isinstance(obj, (list, tuple)):
return all(is_jsonable(item, _visited) for item in obj)
if isinstance(obj, dict):
return all(
isinstance(key, _JSON_SERIALIZABLE_TYPES) and is_jsonable(value, _visited)
for key, value in obj.items()
)
if hasattr(obj, "__json__"):
return True
return False
except RecursionError:
return False
finally:
# Remove the object id from visited to avoid side‑effects for other branches
_visited.discard(obj_id)
43 changes: 43 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from math import sqrt

import nemo_run as run
import pytest

from fiddledyn import config_to_dict
from fiddledyn.serialization import ConfigSerializer
from fiddledyn.utils import is_jsonable

from .conftest import (
Decoder,
Expand Down Expand Up @@ -446,3 +448,44 @@ def test_deeply_nested(self) -> None:

assert result["model"]["encoder"]["vocab_size"] == 10000
assert result["optimizer"]["lr"] == 0.001


class TestJsonSerializableCheck:
"""Tests for json_serializable mode."""

def test_json_serializable_converts_np_array(self) -> None:
np = pytest.importorskip("numpy")

arr = np.array([1, 2, 3])
config = run.Config(SimpleClass, x=arr)
result = config_to_dict(config, json_serializable=True)

assert is_jsonable(result)
assert result["x"] == [1.0, 2.0, 3.0]

def test_json_serializable_converts_torch_tensor(self) -> None:
torch = pytest.importorskip("torch")

arr = torch.tensor([1, 2, 3])
config = run.Config(SimpleClass, x=arr)
result = config_to_dict(config, json_serializable=True)

assert is_jsonable(result)
assert result["x"] == [1.0, 2.0, 3.0]

def test_json_serializable_converts_none(self) -> None:

arr = None
config = run.Config(SimpleClass, x=arr)
result = config_to_dict(config, json_serializable=True)

assert is_jsonable(result)
assert result["x"] == None

def test_json_serializable_catch_none_serialization(self) -> None:

arr = SimpleClass(x=1)
config = run.Config(SimpleClass, x=arr)
# catch that the serialization raises an error
with pytest.raises(ValueError):
_ = config_to_dict(config, json_serializable=True)
Loading