diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py index f25316f474fe..cad50b1e8dea 100644 --- a/sdks/python/apache_beam/ml/inference/base_test.py +++ b/sdks/python/apache_beam/ml/inference/base_test.py @@ -25,6 +25,7 @@ import tempfile import time import unittest +import unittest.mock from collections.abc import Iterable from collections.abc import Mapping from collections.abc import Sequence @@ -2319,6 +2320,36 @@ def test_batching_kwargs_none_values_omitted(self): self.assertEqual(kwargs['min_batch_size'], 5) +class PaddingReportingStringModelHandler(base.ModelHandler[str, str, + FakeModel]): + """Reports each element with the max length of the batch it ran in.""" + def load_model(self): + return FakeModel() + + def run_inference(self, batch, model, inference_args=None): + max_len = max(len(s) for s in batch) + return [f'{s}:{max_len}' for s in batch] + + +class RunInferenceLengthAwareBatchingTest(unittest.TestCase): + """End-to-end tests for PR2 length-aware batching in RunInference.""" + def test_run_inference_with_length_aware_batch_elements(self): + handler = PaddingReportingStringModelHandler( + min_batch_size=2, + max_batch_size=2, + max_batch_duration_secs=60, + batch_length_fn=len, + batch_bucket_boundaries=[5]) + + examples = ['a', 'cccccc', 'bb', 'ddddddd'] + with TestPipeline('FnApiRunner') as p: + results = ( + p + | beam.Create(examples, reshuffle=False) + | base.RunInference(handler)) + assert_that(results, equal_to(['a:2', 'bb:2', 'cccccc:7', 'ddddddd:7'])) + + class SimpleFakeModelHandler(base.ModelHandler[int, int, FakeModel]): def load_model(self): return FakeModel() diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py index 2ba220d0162a..9d8cf685f79d 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py @@ -117,6 +117,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for Google Gemini. **NOTE:** This API and its implementation are under development and @@ -158,6 +160,10 @@ def __init__( max_batch_weight: optional. the maximum total weight of a batch. element_size_fn: optional. a function that returns the size (weight) of an element. + batch_length_fn: optional. a callable that returns the length of an + element for length-aware batching. + batch_bucket_boundaries: optional. a sorted list of positive boundary + values for length-aware batching buckets. """ self._batching_kwargs = {} self._env_vars = kwargs.get('env_vars', {}) @@ -171,6 +177,10 @@ def __init__( self._batching_kwargs["max_batch_weight"] = max_batch_weight if element_size_fn is not None: self._batching_kwargs['element_size_fn'] = element_size_fn + if batch_length_fn is not None: + self._batching_kwargs['length_fn'] = batch_length_fn + if batch_bucket_boundaries is not None: + self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries self.model_name = model_name self.request_fn = request_fn diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py index 2c1f5e2cc908..a9893ea9290c 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py +++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py @@ -229,6 +229,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """ Implementation of the ModelHandler interface for HuggingFace with @@ -266,6 +268,10 @@ def __init__( GPU capacity and want to maximize resource utilization. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -278,6 +284,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) @@ -411,6 +419,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """ Implementation of the ModelHandler interface for HuggingFace with @@ -448,6 +458,10 @@ def __init__( GPU capacity and want to maximize resource utilization. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -460,6 +474,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) @@ -576,6 +592,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """ Implementation of the ModelHandler interface for Hugging Face Pipelines. @@ -621,6 +639,10 @@ def __init__( GPU capacity and want to maximize resource utilization. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -633,6 +655,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py index 4423eed2e407..c525abbff45e 100644 --- a/sdks/python/apache_beam/ml/inference/onnx_inference.py +++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py @@ -68,6 +68,8 @@ def __init__( #pylint: disable=dangerous-default-value max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """ Implementation of the ModelHandler interface for onnx using numpy arrays as input. @@ -94,6 +96,10 @@ def __init__( #pylint: disable=dangerous-default-value before emitting; used in streaming contexts. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. """ @@ -103,6 +109,8 @@ def __init__( #pylint: disable=dangerous-default-value max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 63c2a116fcc9..8dc4b5c43778 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -199,6 +199,8 @@ def __init__( load_model_args: Optional[dict[str, Any]] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for PyTorch. @@ -244,6 +246,10 @@ def __init__( function to specify custom config for loading models. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -256,6 +262,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) @@ -431,6 +439,8 @@ def __init__( load_model_args: Optional[dict[str, Any]] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for PyTorch. @@ -481,6 +491,10 @@ def __init__( function to specify custom config for loading models. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -493,6 +507,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py index e61ef9c194aa..8d02cacaac34 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py @@ -95,6 +95,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """ Implementation of the ModelHandler interface for scikit-learn using numpy arrays as input. @@ -126,6 +128,10 @@ def __init__( GPU capacity and want to maximize resource utilization. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. """ @@ -135,6 +141,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) @@ -224,6 +232,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for scikit-learn that supports pandas dataframes. @@ -258,6 +268,10 @@ def __init__( GPU capacity and want to maximize resource utilization. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. """ @@ -267,6 +281,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py index 97b74eb360a7..3563b1113ba4 100644 --- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py @@ -114,6 +114,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for Tensorflow. @@ -145,6 +147,10 @@ def __init__( max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -157,6 +163,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) @@ -242,6 +250,8 @@ def __init__( model_copies: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for Tensorflow. @@ -278,6 +288,10 @@ def __init__( max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -290,6 +304,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py index 00a61b4934aa..333187301b29 100644 --- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py +++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py @@ -232,6 +232,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for TensorRT. @@ -262,6 +264,10 @@ def __init__( a batch before emitting; used in streaming contexts. max_batch_weight: the maximum total weight of a batch. element_size_fn: a function that returns the size (weight) of an element. + batch_length_fn: a callable that returns the length of an element for + length-aware batching. + batch_bucket_boundaries: a sorted list of positive boundary values for + length-aware batching buckets. kwargs: Additional arguments like 'engine_path' and 'onnx_path' are currently supported. 'env_vars' can be used to set environment variables before loading the model. @@ -275,6 +281,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, large_model=large_model, model_copies=model_copies, **kwargs) diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py index 02827f9578f1..7757146ab8dc 100644 --- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py +++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py @@ -72,6 +72,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for Vertex AI. **NOTE:** This API and its implementation are under development and @@ -115,6 +117,10 @@ def __init__( max_batch_weight: optional. the maximum total weight of a batch. element_size_fn: optional. a function that returns the size (weight) of an element. + batch_length_fn: optional. a callable that returns the length of an + element for length-aware batching. + batch_bucket_boundaries: optional. a sorted list of positive boundary + values for length-aware batching buckets. """ self._batching_kwargs = {} self._env_vars = kwargs.get('env_vars', {}) @@ -129,6 +135,10 @@ def __init__( self._batching_kwargs["max_batch_weight"] = max_batch_weight if element_size_fn is not None: self._batching_kwargs['element_size_fn'] = element_size_fn + if batch_length_fn is not None: + self._batching_kwargs['length_fn'] = batch_length_fn + if batch_bucket_boundaries is not None: + self._batching_kwargs['bucket_boundaries'] = batch_bucket_boundaries if private and network is None: raise ValueError( diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py index 918b49155606..5a982f5cc55b 100644 --- a/sdks/python/apache_beam/ml/inference/vllm_inference.py +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -182,7 +182,9 @@ def __init__( max_batch_size: Optional[int] = None, max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, - element_size_fn: Optional[Callable[[Any], int]] = None): + element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None): """Implementation of the ModelHandler interface for vLLM using text as input. @@ -210,13 +212,19 @@ def __init__( max_batch_weight: optional. the maximum total weight of a batch. element_size_fn: optional. a function that returns the size (weight) of an element. + batch_length_fn: optional. a callable that returns the length of an + element for length-aware batching. + batch_bucket_boundaries: optional. a sorted list of positive boundary + values for length-aware batching buckets. """ super().__init__( min_batch_size=min_batch_size, max_batch_size=max_batch_size, max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, - element_size_fn=element_size_fn) + element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries) self._model_name = model_name self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {} @@ -280,7 +288,9 @@ def __init__( max_batch_size: Optional[int] = None, max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, - element_size_fn: Optional[Callable[[Any], int]] = None): + element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None): """ Implementation of the ModelHandler interface for vLLM using previous messages as input. @@ -313,13 +323,19 @@ def __init__( max_batch_weight: optional. the maximum total weight of a batch. element_size_fn: optional. a function that returns the size (weight) of an element. + batch_length_fn: optional. a callable that returns the length of an + element for length-aware batching. + batch_bucket_boundaries: optional. a sorted list of positive boundary + values for length-aware batching buckets. """ super().__init__( min_batch_size=min_batch_size, max_batch_size=max_batch_size, max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, - element_size_fn=element_size_fn) + element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries) self._model_name = model_name self._vllm_server_kwargs: dict[str, str] = vllm_server_kwargs or {} self._chat_template_path = chat_template_path diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference.py b/sdks/python/apache_beam/ml/inference/xgboost_inference.py index 9d7413685113..f11b892062e8 100644 --- a/sdks/python/apache_beam/ml/inference/xgboost_inference.py +++ b/sdks/python/apache_beam/ml/inference/xgboost_inference.py @@ -80,6 +80,8 @@ def __init__( max_batch_duration_secs: Optional[int] = None, max_batch_weight: Optional[int] = None, element_size_fn: Optional[Callable[[Any], int]] = None, + batch_length_fn: Optional[Callable[[Any], int]] = None, + batch_bucket_boundaries: Optional[list[int]] = None, **kwargs): """Implementation of the ModelHandler interface for XGBoost. @@ -109,6 +111,10 @@ def __init__( max_batch_weight: optional. the maximum total weight of a batch. element_size_fn: optional. a function that returns the size (weight) of an element. + batch_length_fn: optional. a callable that returns the length of an + element for length-aware batching. + batch_bucket_boundaries: optional. a sorted list of positive boundary + values for length-aware batching buckets. kwargs: 'env_vars' can be used to set environment variables before loading the model. @@ -131,6 +137,8 @@ def __init__( max_batch_duration_secs=max_batch_duration_secs, max_batch_weight=max_batch_weight, element_size_fn=element_size_fn, + batch_length_fn=batch_length_fn, + batch_bucket_boundaries=batch_bucket_boundaries, **kwargs) self._model_class = model_class self._model_state = model_state