diff --git a/changelog/793.added.md b/changelog/793.added.md new file mode 100644 index 000000000..7ccf0c498 --- /dev/null +++ b/changelog/793.added.md @@ -0,0 +1 @@ +Add a batch_size_predict parameter to our .predict() functions that performs automatic batching of the test set. \ No newline at end of file diff --git a/src/tabpfn/base.py b/src/tabpfn/base.py index 04c912aff..e6e64537f 100644 --- a/src/tabpfn/base.py +++ b/src/tabpfn/base.py @@ -487,3 +487,37 @@ def get_embeddings( embeddings.append(embed.squeeze().cpu().numpy()) return np.array(embeddings) + + +def predict_in_batches( + predict_fn: typing.Callable, + X: XType, + batch_size: int, + concat_fn: typing.Callable | None = None, +) -> typing.Any: + """Split X into batches, apply predict_fn to each, and concatenate results. + + Args: + predict_fn: A callable that takes a data slice and returns predictions. + X: The full input data. + batch_size: The number of samples per batch. + concat_fn: Optional custom function to concatenate results. + If None, uses ``np.concatenate(..., axis=0)``. + + Returns: + The concatenated predictions. + + Raises: + ValueError: If batch_size is not a positive integer. 
+ """ + if batch_size <= 0: + raise ValueError("batch_size must be a positive integer") + + n_samples = X.shape[0] + results = [ + predict_fn(X[start : min(start + batch_size, n_samples)]) + for start in range(0, n_samples, batch_size) + ] + if concat_fn is not None: + return concat_fn(results) + return np.concatenate(results, axis=0) diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index 01397194a..c8310c429 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -40,6 +40,7 @@ get_embeddings, initialize_model_variables_helper, initialize_telemetry, + predict_in_batches, ) from tabpfn.constants import ( PROBABILITY_EPSILON_ROUND_ZERO, @@ -1028,6 +1029,7 @@ def _raw_predict( *, return_logits: bool, return_raw_logits: bool = False, + batch_size_predict: int | None = None, ) -> torch.Tensor: """Internal method to run prediction. @@ -1042,11 +1044,26 @@ def _raw_predict( post-processing steps. return_raw_logits: If True, returns the raw logits without averaging estimators or temperature scaling. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: The raw torch.Tensor output, either logits or probabilities, depending on `return_logits` and `return_raw_logits`. """ + if batch_size_predict is not None: + concat_dim = 1 if return_raw_logits else 0 + return predict_in_batches( + lambda chunk: self._raw_predict( + chunk, + return_logits=return_logits, + return_raw_logits=return_raw_logits, + ), + X, + batch_size_predict, + concat_fn=lambda results: torch.cat(results, dim=concat_dim), + ) + check_is_fitted(self) if not self.differentiable_input: @@ -1071,16 +1088,18 @@ def _raw_predict( ) @track_model_call(model_method="predict", param_names=["X"]) - def predict(self, X: XType) -> np.ndarray: + def predict(self, X: XType, *, batch_size_predict: int | None = None) -> np.ndarray: """Predict the class labels for the provided input samples. 
Args: X: The input data for prediction. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: The predicted class labels as a NumPy array. """ - probas = self._predict_proba(X=X) + probas = self._predict_proba(X=X, batch_size_predict=batch_size_predict) y_pred = np.argmax(probas, axis=1) if hasattr(self, "label_encoder_") and self.label_encoder_ is not None: return self.label_encoder_.inverse_transform(y_pred) @@ -1089,7 +1108,9 @@ def predict(self, X: XType) -> np.ndarray: @config_context(transform_output="default") @track_model_call(model_method="predict", param_names=["X"]) - def predict_logits(self, X: XType) -> np.ndarray: + def predict_logits( + self, X: XType, *, batch_size_predict: int | None = None + ) -> np.ndarray: """Predict the raw logits for the provided input samples. Logits represent the unnormalized log-probabilities of the classes @@ -1097,16 +1118,22 @@ def predict_logits(self, X: XType) -> np.ndarray: Args: X: The input data for prediction. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: The predicted logits as a NumPy array. Shape (n_samples, n_classes). """ - logits_tensor = self._raw_predict(X, return_logits=True) + logits_tensor = self._raw_predict( + X, return_logits=True, batch_size_predict=batch_size_predict + ) return logits_tensor.float().detach().cpu().numpy() @config_context(transform_output="default") @track_model_call(model_method="predict", param_names=["X"]) - def predict_raw_logits(self, X: XType) -> np.ndarray: + def predict_raw_logits( + self, X: XType, *, batch_size_predict: int | None = None + ) -> np.ndarray: """Predict the raw logits for the provided input samples. Logits represent the unnormalized log-probabilities of the classes @@ -1116,6 +1143,8 @@ def predict_raw_logits(self, X: XType) -> np.ndarray: Args: X: The input data for prediction. 
+ batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: An array of predicted logits for each estimator, @@ -1125,37 +1154,56 @@ def predict_raw_logits(self, X: XType) -> np.ndarray: X, return_logits=False, return_raw_logits=True, + batch_size_predict=batch_size_predict, ) return logits_tensor.float().detach().cpu().numpy() @track_model_call(model_method="predict", param_names=["X"]) - def predict_proba(self, X: XType) -> np.ndarray: + def predict_proba( + self, X: XType, *, batch_size_predict: int | None = None + ) -> np.ndarray: """Predict the probabilities of the classes for the provided input samples. This is a wrapper around the `_predict_proba` method. Args: X: The input data for prediction. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: The predicted probabilities of the classes as a NumPy array. Shape (n_samples, n_classes). """ - return self._predict_proba(X) + return self._predict_proba(X, batch_size_predict=batch_size_predict) @config_context(transform_output="default") # type: ignore - def _predict_proba(self, X: XType) -> np.ndarray: + def _predict_proba( + self, + X: XType, + batch_size_predict: int | None = None, + ) -> np.ndarray: """Predict the probabilities of the classes for the provided input samples. Args: X: The input data for prediction. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. Returns: The predicted probabilities of the classes as a NumPy array. Shape (n_samples, n_classes). 
""" probas = ( - self._raw_predict(X, return_logits=False).float().detach().cpu().numpy() + self._raw_predict( + X, + return_logits=False, + batch_size_predict=batch_size_predict, + ) + .float() + .detach() + .cpu() + .numpy() ) probas = self._maybe_reweight_probas(probas=probas) if self.inference_config_.USE_SKLEARN_16_DECIMAL_PRECISION: diff --git a/src/tabpfn/errors.py b/src/tabpfn/errors.py index 87e99f04c..525b1c125 100644 --- a/src/tabpfn/errors.py +++ b/src/tabpfn/errors.py @@ -69,13 +69,10 @@ def __init__( message = ( f"{self.device_name} out of memory{size_info}.\n\n" - f"Solution: Split your test data into smaller batches:\n\n" - f" batch_size = 1000 # depends on hardware\n" - f" predictions = []\n" - f" for i in range(0, len(X_test), batch_size):\n" - f" batch = model.{predict_method}(X_test[i:i + batch_size])\n" - f" predictions.append(batch)\n" - f" predictions = np.vstack(predictions)" + f"Solution: Use batch_size_predict to split test data" + f" into smaller batches:\n\n" + f" model.{predict_method}(" + f"X_test, batch_size_predict=100)" ) if original_error is not None: message += f"\n\nOriginal error: {original_error}" diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index c1d873c76..c1772abbf 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -46,6 +46,7 @@ get_embeddings, initialize_model_variables_helper, initialize_telemetry, + predict_in_batches, ) from tabpfn.constants import REGRESSION_CONSTANT_TARGET_BORDER_EPSILON, ModelVersion from tabpfn.errors import TabPFNValidationError, handle_oom_errors @@ -830,6 +831,7 @@ def predict( *, output_type: Literal["mean", "median", "mode"] = "mean", quantiles: list[float] | None = None, + batch_size_predict: int | None = None, ) -> np.ndarray: ... @overload @@ -839,6 +841,7 @@ def predict( *, output_type: Literal["quantiles"], quantiles: list[float] | None = None, + batch_size_predict: int | None = None, ) -> list[np.ndarray]: ... 
@overload @@ -848,6 +851,7 @@ def predict( *, output_type: Literal["main"], quantiles: list[float] | None = None, + batch_size_predict: int | None = None, ) -> MainOutputDict: ... @overload @@ -857,6 +861,7 @@ def predict( *, output_type: Literal["full"], quantiles: list[float] | None = None, + batch_size_predict: int | None = None, ) -> FullOutputDict: ... @config_context(transform_output="default") # type: ignore @@ -868,6 +873,7 @@ def predict( # TODO: support "ei", "pi" output_type: OutputType = "mean", quantiles: list[float] | None = None, + batch_size_predict: int | None = None, ) -> RegressionResultType: """Runs the forward() method and then transform the logits from the binning space in order to predict target variable. @@ -894,6 +900,9 @@ def predict( quantiles are returned. The predictions per quantile match the input order. + batch_size_predict: If not None, split the test data into + chunks of this size and predict each chunk independently. + Returns: The prediction, which can be a numpy array, a list of arrays (for quantiles), or a dictionary with detailed outputs. @@ -925,6 +934,39 @@ def predict( X, ord_encoder=getattr(self, "ordinal_encoder_", None) ) + if batch_size_predict is not None: + return predict_in_batches( + lambda chunk: self._predict_core( + chunk, output_type=output_type, quantiles=quantiles + ), + X, + batch_size_predict, + concat_fn=lambda results: _concatenate_regression_results( + results, output_type + ), + ) + + return self._predict_core(X, output_type=output_type, quantiles=quantiles) + + def _predict_core( + self, + X: XType, + output_type: OutputType, + quantiles: list[float], + ) -> RegressionResultType: + """Core prediction logic on already-preprocessed data. + + Runs the forward pass, translates logits, and formats the output. + This method assumes X has already been validated and preprocessed. + + Args: + X: The preprocessed input data. + output_type: The type of output to return. + quantiles: The quantiles to compute. 
+ + Returns: + The prediction result. + """ # Runs over iteration engine with handle_oom_errors(self.devices_, X, model_type="regressor"): ( @@ -1241,3 +1283,40 @@ def _logits_to_output( raise ValueError(f"Invalid output type: {output_type}") return output.cpu().detach().numpy() + + +def _concatenate_regression_results( + results: list[RegressionResultType], + output_type: str, +) -> RegressionResultType: + """Concatenate batched regression prediction results.""" + if output_type in _OUTPUT_TYPES_BASIC: + return np.concatenate(results, axis=0) + + if output_type == "quantiles": + return [ + np.concatenate([r[q] for r in results], axis=0) + for q in range(len(results[0])) + ] + + if output_type in ("main", "full"): + main = MainOutputDict( + mean=np.concatenate([r["mean"] for r in results], axis=0), + median=np.concatenate([r["median"] for r in results], axis=0), + mode=np.concatenate([r["mode"] for r in results], axis=0), + quantiles=[ + np.concatenate([r["quantiles"][q] for r in results], axis=0) + for q in range(len(results[0]["quantiles"])) + ], + ) + if output_type == "main": + return main + # criterion is a model-level attribute (raw_space_bardist_), identical + # across all batches, so we take it from the first result. 
+ return FullOutputDict( + **main, + criterion=results[0]["criterion"], + logits=torch.cat([r["logits"] for r in results], dim=0), + ) + + raise ValueError(f"Invalid output type: {output_type}") diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index caa000f24..2d006840d 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -1203,6 +1203,37 @@ def _create_dummy_classifier_model_specs( ) +@pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) +def test__predict__batch_size_predict__matches_unbatched( + X_y: tuple[np.ndarray, np.ndarray], + batch_size_predict: int, +) -> None: + """Test that batch_size_predict matches unbatched prediction.""" + X, y = X_y + + model = TabPFNClassifier(n_estimators=2, random_state=42) + model.fit(X, y) + + # Unbatched predictions + pred_all = model.predict(X) + proba_all = model.predict_proba(X) + logits_all = model.predict_logits(X) + raw_logits_all = model.predict_raw_logits(X) + + # Batched predictions + pred_batched = model.predict(X, batch_size_predict=batch_size_predict) + proba_batched = model.predict_proba(X, batch_size_predict=batch_size_predict) + logits_batched = model.predict_logits(X, batch_size_predict=batch_size_predict) + raw_logits_batched = model.predict_raw_logits( + X, batch_size_predict=batch_size_predict + ) + + np.testing.assert_array_equal(pred_all, pred_batched) + np.testing.assert_allclose(proba_all, proba_batched, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose(logits_all, logits_batched, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose(raw_logits_all, raw_logits_batched, atol=1e-5, rtol=1e-5) + + def test__create_default_for_version__v2__uses_correct_defaults() -> None: estimator = TabPFNClassifier.create_default_for_version(ModelVersion.V2) diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py index 545cd71da..48fc9660c 100644 --- a/tests/test_regressor_interface.py +++ 
b/tests/test_regressor_interface.py @@ -807,6 +807,101 @@ def test__TabPFNRegressor__few_features__works(n_features: int) -> None: ) +@pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) +@pytest.mark.parametrize("output_type", ["mean", "median", "mode"]) +def test__predict__batch_size_predict__matches_unbatched( + X_y: tuple[np.ndarray, np.ndarray], + batch_size_predict: int, + output_type: str, +) -> None: + """Test that batch_size_predict matches unbatched prediction.""" + X, y = X_y + + model = TabPFNRegressor(n_estimators=2, random_state=42) + model.fit(X, y) + + pred_all = model.predict(X, output_type=output_type) + pred_batched = model.predict( + X, output_type=output_type, batch_size_predict=batch_size_predict + ) + np.testing.assert_allclose(pred_all, pred_batched, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) +def test__predict__batch_size_predict__quantiles_matches_unbatched( + X_y: tuple[np.ndarray, np.ndarray], + batch_size_predict: int, +) -> None: + """Test that batch_size_predict matches unbatched quantiles.""" + X, y = X_y + + model = TabPFNRegressor(n_estimators=2, random_state=42) + model.fit(X, y) + + quantiles_list = [0.1, 0.5, 0.9] + quant_all = model.predict(X, output_type="quantiles", quantiles=quantiles_list) + quant_batched = model.predict( + X, + output_type="quantiles", + quantiles=quantiles_list, + batch_size_predict=batch_size_predict, + ) + for q_all, q_batched in zip(quant_all, quant_batched): + np.testing.assert_allclose(q_all, q_batched, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) +def test__predict__batch_size_predict__main_matches_unbatched( + X_y: tuple[np.ndarray, np.ndarray], + batch_size_predict: int, +) -> None: + """Test that batch_size_predict matches unbatched main output.""" + X, y = X_y + + model = TabPFNRegressor(n_estimators=2, random_state=42) + model.fit(X, y) + + main_all = model.predict(X, output_type="main") + main_batched = 
model.predict( + X, output_type="main", batch_size_predict=batch_size_predict + ) + for key in ["mean", "median", "mode"]: + np.testing.assert_allclose( + main_all[key], main_batched[key], atol=1e-5, rtol=1e-5 + ) + for q_all, q_batched in zip(main_all["quantiles"], main_batched["quantiles"]): + np.testing.assert_allclose(q_all, q_batched, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("batch_size_predict", [1, 3, 5]) +def test__predict__batch_size_predict__full_matches_unbatched( + X_y: tuple[np.ndarray, np.ndarray], + batch_size_predict: int, +) -> None: + """Test that batch_size_predict matches unbatched full output.""" + X, y = X_y + + model = TabPFNRegressor(n_estimators=2, random_state=42) + model.fit(X, y) + + full_all = model.predict(X, output_type="full") + full_batched = model.predict( + X, output_type="full", batch_size_predict=batch_size_predict + ) + for key in ["mean", "median", "mode"]: + np.testing.assert_allclose( + full_all[key], full_batched[key], atol=1e-5, rtol=1e-5 + ) + for q_all, q_batched in zip(full_all["quantiles"], full_batched["quantiles"]): + np.testing.assert_allclose(q_all, q_batched, atol=1e-5, rtol=1e-5) + # logits should match + torch.testing.assert_close( + full_all["logits"], full_batched["logits"], atol=1e-5, rtol=1e-5 + ) + # criterion is a model-level attribute, should be the same object + assert full_all["criterion"] is full_batched["criterion"] + + def test__create_default_for_version__v2__uses_correct_defaults() -> None: estimator = TabPFNRegressor.create_default_for_version(ModelVersion.V2)