diff --git a/tests/models/foundation/test_chronos.py b/tests/models/foundation/test_chronos.py
new file mode 100644
index 00000000..aa02afc2
--- /dev/null
+++ b/tests/models/foundation/test_chronos.py
@@ -0,0 +1,75 @@
+import torch
+
+
+def test_chronos_default_dtype_is_float32():
+    """Ensure Chronos defaults to float32 dtype."""
+    from timecopilot.models.foundation.chronos import Chronos
+
+    model = Chronos(repo_id="amazon/chronos-t5-tiny")
+    assert model.dtype == torch.float32
+
+
+def test_chronos_model_uses_configured_dtype(mocker):
+    """Ensure Chronos loads models with the configured dtype."""
+    mock_pipeline = mocker.patch(
+        "timecopilot.models.foundation.chronos.BaseChronosPipeline.from_pretrained"
+    )
+    mocker.patch("torch.cuda.is_available", return_value=False)
+
+    from timecopilot.models.foundation.chronos import Chronos
+
+    # Test default (float32)
+    model = Chronos(repo_id="amazon/chronos-t5-tiny")
+    with model._get_model():
+        pass
+    call_kwargs = mock_pipeline.call_args[1]
+    assert call_kwargs["torch_dtype"] == torch.float32
+
+    # Test custom dtype (bfloat16)
+    mock_pipeline.reset_mock()
+    model_bf16 = Chronos(repo_id="amazon/chronos-t5-tiny", dtype=torch.bfloat16)
+    with model_bf16._get_model():
+        pass
+    call_kwargs = mock_pipeline.call_args[1]
+    assert call_kwargs["torch_dtype"] == torch.bfloat16
+
+
+def test_chronos_forecast_uses_configured_dtype(mocker):
+    """Ensure Chronos.forecast uses the configured dtype for dataset creation."""
+    import pandas as pd
+    import pytest
+
+    from timecopilot.models.foundation.chronos import Chronos
+
+    # Patch dataset creation to capture dtype argument
+    mock_from_df = mocker.patch(
+        "timecopilot.models.foundation.chronos.TimeSeriesDataset.from_df"
+    )
+
+    # Avoid real model loading and CUDA branching
+    mocker.patch(
+        "timecopilot.models.foundation.chronos.BaseChronosPipeline.from_pretrained"
+    )
+    mocker.patch("torch.cuda.is_available", return_value=False)
+
+    model_dtype = torch.bfloat16
+    model = Chronos(repo_id="amazon/chronos-t5-tiny", dtype=model_dtype)
+
+    df = pd.DataFrame(
+        {
+            "unique_id": ["A"] * 10,
+            "ds": pd.date_range("2020-01-01", periods=10),
+            "y": range(10),
+        }
+    )
+
+    def _from_df_side_effect(*args, **kwargs):
+        # Assert that Chronos.forecast passes the configured dtype through
+        assert kwargs.get("dtype") == model_dtype
+        # Short-circuit the rest of the forecast call
+        raise RuntimeError("stop after dtype check")
+
+    mock_from_df.side_effect = _from_df_side_effect
+
+    with pytest.raises(RuntimeError, match="stop after dtype check"):
+        model.forecast(df=df, h=2)
diff --git a/tests/models/foundation/test_utils.py b/tests/models/foundation/test_utils.py
new file mode 100644
index 00000000..c51e6574
--- /dev/null
+++ b/tests/models/foundation/test_utils.py
@@ -0,0 +1,33 @@
+import torch
+
+from timecopilot.models.foundation.utils import TimeSeriesDataset
+
+
+def test_timeseries_dataset_class_default_dtype_is_bfloat16():
+    """Ensure TimeSeriesDataset defaults to bfloat16 for backward compatibility."""
+    import pandas as pd
+
+    df = pd.DataFrame(
+        {
+            "unique_id": ["A"] * 10,
+            "ds": pd.date_range("2020-01-01", periods=10),
+            "y": range(10),
+        }
+    )
+    dataset = TimeSeriesDataset.from_df(df, batch_size=10)
+    assert dataset.data[0].dtype == torch.bfloat16
+
+
+def test_timeseries_dataset_respects_custom_dtype():
+    """Ensure TimeSeriesDataset respects custom dtype parameter."""
+    import pandas as pd
+
+    df = pd.DataFrame(
+        {
+            "unique_id": ["A"] * 10,
+            "ds": pd.date_range("2020-01-01", periods=10),
+            "y": range(10),
+        }
+    )
+    dataset = TimeSeriesDataset.from_df(df, batch_size=10, dtype=torch.float32)
+    assert dataset.data[0].dtype == torch.float32
diff --git a/timecopilot/models/foundation/chronos.py b/timecopilot/models/foundation/chronos.py
index 95b29062..d414a9e5 100644
--- a/timecopilot/models/foundation/chronos.py
+++ b/timecopilot/models/foundation/chronos.py
@@ -28,6 +28,7 @@ def __init__(
         repo_id: str = "amazon/chronos-t5-large",
         batch_size: int = 16,
         alias: str = "Chronos",
+        dtype: torch.dtype = torch.float32,
     ):
         # ruff: noqa: E501
         """
@@ -45,6 +46,10 @@ def __init__(
                 higher batch sizes (e.g., 256) are possible.
             alias (str, optional): Name to use for the model in output
                 DataFrames and logs. Defaults to "Chronos".
+            dtype (torch.dtype, optional): Data type for model weights and
+                input tensors. Defaults to torch.float32 for numerical
+                precision. Use torch.bfloat16 for reduced memory usage on
+                supported hardware.
 
         Notes:
             **Available models:**
@@ -77,13 +82,14 @@ def __init__(
               available, otherwise CPU).
             - For best performance with large models (e.g., "chronos-t5-large"),
               a CUDA-compatible GPU is recommended.
-            - The model weights are loaded with torch_dtype=torch.bfloat16 for
-              efficiency on supported hardware.
+            - Model weights and input tensors use dtype (default: torch.float32)
+              for numerical precision. Can be overridden via the dtype parameter.
 
         """
         self.repo_id = repo_id
         self.batch_size = batch_size
         self.alias = alias
+        self.dtype = dtype
 
     @contextmanager
     def _get_model(self) -> BaseChronosPipeline:
@@ -91,7 +97,7 @@ def _get_model(self) -> BaseChronosPipeline:
         model = BaseChronosPipeline.from_pretrained(
             self.repo_id,
             device_map=device_map,
-            torch_dtype=torch.bfloat16,
+            torch_dtype=self.dtype,
         )
         try:
             yield model
@@ -218,7 +224,9 @@ def forecast(
         """
         freq = self._maybe_infer_freq(df, freq)
         qc = QuantileConverter(level=level, quantiles=quantiles)
-        dataset = TimeSeriesDataset.from_df(df, batch_size=self.batch_size)
+        dataset = TimeSeriesDataset.from_df(
+            df, batch_size=self.batch_size, dtype=self.dtype
+        )
         fcst_df = dataset.make_future_dataframe(h=h, freq=freq)
         with self._get_model() as model:
             fcsts_mean_np, fcsts_quantiles_np = self._predict(
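
Below is a minimal usage sketch of the new dtype option, assembled only from the constructor signature and forecast call changed in this patch; the tiny repo id and the toy DataFrame mirror the test fixtures above and are illustrative, not part of the diff.

# Illustrative sketch (not part of the patch): exercising the dtype parameter
# added above. Repo id and data mirror the test fixtures and are assumptions.
import pandas as pd
import torch

from timecopilot.models.foundation.chronos import Chronos

df = pd.DataFrame(
    {
        "unique_id": ["A"] * 10,
        "ds": pd.date_range("2020-01-01", periods=10),
        "y": range(10),
    }
)

# Default behaviour after this patch: weights and input tensors in float32.
model = Chronos(repo_id="amazon/chronos-t5-tiny")

# Opt back into bfloat16 (the previously hard-coded dtype) to save memory on
# supported hardware; the value is forwarded to both
# BaseChronosPipeline.from_pretrained and TimeSeriesDataset.from_df.
model_bf16 = Chronos(repo_id="amazon/chronos-t5-tiny", dtype=torch.bfloat16)

fcst_df = model_bf16.forecast(df=df, h=2)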