Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/802.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix inference precision to respect force_inference_dtype in KV cache engine and skip thinking tokens during cache-building.
5 changes: 4 additions & 1 deletion src/tabpfn/architectures/base/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,10 @@ def forward( # noqa: PLR0912, C901
)
del embedded_y, embedded_x

if self.add_thinking_tokens is not None:
is_kv_cache_prediction = (
self.cache_trainset_representation and single_eval_pos == 0
)
if self.add_thinking_tokens is not None and not is_kv_cache_prediction:
embedded_input, single_eval_pos = self.add_thinking_tokens(
Comment on lines +526 to 530
Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change alters when thinking tokens are added (they’re skipped for KV-cache prediction when single_eval_pos==0). There’s currently no test covering this specific behavior/contract (e.g., that fit_with_cache prediction path doesn’t append thinking tokens and stays consistent with other fit modes for a fixed seed). Please add/adjust a unit/integration test to lock this in—re-enabling the existing skipped “fit modes return equal results” tests (or adding a targeted regression test for #631) would help prevent regressions.

Copilot uses AI. Check for mistakes.
embedded_input,
single_eval_pos,
Expand Down
30 changes: 23 additions & 7 deletions src/tabpfn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing_extensions import override

import joblib
import numpy as np
import torch

from tabpfn.architectures.base.memory import (
Expand All @@ -30,6 +29,8 @@
from tabpfn.utils import get_autocast_context

if TYPE_CHECKING:
import numpy as np

from tabpfn.architectures.interface import Architecture
from tabpfn.preprocessing import EnsembleConfig
from tabpfn.preprocessing.ensemble import (
Expand Down Expand Up @@ -364,7 +365,7 @@ def iter_outputs(
y_train=self.y_train,
feature_schema=self.feature_schema,
parallel_mode="in-order",
override_random_state=np.random.default_rng(self.static_seed),
override_random_state=self.static_seed,
)
Comment on lines 365 to 369
Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

override_random_state is now passed as an int (self.static_seed). In TabPFNEnsemblePreprocessor.fit_transform_ensemble_members_iterator the random_state is selected via override_random_state or self.random_state, which will ignore an override of 0 (since 0 is falsy) and fall back to self.random_state, reintroducing non-deterministic preprocessing across predict calls. Prefer either passing a truthy override (e.g., a np.random.Generator like before) or (better) changing the downstream selection to override_random_state if override_random_state is not None else self.random_state so that seed 0 is respected.

Copilot uses AI. Check for mistakes.
)

Expand Down Expand Up @@ -768,14 +769,20 @@ def __init__( # noqa: PLR0913

ens_model = deepcopy(models[ensemble_member.config._model_index])
ens_model = ens_model.to(device)
if force_inference_dtype is not None:
ens_model = ens_model.type(force_inference_dtype)
X = ensemble_member.X_train
y = ensemble_member.y_train

if not isinstance(X, torch.Tensor):
X = torch.as_tensor(X, dtype=torch.float32, device=device)
inference_dtype = (
force_inference_dtype
if force_inference_dtype is not None
else torch.float32
)
Comment on lines +777 to +781
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block for determining inference_dtype can be made more concise. Since torch.dtype objects are not falsy, you can use the or operator to simplify this assignment.

            inference_dtype = force_inference_dtype or torch.float32


X = torch.as_tensor(X, dtype=inference_dtype, device=device)
X = X.unsqueeze(1)
if not isinstance(y, torch.Tensor):
y = torch.as_tensor(y, dtype=torch.float32, device=device)
y = torch.as_tensor(y, dtype=inference_dtype, device=device)

batched_preprocessor_cat_ix = [
ensemble_member.feature_schema.indices_for(FeatureModality.CATEGORICAL)
Expand Down Expand Up @@ -825,7 +832,16 @@ def iter_outputs(
for ensemble_member, model in zip(self.ensemble_members, self.models):
model.to(self.device)
X_test = ensemble_member.transform_X_test(X)
X_test = torch.as_tensor(X_test, dtype=torch.float32, device=self.device)
inference_dtype = (
self.force_inference_dtype
if self.force_inference_dtype is not None
else torch.float32
)
Comment on lines +835 to +839
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block for determining inference_dtype can be simplified for better readability and conciseness. Using the or operator is a more idiomatic way to provide a default value in this case.

            inference_dtype = self.force_inference_dtype or torch.float32

X_test = torch.as_tensor(
X_test,
dtype=inference_dtype,
device=self.device,
)
X_test = X_test.unsqueeze(1)
batched_cat_ix = [
ensemble_member.feature_schema.indices_for(FeatureModality.CATEGORICAL)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
[
[
0.9999790191650391,
1.4922236005077139e-05,
6.077526904846309e-06
0.9999856352806091,
9.786756891116966e-06,
4.600969987222925e-06
],
[
0.00038013834273442626,
0.9990183115005493,
0.0006015249527990818
0.0003290163876954466,
0.9991994500160217,
0.0004715345858130604
],
[
0.00032731943065300584,
0.0018296991474926472,
0.9978430271148682
0.00027023980510421097,
0.0014350548153743148,
0.9982947111129761
]
]
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
[
[
0.9999790191650391,
1.4922236005077139e-05,
6.077526904846309e-06
0.9999856352806091,
9.786756891116966e-06,
4.600969987222925e-06
],
[
0.00038013834273442626,
0.9990183115005493,
0.0006015249527990818
0.0003290163876954466,
0.9991994500160217,
0.0004715345858130604
],
[
0.00032731943065300584,
0.0018296991474926472,
0.9978430271148682
0.00027023980510421097,
0.0014350548153743148,
0.9982947111129761
]
]
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[
[
0.7395162582397461,
0.2604837119579315
0.2604838013648987
],
[
0.46708282828330994,
0.5329171419143677
0.4670824706554413,
0.5329175591468811
],
[
0.751488208770752,
0.24851179122924805
0.7514878511428833,
0.2485121190547943
]
]
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[
[
0.777837,
0.222163
0.7778364419937134,
0.22216357290744781
],
[
0.4193180501461029,
0.5806819796562195
0.41931843757629395,
0.580681562423706
],
[
0.6285772323608398,
0.37142282724380493
0.37142279744148254
]
]
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[
[
0.751238226890564,
0.24876181781291962
0.7512377500534058,
0.24876223504543304
],
[
0.4854810833930969,
0.5145189166069031
0.4854806065559387,
0.5145193934440613
],
[
0.736987978219986,
0.26301202178001404
0.7369879484176636,
0.2630121111869812
]
]
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[
[
0.5095966458320618,
0.49040335416793823
0.7512380480766296,
0.24876198172569275
],
[
0.5075675845146179,
0.4924324154853821
0.4854806661605835,
0.5145193338394165
],
[
0.5023237466812134,
0.4976762533187866
0.7369883060455322,
0.2630116939544678
]
]
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[
[
0.7434266805648804,
0.25657331943511963
0.7512377500534058,
0.24876223504543304
],
[
0.46427232027053833,
0.5357276797294617
0.4854806065559387,
0.5145193934440613
],
[
0.7502415180206299,
0.24975848197937012
0.7369879484176636,
0.2630121111869812
]
]
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[
[
0.7953547835350037,
0.20464521646499634
0.7953541874885559,
0.2046457827091217
],
[
0.5345571041107178,
0.4654428958892822
0.5345551371574402,
0.4654448628425598
],
[
0.7656365036964417,
0.23436351120471954
0.765636146068573,
0.2343638837337494
]
]
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[
4.636890411376953,
4.241406440734863,
4.2678446769714355
4.636868000030518,
4.241325378417969,
4.267908096313477
]
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[
4.636890411376953,
4.241406440734863,
4.2678446769714355
4.636868000030518,
4.241325378417969,
4.267908096313477
]
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[
5.126402854919434,
5.2795634269714355,
5.346292972564697
4.636868000030518,
4.241325378417969,
4.267908096313477
]
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[
4.582385063171387,
4.221979141235352,
4.202066421508789
4.636868000030518,
4.241325378417969,
4.267908096313477
]
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[
4.594064712524414,
4.594013690948486,
4.312458515167236,
4.53116512298584
4.531164646148682
]
Loading