From 7670ba24acfc49dd015765c89d65aaac9ba93dca Mon Sep 17 00:00:00 2001 From: Yunjie Xu Date: Mon, 4 May 2026 10:40:05 -0700 Subject: [PATCH] Add support for native_serialization_platforms in JaxDataProcessor. PiperOrigin-RevId: 910114766 --- .../data_processors/jax_data_processor.py | 15 +++++++++++++ .../jax_data_processor_test.py | 21 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/export/orbax/export/data_processors/jax_data_processor.py b/export/orbax/export/data_processors/jax_data_processor.py index 369c83622..85eb08864 100644 --- a/export/orbax/export/data_processors/jax_data_processor.py +++ b/export/orbax/export/data_processors/jax_data_processor.py @@ -20,6 +20,7 @@ import jax import jaxtyping +from orbax.experimental.model.core.protos import manifest_pb2 from orbax.export import constants from orbax.export import obm_configs from orbax.export.data_processors import data_processor_base @@ -150,12 +151,26 @@ def _save_checkpoint( self._params, ) + native_serialization_platforms = ( + self._options.native_serialization_platforms + ) + if native_serialization_platforms is None: + platforms = None + elif isinstance(native_serialization_platforms, str): + platforms = [native_serialization_platforms] + else: + platforms = native_serialization_platforms + + if platforms is not None: + platforms = [manifest_pb2.Platform.Value(p.upper()) for p in platforms] + # Convert the JAX function to an Orbax Model function using jax2obm, # making it compatible with the Orbax Export framework. self._obm_function = jax2obm.convert( fun_jax=self._processor_callable, args_spec=args_spec, kwargs_spec={}, + platforms=platforms, native_serialization_disabled_checks=self._options.native_serialization_disabled_checks, model_param_names=jax.tree.leaves(param_names_tree), # TODO: b/485622993 - Add other options if needed. diff --git a/export/orbax/export/data_processors/jax_data_processor_test.py b/export/orbax/export/data_processors/jax_data_processor_test.py index e3fa0fc9f..0a43bda7d 100644 --- a/export/orbax/export/data_processors/jax_data_processor_test.py +++ b/export/orbax/export/data_processors/jax_data_processor_test.py @@ -145,6 +145,27 @@ def add_params(params: Mapping[str, Any], x: jax.Array) -> jax.Array: checkpoint_path = pathlib.Path(temp_dir) / 'processor' / 'add_params' self.assertTrue(checkpoint_path.exists()) + def test_prepare_with_platforms_option(self): + def add(x: jax.Array) -> jax.Array: + return x + 1.0 + + processor = jax_data_processor.JaxDataProcessor( + add, + name='add', + options=obm_configs.Jax2ObmOptions( + native_serialization_platforms=['cpu', 'tpu'] + ), + ) + processor.prepare( + jax.ShapeDtypeStruct((2, 3), jnp.float32), + ) + + self.assertIsNotNone(processor.obm_function) + self.assertEqual( + processor.obm_function.lowering_platforms, # pytype: disable=attribute-error + ('cpu', 'tpu'), + ) + if __name__ == '__main__': googletest.main()