diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index 014a88527..41784544e 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -376,6 +376,7 @@ def __init__( ), enable_pinned_host_transfer: Optional[bool] = None, is_prioritized_key_fn: Optional[types.IsPrioritizedKeyFn] = None, + use_non_atomic_file_io_locking: bool = True, ): """Creates BasePyTreeCheckpointHandler. @@ -420,6 +421,8 @@ def __init__( not prioritized. Note that any "prioritized" keys are assumed to be lightweight, and `save_device_host_concurrent_gb` will be ignored for them. + use_non_atomic_file_io_locking: If True, enables non-atomic file I/O + locking mode for TensorStore OCDBT data files. """ self._save_concurrent_bytes = save_concurrent_bytes self._restore_concurrent_bytes = restore_concurrent_bytes @@ -449,6 +452,7 @@ def __init__( self._use_ocdbt = use_ocdbt self._use_zarr3 = use_zarr3 self._use_compression = use_compression + self._use_non_atomic_file_io_locking = use_non_atomic_file_io_locking self._primary_host = multiprocessing_options.primary_host self._type_handler_registry = type_handler_registry self._enable_post_merge_validation = enable_post_merge_validation @@ -506,6 +510,7 @@ def _get_param_infos( byte_limiter: Optional[limits.ByteLimiter] = None, device_host_byte_limiter: Optional[limits.ByteLimiter] = None, raise_array_data_missing_error: bool = True, + use_non_atomic_file_io_locking: bool = True, ) -> PyTree: """Returns parameter information for elements in `item`. @@ -523,6 +528,8 @@ def _get_param_infos( byte_limiter: ByteLimiter object. device_host_byte_limiter: ByteLimiter object for device-to-host transfer. raise_array_data_missing_error: See documentation in ParamInfo. + use_non_atomic_file_io_locking: If True, enables non-atomic file I/O + locking mode for TensorStore OCDBT data files. Returns: A PyTree matching `item` of ParamInfo. @@ -557,6 +564,7 @@ def _param_info(keypath, name, value): ), raise_array_data_missing_error=raise_array_data_missing_error, is_prioritized_key_fn=self._is_prioritized_key_fn, + use_non_atomic_file_io_locking=use_non_atomic_file_io_locking, ) return jax.tree.map_with_path( @@ -718,6 +726,7 @@ async def async_save( device_host_byte_limiter=device_host_byte_limiter, use_compression=self._use_compression, use_zarr3=self._use_zarr3, + use_non_atomic_file_io_locking=self._use_non_atomic_file_io_locking, ) # TODO(b/425293362): Add validation for PathAwaitingCreation. if isinstance(directory, epath.Path): @@ -1147,6 +1156,7 @@ class TrainState: use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, raise_array_data_missing_error=raise_array_data_missing_error, + use_non_atomic_file_io_locking=self._use_non_atomic_file_io_locking, ) # Begin restore. tree_memory_size, restored_item = asyncio_utils.run_sync( diff --git a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py index 0c5de7e4f..802b2af45 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py @@ -522,6 +522,7 @@ def __init__( is_prioritized_key_fn: Optional[ serialization_types.IsPrioritizedKeyFn ] = None, + use_non_atomic_file_io_locking: bool = True, ): """Creates PyTreeCheckpointHandler. @@ -576,6 +577,10 @@ def __init__( not prioritized. Note that any "prioritized" keys are assumed to be lightweight, and `save_device_host_concurrent_gb` will be ignored for them. + use_non_atomic_file_io_locking: If True, enables non-atomic file I/O + locking mode for TensorStore OCDBT data files. This can improve + performance on filesystems like GCSFuse by avoiding expensive renames. + Defaults to True. """ self._aggregate_handler = MsgpackHandler( @@ -612,6 +617,7 @@ def __init__( array_metadata_validator=array_metadata_validator, enable_pinned_host_transfer=enable_pinned_host_transfer, is_prioritized_key_fn=is_prioritized_key_fn, + use_non_atomic_file_io_locking=use_non_atomic_file_io_locking, ) self._pytree_metadata_options = pytree_metadata_options diff --git a/checkpoint/orbax/checkpoint/_src/handlers/pytree_metadata_test.py b/checkpoint/orbax/checkpoint/_src/handlers/pytree_metadata_test.py index 8b59d1013..c298ea145 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/pytree_metadata_test.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/pytree_metadata_test.py @@ -37,7 +37,11 @@ class PyTreeMetadataTest(parameterized.TestCase): @parameterized.product(use_ocdbt=(True, False), use_zarr3=(True, False)) def test_metadata_properties(self, use_ocdbt: bool, use_zarr3: bool): directory = epath.Path(self.create_tempdir().full_path) - handler = PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3) + handler = PyTreeCheckpointHandler( + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + use_non_atomic_file_io_locking=False, + ) item = {'a': np.array([1, 2, 3]), 'b': {'c': 'test'}} custom_metadata = {'key1': 'value1', 'key2': 123} handler.save( @@ -60,7 +64,11 @@ def test_as_custom_metadata(self, use_ocdbt: bool, use_zarr3: bool): directory = epath.Path(self.create_tempdir().full_path) item = {'a': np.array([1, 2, 3]), 'b': {'c': 'test'}} custom_metadata = {'key1': 'value1', 'key2': 123} - handler = PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3) + handler = PyTreeCheckpointHandler( + use_ocdbt=use_ocdbt, + use_zarr3=use_zarr3, + use_non_atomic_file_io_locking=False, + ) handler.save( directory, args=PyTreeSaveArgs(item, custom_metadata=custom_metadata), diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py index 3ac69ebc6..53368b381 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py @@ -143,6 +143,7 @@ def build_kvstore_tspec( use_ocdbt: bool = True, process_id: int | str | None = None, replica_separate_folder: bool = False, + use_non_atomic_file_io_locking: bool = True, ) -> JsonSpec: """Constructs a spec for a Tensorstore KvStore. @@ -155,6 +156,8 @@ def build_kvstore_tspec( `{directory}/ocdbt.process_{process_id}` path is used as the base path. If a string, must conform to [A-Za-z0-9]+ pattern. replica_separate_folder: Whether a replica separated folder is used. + use_non_atomic_file_io_locking: If True, enables non-atomic file I/O + locking mode for TensorStore OCDBT data files. Returns: A Tensorstore KvStore spec in dictionary form. @@ -194,6 +197,26 @@ def build_kvstore_tspec( 'driver': 'ocdbt', 'base': base_driver_spec, }) + # For OCDBT on local filesystems (including GCSFuse), we can safely use + # non-atomic writes for data files to avoid expensive renames. However, + # the manifest file still requires atomic writes to avoid corruption. + # We achieve this by splitting the spec into 'base' (for data files) and + # 'manifest'. Direct GCS paths ('gs://') do not need this optimization + # as they don't use the file driver. + is_tree_verity_path = str(directory).startswith('/tree_verity/') + is_remote = is_remote_storage(base_driver_spec) + if ( + use_non_atomic_file_io_locking + and not is_gcs_path + and not is_tree_verity_path + and not is_remote + ): + base_spec = copy.deepcopy(base_driver_spec) + if isinstance(base_spec, dict): + base_spec['driver'] = 'file' # Force 'file' driver + base_spec['file_io_locking'] = {'mode': 'non_atomic'} + kv_spec['base'] = base_spec + kv_spec['manifest'] = base_driver_spec if name is not None: kv_spec['path'] = name @@ -503,6 +526,7 @@ def __init__( metadata_key: str | None = None, replica_separate_folder: bool = False, ext_metadata: ExtMetadata | None = None, + use_non_atomic_file_io_locking: bool = True, ): """Builds a TensorStore spec for writing an array.""" # Construct the underlying KvStore spec. @@ -512,6 +536,7 @@ def __init__( use_ocdbt=use_ocdbt, process_id=process_id, replica_separate_folder=replica_separate_folder, + use_non_atomic_file_io_locking=use_non_atomic_file_io_locking, ) # Construct the top-level array spec. tspec = { @@ -684,6 +709,7 @@ def _get_json_tspec( name=info.name, use_ocdbt=use_ocdbt, process_id=process_index, + use_non_atomic_file_io_locking=info.use_non_atomic_file_io_locking, ) tspec = { @@ -803,6 +829,7 @@ def build_array_write_spec( replica_separate_folder: bool = False, metadata_key: str | None = None, ext_metadata: dict[str, Any] | None = None, + use_non_atomic_file_io_locking: bool = True, ) -> ArrayWriteSpec: """Gets ArrayWriteSpec for writing.""" if info.name is None or info.parent_dir is None: @@ -828,6 +855,7 @@ def build_array_write_spec( ocdbt_target_data_file_size=info.ocdbt_target_data_file_size, metadata_key=metadata_key, ext_metadata=ext_metadata, + use_non_atomic_file_io_locking=use_non_atomic_file_io_locking, ) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py index 516dcef3e..5ee28e510 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py @@ -245,6 +245,7 @@ def test_ocdbt_kvstore( use_zarr3=False, use_ocdbt=True, process_id=13, + use_non_atomic_file_io_locking=False, ) self.assertTrue(tspec.metadata.use_ocdbt) json_tspec = tspec.json @@ -257,6 +258,49 @@ def test_ocdbt_kvstore( ) self.assertEqual(json_tspec['kvstore']['path'], self.param_name) + def test_ocdbt_kvstore_with_non_atomic_locking(self): + tspec = self.array_write_spec_constructor( + directory=self.directory, + relative_array_filename=self.param_name, + use_zarr3=False, + use_ocdbt=True, + process_id=13, + ) + self.assertTrue(tspec.metadata.use_ocdbt) + json_tspec = tspec.json + kvstore_tspec = json_tspec['kvstore'] + self.assertEqual(kvstore_tspec['driver'], 'ocdbt') + + # Base spec should have non_atomic locking + base_spec = kvstore_tspec['base'] + self.assertEqual(base_spec['driver'], 'file') + self.assertEqual(base_spec['file_io_locking'], {'mode': 'non_atomic'}) + + # Manifest spec should be present and NOT have non_atomic locking + self.assertIn('manifest', kvstore_tspec) + manifest_spec = kvstore_tspec['manifest'] + self.assertEqual(manifest_spec['driver'], ts_utils.DEFAULT_DRIVER) + self.assertNotIn('file_io_locking', manifest_spec) + + def test_ocdbt_kvstore_with_non_atomic_locking_gcs_path(self): + tspec = self.array_write_spec_constructor( + directory='gs://gcs_bucket/object_path', + relative_array_filename=self.param_name, + use_zarr3=False, + use_ocdbt=True, + process_id=0, + ) + self.assertTrue(tspec.metadata.use_ocdbt) + kvstore_tspec = tspec.json['kvstore'] + self.assertEqual(kvstore_tspec['driver'], 'ocdbt') + + # Base should be a string (URL) and NOT have file_io_locking + self.assertIsInstance(kvstore_tspec['base'], str) + + # Manifest should NOT be present in the spec + self.assertNotIn('manifest', kvstore_tspec) + + @parameterized.named_parameters( dict( testcase_name='regular_path', diff --git a/checkpoint/orbax/checkpoint/_src/serialization/types.py b/checkpoint/orbax/checkpoint/_src/serialization/types.py index bd6146006..b0194c667 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/types.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/types.py @@ -99,6 +99,8 @@ class ParamInfo: write_shape: Shape of the array shard. Used in the subchunking context. is_prioritized_key_fn: See ``IsPrioritizedKeyFn`` definition. keypath: Tuple of keys identifying the parameter's position in the PyTree. + use_non_atomic_file_io_locking: If True, enables non-atomic file I/O + locking mode for TensorStore OCDBT data files. """ def __init__( @@ -121,6 +123,7 @@ def __init__( raise_array_data_missing_error: bool = True, write_shape: arrays_types.Shape | None = None, is_prioritized_key_fn: Optional[IsPrioritizedKeyFn] = None, + use_non_atomic_file_io_locking: bool = True, ): self.name = name self._parent_dir = parent_dir @@ -139,6 +142,7 @@ def __init__( self.raise_array_data_missing_error = raise_array_data_missing_error self.write_shape = write_shape self.is_prioritized_key_fn = is_prioritized_key_fn + self.use_non_atomic_file_io_locking = use_non_atomic_file_io_locking @property def parent_dir(self) -> epath.Path: diff --git a/checkpoint/orbax/checkpoint/options.py b/checkpoint/orbax/checkpoint/options.py index a3f636319..da8703732 100644 --- a/checkpoint/orbax/checkpoint/options.py +++ b/checkpoint/orbax/checkpoint/options.py @@ -72,6 +72,7 @@ class FileOptions: """ path_permission_mode: int | None = None + use_non_atomic_file_io_locking: bool = True @dataclasses.dataclass