From 2d60436c9f8715c4b36fe9fe52a28a92366e710d Mon Sep 17 00:00:00 2001 From: balaji Date: Wed, 21 Jan 2026 02:52:30 +0000 Subject: [PATCH 1/5] refactor executor forward loop Signed-off-by: balaji --- src/tabpfn/regressor.py | 71 ++++++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index 29ade4071..cbbdd053f 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -23,7 +23,7 @@ from collections.abc import Sequence from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, Literal, Union +from typing import TYPE_CHECKING, Annotated, Any, Literal, Union, Iterator from typing_extensions import Self, TypedDict, deprecated, overload import numpy as np @@ -982,31 +982,12 @@ def predict( return logit_to_output(output_type=output_type) - def forward( + def _iter_forward_executor( self, X: list[torch.Tensor] | XType, *, use_inference_mode: bool = False, - ) -> tuple[torch.Tensor | None, list[torch.Tensor], list[np.ndarray]]: - """Forward pass for TabPFNRegressor Inference Engine. - Used in fine-tuning and prediction. Called directly - in FineTuning training loop or by predict() function - with the use_inference_mode flag explicitly set to True. - - Iterates over outputs of InferenceEngine. - - Args: - X: list[torch.Tensor] in fine-tuning, XType in normal predictions. - use_inference_mode: Flag for inference mode., default at False since - it is called within predict. During FineTuning forward() is called - directly by user, so default should be False here. - - Returns: - A tuple containing: - - Averaged logits over the ensemble (for fine-tuning). - - Raw outputs from each estimator in the ensemble. - - Borders used for each estimator. 
- """ + ) -> Iterator[tuple[np.ndaarray, torch.Tensor]]: # Scenario 1: Standard inference path is_standard_inference = use_inference_mode and not isinstance( self.executor_, InferenceEngineBatchedNoPreprocessing @@ -1036,18 +1017,12 @@ def forward( "fine-tuning workflow (requires float32 for backpropagation)." ) + check_is_fitted(self) # Ensure torch.inference_mode is OFF to allow gradients if self.fit_mode in ["fit_preprocessors", "batched"]: # only these two modes support this option self.executor_.use_torch_inference_mode(use_inference=use_inference_mode) - - check_is_fitted(self) - std_borders = self.znorm_space_bardist_.borders.cpu().numpy() - outputs: list[torch.Tensor] = [] - borders: list[np.ndarray] = [] - - # Iterate over estimators for output, config in self.executor_.iter_outputs( X, autocast=self.use_autocast_ ): @@ -1091,19 +1066,49 @@ def forward( if descending_borders: borders_t = borders_t.flip(-1) # type: ignore - borders.append(borders_t) - if logit_cancel_mask is not None: output = output.clone() # noqa: PLW2901 output[..., logit_cancel_mask] = float("-inf") - + yield borders_t, output else: raise ValueError( "Unexpected config format " "and Batch prediction is not supported yet!" ) - outputs.append(output) # type: ignore + def forward( + self, + X: list[torch.Tensor] | XType, + *, + use_inference_mode: bool = False, + ) -> tuple[torch.Tensor | None, list[torch.Tensor], list[np.ndarray]]: + """Forward pass for TabPFNRegressor Inference Engine. + Used in fine-tuning and prediction. Called directly + in FineTuning training loop or by predict() function + with the use_inference_mode flag explicitly set to True. + + Iterates over outputs of InferenceEngine. + + Args: + X: list[torch.Tensor] in fine-tuning, XType in normal predictions. + use_inference_mode: Flag for inference mode., default at False since + it is called within predict. During FineTuning forward() is called + directly by user, so default should be False here. 
+ + Returns: + A tuple containing: + - Averaged logits over the ensemble (for fine-tuning). + - Raw outputs from each estimator in the ensemble. + - Borders used for each estimator. + """ + outputs: list[torch.Tensor] = [] + borders: list[np.ndarray] = [] + + for border, output in self._iter_forward_executor( + X, use_inference_mode=use_inference_mode + ): + borders.append(border) + outputs.append(output) averaged_logits = None all_logits = None From 99b9ac5e461ae748c1ae7fbb4f12ee51e9139293 Mon Sep 17 00:00:00 2001 From: balaji Date: Wed, 21 Jan 2026 03:04:55 +0000 Subject: [PATCH 2/5] Optimize regressor predict method for memory efficiency - Use _iter_forward_executor directly instead of forward method - Transform probabilities across borders inside the loop - Average ensemble outputs on-the-fly instead of accumulating all outputs This reduces memory usage by avoiding storage of all intermediate outputs, especially beneficial for large n_estimators. Co-Authored-By: glm4.5 --- src/tabpfn/regressor.py | 45 +++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index cbbdd053f..e0aa941fe 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -913,29 +913,34 @@ def predict( ) # Runs over iteration engine + + n_estimators = 0 + averaged_logits: torch.Tensor | None = None with handle_oom_errors(self.devices_, X, model_type="regressor"): - ( - _, - # list of tensors [N_est, N_samples, N_borders] (after forward) - outputs, - # list of numpy arrays containing borders for each estimator - borders, - ) = self.forward(X, use_inference_mode=True) - - # --- Translate probs, average, get final logits --- - transformed_logits = [ - translate_probs_across_borders( - logits, - frm=torch.as_tensor(borders_t, device=logits.device), - to=self.znorm_space_bardist_.borders.to(logits.device), - ) - for logits, borders_t in zip(outputs, borders) - ] - stacked_logits = 
torch.stack(transformed_logits, dim=0) + for borders_t, output in self._iter_forward_executor( + X, use_inference_mode=True + ): + # Transform probabilities across borders + transformed = translate_probs_across_borders( + output, + frm=torch.as_tensor(borders_t, device=output.device), + to=self.znorm_space_bardist_.borders.to(output.device), + ) + + if self.average_before_softmax: + transformed = transformed.log() + + if averaged_logits is None: + averaged_logits = transformed + else: + averaged_logits = averaged_logits + transformed + n_estimators += 1 + + # Finalize averaging if self.average_before_softmax: - logits = stacked_logits.log().mean(dim=0).softmax(dim=-1) + logits = (averaged_logits / n_estimators).softmax(dim=-1) # type: ignore else: - logits = stacked_logits.mean(dim=0) + logits = averaged_logits / n_estimators # type: ignore # Post-process the logits logits = logits.log() From 866e299dde28c3c029b707666d625de0bccbadca Mon Sep 17 00:00:00 2001 From: balaji Date: Thu, 22 Jan 2026 09:01:35 +0000 Subject: [PATCH 3/5] add changelog --- changelog/745.changed.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 changelog/745.changed.md diff --git a/changelog/745.changed.md b/changelog/745.changed.md new file mode 100644 index 000000000..f48b0f7c0 --- /dev/null +++ b/changelog/745.changed.md @@ -0,0 +1,3 @@ +- Optimize regressor predict method for memory efficiency + - Average ensemble outputs on-the-fly instead of accumulating all outputs + - Reduces memory usage by avoiding storage of all intermediate outputs, especially beneficial for large `n_estimators` From 962b2a384a7c3d448637f1930a2743cdd134a538 Mon Sep 17 00:00:00 2001 From: balaji Date: Wed, 4 Feb 2026 06:42:25 +0000 Subject: [PATCH 4/5] resolve pre-commit issues Signed-off-by: balaji --- src/tabpfn/regressor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index e0aa941fe..a880b8701 100644 --- 
a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -20,10 +20,10 @@ import logging import typing import warnings -from collections.abc import Sequence +from collections.abc import Iterator, Sequence from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, Literal, Union, Iterator +from typing import TYPE_CHECKING, Annotated, Any, Literal, Union from typing_extensions import Self, TypedDict, deprecated, overload import numpy as np @@ -849,7 +849,7 @@ def predict( @config_context(transform_output="default") # type: ignore @track_model_call(model_method="predict", param_names=["X"]) - def predict( + def predict( # noqa: C901, PLR0912 self, X: XType, *, From c076ee9987da6cad1869033fed06b74f5d55a8bc Mon Sep 17 00:00:00 2001 From: balaji Date: Wed, 4 Feb 2026 08:04:02 +0000 Subject: [PATCH 5/5] fix np.ndarray typo in return type annotation --- src/tabpfn/regressor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index a880b8701..779e1fd34 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -992,7 +992,7 @@ def _iter_forward_executor( X: list[torch.Tensor] | XType, *, use_inference_mode: bool = False, - ) -> Iterator[tuple[np.ndaarray, torch.Tensor]]: + ) -> Iterator[tuple[np.ndarray, torch.Tensor]]: # Scenario 1: Standard inference path is_standard_inference = use_inference_mode and not isinstance( self.executor_, InferenceEngineBatchedNoPreprocessing