diff --git a/tests/models/foundation/test_flowstate.py b/tests/models/foundation/test_flowstate.py new file mode 100644 index 00000000..6e0cfe4b --- /dev/null +++ b/tests/models/foundation/test_flowstate.py @@ -0,0 +1,22 @@ +import numpy as np +import pandas as pd + +from timecopilot import TimeCopilotForecaster +from timecopilot.models.foundation.flowstate import FlowState + + +def test_flowstate_h1_single_uid(): + # create simple weekly data for one unique_id + ds = pd.date_range("2024-01-01", periods=20, freq="W") + df = pd.DataFrame({"unique_id": "u1", "ds": ds, "y": np.arange(20)}) + + tcf = TimeCopilotForecaster(models=[FlowState()]) + + # this used to crash before the fix + fcst = tcf.forecast(df=df, h=1, freq="W") + + # basic checks + assert isinstance(fcst, pd.DataFrame) + assert len(fcst) == 1 + assert "unique_id" in fcst.columns + assert "ds" in fcst.columns diff --git a/timecopilot/models/foundation/flowstate.py b/timecopilot/models/foundation/flowstate.py index de9d95b8..493232c3 100644 --- a/timecopilot/models/foundation/flowstate.py +++ b/timecopilot/models/foundation/flowstate.py @@ -167,8 +167,8 @@ def _predict_batch( batch_first=False, ).prediction_outputs fcst = fcst.squeeze(-1).transpose(-1, -2) # now shape is (batch, h, quantiles) - fcst_mean = fcst[..., supported_quantiles.index(0.5)].squeeze() - fcst_mean_np = fcst_mean.detach().numpy() + fcst_mean = fcst[..., supported_quantiles.index(0.5)] + fcst_mean_np = fcst_mean.detach().numpy(force=True) fcst_quantiles_np = fcst.detach().numpy() if quantiles is not None else None return fcst_mean_np, fcst_quantiles_np