From 7180dc3f46e0ee33782fc3e8c97b224ff3ea5280 Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Thu, 23 Apr 2026 02:13:21 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 904321063 --- .../_src/serialization/tensorstore_utils.py | 12 ++++++ .../serialization/tensorstore_utils_test.py | 42 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py index 1013584d3..bff9a165c 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils.py @@ -193,6 +193,18 @@ def build_kvstore_tspec( 'driver': 'ocdbt', 'base': base_driver_spec, }) + # For OCDBT on local filesystems (including GCSFuse), we can safely use + # non-atomic writes for data files to avoid expensive renames. However, + # the manifest file still requires atomic writes to avoid corruption. + # We achieve this by splitting the spec into 'base' (for data files) and + # 'manifest'. Direct GCS paths ('gs://') do not need this optimization + # as they don't use the file driver. + if not is_gcs_path: + base_spec = copy.deepcopy(base_driver_spec) + if isinstance(base_spec, dict): + base_spec['file_io_locking'] = {'mode': 'non_atomic'} + kv_spec['base'] = base_spec + kv_spec['manifest'] = base_driver_spec if name is not None: kv_spec['path'] = name diff --git a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py index 71c3a9424..98ce5251c 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/tensorstore_utils_test.py @@ -199,6 +199,48 @@ def test_ocdbt_kvstore( ) self.assertEqual(json_tspec['kvstore']['path'], self.param_name) + def test_ocdbt_kvstore_with_non_atomic_locking(self): + tspec = self.array_write_spec_constructor( + directory=self.directory, + relative_array_filename=self.param_name, + use_zarr3=False, + use_ocdbt=True, + process_id=13, + ) + self.assertTrue(tspec.metadata.use_ocdbt) + json_tspec = tspec.json + kvstore_tspec = json_tspec['kvstore'] + self.assertEqual(kvstore_tspec['driver'], 'ocdbt') + + # Base spec should have non_atomic locking + base_spec = kvstore_tspec['base'] + self.assertEqual(base_spec['driver'], ts_utils.DEFAULT_DRIVER) + self.assertEqual(base_spec['file_io_locking'], {'mode': 'non_atomic'}) + + # Manifest spec should be present and NOT have non_atomic locking + self.assertIn('manifest', kvstore_tspec) + manifest_spec = kvstore_tspec['manifest'] + self.assertEqual(manifest_spec['driver'], ts_utils.DEFAULT_DRIVER) + self.assertNotIn('file_io_locking', manifest_spec) + + def test_ocdbt_kvstore_with_non_atomic_locking_gcs_path(self): + tspec = self.array_write_spec_constructor( + directory='gs://gcs_bucket/object_path', + relative_array_filename=self.param_name, + use_zarr3=False, + use_ocdbt=True, + process_id=0, + ) + self.assertTrue(tspec.metadata.use_ocdbt) + kvstore_tspec = tspec.json['kvstore'] + self.assertEqual(kvstore_tspec['driver'], 'ocdbt') + + # Base should be a string (URL) and NOT have file_io_locking + self.assertIsInstance(kvstore_tspec['base'], str) + + # Manifest should NOT be present in the spec + self.assertNotIn('manifest', kvstore_tspec) + @parameterized.named_parameters( dict( testcase_name='regular_path',