75 changes: 75 additions & 0 deletions tests/models/foundation/test_chronos.py
@@ -0,0 +1,75 @@
import torch


def test_chronos_default_dtype_is_float32():
"""Ensure Chronos defaults to float32 dtype."""
from timecopilot.models.foundation.chronos import Chronos

model = Chronos(repo_id="amazon/chronos-t5-tiny")
assert model.dtype == torch.float32


def test_chronos_model_uses_configured_dtype(mocker):
"""Ensure Chronos loads models with the configured dtype."""
mock_pipeline = mocker.patch(
"timecopilot.models.foundation.chronos.BaseChronosPipeline.from_pretrained"
)
mocker.patch("torch.cuda.is_available", return_value=False)

from timecopilot.models.foundation.chronos import Chronos

# Test default (float32)
model = Chronos(repo_id="amazon/chronos-t5-tiny")
with model._get_model():
pass
call_kwargs = mock_pipeline.call_args[1]
assert call_kwargs["torch_dtype"] == torch.float32

# Test custom dtype (bfloat16)
mock_pipeline.reset_mock()
model_bf16 = Chronos(repo_id="amazon/chronos-t5-tiny", dtype=torch.bfloat16)
with model_bf16._get_model():
pass
call_kwargs = mock_pipeline.call_args[1]
assert call_kwargs["torch_dtype"] == torch.bfloat16


def test_chronos_forecast_uses_configured_dtype(mocker):
"""Ensure Chronos.forecast uses the configured dtype for dataset creation."""
import pandas as pd
import pytest

from timecopilot.models.foundation.chronos import Chronos

# Patch dataset creation to capture dtype argument
mock_from_df = mocker.patch(
"timecopilot.models.foundation.chronos.TimeSeriesDataset.from_df"
)

# Avoid real model loading and CUDA branching
mocker.patch(
"timecopilot.models.foundation.chronos.BaseChronosPipeline.from_pretrained"
)
mocker.patch("torch.cuda.is_available", return_value=False)

model_dtype = torch.bfloat16
model = Chronos(repo_id="amazon/chronos-t5-tiny", dtype=model_dtype)

df = pd.DataFrame(
{
"unique_id": ["A"] * 10,
"ds": pd.date_range("2020-01-01", periods=10),
"y": range(10),
}
)

def _from_df_side_effect(*args, **kwargs):
# Assert that Chronos.forecast passes the configured dtype through
assert kwargs.get("dtype") == model_dtype
# Short-circuit the rest of the forecast call
raise RuntimeError("stop after dtype check")

mock_from_df.side_effect = _from_df_side_effect

with pytest.raises(RuntimeError, match="stop after dtype check"):
model.forecast(df=df, h=2)
33 changes: 33 additions & 0 deletions tests/models/foundation/test_utils.py
@@ -0,0 +1,33 @@
import torch

from timecopilot.models.foundation.utils import TimeSeriesDataset


def test_timeseries_dataset_class_default_dtype_is_bfloat16():
"""Ensure TimeSeriesDataset defaults to bfloat16 for backward compatibility."""
import pandas as pd

df = pd.DataFrame(
{
"unique_id": ["A"] * 10,
"ds": pd.date_range("2020-01-01", periods=10),
"y": range(10),
}
)
dataset = TimeSeriesDataset.from_df(df, batch_size=10)
assert dataset.data[0].dtype == torch.bfloat16


def test_timeseries_dataset_respects_custom_dtype():
"""Ensure TimeSeriesDataset respects custom dtype parameter."""
import pandas as pd

df = pd.DataFrame(
{
"unique_id": ["A"] * 10,
"ds": pd.date_range("2020-01-01", periods=10),
"y": range(10),
}
)
dataset = TimeSeriesDataset.from_df(df, batch_size=10, dtype=torch.float32)
assert dataset.data[0].dtype == torch.float32
16 changes: 12 additions & 4 deletions timecopilot/models/foundation/chronos.py
@@ -28,6 +28,7 @@ def __init__(
repo_id: str = "amazon/chronos-t5-large",
batch_size: int = 16,
alias: str = "Chronos",
dtype: torch.dtype = torch.float32,
):
# ruff: noqa: E501
"""
@@ -45,6 +46,10 @@ def __init__(
higher batch sizes (e.g., 256) are possible.
alias (str, optional): Name to use for the model in output
DataFrames and logs. Defaults to "Chronos".
dtype (torch.dtype, optional): Data type for model weights and
input tensors. Defaults to torch.float32 for numerical
precision. Use torch.bfloat16 for reduced memory usage on
supported hardware.

Notes:
**Available models:**
@@ -77,21 +82,22 @@ def __init__(
available, otherwise CPU).
- For best performance with large models (e.g., "chronos-t5-large"),
a CUDA-compatible GPU is recommended.
- The model weights are loaded with torch_dtype=torch.bfloat16 for
efficiency on supported hardware.
- Model weights and input tensors use dtype (default: torch.float32)
for numerical precision. Can be overridden via the dtype parameter.

"""
self.repo_id = repo_id
self.batch_size = batch_size
self.alias = alias
self.dtype = dtype

@contextmanager
def _get_model(self) -> BaseChronosPipeline:
device_map = "cuda:0" if torch.cuda.is_available() else "cpu"
model = BaseChronosPipeline.from_pretrained(
self.repo_id,
device_map=device_map,
torch_dtype=torch.bfloat16,
torch_dtype=self.dtype,
)
try:
yield model
@@ -218,7 +224,9 @@ def forecast(
"""
freq = self._maybe_infer_freq(df, freq)
qc = QuantileConverter(level=level, quantiles=quantiles)
dataset = TimeSeriesDataset.from_df(df, batch_size=self.batch_size)
dataset = TimeSeriesDataset.from_df(
df, batch_size=self.batch_size, dtype=self.dtype
)
fcst_df = dataset.make_future_dataframe(h=h, freq=freq)
with self._get_model() as model:
fcsts_mean_np, fcsts_quantiles_np = self._predict(
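
For reference, a minimal usage sketch of the new `dtype` parameter, based on the signatures shown in this diff (running the forecast call would download the Chronos weights from Hugging Face):

```python
import pandas as pd
import torch

from timecopilot.models.foundation.chronos import Chronos

df = pd.DataFrame(
    {
        "unique_id": ["A"] * 10,
        "ds": pd.date_range("2020-01-01", periods=10),
        "y": range(10),
    }
)

# Default: model weights and input tensors are loaded in float32 (the new default).
model = Chronos(repo_id="amazon/chronos-t5-tiny")

# Opt into bfloat16 to reduce memory use on supported hardware; forecast()
# forwards the same dtype to TimeSeriesDataset.from_df for the input tensors.
model_bf16 = Chronos(repo_id="amazon/chronos-t5-tiny", dtype=torch.bfloat16)
fcst_df = model_bf16.forecast(df=df, h=2)
```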