Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1705,7 +1705,8 @@ def __init__(self, namespace: str, prefix: str = ''):
self._load_model_latency_milli_secs = beam.metrics.Metrics.distribution(
namespace, prefix + 'load_model_latency_milli_secs')

# Metrics cache
# Model load can happen during setup(), before bundle-scoped metric updates
# are tracked. Cache those values and flush them once finish_bundle() runs.
self._load_model_latency_milli_secs_cache = None
self._model_byte_size_cache = None

Expand Down Expand Up @@ -2133,8 +2134,8 @@ def process(
return self._run_inference(batch, inference_args)

def finish_bundle(self):
# TODO(https://github.com/apache/beam/issues/21435): Figure out why there
# is a cache.
# setup() may load the model before bundle-scoped metrics are active, so
# flush the cached model load metrics once the bundle lifecycle is running.
Comment on lines +2137 to +2138
if self._metrics_collector:
self._metrics_collector.update_metrics_with_cache()

Expand Down
19 changes: 19 additions & 0 deletions sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,25 @@ def test_timing_metrics(self):
self.assertEqual(load_model_latency.result.count, 1)
self.assertEqual(load_model_latency.result.mean, 500)

def test_setup_caches_model_load_metrics_until_finish_bundle(self):
fake_clock = FakeClock()
dofn = base._RunInferenceDoFn(
FakeModelHandler(clock=fake_clock), fake_clock, None)
metrics_collector = unittest.mock.Mock(spec=base._MetricsCollector)
dofn.get_metrics_collector = unittest.mock.Mock(
return_value=metrics_collector)

with unittest.mock.patch.object(
base, '_get_current_process_memory_in_bytes', side_effect=[100, 125]):
dofn.setup()
Comment on lines +1055 to +1057

metrics_collector.cache_load_model_metrics.assert_called_once_with(500, 25)
metrics_collector.update_metrics_with_cache.assert_not_called()

dofn.finish_bundle()

metrics_collector.update_metrics_with_cache.assert_called_once_with()

def test_forwards_batch_args(self):
examples = list(range(100))
with TestPipeline('FnApiRunner') as pipeline:
Expand Down
Loading