From 0e7f21640a0000f305d5028e41b4728215207702 Mon Sep 17 00:00:00 2001 From: Elia LIU Date: Wed, 25 Mar 2026 03:21:35 +0000 Subject: [PATCH 1/4] [Integration] Expose length-aware batching in all ModelHandler subclasses Completes the smart bucketing feature (#37531) by exposing batch_length_fn and batch_bucket_boundaries parameters across all concrete ModelHandler implementations. This allows users to enable length-aware batching on supported inference backends by passing these parameters directly to the handler constructor. - adds batch_length_fn / batch_bucket_boundaries to 16 handler classes - wires Gemini and Vertex AI batching params into _batching_kwargs - adds end-to-end RunInference coverage for length-aware batching - adds per-handler forwarding regression tests and fixes them to be hermetic --- .../apache_beam/ml/inference/base_test.py | 261 ++++++++++++++++++ .../ml/inference/gemini_inference.py | 10 + .../ml/inference/huggingface_inference.py | 24 ++ .../ml/inference/onnx_inference.py | 8 + .../ml/inference/pytorch_inference.py | 16 ++ .../ml/inference/sklearn_inference.py | 16 ++ .../ml/inference/tensorflow_inference.py | 16 ++ .../ml/inference/tensorrt_inference.py | 8 + .../ml/inference/vertex_ai_inference.py | 10 + .../ml/inference/vllm_inference.py | 24 +- .../ml/inference/xgboost_inference.py | 8 + 11 files changed, 397 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index f25316f474fe..94edfe626895 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -16,6 +16,8 @@ # """Tests for apache_beam.ml.base.""" +import contextlib +import importlib import math import multiprocessing import os @@ -25,6 +27,7 @@ import tempfile import time import unittest +import unittest.mock from collections.abc import Iterable from collections.abc import Mapping from collections.abc import Sequence @@ -2319,6 +2322,264 @@ def test_batching_kwargs_none_values_omitted(self): self.assertEqual(kwargs['min_batch_size'], 5) +class PaddingReportingStringModelHandler(base.ModelHandler[str, str, + FakeModel]): + """Reports each element with the max length of the batch it ran in.""" + def load_model(self): + return FakeModel() + + def run_inference(self, batch, model, inference_args=None): + max_len = max(len(s) for s in batch) + return [f'{s}:{max_len}' for s in batch] + + +class RunInferenceLengthAwareBatchingTest(unittest.TestCase): + """End-to-end tests for PR2 length-aware batching in RunInference.""" + def test_run_inference_with_length_aware_batch_elements(self): + handler = PaddingReportingStringModelHandler( + min_batch_size=2, + max_batch_size=2, + max_batch_duration_secs=60, + batch_length_fn=len, + batch_bucket_boundaries=[5]) + + examples = ['a', 'cccccc', 'bb', 'ddddddd'] + with TestPipeline('FnApiRunner') as p: + results = ( + p + | beam.Create(examples, reshuffle=False) + | base.RunInference(handler)) + assert_that(results, equal_to(['a:2', 'bb:2', 'cccccc:7', 'ddddddd:7'])) + + +class HandlerBucketingKwargsForwardingTest(unittest.TestCase): + """Verify each concrete ModelHandler forwards batch_length_fn and + batch_bucket_boundaries through to batch_elements_kwargs().""" + _BUCKETING_KWARGS = { + 'batch_length_fn': len, + 'batch_bucket_boundaries': [32], + } + + def _assert_bucketing_kwargs_forwarded(self, handler): + kwargs = handler.batch_elements_kwargs() + self.assertIs(kwargs['length_fn'], len) + self.assertEqual(kwargs['bucket_boundaries'], [32]) + + def _load_handler_class(self, case): + try: + module = importlib.import_module(case['module_name']) + except ImportError: + raise unittest.SkipTest(case['skip_message']) + return getattr(module, case['class_name']) + + @contextlib.contextmanager + def _handler_setup(self, case): + if not case.get('mock_aiplatform'): + yield + return + + with unittest.mock.patch( + 'apache_beam.ml.inference.vertex_ai_inference.aiplatform') as mock_aip: + mock_aip.init.return_value = None + mock_endpoint = unittest.mock.MagicMock() + mock_endpoint.list_models.return_value = ['fake-model'] + mock_aip.Endpoint.return_value = mock_endpoint + yield + + def _assert_handler_cases(self, cases): + for case in cases: + with self.subTest(handler=case['name']): + handler_cls = self._load_handler_class(case) + init_kwargs = dict(case['init_kwargs']) + init_kwargs.update(self._BUCKETING_KWARGS) + + with self._handler_setup(case): + handler = handler_cls(**init_kwargs) + + self._assert_bucketing_kwargs_forwarded(handler) + + def test_pytorch_handlers(self): + self._assert_handler_cases(( + { + 'name': 'pytorch_tensor', + 'module_name': 'apache_beam.ml.inference.pytorch_inference', + 'class_name': 'PytorchModelHandlerTensor', + 'skip_message': 'PyTorch not available', + 'init_kwargs': {}, + }, + { + 'name': 'pytorch_keyed_tensor', + 'module_name': 'apache_beam.ml.inference.pytorch_inference', + 'class_name': 'PytorchModelHandlerKeyedTensor', + 'skip_message': 'PyTorch not available', + 'init_kwargs': {}, + }, + )) + + def test_huggingface_handlers(self): + self._assert_handler_cases(( + { + 'name': 'huggingface_keyed_tensor', + 'module_name': 'apache_beam.ml.inference.huggingface_inference', + 'class_name': 'HuggingFaceModelHandlerKeyedTensor', + 'skip_message': 'HuggingFace transformers not available', + 'init_kwargs': { + 'model_uri': 'unused', + 'model_class': object, + 'framework': 'pt', + }, + }, + { + 'name': 'huggingface_tensor', + 'module_name': 'apache_beam.ml.inference.huggingface_inference', + 'class_name': 'HuggingFaceModelHandlerTensor', + 'skip_message': 'HuggingFace transformers not available', + 'init_kwargs': { + 'model_uri': 'unused', + 'model_class': object, + }, + }, + { + 'name': 'huggingface_pipeline', + 'module_name': 'apache_beam.ml.inference.huggingface_inference', + 'class_name': 'HuggingFacePipelineModelHandler', + 'skip_message': 'HuggingFace transformers not available', + 'init_kwargs': { + 'task': 'text-classification', + }, + }, + )) + + def test_sklearn_handlers(self): + self._assert_handler_cases(( + { + 'name': 'sklearn_numpy', + 'module_name': 'apache_beam.ml.inference.sklearn_inference', + 'class_name': 'SklearnModelHandlerNumpy', + 'skip_message': 'scikit-learn not available', + 'init_kwargs': { + 'model_uri': 'unused', + }, + }, + { + 'name': 'sklearn_pandas', + 'module_name': 'apache_beam.ml.inference.sklearn_inference', + 'class_name': 'SklearnModelHandlerPandas', + 'skip_message': 'scikit-learn not available', + 'init_kwargs': { + 'model_uri': 'unused', + }, + }, + )) + + def test_tensorflow_handlers(self): + self._assert_handler_cases(( + { + 'name': 'tensorflow_numpy', + 'module_name': 'apache_beam.ml.inference.tensorflow_inference', + 'class_name': 'TFModelHandlerNumpy', + 'skip_message': 'TensorFlow not available', + 'init_kwargs': { + 'model_uri': 'unused', + }, + }, + { + 'name': 'tensorflow_tensor', + 'module_name': 'apache_beam.ml.inference.tensorflow_inference', + 'class_name': 'TFModelHandlerTensor', + 'skip_message': 'TensorFlow not available', + 'init_kwargs': { + 'model_uri': 'unused', + }, + }, + )) + + def test_onnx_handler(self): + self._assert_handler_cases(({ + 'name': 'onnx_numpy', + 'module_name': 'apache_beam.ml.inference.onnx_inference', + 'class_name': 'OnnxModelHandlerNumpy', + 'skip_message': 'ONNX Runtime not available', + 'init_kwargs': { + 'model_uri': 'unused', + }, + }, )) + + def test_xgboost_handler(self): + self._assert_handler_cases(({ + 'name': 'xgboost_numpy', + 'module_name': 'apache_beam.ml.inference.xgboost_inference', + 'class_name': 'XGBoostModelHandlerNumpy', + 'skip_message': 'XGBoost not available', + 'init_kwargs': { + 'model_class': object, + 'model_state': 'unused', + }, + }, )) + + def test_tensorrt_handler(self): + self._assert_handler_cases(({ + 'name': 'tensorrt_numpy', + 'module_name': 'apache_beam.ml.inference.tensorrt_inference', + 'class_name': 'TensorRTEngineHandlerNumPy', + 'skip_message': 'TensorRT not available', + 'init_kwargs': { + 'min_batch_size': 1, + 'max_batch_size': 8, + }, + }, )) + + def test_vllm_handlers(self): + self._assert_handler_cases(( + { + 'name': 'vllm_completions', + 'module_name': 'apache_beam.ml.inference.vllm_inference', + 'class_name': 'VLLMCompletionsModelHandler', + 'skip_message': 'vLLM not available', + 'init_kwargs': { + 'model_name': 'unused', + }, + }, + { + 'name': 'vllm_chat', + 'module_name': 'apache_beam.ml.inference.vllm_inference', + 'class_name': 'VLLMChatModelHandler', + 'skip_message': 'vLLM not available', + 'init_kwargs': { + 'model_name': 'unused', + }, + }, + )) + + def test_vertex_ai_handler(self): + self._assert_handler_cases(({ + 'name': 'vertex_ai', + 'module_name': 'apache_beam.ml.inference.vertex_ai_inference', + 'class_name': 'VertexAIModelHandlerJSON', + 'skip_message': 'Vertex AI SDK not available', + 'init_kwargs': { + 'endpoint_id': 'unused', + 'project': 'unused', + 'location': 'unused', + }, + 'mock_aiplatform': True, + }, )) + + def test_gemini_handler(self): + self._assert_handler_cases(({ + 'name': 'gemini', + 'module_name': 'apache_beam.ml.inference.gemini_inference', + 'class_name': 'GeminiModelHandler', + 'skip_message': 'Google GenAI SDK not available', + 'init_kwargs': { + 'model_name': 'unused', + 'request_fn': lambda *args: None, + 'project': 'unused', + 'location': 'unused', + }, + }, )) + + class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]): def load_model(self): return FakeModel() diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py index 2ba220d0162a..9d8cf685f79d 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py @@ -117,6 +117,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for Google Gemini. **NOTE:** This API and its implementation are under development and @@ -158,6 +160,10 @@ def __init__( max_batch_weight: optional. the maximum total weight of a batch. element_size_fn: optional. a function that returns the size (weight) of an element. + batch_length_fn: optional. a callable that returns the length of an + element for length-aware batching. + batch_bucket_boundaries: optional. a sorted list of positive boundary + values for length-aware batching buckets. """ self._batching_kwargs = {} self._env_vars = kwargs.get('env_vars', {}) @@ -171,6 +177,10 @@ def __init__( self._batching_kwargs["max_batch_weight"] = max_batch_weight if element_size_fn is not None: self._batching_kwargs['element_size_fn'] = element_size_fn + if batch_length_fn is not None: + self._batching_kwargs['length_fn'] = batch_length_fn + if batch_bucket_boundaries is not None: + self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries self.model_name = model_name self.request_fn = request_fn diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py index 2c1f5e2cc908..a9893ea9290c 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py +++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py @@ -229,6 +229,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """ Implementation of the ModelHandler interface for HuggingFace with @@ -266,6 +268,10 @@ def __init__( GPU capacity and want to maximize resource utilization. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -278,6 +284,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) @@ -411,6 +419,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """ Implementation of the ModelHandler interface for HuggingFace with @@ -448,6 +458,10 @@ def __init__( GPU capacity and want to maximize resource utilization. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -460,6 +474,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) @@ -576,6 +592,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """ Implementation of the ModelHandler interface for Hugging Face Pipelines. @@ -621,6 +639,10 @@ def __init__( GPU capacity and want to maximize resource utilization. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -633,6 +655,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index 4423eed2e407..c525abbff45e 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -68,6 +68,8 @@ def __init__( #pylint: disable=dangerous-default-value max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """ Implementation of the ModelHandler interface for onnx using numpy arrays as input. @@ -94,6 +96,10 @@ def __init__( #pylint: disable=dangerous-default-value before emitting; used in streaming contexts. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. """ @@ -103,6 +109,8 @@ def __init__( #pylint: disable=dangerous-default-value max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 63c2a116fcc9..8dc4b5c43778 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -199,6 +199,8 @@ def __init__( load_model_args: Optional[dict[str, Any]] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for PyTorch. @@ -244,6 +246,10 @@ def __init__( function to specify custom config for loading models. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -256,6 +262,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) @@ -431,6 +439,8 @@ def __init__( load_model_args: Optional[dict[str, Any]] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for PyTorch. @@ -481,6 +491,10 @@ def __init__( function to specify custom config for loading models. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -493,6 +507,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index e61ef9c194aa..8d02cacaac34 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -95,6 +95,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """ Implementation of the ModelHandler interface for scikit-learn using numpy arrays as input. @@ -126,6 +128,10 @@ def __init__( GPU capacity and want to maximize resource utilization. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. """ @@ -135,6 +141,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) @@ -224,6 +232,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for scikit-learn that supports pandas dataframes. @@ -258,6 +268,10 @@ def __init__( GPU capacity and want to maximize resource utilization. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. """ @@ -267,6 +281,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py index 97b74eb360a7..3563b1113ba4 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py @@ -114,6 +114,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for Tensorflow. @@ -145,6 +147,10 @@ def __init__( max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -157,6 +163,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) @@ -242,6 +250,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for Tensorflow. @@ -278,6 +288,10 @@ def __init__( max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -290,6 +304,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py index 00a61b4934aa..333187301b29 100644 --- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py @@ -232,6 +232,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for TensorRT. @@ -262,6 +264,10 @@ def __init__( a batch before emitting; used in streaming contexts. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: Additional arguments like 'engine_path' and 'onnx_path' are currently supported. 'env_vars' can be used to set environment variables before loading the model. @@ -275,6 +281,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py index 02827f9578f1..7757146ab8dc 100644 --- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py +++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py @@ -72,6 +72,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for Vertex AI. **NOTE:** This API and its implementation are under development and @@ -115,6 +117,10 @@ def __init__( max_batch_weight: optional. the maximum total weight of a batch. element_size_fn: optional. a function that returns the size (weight) of an element. + batch_length_fn: optional. a callable that returns the length of an + element for length-aware batching. + batch_bucket_boundaries: optional. a sorted list of positive boundary + values for length-aware batching buckets. """ self._batching_kwargs = {} self._env_vars = kwargs.get('env_vars', {}) @@ -129,6 +135,10 @@ def __init__( self._batching_kwargs["max_batch_weight"] = max_batch_weight if element_size_fn is not None: self._batching_kwargs['element_size_fn'] = element_size_fn + if batch_length_fn is not None: + self._batching_kwargs['length_fn'] = batch_length_fn + if batch_bucket_boundaries is not None: + self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries if private and network is None: raise ValueError( diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index 918b49155606..5a982f5cc55b 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -182,7 +182,9 @@ def __init__( max_batch_size: Optional[int] = None, max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, - element_size_fn: Optional[Callable[[Any], int]] = None): + element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None): """Implementation of the ModelHandler interface for vLLM using text as input. @@ -210,13 +212,19 @@ def __init__( max_batch_weight: optional. the maximum total weight of a batch. element_size_fn: optional. a function that returns the size (weight) of an element. + batch_length_fn: optional. a callable that returns the length of an + element for length-aware batching. + batch_bucket_boundaries: optional. a sorted list of positive boundary + values for length-aware batching buckets. """ super().__init__( min_batch_size=min_batch_size, max_batch_size=max_batch_size, max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, - element_size_fn=element_size_fn) + element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries) self._model_name = model_name self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {} @@ -280,7 +288,9 @@ def __init__( max_batch_size: Optional[int] = None, max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, - element_size_fn: Optional[Callable[[Any], int]] = None): + element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None): """ Implementation of the ModelHandler interface for vLLM using previous messages as input. @@ -313,13 +323,19 @@ def __init__( max_batch_weight: optional. the maximum total weight of a batch. element_size_fn: optional. a function that returns the size (weight) of an element. + batch_length_fn: optional. a callable that returns the length of an + element for length-aware batching. + batch_bucket_boundaries: optional. a sorted list of positive boundary + values for length-aware batching buckets. """ super().__init__( min_batch_size=min_batch_size, max_batch_size=max_batch_size, max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, - element_size_fn=element_size_fn) + element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries) self._model_name = model_name self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {} self._chat_template_path = chat_template_path diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference.py b/sdks/python/apache_beam/ml/inference/xgboost_inference.py index 9d7413685113..f11b892062e8 100644 --- a/sdks/python/apache_beam/ml/inference/xgboost_inference.py +++ b/sdks/python/apache_beam/ml/inference/xgboost_inference.py @@ -80,6 +80,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for XGBoost. @@ -109,6 +111,10 @@ def __init__( max_batch_weight: optional. the maximum total weight of a batch. element_size_fn: optional. a function that returns the size (weight) of an element. + batch_length_fn: optional. a callable that returns the length of an + element for length-aware batching. + batch_bucket_boundaries: optional. a sorted list of positive boundary + values for length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -131,6 +137,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, **kwargs) self._model_class = model_class self._model_state = model_state From 1bb2a3f3faf0475d0ae36a3fa1bf478228393380 Mon Sep 17 00:00:00 2001 From: Elia LIU Date: Wed, 25 Mar 2026 03:21:35 +0000 Subject: [PATCH 2/4] [Integration] Expose length-aware batching in all ModelHandler subclasses Completes the smart bucketing feature (#37531) by exposing batch_length_fn and batch_bucket_boundaries parameters across all concrete ModelHandler implementations. This allows users to enable length-aware batching on supported inference backends by passing these parameters directly to the handler constructor. - adds batch_length_fn / batch_bucket_boundaries to 16 handler classes - wires Gemini and Vertex AI batching params into _batching_kwargs - adds end-to-end RunInference coverage for length-aware batching - adds per-handler forwarding regression tests and fixes them to be hermetic From 78c936d11d7e80cccceb158ff5d6827ea2a14a72 Mon Sep 17 00:00:00 2001 From: Elia LIU Date: Wed, 25 Mar 2026 03:21:35 +0000 Subject: [PATCH 3/4] [Integration] Expose length-aware batching in all ModelHandler subclasses Completes the smart bucketing feature (#37531) by exposing batch_length_fn and batch_bucket_boundaries parameters across all concrete ModelHandler implementations. This allows users to enable length-aware batching on supported inference backends by passing these parameters directly to the handler constructor. - adds batch_length_fn / batch_bucket_boundaries to 16 handler classes - wires Gemini and Vertex AI batching params into _batching_kwargs - adds end-to-end RunInference coverage for length-aware batching - adds per-handler forwarding regression tests and fixes them to be hermetic From 82c280bebe933d669206986e158631cc210f097c Mon Sep 17 00:00:00 2001 From: Eliaazzz Date: Tue, 31 Mar 2026 02:00:03 +1100 Subject: [PATCH 4/4] Remove redundant bucketing forwarding tests --- .../apache_beam/ml/inference/base_test.py | 230 ------------------ 1 file changed, 230 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index 94edfe626895..cad50b1e8dea 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -16,8 +16,6 @@ # """Tests for apache_beam.ml.base.""" -import contextlib -import importlib import math import multiprocessing import os @@ -2352,234 +2350,6 @@ def test_run_inference_with_length_aware_batch_elements(self): assert_that(results, equal_to(['a:2', 'bb:2', 'cccccc:7', 'ddddddd:7'])) -class HandlerBucketingKwargsForwardingTest(unittest.TestCase): - """Verify each concrete ModelHandler forwards batch_length_fn and - batch_bucket_boundaries through to batch_elements_kwargs().""" - _BUCKETING_KWARGS = { - 'batch_length_fn': len, - 'batch_bucket_boundaries': [32], - } - - def _assert_bucketing_kwargs_forwarded(self, handler): - kwargs = handler.batch_elements_kwargs() - self.assertIs(kwargs['length_fn'], len) - self.assertEqual(kwargs['bucket_boundaries'], [32]) - - def _load_handler_class(self, case): - try: - module = importlib.import_module(case['module_name']) - except ImportError: - raise unittest.SkipTest(case['skip_message']) - return getattr(module, case['class_name']) - - @contextlib.contextmanager - def _handler_setup(self, case): - if not case.get('mock_aiplatform'): - yield - return - - with unittest.mock.patch( - 'apache_beam.ml.inference.vertex_ai_inference.aiplatform') as mock_aip: - mock_aip.init.return_value = None - mock_endpoint = unittest.mock.MagicMock() - mock_endpoint.list_models.return_value = ['fake-model'] - mock_aip.Endpoint.return_value = mock_endpoint - yield - - def _assert_handler_cases(self, cases): - for case in cases: - with self.subTest(handler=case['name']): - handler_cls = self._load_handler_class(case) - init_kwargs = dict(case['init_kwargs']) - init_kwargs.update(self._BUCKETING_KWARGS) - - with self._handler_setup(case): - handler = handler_cls(**init_kwargs) - - self._assert_bucketing_kwargs_forwarded(handler) - - def test_pytorch_handlers(self): - self._assert_handler_cases(( - { - 'name': 'pytorch_tensor', - 'module_name': 'apache_beam.ml.inference.pytorch_inference', - 'class_name': 'PytorchModelHandlerTensor', - 'skip_message': 'PyTorch not available', - 'init_kwargs': {}, - }, - { - 'name': 'pytorch_keyed_tensor', - 'module_name': 'apache_beam.ml.inference.pytorch_inference', - 'class_name': 'PytorchModelHandlerKeyedTensor', - 'skip_message': 'PyTorch not available', - 'init_kwargs': {}, - }, - )) - - def test_huggingface_handlers(self): - self._assert_handler_cases(( - { - 'name': 'huggingface_keyed_tensor', - 'module_name': 'apache_beam.ml.inference.huggingface_inference', - 'class_name': 'HuggingFaceModelHandlerKeyedTensor', - 'skip_message': 'HuggingFace transformers not available', - 'init_kwargs': { - 'model_uri': 'unused', - 'model_class': object, - 'framework': 'pt', - }, - }, - { - 'name': 'huggingface_tensor', - 'module_name': 'apache_beam.ml.inference.huggingface_inference', - 'class_name': 'HuggingFaceModelHandlerTensor', - 'skip_message': 'HuggingFace transformers not available', - 'init_kwargs': { - 'model_uri': 'unused', - 'model_class': object, - }, - }, - { - 'name': 'huggingface_pipeline', - 'module_name': 'apache_beam.ml.inference.huggingface_inference', - 'class_name': 'HuggingFacePipelineModelHandler', - 'skip_message': 'HuggingFace transformers not available', - 'init_kwargs': { - 'task': 'text-classification', - }, - }, - )) - - def test_sklearn_handlers(self): - self._assert_handler_cases(( - { - 'name': 'sklearn_numpy', - 'module_name': 'apache_beam.ml.inference.sklearn_inference', - 'class_name': 'SklearnModelHandlerNumpy', - 'skip_message': 'scikit-learn not available', - 'init_kwargs': { - 'model_uri': 'unused', - }, - }, - { - 'name': 'sklearn_pandas', - 'module_name': 'apache_beam.ml.inference.sklearn_inference', - 'class_name': 'SklearnModelHandlerPandas', - 'skip_message': 'scikit-learn not available', - 'init_kwargs': { - 'model_uri': 'unused', - }, - }, - )) - - def test_tensorflow_handlers(self): - self._assert_handler_cases(( - { - 'name': 'tensorflow_numpy', - 'module_name': 'apache_beam.ml.inference.tensorflow_inference', - 'class_name': 'TFModelHandlerNumpy', - 'skip_message': 'TensorFlow not available', - 'init_kwargs': { - 'model_uri': 'unused', - }, - }, - { - 'name': 'tensorflow_tensor', - 'module_name': 'apache_beam.ml.inference.tensorflow_inference', - 'class_name': 'TFModelHandlerTensor', - 'skip_message': 'TensorFlow not available', - 'init_kwargs': { - 'model_uri': 'unused', - }, - }, - )) - - def test_onnx_handler(self): - self._assert_handler_cases(({ - 'name': 'onnx_numpy', - 'module_name': 'apache_beam.ml.inference.onnx_inference', - 'class_name': 'OnnxModelHandlerNumpy', - 'skip_message': 'ONNX Runtime not available', - 'init_kwargs': { - 'model_uri': 'unused', - }, - }, )) - - def test_xgboost_handler(self): - self._assert_handler_cases(({ - 'name': 'xgboost_numpy', - 'module_name': 'apache_beam.ml.inference.xgboost_inference', - 'class_name': 'XGBoostModelHandlerNumpy', - 'skip_message': 'XGBoost not available', - 'init_kwargs': { - 'model_class': object, - 'model_state': 'unused', - }, - }, )) - - def test_tensorrt_handler(self): - self._assert_handler_cases(({ - 'name': 'tensorrt_numpy', - 'module_name': 'apache_beam.ml.inference.tensorrt_inference', - 'class_name': 'TensorRTEngineHandlerNumPy', - 'skip_message': 'TensorRT not available', - 'init_kwargs': { - 'min_batch_size': 1, - 'max_batch_size': 8, - }, - }, )) - - def test_vllm_handlers(self): - self._assert_handler_cases(( - { - 'name': 'vllm_completions', - 'module_name': 'apache_beam.ml.inference.vllm_inference', - 'class_name': 'VLLMCompletionsModelHandler', - 'skip_message': 'vLLM not available', - 'init_kwargs': { - 'model_name': 'unused', - }, - }, - { - 'name': 'vllm_chat', - 'module_name': 'apache_beam.ml.inference.vllm_inference', - 'class_name': 'VLLMChatModelHandler', - 'skip_message': 'vLLM not available', - 'init_kwargs': { - 'model_name': 'unused', - }, - }, - )) - - def test_vertex_ai_handler(self): - self._assert_handler_cases(({ - 'name': 'vertex_ai', - 'module_name': 'apache_beam.ml.inference.vertex_ai_inference', - 'class_name': 'VertexAIModelHandlerJSON', - 'skip_message': 'Vertex AI SDK not available', - 'init_kwargs': { - 'endpoint_id': 'unused', - 'project': 'unused', - 'location': 'unused', - }, - 'mock_aiplatform': True, - }, )) - - def test_gemini_handler(self): - self._assert_handler_cases(({ - 'name': 'gemini', - 'module_name': 'apache_beam.ml.inference.gemini_inference', - 'class_name': 'GeminiModelHandler', - 'skip_message': 'Google GenAI SDK not available', - 'init_kwargs': { - 'model_name': 'unused', - 'request_fn': lambda *args: None, - 'project': 'unused', - 'location': 'unused', - }, - }, )) - - class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]): def load_model(self): return FakeModel()