Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/793.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a batch_size_predict parameter to our .predict() functions that performs automatic batching of the test set.
34 changes: 34 additions & 0 deletions src/tabpfn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,3 +487,37 @@ def get_embeddings(
embeddings.append(embed.squeeze().cpu().numpy())

return np.array(embeddings)


def predict_in_batches(
predict_fn: typing.Callable,
X: XType,
batch_size: int,
concat_fn: typing.Callable | None = None,
) -> typing.Any:
Comment on lines +492 to +497
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The predict_in_batches function is a good addition for handling large datasets. However, the type hint for X is XType, which is a generic Any. It would be more precise to use np.ndarray or torch.Tensor as X is indexed directly, which implies it's an array-like object. This improves type safety and readability.

Suggested change
def predict_in_batches(
predict_fn: typing.Callable,
X: XType,
batch_size: int,
concat_fn: typing.Callable | None = None,
) -> typing.Any:
def predict_in_batches(
predict_fn: typing.Callable,
X: np.ndarray, # More specific type hint
batch_size: int,
concat_fn: typing.Callable | None = None,
) -> typing.Any:

"""Split X into batches, apply predict_fn to each, and concatenate results.

Args:
predict_fn: A callable that takes a data slice and returns predictions.
X: The full input data.
batch_size: The number of samples per batch.
concat_fn: Optional custom function to concatenate results.
If None, uses ``np.concatenate(..., axis=0)``.

Returns:
The concatenated predictions.

Raises:
ValueError: If batch_size is not a positive integer.
"""
if batch_size <= 0:
raise ValueError("batch_size must be a positive integer")

n_samples = X.shape[0]
results = [
predict_fn(X[start : min(start + batch_size, n_samples)])
for start in range(0, n_samples, batch_size)
]
if concat_fn is not None:
return concat_fn(results)
return np.concatenate(results, axis=0)
66 changes: 57 additions & 9 deletions src/tabpfn/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
get_embeddings,
initialize_model_variables_helper,
initialize_telemetry,
predict_in_batches,
)
from tabpfn.constants import (
PROBABILITY_EPSILON_ROUND_ZERO,
Expand Down Expand Up @@ -1028,6 +1029,7 @@ def _raw_predict(
*,
return_logits: bool,
return_raw_logits: bool = False,
batch_size_predict: int | None = None,
) -> torch.Tensor:
"""Internal method to run prediction.

Expand All @@ -1042,11 +1044,26 @@ def _raw_predict(
post-processing steps.
return_raw_logits: If True, returns the raw logits without
averaging estimators or temperature scaling.
batch_size_predict: If not None, split the test data into
chunks of this size and predict each chunk independently.

Returns:
The raw torch.Tensor output, either logits or probabilities,
depending on `return_logits` and `return_raw_logits`.
"""
if batch_size_predict is not None:
concat_dim = 1 if return_raw_logits else 0
return predict_in_batches(
lambda chunk: self._raw_predict(
chunk,
return_logits=return_logits,
return_raw_logits=return_raw_logits,
),
X,
batch_size_predict,
concat_fn=lambda results: torch.cat(results, dim=concat_dim),
)

check_is_fitted(self)

if not self.differentiable_input:
Expand All @@ -1071,16 +1088,18 @@ def _raw_predict(
)

@track_model_call(model_method="predict", param_names=["X"])
def predict(self, X: XType) -> np.ndarray:
def predict(self, X: XType, *, batch_size_predict: int | None = None) -> np.ndarray:
"""Predict the class labels for the provided input samples.

Args:
X: The input data for prediction.
batch_size_predict: If not None, split the test data into
chunks of this size and predict each chunk independently.

Returns:
The predicted class labels as a NumPy array.
"""
probas = self._predict_proba(X=X)
probas = self._predict_proba(X=X, batch_size_predict=batch_size_predict)
y_pred = np.argmax(probas, axis=1)
if hasattr(self, "label_encoder_") and self.label_encoder_ is not None:
return self.label_encoder_.inverse_transform(y_pred)
Expand All @@ -1089,24 +1108,32 @@ def predict(self, X: XType) -> np.ndarray:

@config_context(transform_output="default")
@track_model_call(model_method="predict", param_names=["X"])
def predict_logits(self, X: XType) -> np.ndarray:
def predict_logits(
self, X: XType, *, batch_size_predict: int | None = None
) -> np.ndarray:
"""Predict the raw logits for the provided input samples.

Logits represent the unnormalized log-probabilities of the classes
before the softmax activation function is applied.

Args:
X: The input data for prediction.
batch_size_predict: If not None, split the test data into
chunks of this size and predict each chunk independently.

Returns:
The predicted logits as a NumPy array. Shape (n_samples, n_classes).
"""
logits_tensor = self._raw_predict(X, return_logits=True)
logits_tensor = self._raw_predict(
X, return_logits=True, batch_size_predict=batch_size_predict
)
return logits_tensor.float().detach().cpu().numpy()

@config_context(transform_output="default")
@track_model_call(model_method="predict", param_names=["X"])
def predict_raw_logits(self, X: XType) -> np.ndarray:
def predict_raw_logits(
self, X: XType, *, batch_size_predict: int | None = None
) -> np.ndarray:
"""Predict the raw logits for the provided input samples.

Logits represent the unnormalized log-probabilities of the classes
Expand All @@ -1116,6 +1143,8 @@ def predict_raw_logits(self, X: XType) -> np.ndarray:

Args:
X: The input data for prediction.
batch_size_predict: If not None, split the test data into
chunks of this size and predict each chunk independently.

Returns:
An array of predicted logits for each estimator,
Expand All @@ -1125,37 +1154,56 @@ def predict_raw_logits(self, X: XType) -> np.ndarray:
X,
return_logits=False,
return_raw_logits=True,
batch_size_predict=batch_size_predict,
)
return logits_tensor.float().detach().cpu().numpy()

@track_model_call(model_method="predict", param_names=["X"])
def predict_proba(self, X: XType) -> np.ndarray:
def predict_proba(
self, X: XType, *, batch_size_predict: int | None = None
) -> np.ndarray:
"""Predict the probabilities of the classes for the provided input samples.

This is a wrapper around the `_predict_proba` method.

Args:
X: The input data for prediction.
batch_size_predict: If not None, split the test data into
chunks of this size and predict each chunk independently.

Returns:
The predicted probabilities of the classes as a NumPy array.
Shape (n_samples, n_classes).
"""
return self._predict_proba(X)
return self._predict_proba(X, batch_size_predict=batch_size_predict)

@config_context(transform_output="default") # type: ignore
def _predict_proba(self, X: XType) -> np.ndarray:
def _predict_proba(
self,
X: XType,
batch_size_predict: int | None = None,
) -> np.ndarray:
"""Predict the probabilities of the classes for the provided input samples.

Args:
X: The input data for prediction.
batch_size_predict: If not None, split the test data into
chunks of this size and predict each chunk independently.

Returns:
The predicted probabilities of the classes as a NumPy array.
Shape (n_samples, n_classes).
"""
probas = (
self._raw_predict(X, return_logits=False).float().detach().cpu().numpy()
self._raw_predict(
X,
return_logits=False,
batch_size_predict=batch_size_predict,
)
.float()
.detach()
.cpu()
.numpy()
)
probas = self._maybe_reweight_probas(probas=probas)
if self.inference_config_.USE_SKLEARN_16_DECIMAL_PRECISION:
Expand Down
11 changes: 4 additions & 7 deletions src/tabpfn/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,10 @@ def __init__(

message = (
f"{self.device_name} out of memory{size_info}.\n\n"
f"Solution: Split your test data into smaller batches:\n\n"
f" batch_size = 1000 # depends on hardware\n"
f" predictions = []\n"
f" for i in range(0, len(X_test), batch_size):\n"
f" batch = model.{predict_method}(X_test[i:i + batch_size])\n"
f" predictions.append(batch)\n"
f" predictions = np.vstack(predictions)"
f"Solution: Use batch_size_predict to split test data"
f" into smaller batches:\n\n"
f" model.{predict_method}("
f"X_test, batch_size_predict=100)"
Comment on lines +72 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The example code in the error message uses batch_size_predict=100. While this is a reasonable default, it might be beneficial to mention that the optimal batch_size_predict depends on hardware and the specific dataset, similar to how the previous message suggested batch_size = 1000 # depends on hardware. This would provide more comprehensive guidance to the user.

            f"X_test, batch_size_predict=100)  # depends on hardware"

)
if original_error is not None:
message += f"\n\nOriginal error: {original_error}"
Expand Down
79 changes: 79 additions & 0 deletions src/tabpfn/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
get_embeddings,
initialize_model_variables_helper,
initialize_telemetry,
predict_in_batches,
)
from tabpfn.constants import REGRESSION_CONSTANT_TARGET_BORDER_EPSILON, ModelVersion
from tabpfn.errors import TabPFNValidationError, handle_oom_errors
Expand Down Expand Up @@ -830,6 +831,7 @@ def predict(
*,
output_type: Literal["mean", "median", "mode"] = "mean",
quantiles: list[float] | None = None,
batch_size_predict: int | None = None,
) -> np.ndarray: ...

@overload
Expand All @@ -839,6 +841,7 @@ def predict(
*,
output_type: Literal["quantiles"],
quantiles: list[float] | None = None,
batch_size_predict: int | None = None,
) -> list[np.ndarray]: ...

@overload
Expand All @@ -848,6 +851,7 @@ def predict(
*,
output_type: Literal["main"],
quantiles: list[float] | None = None,
batch_size_predict: int | None = None,
) -> MainOutputDict: ...

@overload
Expand All @@ -857,6 +861,7 @@ def predict(
*,
output_type: Literal["full"],
quantiles: list[float] | None = None,
batch_size_predict: int | None = None,
) -> FullOutputDict: ...

@config_context(transform_output="default") # type: ignore
Expand All @@ -868,6 +873,7 @@ def predict(
# TODO: support "ei", "pi"
output_type: OutputType = "mean",
quantiles: list[float] | None = None,
batch_size_predict: int | None = None,
) -> RegressionResultType:
"""Runs the forward() method and then transform the logits
from the binning space in order to predict target variable.
Expand All @@ -894,6 +900,9 @@ def predict(
quantiles are returned. The predictions per quantile match
the input order.

batch_size_predict: If not None, split the test data into
chunks of this size and predict each chunk independently.

Returns:
The prediction, which can be a numpy array, a list of arrays (for
quantiles), or a dictionary with detailed outputs.
Expand Down Expand Up @@ -925,6 +934,39 @@ def predict(
X, ord_encoder=getattr(self, "ordinal_encoder_", None)
)

if batch_size_predict is not None:
return predict_in_batches(
lambda chunk: self._predict_core(
chunk, output_type=output_type, quantiles=quantiles
),
X,
batch_size_predict,
concat_fn=lambda results: _concatenate_regression_results(
results, output_type
),
)

return self._predict_core(X, output_type=output_type, quantiles=quantiles)

def _predict_core(
self,
X: XType,
output_type: OutputType,
quantiles: list[float],
) -> RegressionResultType:
"""Core prediction logic on already-preprocessed data.

Runs the forward pass, translates logits, and formats the output.
This method assumes X has already been validated and preprocessed.

Args:
X: The preprocessed input data.
output_type: The type of output to return.
quantiles: The quantiles to compute.

Returns:
The prediction result.
"""
# Runs over iteration engine
with handle_oom_errors(self.devices_, X, model_type="regressor"):
(
Expand Down Expand Up @@ -1241,3 +1283,40 @@ def _logits_to_output(
raise ValueError(f"Invalid output type: {output_type}")

return output.cpu().detach().numpy()


def _concatenate_regression_results(
    results: list[RegressionResultType],
    output_type: str,
) -> RegressionResultType:
    """Concatenate batched regression prediction results."""
    if output_type in _OUTPUT_TYPES_BASIC:
        # Plain per-sample arrays: a single row-wise concatenation suffices.
        return np.concatenate(results, axis=0)

    if output_type == "quantiles":
        # Each batch yields one array per requested quantile; stitch the
        # batches together index-wise so the quantile order is preserved.
        return [
            np.concatenate([batch[q] for batch in results], axis=0)
            for q in range(len(results[0]))
        ]

    if output_type not in ("main", "full"):
        raise ValueError(f"Invalid output type: {output_type}")

    def _cat_field(key):
        # Row-wise concatenation of one named field across all batches.
        return np.concatenate([batch[key] for batch in results], axis=0)

    main = MainOutputDict(
        mean=_cat_field("mean"),
        median=_cat_field("median"),
        mode=_cat_field("mode"),
        quantiles=[
            np.concatenate([batch["quantiles"][q] for batch in results], axis=0)
            for q in range(len(results[0]["quantiles"]))
        ],
    )
    if output_type == "main":
        return main

    # criterion is a model-level attribute (raw_space_bardist_), identical
    # across all batches, so we take it from the first result.
    return FullOutputDict(
        **main,
        criterion=results[0]["criterion"],
        logits=torch.cat([batch["logits"] for batch in results], dim=0),
    )
Loading
Loading