-
Notifications
You must be signed in to change notification settings - Fork 594
Fix inference precision. #802
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Fix inference precision to respect force_inference_dtype in KV cache engine and skip thinking tokens during cache-building. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,6 @@ | |
| from typing_extensions import override | ||
|
|
||
| import joblib | ||
| import numpy as np | ||
| import torch | ||
|
|
||
| from tabpfn.architectures.base.memory import ( | ||
|
|
@@ -30,6 +29,8 @@ | |
| from tabpfn.utils import get_autocast_context | ||
|
|
||
| if TYPE_CHECKING: | ||
| import numpy as np | ||
|
|
||
| from tabpfn.architectures.interface import Architecture | ||
| from tabpfn.preprocessing import EnsembleConfig | ||
| from tabpfn.preprocessing.ensemble import ( | ||
|
|
@@ -364,7 +365,7 @@ def iter_outputs( | |
| y_train=self.y_train, | ||
| feature_schema=self.feature_schema, | ||
| parallel_mode="in-order", | ||
| override_random_state=np.random.default_rng(self.static_seed), | ||
| override_random_state=self.static_seed, | ||
| ) | ||
|
Comment on lines
365
to
369
|
||
| ) | ||
|
|
||
|
|
@@ -768,14 +769,20 @@ def __init__( # noqa: PLR0913 | |
|
|
||
| ens_model = deepcopy(models[ensemble_member.config._model_index]) | ||
| ens_model = ens_model.to(device) | ||
| if force_inference_dtype is not None: | ||
| ens_model = ens_model.type(force_inference_dtype) | ||
| X = ensemble_member.X_train | ||
| y = ensemble_member.y_train | ||
|
|
||
| if not isinstance(X, torch.Tensor): | ||
| X = torch.as_tensor(X, dtype=torch.float32, device=device) | ||
| inference_dtype = ( | ||
| force_inference_dtype | ||
| if force_inference_dtype is not None | ||
| else torch.float32 | ||
| ) | ||
|
Comment on lines
+777
to
+781
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| X = torch.as_tensor(X, dtype=inference_dtype, device=device) | ||
| X = X.unsqueeze(1) | ||
| if not isinstance(y, torch.Tensor): | ||
| y = torch.as_tensor(y, dtype=torch.float32, device=device) | ||
| y = torch.as_tensor(y, dtype=inference_dtype, device=device) | ||
|
|
||
| batched_preprocessor_cat_ix = [ | ||
| ensemble_member.feature_schema.indices_for(FeatureModality.CATEGORICAL) | ||
|
|
@@ -825,7 +832,16 @@ def iter_outputs( | |
| for ensemble_member, model in zip(self.ensemble_members, self.models): | ||
| model.to(self.device) | ||
| X_test = ensemble_member.transform_X_test(X) | ||
| X_test = torch.as_tensor(X_test, dtype=torch.float32, device=self.device) | ||
| inference_dtype = ( | ||
| self.force_inference_dtype | ||
| if self.force_inference_dtype is not None | ||
| else torch.float32 | ||
| ) | ||
|
Comment on lines
+835
to
+839
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| X_test = torch.as_tensor( | ||
| X_test, | ||
| dtype=inference_dtype, | ||
| device=self.device, | ||
| ) | ||
| X_test = X_test.unsqueeze(1) | ||
| batched_cat_ix = [ | ||
| ensemble_member.feature_schema.indices_for(FeatureModality.CATEGORICAL) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,17 +1,17 @@ | ||
| [ | ||
| [ | ||
| 0.9999790191650391, | ||
| 1.4922236005077139e-05, | ||
| 6.077526904846309e-06 | ||
| 0.9999856352806091, | ||
| 9.786756891116966e-06, | ||
| 4.600969987222925e-06 | ||
| ], | ||
| [ | ||
| 0.00038013834273442626, | ||
| 0.9990183115005493, | ||
| 0.0006015249527990818 | ||
| 0.0003290163876954466, | ||
| 0.9991994500160217, | ||
| 0.0004715345858130604 | ||
| ], | ||
| [ | ||
| 0.00032731943065300584, | ||
| 0.0018296991474926472, | ||
| 0.9978430271148682 | ||
| 0.00027023980510421097, | ||
| 0.0014350548153743148, | ||
| 0.9982947111129761 | ||
| ] | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,17 +1,17 @@ | ||
| [ | ||
| [ | ||
| 0.9999790191650391, | ||
| 1.4922236005077139e-05, | ||
| 6.077526904846309e-06 | ||
| 0.9999856352806091, | ||
| 9.786756891116966e-06, | ||
| 4.600969987222925e-06 | ||
| ], | ||
| [ | ||
| 0.00038013834273442626, | ||
| 0.9990183115005493, | ||
| 0.0006015249527990818 | ||
| 0.0003290163876954466, | ||
| 0.9991994500160217, | ||
| 0.0004715345858130604 | ||
| ], | ||
| [ | ||
| 0.00032731943065300584, | ||
| 0.0018296991474926472, | ||
| 0.9978430271148682 | ||
| 0.00027023980510421097, | ||
| 0.0014350548153743148, | ||
| 0.9982947111129761 | ||
| ] | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,14 @@ | ||
| [ | ||
| [ | ||
| 0.7395162582397461, | ||
| 0.2604837119579315 | ||
| 0.2604838013648987 | ||
| ], | ||
| [ | ||
| 0.46708282828330994, | ||
| 0.5329171419143677 | ||
| 0.4670824706554413, | ||
| 0.5329175591468811 | ||
| ], | ||
| [ | ||
| 0.751488208770752, | ||
| 0.24851179122924805 | ||
| 0.7514878511428833, | ||
| 0.2485121190547943 | ||
| ] | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,14 @@ | ||
| [ | ||
| [ | ||
| 0.777837, | ||
| 0.222163 | ||
| 0.7778364419937134, | ||
| 0.22216357290744781 | ||
| ], | ||
| [ | ||
| 0.4193180501461029, | ||
| 0.5806819796562195 | ||
| 0.41931843757629395, | ||
| 0.580681562423706 | ||
| ], | ||
| [ | ||
| 0.6285772323608398, | ||
| 0.37142282724380493 | ||
| 0.37142279744148254 | ||
| ] | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,14 @@ | ||
| [ | ||
| [ | ||
| 0.751238226890564, | ||
| 0.24876181781291962 | ||
| 0.7512377500534058, | ||
| 0.24876223504543304 | ||
| ], | ||
| [ | ||
| 0.4854810833930969, | ||
| 0.5145189166069031 | ||
| 0.4854806065559387, | ||
| 0.5145193934440613 | ||
| ], | ||
| [ | ||
| 0.736987978219986, | ||
| 0.26301202178001404 | ||
| 0.7369879484176636, | ||
| 0.2630121111869812 | ||
| ] | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,14 @@ | ||
| [ | ||
| [ | ||
| 0.5095966458320618, | ||
| 0.49040335416793823 | ||
| 0.7512380480766296, | ||
| 0.24876198172569275 | ||
| ], | ||
| [ | ||
| 0.5075675845146179, | ||
| 0.4924324154853821 | ||
| 0.4854806661605835, | ||
| 0.5145193338394165 | ||
| ], | ||
| [ | ||
| 0.5023237466812134, | ||
| 0.4976762533187866 | ||
| 0.7369883060455322, | ||
| 0.2630116939544678 | ||
| ] | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,14 @@ | ||
| [ | ||
| [ | ||
| 0.7434266805648804, | ||
| 0.25657331943511963 | ||
| 0.7512377500534058, | ||
| 0.24876223504543304 | ||
| ], | ||
| [ | ||
| 0.46427232027053833, | ||
| 0.5357276797294617 | ||
| 0.4854806065559387, | ||
| 0.5145193934440613 | ||
| ], | ||
| [ | ||
| 0.7502415180206299, | ||
| 0.24975848197937012 | ||
| 0.7369879484176636, | ||
| 0.2630121111869812 | ||
| ] | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,14 @@ | ||
| [ | ||
| [ | ||
| 0.7953547835350037, | ||
| 0.20464521646499634 | ||
| 0.7953541874885559, | ||
| 0.2046457827091217 | ||
| ], | ||
| [ | ||
| 0.5345571041107178, | ||
| 0.4654428958892822 | ||
| 0.5345551371574402, | ||
| 0.4654448628425598 | ||
| ], | ||
| [ | ||
| 0.7656365036964417, | ||
| 0.23436351120471954 | ||
| 0.765636146068573, | ||
| 0.2343638837337494 | ||
| ] | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| [ | ||
| 4.636890411376953, | ||
| 4.241406440734863, | ||
| 4.2678446769714355 | ||
| 4.636868000030518, | ||
| 4.241325378417969, | ||
| 4.267908096313477 | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| [ | ||
| 4.636890411376953, | ||
| 4.241406440734863, | ||
| 4.2678446769714355 | ||
| 4.636868000030518, | ||
| 4.241325378417969, | ||
| 4.267908096313477 | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| [ | ||
| 5.126402854919434, | ||
| 5.2795634269714355, | ||
| 5.346292972564697 | ||
| 4.636868000030518, | ||
| 4.241325378417969, | ||
| 4.267908096313477 | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| [ | ||
| 4.582385063171387, | ||
| 4.221979141235352, | ||
| 4.202066421508789 | ||
| 4.636868000030518, | ||
| 4.241325378417969, | ||
| 4.267908096313477 | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| [ | ||
| 4.594064712524414, | ||
| 4.594013690948486, | ||
| 4.312458515167236, | ||
| 4.53116512298584 | ||
| 4.531164646148682 | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change alters when thinking tokens are added (they're skipped for KV-cache prediction when `single_eval_pos == 0`). There's currently no test covering this specific behavior/contract (e.g., that the `fit_with_cache` prediction path doesn't append thinking tokens and stays consistent with other fit modes for a fixed seed). Please add/adjust a unit/integration test to lock this in — re-enabling the existing skipped "fit modes return equal results" tests (or adding a targeted regression test for #631) would help prevent regressions.