-
Notifications
You must be signed in to change notification settings - Fork 594
[WIP] Batched Predictions #793
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ec04b5d
5d3e726
49126f9
b64ece9
f9b585a
3e48b16
4b97ade
9901c10
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Add a batch_size_predict parameter to our .predict() functions that performs automatic batching of the test set. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,13 +69,10 @@ def __init__( | |
|
|
||
| message = ( | ||
| f"{self.device_name} out of memory{size_info}.\n\n" | ||
| f"Solution: Split your test data into smaller batches:\n\n" | ||
| f" batch_size = 1000 # depends on hardware\n" | ||
| f" predictions = []\n" | ||
| f" for i in range(0, len(X_test), batch_size):\n" | ||
| f" batch = model.{predict_method}(X_test[i:i + batch_size])\n" | ||
| f" predictions.append(batch)\n" | ||
| f" predictions = np.vstack(predictions)" | ||
| f"Solution: Use batch_size_predict to split test data" | ||
| f" into smaller batches:\n\n" | ||
| f" model.{predict_method}(" | ||
| f"X_test, batch_size_predict=100)" | ||
|
Comment on lines
+72
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The example code in the error message uses f |
||
| ) | ||
| if original_error is not None: | ||
| message += f"\n\nOriginal error: {original_error}" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,6 +46,7 @@ | |
| get_embeddings, | ||
| initialize_model_variables_helper, | ||
| initialize_telemetry, | ||
| predict_in_batches, | ||
| ) | ||
| from tabpfn.constants import REGRESSION_CONSTANT_TARGET_BORDER_EPSILON, ModelVersion | ||
| from tabpfn.errors import TabPFNValidationError, handle_oom_errors | ||
|
|
@@ -830,6 +831,7 @@ def predict( | |
| *, | ||
| output_type: Literal["mean", "median", "mode"] = "mean", | ||
| quantiles: list[float] | None = None, | ||
| batch_size_predict: int | None = None, | ||
| ) -> np.ndarray: ... | ||
|
|
||
| @overload | ||
|
|
@@ -839,6 +841,7 @@ def predict( | |
| *, | ||
| output_type: Literal["quantiles"], | ||
| quantiles: list[float] | None = None, | ||
| batch_size_predict: int | None = None, | ||
| ) -> list[np.ndarray]: ... | ||
|
|
||
| @overload | ||
|
|
@@ -848,6 +851,7 @@ def predict( | |
| *, | ||
| output_type: Literal["main"], | ||
| quantiles: list[float] | None = None, | ||
| batch_size_predict: int | None = None, | ||
| ) -> MainOutputDict: ... | ||
|
|
||
| @overload | ||
|
|
@@ -857,6 +861,7 @@ def predict( | |
| *, | ||
| output_type: Literal["full"], | ||
| quantiles: list[float] | None = None, | ||
| batch_size_predict: int | None = None, | ||
| ) -> FullOutputDict: ... | ||
|
|
||
| @config_context(transform_output="default") # type: ignore | ||
|
|
@@ -868,6 +873,7 @@ def predict( | |
| # TODO: support "ei", "pi" | ||
| output_type: OutputType = "mean", | ||
| quantiles: list[float] | None = None, | ||
| batch_size_predict: int | None = None, | ||
| ) -> RegressionResultType: | ||
| """Runs the forward() method and then transform the logits | ||
| from the binning space in order to predict target variable. | ||
|
|
@@ -894,6 +900,9 @@ def predict( | |
| quantiles are returned. The predictions per quantile match | ||
| the input order. | ||
|
|
||
| batch_size_predict: If not None, split the test data into | ||
| chunks of this size and predict each chunk independently. | ||
|
|
||
| Returns: | ||
| The prediction, which can be a numpy array, a list of arrays (for | ||
| quantiles), or a dictionary with detailed outputs. | ||
|
|
@@ -925,6 +934,39 @@ def predict( | |
| X, ord_encoder=getattr(self, "ordinal_encoder_", None) | ||
| ) | ||
|
|
||
| if batch_size_predict is not None: | ||
| return predict_in_batches( | ||
| lambda chunk: self._predict_core( | ||
| chunk, output_type=output_type, quantiles=quantiles | ||
| ), | ||
| X, | ||
| batch_size_predict, | ||
| concat_fn=lambda results: _concatenate_regression_results( | ||
| results, output_type | ||
| ), | ||
| ) | ||
|
|
||
| return self._predict_core(X, output_type=output_type, quantiles=quantiles) | ||
|
|
||
| def _predict_core( | ||
| self, | ||
| X: XType, | ||
| output_type: OutputType, | ||
| quantiles: list[float], | ||
| ) -> RegressionResultType: | ||
| """Core prediction logic on already-preprocessed data. | ||
|
|
||
| Runs the forward pass, translates logits, and formats the output. | ||
| This method assumes X has already been validated and preprocessed. | ||
|
|
||
| Args: | ||
| X: The preprocessed input data. | ||
| output_type: The type of output to return. | ||
| quantiles: The quantiles to compute. | ||
|
|
||
| Returns: | ||
| The prediction result. | ||
| """ | ||
| # Runs over iteration engine | ||
| with handle_oom_errors(self.devices_, X, model_type="regressor"): | ||
| ( | ||
|
|
@@ -1241,3 +1283,40 @@ def _logits_to_output( | |
| raise ValueError(f"Invalid output type: {output_type}") | ||
|
|
||
| return output.cpu().detach().numpy() | ||
|
|
||
|
|
||
| def _concatenate_regression_results( | ||
| results: list[RegressionResultType], | ||
| output_type: str, | ||
| ) -> RegressionResultType: | ||
| """Concatenate batched regression prediction results.""" | ||
| if output_type in _OUTPUT_TYPES_BASIC: | ||
| return np.concatenate(results, axis=0) | ||
|
|
||
| if output_type == "quantiles": | ||
| return [ | ||
| np.concatenate([r[q] for r in results], axis=0) | ||
| for q in range(len(results[0])) | ||
|
Comment on lines
+1298
to
+1299
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In the return [
np.concatenate([typing.cast(list[np.ndarray], r)[q] for r in results], axis=0)
for q in range(len(typing.cast(list[np.ndarray], results[0])))
] |
||
| ] | ||
|
|
||
| if output_type in ("main", "full"): | ||
| main = MainOutputDict( | ||
| mean=np.concatenate([r["mean"] for r in results], axis=0), | ||
| median=np.concatenate([r["median"] for r in results], axis=0), | ||
| mode=np.concatenate([r["mode"] for r in results], axis=0), | ||
| quantiles=[ | ||
| np.concatenate([r["quantiles"][q] for r in results], axis=0) | ||
| for q in range(len(results[0]["quantiles"])) | ||
| ], | ||
|
Comment on lines
+1304
to
+1310
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In the main = MainOutputDict(
mean=np.concatenate([typing.cast(MainOutputDict, r)["mean"] for r in results], axis=0),
median=np.concatenate([typing.cast(MainOutputDict, r)["median"] for r in results], axis=0),
mode=np.concatenate([typing.cast(MainOutputDict, r)["mode"] for r in results], axis=0),
quantiles=[
np.concatenate([typing.cast(MainOutputDict, r)["quantiles"][q] for r in results], axis=0)
for q in range(len(typing.cast(MainOutputDict, results[0])["quantiles"])) # Cast results[0] to MainOutputDict
],
) |
||
| ) | ||
| if output_type == "main": | ||
| return main | ||
| # criterion is a model-level attribute (raw_space_bardist_), identical | ||
| # across all batches, so we take it from the first result. | ||
| return FullOutputDict( | ||
| **main, | ||
| criterion=results[0]["criterion"], | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to the previous comment, criterion=typing.cast(FullOutputDict, results[0])["criterion"],
logits=torch.cat([typing.cast(FullOutputDict, r)["logits"] for r in results], dim=0), |
||
| logits=torch.cat([r["logits"] for r in results], dim=0), | ||
| ) | ||
|
|
||
| raise ValueError(f"Invalid output type: {output_type}") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
`predict_in_batches` function is a good addition for handling large datasets. However, the type hint for `X` is `XType`, which is a generic `Any`. It would be more precise to use `np.ndarray` or `torch.Tensor`, as `X` is indexed directly, which implies it's an array-like object. This improves type safety and readability.