diff --git a/README.md b/README.md index 43b0418..8891ea7 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,18 @@ The repo associated with training the models run here is https://github.com/open The model checkpoints are hosted at: https://huggingface.co/openclimatefix/cloudcasting_uk +## Environment Variables + +The following environment variables are used in the app: + +- `SATELLITE_ZARR_PATH`: The path to the satellite data in Zarr format. +- `OUTPUT_PREDICTION_DIRECTORY`: The directory where results are saved. + +### Optional Environment Variables + +- `SATELLITE_SCALE_FACTOR`: The scale factor for the satellite data. Defaults to 1023. +- `SATELLITE_15_ZARR_PATH`: The path to the 15 minute satellite data in Zarr format. If +this is not set then the `SATELLITE_ZARR_PATH` is used by `.zarr` is repalced with `_15.zarr` ## Installation diff --git a/src/cloudcasting_app/data.py b/src/cloudcasting_app/data.py index 7f6a4c6..3a6f092 100644 --- a/src/cloudcasting_app/data.py +++ b/src/cloudcasting_app/data.py @@ -1,6 +1,7 @@ import logging import shutil import os +import yaml import fsspec import numpy as np @@ -76,6 +77,12 @@ def prepare_satellite_data(t0: pd.Timestamp): # Load data the data for more preprocessing ds = xr.open_zarr(sat_path) + # make sure area attrs are yaml string + if "area" in ds.data.attrs and isinstance(ds.data.attrs["area"], dict): + logger.warning("Converting area attribute to YAML string, " + "we should do this in the satellite consumer.") + ds.data.attrs["area"] = yaml.dump(ds.data.attrs["area"]) + # Crop the input area to expected ds = crop_input_area(ds) @@ -83,7 +90,12 @@ def prepare_satellite_data(t0: pd.Timestamp): ds = ds.sel(variable=channel_order) # Scale the satellite data from 0-1 - ds = ds / 1023 + scale_factor = int(os.environ.get("SATELLITE_SCALE_FACTOR", 1023)) + logger.info( + f"Scaling satellite data by {scale_factor} to be between 0 and 1" + ) + ds = ds / scale_factor + # Resave ds = ds.compute() @@ -107,8 +119,10 @@ def download_all_sat_data() -> bool: # Set variable to track whether the satellite download is successful sat_available = False + # get paths + sat_5_dl_path, sat_15_dl_path = get_satellite_source_paths() + # download 5 minute satellite data - sat_5_dl_path = os.environ["SATELLITE_ZARR_PATH"] fs, _ = fsspec.core.url_to_fs(sat_5_dl_path) if fs.exists(sat_5_dl_path): sat_available = True @@ -121,7 +135,6 @@ def download_all_sat_data() -> bool: logger.info("No 5-minute data available") # Also download 15-minute satellite if it exists - sat_15_dl_path = sat_5_dl_path.replace(".zarr", "_15.zarr") if fs.exists(sat_15_dl_path): sat_available = True logger.info("Downloading 15-minute satellite data") @@ -226,3 +239,14 @@ def get_input_data(ds: xr.Dataset, t0: pd.Timestamp): X = np.nan_to_num(X, nan=-1) return torch.Tensor(X) + + +def get_satellite_source_paths() -> (str | None, str | None): + """ Get the paths to the satellite data from environment variables""" + sat_source_path_5 = os.getenv("SATELLITE_ZARR_PATH", None) + sat_source_path_15 = os.getenv("SATELLITE_15_ZARR_PATH", None) + if sat_source_path_15 is None and sat_source_path_5 is not None: + sat_source_path_15 = sat_source_path_5.replace(".zarr", "_15.zarr") + logger.info( + f"Satellite source paths: 5-minute: {sat_source_path_5}, 15-minute: {sat_source_path_15}") + return sat_source_path_5, sat_source_path_15