Conversation
removed extra squeeze in _predict_batch so flowstate returns 2D arrays for h=1 with single unique_id.
Pull request overview
Fixes a shape bug in the FlowState foundation model’s batch prediction path so h=1 forecasts don’t collapse dimensions (especially for single-series inputs), preventing downstream failures.
Changes:
- Removed an extra `.squeeze()` when extracting the median (0.5) quantile forecast so outputs preserve a consistent 2D `(batch, h)` shape.
```diff
-fcst_mean = fcst[..., supported_quantiles.index(0.5)].squeeze()
+fcst_mean = fcst[..., supported_quantiles.index(0.5)]
 fcst_mean_np = fcst_mean.detach().numpy()
 fcst_quantiles_np = fcst.detach().numpy() if quantiles is not None else None
```
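For context, a minimal sketch (with illustrative shapes, not taken from the actual model code) of why the extra `.squeeze()` collapsed the output for batch=1, h=1:

```python
import torch

# Quantile output shaped (batch, h, n_quantiles); batch=1, h=1 is the edge
# case from this PR. The index 4 for the 0.5 quantile is illustrative only.
fcst = torch.randn(1, 1, 9)
median_idx = 4

with_squeeze = fcst[..., median_idx].squeeze()
without_squeeze = fcst[..., median_idx]

print(with_squeeze.shape)     # torch.Size([])    -> 0D scalar, breaks downstream code
print(without_squeeze.shape)  # torch.Size([1, 1]) -> consistent (batch, h)
```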
This change fixes a shape edge case for h=1 and/or batch size 1, but there doesn’t appear to be a regression test covering the previously-crashing scenario (single unique_id, h=1). Adding a small pytest that exercises FlowState.forecast with n_series=1 and h=1 would prevent this from reappearing.
```diff
-fcst_quantiles_np = fcst.detach().numpy() if quantiles is not None else None
+# Ensure we never propagate scalar (0D) arrays, which can occur in edge cases
+if fcst_mean_np.shape == ():
+    fcst_mean_np = np.expand_dims(fcst_mean_np, axis=0)
+if quantiles is not None:
+    fcst_quantiles_np = fcst.detach().numpy()
+    if fcst_quantiles_np.shape == ():
+        fcst_quantiles_np = np.expand_dims(fcst_quantiles_np, axis=0)
+else:
+    fcst_quantiles_np = None
```
`fcst_mean` (and `fcst`) will be on `self.device` (CUDA when available), so calling `.detach().numpy()` will raise on GPU. Convert to CPU first (e.g., `.detach().cpu().numpy()`) for both `fcst_mean_np` and `fcst_quantiles_np` to avoid runtime crashes when CUDA is available.
```diff
-fcst_mean_np = fcst_mean.detach().numpy()
-fcst_quantiles_np = fcst.detach().numpy() if quantiles is not None else None
+fcst_mean_np = fcst_mean.detach().cpu().numpy()
+fcst_quantiles_np = (
+    fcst.detach().cpu().numpy() if quantiles is not None else None
+)
```
An alternative solution is setting the `force` argument of the `numpy()` method to `True`, e.g. `fcst_mean_np = fcst_mean.detach().numpy(force=True)`.
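A small sketch comparing the two conversions discussed above (the `force` keyword requires a recent PyTorch version, 1.13 or later):

```python
# `numpy(force=True)` detaches the tensor and moves it to CPU in one call,
# so it is also safe for CUDA tensors.
import numpy as np
import torch

t = torch.randn(2, 1, requires_grad=True)
if torch.cuda.is_available():
    t = t.cuda()

a = t.detach().cpu().numpy()  # explicit detach + device transfer
b = t.numpy(force=True)       # one-liner equivalent
assert np.array_equal(a, b)
```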
thanks @spolisar! @Kushagra7777 could we add this fix to the pr?
pushed the changes.
AzulGarza left a comment
thanks @Kushagra7777!
could you add a small test to ensure that the fix works as expected?
sure @AzulGarza
ruff check . ruff format .
adjust conversion to handle device-related issues safely
Co-Authored-By: spolisar <spolisar@users.noreply.github.com>
Co-Authored-By: azul <azul.garza.r@gmail.com>