fix: change Chronos default dtype from bfloat16 to float32 #309
Conversation
Using `bfloat16` can cause precision issues that significantly impact forecast accuracy. As reported in time-bench benchmarks, switching to `float32` improved Chronos model rankings from 5th to 1st place in terms of MASE. Closes #307
Pull request overview
Updates Chronos model loading defaults to prioritize numerical precision by switching the default `torch_dtype` from `bfloat16` to `float32`, addressing reported accuracy regressions tied to reduced-precision weights.
Changes:
- Update Chronos `_get_model()` to load model weights with `torch_dtype=torch.float32`.
- Update the Chronos class docstring to reflect the new default dtype and rationale.
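For reference, the change presumably boils down to something like the following sketch (the method body is assumed from context, not copied from the diff; `self.repo_id` comes from the constructor shown later in this thread):

```python
import torch
from chronos import BaseChronosPipeline


def _get_model(self):
    # Load weights in full precision; the previous default was torch.bfloat16.
    return BaseChronosPipeline.from_pretrained(
        self.repo_id,
        torch_dtype=torch.float32,
    )
```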
This PR changes a default that affects runtime behavior (memory/precision). There doesn’t appear to be a unit test asserting the `torch_dtype` passed to `BaseChronosPipeline.from_pretrained`; adding a small mocked test would help prevent regressions (e.g., accidentally reverting to `bfloat16`).
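Such a test might look like this minimal sketch (the import path and patch target are hypothetical; adjust them to the actual module layout):

```python
from unittest.mock import patch

import torch

from models.foundation.chronos import Chronos  # hypothetical import path


def test_get_model_loads_float32():
    # Patch BaseChronosPipeline where the Chronos forecaster imports it.
    with patch(
        "models.foundation.chronos.BaseChronosPipeline.from_pretrained"
    ) as mock_fp:
        Chronos()._get_model()
        assert mock_fp.call_args.kwargs["torch_dtype"] == torch.float32
```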
Address review comments:
- Change `TimeSeriesDataset.from_df` default dtype from bfloat16 to float32 to ensure inputs match model precision
- Add unit tests to prevent dtype regression
Addressed review comments in ce3eb38:
AzulGarza left a comment
@rebot-eng thanks! i think it's best if we add the attribute `self.dtype`, as in FlowState, defaulted to float32.
- Add `dtype` parameter to `__init__` (default: `torch.float32`)
- Use `self.dtype` consistently for model loading and dataset creation
- Follow the same pattern as FlowState for consistency
- Update tests to verify dtype configuration
Updated in 9f323c1:

```python
class Chronos(Forecaster):
    def __init__(
        self,
        repo_id: str = "amazon/chronos-t5-large",
        batch_size: int = 16,
        alias: str = "Chronos",
        dtype: torch.dtype = torch.float32,  # new
    ):
        ...
        self.dtype = dtype
```

Now uses:

```python
# Default float32 (recommended)
model = Chronos()

# Explicit bfloat16 for memory-constrained environments
model = Chronos(dtype=torch.bfloat16)
```
Revert `utils.py` default to bfloat16 for backward compatibility. Chronos explicitly passes `dtype=self.dtype` (float32) to avoid breaking other code that relies on the original default.
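A sketch of the resulting split (the `from_df` signature is an assumption for illustration):

```python
import torch


class TimeSeriesDataset:
    @classmethod
    def from_df(cls, df, dtype=torch.bfloat16):  # utils.py default stays bfloat16
        ...


# Inside Chronos, the dtype is passed explicitly, so other callers of
# TimeSeriesDataset.from_df keep the original behavior:
dataset = TimeSeriesDataset.from_df(df, dtype=self.dtype)  # torch.float32 by default
```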
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
This test is located in `test_chronos.py` but tests `TimeSeriesDataset`, which is a shared utility class used by multiple models (Chronos, FlowState, TimesFM, etc.). Consider moving this test to a dedicated `test_utils.py` file in the `tests/models/foundation` directory to better reflect its scope and make it more discoverable.
- Move `TimeSeriesDataset` tests to `test_utils.py` (shared utility)
- Add `test_chronos_forecast_uses_configured_dtype` to verify dtype is passed correctly to `TimeSeriesDataset.from_df`
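The new test might look roughly like this (a sketch; the import path, patch target, and `forecast` call shape are assumptions about the repo layout):

```python
from unittest.mock import patch

import pandas as pd
import torch

from models.foundation.chronos import Chronos  # hypothetical import path


def test_chronos_forecast_uses_configured_dtype():
    model = Chronos(dtype=torch.float32)
    df = pd.DataFrame({"unique_id": ["a"], "ds": ["2024-01-01"], "y": [1.0]})
    # Patch where Chronos looks up TimeSeriesDataset; path is hypothetical.
    with patch("models.foundation.chronos.TimeSeriesDataset.from_df") as mock_from_df:
        try:
            model.forecast(df, h=1)  # call signature assumed
        except Exception:
            pass  # downstream mocks are incomplete; only the dtype matters here
        assert mock_from_df.call_args.kwargs["dtype"] == torch.float32
```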
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated no new comments.
Summary
Changes the default `torch_dtype` for Chronos models from `bfloat16` to `float32` to improve numerical precision.

Problem
As reported in #307 by @abdulfatir, using `bfloat16` can cause precision issues that significantly impact forecast accuracy. In time-bench benchmarks, switching to `float32` improved Chronos model rankings from 5th to 1st place in terms of MASE.

Changes
- Changed `torch_dtype` from `torch.bfloat16` to `torch.float32` in `_get_model()`

Trade-offs
- `float32` uses 2x memory compared to `bfloat16`

The accuracy gains outweigh the performance costs for most use cases.
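For scale, a rough estimate (assuming chronos-t5-large at ~710M parameters): weights alone take about 710M × 4 bytes ≈ 2.8 GB in `float32` versus ≈ 1.4 GB in `bfloat16`, before counting activations.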
Closes #307
cc @AzulGarza