diff --git a/changelog/802.fixed.md b/changelog/802.fixed.md new file mode 100644 index 000000000..868a435d3 --- /dev/null +++ b/changelog/802.fixed.md @@ -0,0 +1 @@ +Fix inference precision to respect force_inference_dtype in KV cache engine and skip thinking tokens during cache-building. \ No newline at end of file diff --git a/src/tabpfn/architectures/base/transformer.py b/src/tabpfn/architectures/base/transformer.py index 632aba37a..bcfd24889 100644 --- a/src/tabpfn/architectures/base/transformer.py +++ b/src/tabpfn/architectures/base/transformer.py @@ -523,7 +523,10 @@ def forward( # noqa: PLR0912, C901 ) del embedded_y, embedded_x - if self.add_thinking_tokens is not None: + is_kv_cache_prediction = ( + self.cache_trainset_representation and single_eval_pos == 0 + ) + if self.add_thinking_tokens is not None and not is_kv_cache_prediction: embedded_input, single_eval_pos = self.add_thinking_tokens( embedded_input, single_eval_pos, diff --git a/src/tabpfn/inference.py b/src/tabpfn/inference.py index 6b9386653..a5773202f 100644 --- a/src/tabpfn/inference.py +++ b/src/tabpfn/inference.py @@ -13,7 +13,6 @@ from typing_extensions import override import joblib -import numpy as np import torch from tabpfn.architectures.base.memory import ( @@ -30,6 +29,8 @@ from tabpfn.utils import get_autocast_context if TYPE_CHECKING: + import numpy as np + from tabpfn.architectures.interface import Architecture from tabpfn.preprocessing import EnsembleConfig from tabpfn.preprocessing.ensemble import ( @@ -364,7 +365,7 @@ def iter_outputs( y_train=self.y_train, feature_schema=self.feature_schema, parallel_mode="in-order", - override_random_state=np.random.default_rng(self.static_seed), + override_random_state=self.static_seed, ) ) @@ -768,14 +769,20 @@ def __init__( # noqa: PLR0913 ens_model = deepcopy(models[ensemble_member.config._model_index]) ens_model = ens_model.to(device) + if force_inference_dtype is not None: + ens_model = ens_model.type(force_inference_dtype) X = ensemble_member.X_train y = ensemble_member.y_train - if not isinstance(X, torch.Tensor): - X = torch.as_tensor(X, dtype=torch.float32, device=device) + inference_dtype = ( + force_inference_dtype + if force_inference_dtype is not None + else torch.float32 + ) + + X = torch.as_tensor(X, dtype=inference_dtype, device=device) X = X.unsqueeze(1) - if not isinstance(y, torch.Tensor): - y = torch.as_tensor(y, dtype=torch.float32, device=device) + y = torch.as_tensor(y, dtype=inference_dtype, device=device) batched_preprocessor_cat_ix = [ ensemble_member.feature_schema.indices_for(FeatureModality.CATEGORICAL) @@ -825,7 +832,16 @@ def iter_outputs( for ensemble_member, model in zip(self.ensemble_members, self.models): model.to(self.device) X_test = ensemble_member.transform_X_test(X) - X_test = torch.as_tensor(X_test, dtype=torch.float32, device=self.device) + inference_dtype = ( + self.force_inference_dtype + if self.force_inference_dtype is not None + else torch.float32 + ) + X_test = torch.as_tensor( + X_test, + dtype=inference_dtype, + device=self.device, + ) X_test = X_test.unsqueeze(1) batched_cat_ix = [ ensemble_member.feature_schema.indices_for(FeatureModality.CATEGORICAL) diff --git a/tests/reference_predictions/darwin_arm64/classifier_iris_dataset.json b/tests/reference_predictions/darwin_arm64/classifier_iris_dataset.json index 7dc4eaf97..86a771297 100644 --- a/tests/reference_predictions/darwin_arm64/classifier_iris_dataset.json +++ b/tests/reference_predictions/darwin_arm64/classifier_iris_dataset.json @@ -1,17 +1,17 @@ [ [ - 0.9999790191650391, - 1.4922236005077139e-05, - 6.077526904846309e-06 + 0.9999856352806091, + 9.786756891116966e-06, + 4.600969987222925e-06 ], [ - 0.00038013834273442626, - 0.9990183115005493, - 0.0006015249527990818 + 0.0003290163876954466, + 0.9991994500160217, + 0.0004715345858130604 ], [ - 0.00032731943065300584, - 0.0018296991474926472, - 0.9978430271148682 + 0.00027023980510421097, + 0.0014350548153743148, + 0.9982947111129761 ] ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/classifier_iris_dataset_several_devices.json b/tests/reference_predictions/darwin_arm64/classifier_iris_dataset_several_devices.json index 7dc4eaf97..86a771297 100644 --- a/tests/reference_predictions/darwin_arm64/classifier_iris_dataset_several_devices.json +++ b/tests/reference_predictions/darwin_arm64/classifier_iris_dataset_several_devices.json @@ -1,17 +1,17 @@ [ [ - 0.9999790191650391, - 1.4922236005077139e-05, - 6.077526904846309e-06 + 0.9999856352806091, + 9.786756891116966e-06, + 4.600969987222925e-06 ], [ - 0.00038013834273442626, - 0.9990183115005493, - 0.0006015249527990818 + 0.0003290163876954466, + 0.9991994500160217, + 0.0004715345858130604 ], [ - 0.00032731943065300584, - 0.0018296991474926472, - 0.9978430271148682 + 0.00027023980510421097, + 0.0014350548153743148, + 0.9982947111129761 ] ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_5_estimators.json b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_5_estimators.json index 6a0157735..5cb7dd154 100644 --- a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_5_estimators.json +++ b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_5_estimators.json @@ -1,14 +1,14 @@ [ [ 0.7395162582397461, - 0.2604837119579315 + 0.2604838013648987 ], [ - 0.46708282828330994, - 0.5329171419143677 + 0.4670824706554413, + 0.5329175591468811 ], [ - 0.751488208770752, - 0.24851179122924805 + 0.7514878511428833, + 0.2485121190547943 ] ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_differentiable_input.json b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_differentiable_input.json index 14f1a37a5..b28f6ce65 100644 --- a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_differentiable_input.json +++ b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_differentiable_input.json @@ -1,14 +1,14 @@ [ [ - 0.777837, - 0.222163 + 0.7778364419937134, + 0.22216357290744781 ], [ - 0.4193180501461029, - 0.5806819796562195 + 0.41931843757629395, + 0.580681562423706 ], [ 0.6285772323608398, - 0.37142282724380493 + 0.37142279744148254 ] ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_fit_preprocessors.json b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_fit_preprocessors.json index ac41b90fb..3389a7efe 100644 --- a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_fit_preprocessors.json +++ b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_fit_preprocessors.json @@ -1,14 +1,14 @@ [ [ - 0.751238226890564, - 0.24876181781291962 + 0.7512377500534058, + 0.24876223504543304 ], [ - 0.4854810833930969, - 0.5145189166069031 + 0.4854806065559387, + 0.5145193934440613 ], [ - 0.736987978219986, - 0.26301202178001404 + 0.7369879484176636, + 0.2630121111869812 ] ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_fit_with_cache.json b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_fit_with_cache.json index 5678b73d4..cd6a4e92e 100644 --- a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_fit_with_cache.json +++ b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_fit_with_cache.json @@ -1,14 +1,14 @@ [ [ - 0.5095966458320618, - 0.49040335416793823 + 0.7512380480766296, + 0.24876198172569275 ], [ - 0.5075675845146179, - 0.4924324154853821 + 0.4854806661605835, + 0.5145193338394165 ], [ - 0.5023237466812134, - 0.4976762533187866 + 0.7369883060455322, + 0.2630116939544678 ] ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_low_memory.json b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_low_memory.json index ff60dbf4e..3389a7efe 100644 --- a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_low_memory.json +++ b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2.5_low_memory.json @@ -1,14 +1,14 @@ [ [ - 0.7434266805648804, - 0.25657331943511963 + 0.7512377500534058, + 0.24876223504543304 ], [ - 0.46427232027053833, - 0.5357276797294617 + 0.4854806065559387, + 0.5145193934440613 ], [ - 0.7502415180206299, - 0.24975848197937012 + 0.7369879484176636, + 0.2630121111869812 ] ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2_fit_preprocessors.json b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2_fit_preprocessors.json index 1bede1184..07b254ded 100644 --- a/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2_fit_preprocessors.json +++ b/tests/reference_predictions/darwin_arm64/classifier_tiny_dataset_v2_fit_preprocessors.json @@ -1,14 +1,14 @@ [ [ - 0.7953547835350037, - 0.20464521646499634 + 0.7953541874885559, + 0.2046457827091217 ], [ - 0.5345571041107178, - 0.4654428958892822 + 0.5345551371574402, + 0.4654448628425598 ], [ - 0.7656365036964417, - 0.23436351120471954 + 0.765636146068573, + 0.2343638837337494 ] ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_several_devices.json b/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_several_devices.json index bc27bdb8c..f4483c3b8 100644 --- a/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_several_devices.json +++ b/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_several_devices.json @@ -1,5 +1,5 @@ [ - 4.636890411376953, - 4.241406440734863, - 4.2678446769714355 + 4.636868000030518, + 4.241325378417969, + 4.267908096313477 ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_fit_preprocessors.json b/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_fit_preprocessors.json index bc27bdb8c..f4483c3b8 100644 --- a/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_fit_preprocessors.json +++ b/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_fit_preprocessors.json @@ -1,5 +1,5 @@ [ - 4.636890411376953, - 4.241406440734863, - 4.2678446769714355 + 4.636868000030518, + 4.241325378417969, + 4.267908096313477 ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_fit_with_cache.json b/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_fit_with_cache.json index 48880dd0f..f4483c3b8 100644 --- a/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_fit_with_cache.json +++ b/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_fit_with_cache.json @@ -1,5 +1,5 @@ [ - 5.126402854919434, - 5.2795634269714355, - 5.346292972564697 + 4.636868000030518, + 4.241325378417969, + 4.267908096313477 ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_low_memory.json b/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_low_memory.json index ee24cab2f..f4483c3b8 100644 --- a/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_low_memory.json +++ b/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2.5_low_memory.json @@ -1,5 +1,5 @@ [ - 4.582385063171387, - 4.221979141235352, - 4.202066421508789 + 4.636868000030518, + 4.241325378417969, + 4.267908096313477 ] \ No newline at end of file diff --git a/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2_fit_preprocessors.json b/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2_fit_preprocessors.json index fd86a7013..0b8aef8bc 100644 --- a/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2_fit_preprocessors.json +++ b/tests/reference_predictions/darwin_arm64/regressor_tiny_dataset_v2_fit_preprocessors.json @@ -1,5 +1,5 @@ [ - 4.594064712524414, + 4.594013690948486, 4.312458515167236, - 4.53116512298584 + 4.531164646148682 ] \ No newline at end of file