From ec04b5d8e604af0c9098fca91702d558432765e5 Mon Sep 17 00:00:00 2001 From: klemens-floege Date: Tue, 17 Feb 2026 18:22:20 +0100 Subject: [PATCH 1/7] error and base function --- src/tabpfn/base.py | 28 ++++++++++++++++++++++++++++ src/tabpfn/errors.py | 11 ++++------- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/tabpfn/base.py b/src/tabpfn/base.py index c29f67260..aa728e9a6 100644 --- a/src/tabpfn/base.py +++ b/src/tabpfn/base.py @@ -489,3 +489,31 @@ def get_embeddings( embeddings.append(embed.squeeze().cpu().numpy()) return np.array(embeddings) + + +def predict_in_batches( + predict_fn: typing.Callable, + X: XType, + batch_size: int, + concat_fn: typing.Callable | None = None, +) -> typing.Any: + """Split X into batches, apply predict_fn to each, and concatenate results. + + Args: + predict_fn: A callable that takes a data slice and returns predictions. + X: The full input data. + batch_size: The number of samples per batch. + concat_fn: Optional custom function to concatenate results. + If None, uses ``np.concatenate(..., axis=0)``. + + Returns: + The concatenated predictions. 
+ """ + n_samples = X.shape[0] + results = [ + predict_fn(X[start : min(start + batch_size, n_samples)]) + for start in range(0, n_samples, batch_size) + ] + if concat_fn is not None: + return concat_fn(results) + return np.concatenate(results, axis=0) diff --git a/src/tabpfn/errors.py b/src/tabpfn/errors.py index 87e99f04c..525b1c125 100644 --- a/src/tabpfn/errors.py +++ b/src/tabpfn/errors.py @@ -69,13 +69,10 @@ def __init__( message = ( f"{self.device_name} out of memory{size_info}.\n\n" - f"Solution: Split your test data into smaller batches:\n\n" - f" batch_size = 1000 # depends on hardware\n" - f" predictions = []\n" - f" for i in range(0, len(X_test), batch_size):\n" - f" batch = model.{predict_method}(X_test[i:i + batch_size])\n" - f" predictions.append(batch)\n" - f" predictions = np.vstack(predictions)" + f"Solution: Use batch_size_predict to split test data" + f" into smaller batches:\n\n" + f" model.{predict_method}(" + f"X_test, batch_size_predict=100)" ) if original_error is not None: message += f"\n\nOriginal error: {original_error}" From 5d3e726b5e8154e83e48273c8b8d4cab4df38ff4 Mon Sep 17 00:00:00 2001 From: klemens-floege Date: Tue, 17 Feb 2026 18:29:18 +0100 Subject: [PATCH 2/7] add to regressor + classifier --- src/tabpfn/classifier.py | 40 ++++++++++++++-- src/tabpfn/regressor.py | 52 ++++++++++++++++++++ tests/test_classifier_interface.py | 31 ++++++++++++ tests/test_regressor_interface.py | 76 ++++++++++++++++++++++++++++++ 4 files changed, 195 insertions(+), 4 deletions(-) diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index 667690e6c..94847a269 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -40,6 +40,7 @@ get_embeddings, initialize_model_variables_helper, initialize_telemetry, + predict_in_batches, ) from tabpfn.constants import ( PROBABILITY_EPSILON_ROUND_ZERO, @@ -1061,15 +1062,20 @@ def _raw_predict( ) @track_model_call(model_method="predict", param_names=["X"]) - def predict(self, X: XType) -> 
np.ndarray: + def predict(self, X: XType, *, batch_size_predict: int | None = None) -> np.ndarray: """Predict the class labels for the provided input samples. Args: X: The input data for prediction. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: The predicted class labels as a NumPy array. """ + if batch_size_predict is not None: + return predict_in_batches(self.predict, X, batch_size_predict) + probas = self._predict_proba(X=X) y_pred = np.argmax(probas, axis=1) if hasattr(self, "label_encoder_") and self.label_encoder_ is not None: @@ -1079,7 +1085,9 @@ def predict(self, X: XType) -> np.ndarray: @config_context(transform_output="default") @track_model_call(model_method="predict", param_names=["X"]) - def predict_logits(self, X: XType) -> np.ndarray: + def predict_logits( + self, X: XType, *, batch_size_predict: int | None = None + ) -> np.ndarray: """Predict the raw logits for the provided input samples. Logits represent the unnormalized log-probabilities of the classes @@ -1087,16 +1095,23 @@ def predict_logits(self, X: XType) -> np.ndarray: Args: X: The input data for prediction. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: The predicted logits as a NumPy array. Shape (n_samples, n_classes). """ + if batch_size_predict is not None: + return predict_in_batches(self.predict_logits, X, batch_size_predict) + logits_tensor = self._raw_predict(X, return_logits=True) return logits_tensor.float().detach().cpu().numpy() @config_context(transform_output="default") @track_model_call(model_method="predict", param_names=["X"]) - def predict_raw_logits(self, X: XType) -> np.ndarray: + def predict_raw_logits( + self, X: XType, *, batch_size_predict: int | None = None + ) -> np.ndarray: """Predict the raw logits for the provided input samples. 
Logits represent the unnormalized log-probabilities of the classes @@ -1106,11 +1121,21 @@ def predict_raw_logits(self, X: XType) -> np.ndarray: Args: X: The input data for prediction. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: An array of predicted logits for each estimator, Shape (n_estimators, n_samples, n_classes). """ + if batch_size_predict is not None: + return predict_in_batches( + self.predict_raw_logits, + X, + batch_size_predict, + concat_fn=lambda results: np.concatenate(results, axis=1), + ) + logits_tensor = self._raw_predict( X, return_logits=False, @@ -1119,18 +1144,25 @@ def predict_raw_logits(self, X: XType) -> np.ndarray: return logits_tensor.float().detach().cpu().numpy() @track_model_call(model_method="predict", param_names=["X"]) - def predict_proba(self, X: XType) -> np.ndarray: + def predict_proba( + self, X: XType, *, batch_size_predict: int | None = None + ) -> np.ndarray: """Predict the probabilities of the classes for the provided input samples. This is a wrapper around the `_predict_proba` method. Args: X: The input data for prediction. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: The predicted probabilities of the classes as a NumPy array. Shape (n_samples, n_classes). 
""" + if batch_size_predict is not None: + return predict_in_batches(self.predict_proba, X, batch_size_predict) + return self._predict_proba(X) @config_context(transform_output="default") # type: ignore diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index 3e611452c..9a64a3226 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -46,6 +46,7 @@ get_embeddings, initialize_model_variables_helper, initialize_telemetry, + predict_in_batches, ) from tabpfn.constants import REGRESSION_CONSTANT_TARGET_BORDER_EPSILON, ModelVersion from tabpfn.errors import TabPFNValidationError, handle_oom_errors @@ -860,6 +861,7 @@ def predict( # TODO: support "ei", "pi" output_type: OutputType = "mean", quantiles: list[float] | None = None, + batch_size_predict: int | None = None, ) -> RegressionResultType: """Runs the forward() method and then transform the logits from the binning space in order to predict target variable. @@ -886,10 +888,25 @@ def predict( quantiles are returned. The predictions per quantile match the input order. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. + Returns: The prediction, which can be a numpy array, a list of arrays (for quantiles), or a dictionary with detailed outputs. 
""" + if batch_size_predict is not None: + return predict_in_batches( + lambda chunk: self.predict( + chunk, output_type=output_type, quantiles=quantiles + ), + X, + batch_size_predict, + concat_fn=lambda results: _concatenate_regression_results( + results, output_type + ), + ) + check_is_fitted(self) # TODO: Move these at some point to InferenceEngine @@ -1233,3 +1250,38 @@ def _logits_to_output( raise ValueError(f"Invalid output type: {output_type}") return output.cpu().detach().numpy() + + +def _concatenate_regression_results( + results: list[RegressionResultType], + output_type: str, +) -> RegressionResultType: + """Concatenate batched regression prediction results.""" + if output_type in _OUTPUT_TYPES_BASIC: + return np.concatenate(results, axis=0) + + if output_type == "quantiles": + return [ + np.concatenate([r[q] for r in results], axis=0) + for q in range(len(results[0])) + ] + + if output_type in ("main", "full"): + main = MainOutputDict( + mean=np.concatenate([r["mean"] for r in results], axis=0), + median=np.concatenate([r["median"] for r in results], axis=0), + mode=np.concatenate([r["mode"] for r in results], axis=0), + quantiles=[ + np.concatenate([r["quantiles"][q] for r in results], axis=0) + for q in range(len(results[0]["quantiles"])) + ], + ) + if output_type == "main": + return main + return FullOutputDict( + **main, + criterion=results[0]["criterion"], + logits=torch.cat([r["logits"] for r in results], dim=0), + ) + + raise ValueError(f"Invalid output type: {output_type}") diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index caa000f24..2d006840d 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -1203,6 +1203,37 @@ def _create_dummy_classifier_model_specs( ) +@pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) +def test__predict__batch_size_predict__matches_unbatched( + X_y: tuple[np.ndarray, np.ndarray], + batch_size_predict: int, +) -> None: + """Test that 
batch_size_predict matches unbatched prediction.""" + X, y = X_y + + model = TabPFNClassifier(n_estimators=2, random_state=42) + model.fit(X, y) + + # Unbatched predictions + pred_all = model.predict(X) + proba_all = model.predict_proba(X) + logits_all = model.predict_logits(X) + raw_logits_all = model.predict_raw_logits(X) + + # Batched predictions + pred_batched = model.predict(X, batch_size_predict=batch_size_predict) + proba_batched = model.predict_proba(X, batch_size_predict=batch_size_predict) + logits_batched = model.predict_logits(X, batch_size_predict=batch_size_predict) + raw_logits_batched = model.predict_raw_logits( + X, batch_size_predict=batch_size_predict + ) + + np.testing.assert_array_equal(pred_all, pred_batched) + np.testing.assert_allclose(proba_all, proba_batched, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose(logits_all, logits_batched, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose(raw_logits_all, raw_logits_batched, atol=1e-5, rtol=1e-5) + + def test__create_default_for_version__v2__uses_correct_defaults() -> None: estimator = TabPFNClassifier.create_default_for_version(ModelVersion.V2) diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py index 545cd71da..aecce26e2 100644 --- a/tests/test_regressor_interface.py +++ b/tests/test_regressor_interface.py @@ -807,6 +807,82 @@ def test__TabPFNRegressor__few_features__works(n_features: int) -> None: ) +@pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) +@pytest.mark.parametrize("output_type", ["mean", "median", "mode"]) +def test__predict__batch_size_predict__matches_unbatched( + X_y: tuple[np.ndarray, np.ndarray], + batch_size_predict: int, + output_type: str, +) -> None: + """Test that batch_size_predict matches unbatched prediction.""" + X, y = X_y + + model = TabPFNRegressor(n_estimators=2, random_state=42) + model.fit(X, y) + + pred_all = model.predict(X, output_type=output_type) + pred_batched = model.predict( + X, output_type=output_type, 
batch_size_predict=batch_size_predict + ) + np.testing.assert_allclose( + pred_all, pred_batched, atol=1e-3, rtol=1e-3 + ) + + +@pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) +def test__predict__batch_size_predict__quantiles_matches_unbatched( + X_y: tuple[np.ndarray, np.ndarray], + batch_size_predict: int, +) -> None: + """Test that batch_size_predict matches unbatched quantiles.""" + X, y = X_y + + model = TabPFNRegressor(n_estimators=2, random_state=42) + model.fit(X, y) + + quantiles_list = [0.1, 0.5, 0.9] + quant_all = model.predict( + X, output_type="quantiles", quantiles=quantiles_list + ) + quant_batched = model.predict( + X, + output_type="quantiles", + quantiles=quantiles_list, + batch_size_predict=batch_size_predict, + ) + for q_all, q_batched in zip(quant_all, quant_batched): + np.testing.assert_allclose( + q_all, q_batched, atol=1e-3, rtol=1e-3 + ) + + +@pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) +def test__predict__batch_size_predict__main_matches_unbatched( + X_y: tuple[np.ndarray, np.ndarray], + batch_size_predict: int, +) -> None: + """Test that batch_size_predict matches unbatched main output.""" + X, y = X_y + + model = TabPFNRegressor(n_estimators=2, random_state=42) + model.fit(X, y) + + main_all = model.predict(X, output_type="main") + main_batched = model.predict( + X, output_type="main", batch_size_predict=batch_size_predict + ) + for key in ["mean", "median", "mode"]: + np.testing.assert_allclose( + main_all[key], main_batched[key], atol=1e-3, rtol=1e-3 + ) + for q_all, q_batched in zip( + main_all["quantiles"], main_batched["quantiles"] + ): + np.testing.assert_allclose( + q_all, q_batched, atol=1e-3, rtol=1e-3 + ) + + def test__create_default_for_version__v2__uses_correct_defaults() -> None: estimator = TabPFNRegressor.create_default_for_version(ModelVersion.V2) From 49126f91be331356f78382a9c359d8392163ad42 Mon Sep 17 00:00:00 2001 From: klemens-floege Date: Tue, 24 Feb 2026 12:13:00 +0100 Subject: [PATCH 3/7] new 
code --- src/tabpfn/classifier.py | 60 +++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index 94847a269..9de57fb6e 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -1019,6 +1019,7 @@ def _raw_predict( *, return_logits: bool, return_raw_logits: bool = False, + batch_size_predict: int | None = None, ) -> torch.Tensor: """Internal method to run prediction. @@ -1033,11 +1034,26 @@ def _raw_predict( post-processing steps. return_raw_logits: If True, returns the raw logits without averaging estimators or temperature scaling. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: The raw torch.Tensor output, either logits or probabilities, depending on `return_logits` and `return_raw_logits`. """ + if batch_size_predict is not None: + concat_dim = 1 if return_raw_logits else 0 + return predict_in_batches( + lambda chunk: self._raw_predict( + chunk, + return_logits=return_logits, + return_raw_logits=return_raw_logits, + ), + X, + batch_size_predict, + concat_fn=lambda results: torch.cat(results, dim=concat_dim), + ) + check_is_fitted(self) if not self.differentiable_input: @@ -1073,10 +1089,7 @@ def predict(self, X: XType, *, batch_size_predict: int | None = None) -> np.ndar Returns: The predicted class labels as a NumPy array. """ - if batch_size_predict is not None: - return predict_in_batches(self.predict, X, batch_size_predict) - - probas = self._predict_proba(X=X) + probas = self._predict_proba(X=X, batch_size_predict=batch_size_predict) y_pred = np.argmax(probas, axis=1) if hasattr(self, "label_encoder_") and self.label_encoder_ is not None: return self.label_encoder_.inverse_transform(y_pred) @@ -1101,10 +1114,9 @@ def predict_logits( Returns: The predicted logits as a NumPy array. Shape (n_samples, n_classes). 
""" - if batch_size_predict is not None: - return predict_in_batches(self.predict_logits, X, batch_size_predict) - - logits_tensor = self._raw_predict(X, return_logits=True) + logits_tensor = self._raw_predict( + X, return_logits=True, batch_size_predict=batch_size_predict + ) return logits_tensor.float().detach().cpu().numpy() @config_context(transform_output="default") @@ -1128,18 +1140,11 @@ def predict_raw_logits( An array of predicted logits for each estimator, Shape (n_estimators, n_samples, n_classes). """ - if batch_size_predict is not None: - return predict_in_batches( - self.predict_raw_logits, - X, - batch_size_predict, - concat_fn=lambda results: np.concatenate(results, axis=1), - ) - logits_tensor = self._raw_predict( X, return_logits=False, return_raw_logits=True, + batch_size_predict=batch_size_predict, ) return logits_tensor.float().detach().cpu().numpy() @@ -1160,24 +1165,35 @@ def predict_proba( The predicted probabilities of the classes as a NumPy array. Shape (n_samples, n_classes). """ - if batch_size_predict is not None: - return predict_in_batches(self.predict_proba, X, batch_size_predict) - - return self._predict_proba(X) + return self._predict_proba(X, batch_size_predict=batch_size_predict) @config_context(transform_output="default") # type: ignore - def _predict_proba(self, X: XType) -> np.ndarray: + def _predict_proba( + self, + X: XType, + batch_size_predict: int | None = None, + ) -> np.ndarray: """Predict the probabilities of the classes for the provided input samples. Args: X: The input data for prediction. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: The predicted probabilities of the classes as a NumPy array. Shape (n_samples, n_classes). 
""" probas = ( - self._raw_predict(X, return_logits=False).float().detach().cpu().numpy() + self._raw_predict( + X, + return_logits=False, + batch_size_predict=batch_size_predict, + ) + .float() + .detach() + .cpu() + .numpy() ) probas = self._maybe_reweight_probas(probas=probas) if self.inference_config_.USE_SKLEARN_16_DECIMAL_PRECISION: From f9b585a7aee6c0449162ad516593c03b7c28da6f Mon Sep 17 00:00:00 2001 From: klemens-floege Date: Fri, 27 Feb 2026 14:44:36 +0100 Subject: [PATCH 4/7] add code --- src/tabpfn/base.py | 6 ++++ src/tabpfn/regressor.py | 51 +++++++++++++++++++++++-------- tests/test_regressor_interface.py | 33 ++++++++++++++++++++ 3 files changed, 78 insertions(+), 12 deletions(-) diff --git a/src/tabpfn/base.py b/src/tabpfn/base.py index 4f9dfc6de..e6e64537f 100644 --- a/src/tabpfn/base.py +++ b/src/tabpfn/base.py @@ -506,7 +506,13 @@ def predict_in_batches( Returns: The concatenated predictions. + + Raises: + ValueError: If batch_size is not a positive integer. """ + if batch_size <= 0: + raise ValueError("batch_size must be a positive integer") + n_samples = X.shape[0] results = [ predict_fn(X[start : min(start + batch_size, n_samples)]) diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index 3367ad59c..c1772abbf 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -831,6 +831,7 @@ def predict( *, output_type: Literal["mean", "median", "mode"] = "mean", quantiles: list[float] | None = None, + batch_size_predict: int | None = None, ) -> np.ndarray: ... @overload @@ -840,6 +841,7 @@ def predict( *, output_type: Literal["quantiles"], quantiles: list[float] | None = None, + batch_size_predict: int | None = None, ) -> list[np.ndarray]: ... @overload @@ -849,6 +851,7 @@ def predict( *, output_type: Literal["main"], quantiles: list[float] | None = None, + batch_size_predict: int | None = None, ) -> MainOutputDict: ... 
@overload @@ -858,6 +861,7 @@ def predict( *, output_type: Literal["full"], quantiles: list[float] | None = None, + batch_size_predict: int | None = None, ) -> FullOutputDict: ... @config_context(transform_output="default") # type: ignore @@ -903,18 +907,6 @@ def predict( The prediction, which can be a numpy array, a list of arrays (for quantiles), or a dictionary with detailed outputs. """ - if batch_size_predict is not None: - return predict_in_batches( - lambda chunk: self.predict( - chunk, output_type=output_type, quantiles=quantiles - ), - X, - batch_size_predict, - concat_fn=lambda results: _concatenate_regression_results( - results, output_type - ), - ) - check_is_fitted(self) # TODO: Move these at some point to InferenceEngine @@ -942,6 +934,39 @@ def predict( X, ord_encoder=getattr(self, "ordinal_encoder_", None) ) + if batch_size_predict is not None: + return predict_in_batches( + lambda chunk: self._predict_core( + chunk, output_type=output_type, quantiles=quantiles + ), + X, + batch_size_predict, + concat_fn=lambda results: _concatenate_regression_results( + results, output_type + ), + ) + + return self._predict_core(X, output_type=output_type, quantiles=quantiles) + + def _predict_core( + self, + X: XType, + output_type: OutputType, + quantiles: list[float], + ) -> RegressionResultType: + """Core prediction logic on already-preprocessed data. + + Runs the forward pass, translates logits, and formats the output. + This method assumes X has already been validated and preprocessed. + + Args: + X: The preprocessed input data. + output_type: The type of output to return. + quantiles: The quantiles to compute. + + Returns: + The prediction result. 
+ """ # Runs over iteration engine with handle_oom_errors(self.devices_, X, model_type="regressor"): ( @@ -1286,6 +1311,8 @@ def _concatenate_regression_results( ) if output_type == "main": return main + # criterion is a model-level attribute (raw_space_bardist_), identical + # across all batches, so we take it from the first result. return FullOutputDict( **main, criterion=results[0]["criterion"], diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py index aecce26e2..e982e8206 100644 --- a/tests/test_regressor_interface.py +++ b/tests/test_regressor_interface.py @@ -883,6 +883,39 @@ def test__predict__batch_size_predict__main_matches_unbatched( ) +@pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) +def test__predict__batch_size_predict__full_matches_unbatched( + X_y: tuple[np.ndarray, np.ndarray], + batch_size_predict: int, +) -> None: + """Test that batch_size_predict matches unbatched full output.""" + X, y = X_y + + model = TabPFNRegressor(n_estimators=2, random_state=42) + model.fit(X, y) + + full_all = model.predict(X, output_type="full") + full_batched = model.predict( + X, output_type="full", batch_size_predict=batch_size_predict + ) + for key in ["mean", "median", "mode"]: + np.testing.assert_allclose( + full_all[key], full_batched[key], atol=1e-3, rtol=1e-3 + ) + for q_all, q_batched in zip( + full_all["quantiles"], full_batched["quantiles"] + ): + np.testing.assert_allclose( + q_all, q_batched, atol=1e-3, rtol=1e-3 + ) + # logits should match + torch.testing.assert_close( + full_all["logits"], full_batched["logits"], atol=1e-3, rtol=1e-3 + ) + # criterion is a model-level attribute, should be the same object + assert full_all["criterion"] is full_batched["criterion"] + + def test__create_default_for_version__v2__uses_correct_defaults() -> None: estimator = TabPFNRegressor.create_default_for_version(ModelVersion.V2) From 3e48b1663f50ab8d7359d0b125bbf143a68846a7 Mon Sep 17 00:00:00 2001 From: klemens-floege Date: Fri, 27 
Feb 2026 14:52:34 +0100 Subject: [PATCH 5/7] linting --- tests/test_regressor_interface.py | 28 +++++++--------------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py index e982e8206..52b8ebb92 100644 --- a/tests/test_regressor_interface.py +++ b/tests/test_regressor_interface.py @@ -824,9 +824,7 @@ def test__predict__batch_size_predict__matches_unbatched( pred_batched = model.predict( X, output_type=output_type, batch_size_predict=batch_size_predict ) - np.testing.assert_allclose( - pred_all, pred_batched, atol=1e-3, rtol=1e-3 - ) + np.testing.assert_allclose(pred_all, pred_batched, atol=1e-3, rtol=1e-3) @pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) @@ -841,9 +839,7 @@ def test__predict__batch_size_predict__quantiles_matches_unbatched( model.fit(X, y) quantiles_list = [0.1, 0.5, 0.9] - quant_all = model.predict( - X, output_type="quantiles", quantiles=quantiles_list - ) + quant_all = model.predict(X, output_type="quantiles", quantiles=quantiles_list) quant_batched = model.predict( X, output_type="quantiles", @@ -851,9 +847,7 @@ def test__predict__batch_size_predict__quantiles_matches_unbatched( batch_size_predict=batch_size_predict, ) for q_all, q_batched in zip(quant_all, quant_batched): - np.testing.assert_allclose( - q_all, q_batched, atol=1e-3, rtol=1e-3 - ) + np.testing.assert_allclose(q_all, q_batched, atol=1e-3, rtol=1e-3) @pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) @@ -875,12 +869,8 @@ def test__predict__batch_size_predict__main_matches_unbatched( np.testing.assert_allclose( main_all[key], main_batched[key], atol=1e-3, rtol=1e-3 ) - for q_all, q_batched in zip( - main_all["quantiles"], main_batched["quantiles"] - ): - np.testing.assert_allclose( - q_all, q_batched, atol=1e-3, rtol=1e-3 - ) + for q_all, q_batched in zip(main_all["quantiles"], main_batched["quantiles"]): + np.testing.assert_allclose(q_all, q_batched, atol=1e-3, rtol=1e-3) 
@pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) @@ -902,12 +892,8 @@ def test__predict__batch_size_predict__full_matches_unbatched( np.testing.assert_allclose( full_all[key], full_batched[key], atol=1e-3, rtol=1e-3 ) - for q_all, q_batched in zip( - full_all["quantiles"], full_batched["quantiles"] - ): - np.testing.assert_allclose( - q_all, q_batched, atol=1e-3, rtol=1e-3 - ) + for q_all, q_batched in zip(full_all["quantiles"], full_batched["quantiles"]): + np.testing.assert_allclose(q_all, q_batched, atol=1e-3, rtol=1e-3) # logits should match torch.testing.assert_close( full_all["logits"], full_batched["logits"], atol=1e-3, rtol=1e-3 From 4b97ade9823989a3fa7e8cdc1a394ded48f499a6 Mon Sep 17 00:00:00 2001 From: klemens-floege Date: Fri, 27 Feb 2026 14:53:54 +0100 Subject: [PATCH 6/7] add changelog --- changelog/793.added.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/793.added.md diff --git a/changelog/793.added.md b/changelog/793.added.md new file mode 100644 index 000000000..7ccf0c498 --- /dev/null +++ b/changelog/793.added.md @@ -0,0 +1 @@ +Add a batch_size_predict parameter to our .predict() functions that performs automatic batching of the test set. 
\ No newline at end of file From 9901c10658296c5cf42be5e80aecff36fb069073 Mon Sep 17 00:00:00 2001 From: klemens-floege Date: Fri, 27 Feb 2026 14:54:01 +0100 Subject: [PATCH 7/7] more precision --- tests/test_regressor_interface.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py index 52b8ebb92..48fc9660c 100644 --- a/tests/test_regressor_interface.py +++ b/tests/test_regressor_interface.py @@ -824,7 +824,7 @@ def test__predict__batch_size_predict__matches_unbatched( pred_batched = model.predict( X, output_type=output_type, batch_size_predict=batch_size_predict ) - np.testing.assert_allclose(pred_all, pred_batched, atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(pred_all, pred_batched, atol=1e-5, rtol=1e-5) @pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) @@ -847,7 +847,7 @@ def test__predict__batch_size_predict__quantiles_matches_unbatched( batch_size_predict=batch_size_predict, ) for q_all, q_batched in zip(quant_all, quant_batched): - np.testing.assert_allclose(q_all, q_batched, atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(q_all, q_batched, atol=1e-5, rtol=1e-5) @pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) @@ -867,10 +867,10 @@ def test__predict__batch_size_predict__main_matches_unbatched( ) for key in ["mean", "median", "mode"]: np.testing.assert_allclose( - main_all[key], main_batched[key], atol=1e-3, rtol=1e-3 + main_all[key], main_batched[key], atol=1e-5, rtol=1e-5 ) for q_all, q_batched in zip(main_all["quantiles"], main_batched["quantiles"]): - np.testing.assert_allclose(q_all, q_batched, atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(q_all, q_batched, atol=1e-5, rtol=1e-5) @pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) @@ -890,13 +890,13 @@ def test__predict__batch_size_predict__full_matches_unbatched( ) for key in ["mean", "median", "mode"]: np.testing.assert_allclose( - full_all[key], full_batched[key], 
atol=1e-3, rtol=1e-3 + full_all[key], full_batched[key], atol=1e-5, rtol=1e-5 ) for q_all, q_batched in zip(full_all["quantiles"], full_batched["quantiles"]): - np.testing.assert_allclose(q_all, q_batched, atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(q_all, q_batched, atol=1e-5, rtol=1e-5) # logits should match torch.testing.assert_close( - full_all["logits"], full_batched["logits"], atol=1e-3, rtol=1e-3 + full_all["logits"], full_batched["logits"], atol=1e-5, rtol=1e-5 ) # criterion is a model-level attribute, should be the same object assert full_all["criterion"] is full_batched["criterion"]