From cbf59930831e5dcf69c05f17f4bbb63fd679901c Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Fri, 30 May 2025 15:04:47 +0200 Subject: [PATCH 1/6] Add missing modalities to config --- aion/codecs/config.py | 60 ++++++++++++++++++++++++------------------ aion/codecs/manager.py | 5 ---- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/aion/codecs/config.py b/aion/codecs/config.py index bb3734c..8835677 100644 --- a/aion/codecs/config.py +++ b/aion/codecs/config.py @@ -18,12 +18,14 @@ HSCAY, HSCAZ, Dec, + DESISpectrum, GaiaFluxBp, GaiaFluxG, GaiaFluxRp, GaiaParallax, GaiaXpBp, GaiaXpRp, + HSCImage, HSCMagG, HSCMagI, HSCMagR, @@ -43,11 +45,13 @@ LegacySurveyFluxW3, LegacySurveyFluxW4, LegacySurveyFluxZ, + LegacySurveyImage, LegacySurveySegmentationMap, LegacySurveyShapeE1, LegacySurveyShapeE2, LegacySurveyShapeR, Ra, + SDSSSpectrum, Spectrum, Z, ) @@ -76,43 +80,47 @@ class CodecHFConfig: MODALITY_CODEC_MAPPING = { + Dec: ScalarCodec, + DESISpectrum: SpectrumCodec, + GaiaFluxBp: LogScalarCodec, + GaiaFluxG: LogScalarCodec, + GaiaFluxRp: LogScalarCodec, + GaiaParallax: LogScalarCodec, + GaiaXpBp: MultiScalarCodec, + GaiaXpRp: MultiScalarCodec, + HSCAG: ScalarCodec, + HSCAI: ScalarCodec, + HSCAR: ScalarCodec, + HSCAY: ScalarCodec, + HSCAZ: ScalarCodec, + HSCImage: ImageCodec, + HSCMagG: ScalarCodec, + HSCMagI: ScalarCodec, + HSCMagR: ScalarCodec, + HSCMagY: ScalarCodec, + HSCMagZ: ScalarCodec, + HSCShape11: ScalarCodec, + HSCShape12: ScalarCodec, + HSCShape22: ScalarCodec, Image: ImageCodec, - Spectrum: SpectrumCodec, LegacySurveyCatalog: CatalogCodec, - LegacySurveySegmentationMap: ScalarFieldCodec, + LegacySurveyEBV: ScalarCodec, LegacySurveyFluxG: LogScalarCodec, - LegacySurveyFluxR: LogScalarCodec, LegacySurveyFluxI: LogScalarCodec, - LegacySurveyFluxZ: LogScalarCodec, + LegacySurveyFluxR: LogScalarCodec, LegacySurveyFluxW1: LogScalarCodec, LegacySurveyFluxW2: LogScalarCodec, LegacySurveyFluxW3: LogScalarCodec, LegacySurveyFluxW4: LogScalarCodec, - LegacySurveyShapeR: LogScalarCodec, - GaiaFluxG: LogScalarCodec, - GaiaFluxBp: LogScalarCodec, - GaiaFluxRp: LogScalarCodec, - GaiaParallax: LogScalarCodec, + LegacySurveyFluxZ: LogScalarCodec, + LegacySurveyImage: ImageCodec, + LegacySurveySegmentationMap: ScalarFieldCodec, LegacySurveyShapeE1: ScalarCodec, LegacySurveyShapeE2: ScalarCodec, - LegacySurveyEBV: ScalarCodec, - HSCMagG: ScalarCodec, - HSCMagR: ScalarCodec, - HSCMagI: ScalarCodec, - HSCMagZ: ScalarCodec, - HSCMagY: ScalarCodec, - HSCShape11: ScalarCodec, - HSCShape22: ScalarCodec, - HSCShape12: ScalarCodec, - HSCAG: ScalarCodec, - HSCAR: ScalarCodec, - HSCAI: ScalarCodec, - HSCAZ: ScalarCodec, - HSCAY: ScalarCodec, + LegacySurveyShapeR: LogScalarCodec, Ra: ScalarCodec, - Dec: ScalarCodec, - GaiaXpBp: MultiScalarCodec, - GaiaXpRp: MultiScalarCodec, + SDSSSpectrum: SpectrumCodec, + Spectrum: SpectrumCodec, Z: GridScalarCodec, } diff --git a/aion/codecs/manager.py b/aion/codecs/manager.py index 9424e47..fed0000 100644 --- a/aion/codecs/manager.py +++ b/aion/codecs/manager.py @@ -60,11 +60,6 @@ def _load_codec(self, modality_type: type[BaseModality]) -> Codec: # Look up configuration in CODEC_CONFIG if modality_type in MODALITY_CODEC_MAPPING: codec_class = MODALITY_CODEC_MAPPING[modality_type] - elif ( - hasattr(modality_type, "__base__") - and modality_type.__base__ in MODALITY_CODEC_MAPPING - ): - codec_class = MODALITY_CODEC_MAPPING[modality_type.__base__] else: raise ModalityTypeError( f"No codec configuration found for modality type: {modality_type.__name__}" From f9fde2067a17133ecf8c297f24992e1957409ba7 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Fri, 30 May 2025 15:05:13 +0200 Subject: [PATCH 2/6] Check modality codec compatibility --- aion/codecs/utils.py | 59 ++++++++++++++++++++------------ tests/codecs/test_load_codecs.py | 24 +++++++++++++ 2 files changed, 61 insertions(+), 22 deletions(-) create mode 100644 tests/codecs/test_load_codecs.py diff --git a/aion/codecs/utils.py b/aion/codecs/utils.py index 79101a2..6d2841a 100644 --- a/aion/codecs/utils.py +++ b/aion/codecs/utils.py @@ -3,33 +3,50 @@ from aion.codecs.base import Codec from aion.modalities import Modality + ORIGINAL_CONFIG_NAME = hub_mixin.constants.CONFIG_NAME ORIGINAL_PYTORCH_WEIGHTS_NAME = hub_mixin.constants.PYTORCH_WEIGHTS_NAME ORIGINAL_SAFETENSORS_SINGLE_FILE = hub_mixin.constants.SAFETENSORS_SINGLE_FILE -def _override_config_and_weights_names(modality: type[Modality]): - hub_mixin.constants.CONFIG_NAME = f"codecs/{modality.name}/{ORIGINAL_CONFIG_NAME}" - hub_mixin.constants.SAFETENSORS_SINGLE_FILE = ( - f"codecs/{modality.name}/{ORIGINAL_SAFETENSORS_SINGLE_FILE}" - ) - hub_mixin.constants.PYTORCH_WEIGHTS_NAME = ( - f"codecs/{modality.name}/{ORIGINAL_PYTORCH_WEIGHTS_NAME}" - ) - - -def _reset_config_and_weights_names(): - hub_mixin.constants.PYTORCH_WEIGHTS_NAME = ORIGINAL_PYTORCH_WEIGHTS_NAME - hub_mixin.constants.CONFIG_NAME = ORIGINAL_CONFIG_NAME - hub_mixin.constants.SAFETENSORS_SINGLE_FILE = ORIGINAL_SAFETENSORS_SINGLE_FILE - - class CodecPytorchHubMixin(hub_mixin.PyTorchModelHubMixin): """Mixin for PyTorch models that correspond to codecs. Codec don't have their own model repo. Instead they lie in the transformer model repo as subfolders. """ + @staticmethod + def _override_config_and_weights_names(modality: type[Modality]): + hub_mixin.constants.CONFIG_NAME = ( + f"codecs/{modality.name}/{ORIGINAL_CONFIG_NAME}" + ) + hub_mixin.constants.SAFETENSORS_SINGLE_FILE = ( + f"codecs/{modality.name}/{ORIGINAL_SAFETENSORS_SINGLE_FILE}" + ) + hub_mixin.constants.PYTORCH_WEIGHTS_NAME = ( + f"codecs/{modality.name}/{ORIGINAL_PYTORCH_WEIGHTS_NAME}" + ) + + @staticmethod + def _reset_config_and_weights_names(): + hub_mixin.constants.PYTORCH_WEIGHTS_NAME = ORIGINAL_PYTORCH_WEIGHTS_NAME + hub_mixin.constants.CONFIG_NAME = ORIGINAL_CONFIG_NAME + hub_mixin.constants.SAFETENSORS_SINGLE_FILE = ORIGINAL_SAFETENSORS_SINGLE_FILE + + @staticmethod + def _validate_codec_modality(codec: type[Codec], modality: type[Modality]): + # Import MODALITY_CODEC_MAPPING here to avoid circular import + from aion.codecs.config import MODALITY_CODEC_MAPPING + + if not issubclass(codec, Codec): + raise TypeError("Only codecs can be loaded using this method.") + if modality not in MODALITY_CODEC_MAPPING: + raise ValueError(f"Modality {modality} has no corresponding codec.") + elif MODALITY_CODEC_MAPPING[modality] != codec: + raise TypeError( + f"Modality {modality} is associated with {MODALITY_CODEC_MAPPING[modality]} codec but {codec} requested." + ) + @classmethod def from_pretrained( cls, @@ -51,15 +68,13 @@ def from_pretrained( Returns: The loaded codec model. """ - if not issubclass(cls, Codec): - raise ValueError("Only codecs can be loaded using this method.") - # TODO: Check modality is valid + cls._validate_codec_modality(cls, modality) # Overwrite config and pytorch weights names to load codecs stored as submodels - _override_config_and_weights_names(modality) + cls._override_config_and_weights_names(modality) model = super().from_pretrained( pretrained_model_name_or_path, *model_args, **kwargs ) - _reset_config_and_weights_names() + cls._reset_config_and_weights_names() return model def save_pretrained(self, save_directory, *args, **kwargs): @@ -71,7 +86,7 @@ def save_pretrained(self, save_directory, *args, **kwargs): **kwargs: Additional keyword arguments to pass to the save method. """ if not isinstance(self, Codec): - raise ValueError("Only codecs can be saved using this method.") + raise TypeError("Only codecs can be saved using this method.") # Construct the path to the codec subfolder codec_path = f"{save_directory}/codecs/{self.modality.name}" super().save_pretrained(codec_path, *args, **kwargs) diff --git a/tests/codecs/test_load_codecs.py b/tests/codecs/test_load_codecs.py new file mode 100644 index 0000000..37c7bff --- /dev/null +++ b/tests/codecs/test_load_codecs.py @@ -0,0 +1,24 @@ +import pytest +import torch + +from aion.codecs import ImageCodec +from aion.codecs.config import HF_REPO_ID +from aion.modalities import Image, LegacySurveyCatalog, LegacySurveyImage + + +def test_load_invalid_modality(): + """Test that loading a modality raises an error.""" + with pytest.raises(TypeError): + ImageCodec.from_pretrained(HF_REPO_ID, modality=LegacySurveyCatalog) + + +def test_load_image_codec(): + """Test that loading an image codec raises an error.""" + codec_image = ImageCodec.from_pretrained(HF_REPO_ID, modality=Image) + codec_legacy_survey_image = ImageCodec.from_pretrained( + HF_REPO_ID, modality=LegacySurveyImage + ) + for param_image, param_legacy_survey_image in zip( + codec_image.parameters(), codec_legacy_survey_image.parameters() + ): + assert torch.equal(param_image, param_legacy_survey_image) From dcffbc908ddb99b2c7c1e6e3cb7cee26cf669c72 Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Fri, 30 May 2025 15:17:23 +0200 Subject: [PATCH 3/6] Update modality class retrieval for scalar codecs --- aion/codecs/scalar.py | 8 ++--- aion/modalities.py | 75 ++++++++++++++++++++++--------------------- 2 files changed, 43 insertions(+), 40 deletions(-) diff --git a/aion/codecs/scalar.py b/aion/codecs/scalar.py index e40b235..a256b4e 100644 --- a/aion/codecs/scalar.py +++ b/aion/codecs/scalar.py @@ -64,7 +64,7 @@ def __init__( reservoir_size: int, ): super().__init__() - self._modality_class = next(m for m in ScalarModalities if m.name == modality) + self._modality_class = ScalarModalities[modality] self._quantizer = ScalarReservoirQuantizer( codebook_size=codebook_size, reservoir_size=reservoir_size, @@ -80,7 +80,7 @@ def __init__( min_log_value: float | None = -3, ): super().__init__() - self._modality_class = next(m for m in ScalarModalities if m.name == modality) + self._modality_class = ScalarModalities[modality] self._quantizer = ScalarLogReservoirQuantizer( codebook_size=codebook_size, reservoir_size=reservoir_size, @@ -99,7 +99,7 @@ def __init__( num_quantizers: int, ): super().__init__() - self._modality_class = next(m for m in ScalarModalities if m.name == modality) + self._modality_class = ScalarModalities[modality] self._quantizer = MultiScalarCompressedReservoirQuantizer( compression_fns=compression_fns, decompression_fns=decompression_fns, @@ -112,7 +112,7 @@ def __init__( class GridScalarCodec(BaseScalarIdentityCodec): def __init__(self, modality: str, codebook_size: int): super().__init__() - self._modality_class = next(m for m in ScalarModalities if m.name == modality) + self._modality_class = ScalarModalities[modality] self._quantizer = ScalarLinearQuantizer( codebook_size=codebook_size, range=(0.0, 1.0), diff --git a/aion/modalities.py b/aion/modalities.py index 7fb4f2b..075887a 100644 --- a/aion/modalities.py +++ b/aion/modalities.py @@ -412,42 +412,45 @@ class GaiaXpRp(Scalar, Modality): token_key: ClassVar[str] = "tok_xp_rp" -ScalarModalities = [ - LegacySurveyFluxG, - LegacySurveyFluxR, - LegacySurveyFluxI, - LegacySurveyFluxZ, - LegacySurveyFluxW1, - LegacySurveyFluxW2, - LegacySurveyFluxW3, - LegacySurveyFluxW4, - LegacySurveyShapeR, - LegacySurveyShapeE1, - LegacySurveyShapeE2, - LegacySurveyEBV, - Z, - HSCAG, - HSCAR, - HSCAI, - HSCAZ, - HSCAY, - HSCMagG, - HSCMagR, - HSCMagI, - HSCMagZ, - HSCMagY, - HSCShape11, - HSCShape22, - HSCShape12, - GaiaFluxG, - GaiaFluxBp, - GaiaFluxRp, - GaiaParallax, - Ra, - Dec, - GaiaXpBp, - GaiaXpRp, -] +ScalarModalities = { + modality.name: modality + for modality in [ + LegacySurveyFluxG, + LegacySurveyFluxR, + LegacySurveyFluxI, + LegacySurveyFluxZ, + LegacySurveyFluxW1, + LegacySurveyFluxW2, + LegacySurveyFluxW3, + LegacySurveyFluxW4, + LegacySurveyShapeR, + LegacySurveyShapeE1, + LegacySurveyShapeE2, + LegacySurveyEBV, + Z, + HSCAG, + HSCAR, + HSCAI, + HSCAZ, + HSCAY, + HSCMagG, + HSCMagR, + HSCMagI, + HSCMagZ, + HSCMagY, + HSCShape11, + HSCShape22, + HSCShape12, + GaiaFluxG, + GaiaFluxBp, + GaiaFluxRp, + GaiaParallax, + Ra, + Dec, + GaiaXpBp, + GaiaXpRp, + ] +} # Convenience type for any modality data ModalityType = ( From 6b505e57df0c6953946cca933b18566b974720fd Mon Sep 17 00:00:00 2001 From: Lucas Meyer Date: Fri, 30 May 2025 15:17:43 +0200 Subject: [PATCH 4/6] Remove unused function --- aion/codecs/scalar_field.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/aion/codecs/scalar_field.py b/aion/codecs/scalar_field.py index c736a9b..5499a42 100644 --- a/aion/codecs/scalar_field.py +++ b/aion/codecs/scalar_field.py @@ -1,4 +1,3 @@ -from functools import reduce from typing import Callable, Optional, Type import torch @@ -17,10 +16,6 @@ from .quantizers import FiniteScalarQuantizer, Quantizer -def _deep_get(dictionary, path, default=None): - return reduce(lambda d, key: d[key], path.split("."), dictionary) - - class AutoencoderScalarFieldCodec(Codec): """Abstract class for autoencoding scalar field codecs.""" From 13d611efb0bfb49d5d13620b1f4bb819d1db3b4d Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Tue, 17 Jun 2025 18:29:47 +0200 Subject: [PATCH 5/6] adding claude config --- CLAUDE.md | 89 ++++++++++++++++++++++++++++++++++++++++++++ aion/codecs/utils.py | 8 ++-- 2 files changed, 93 insertions(+), 4 deletions(-) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..9990fd1 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,89 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +AION (AstronomIcal Omnimodal Network) is a large omnimodal transformer model for astronomical surveys. It processes 39 distinct astronomical data modalities using a two-stage architecture: + +1. **Modality-specific tokenizers** transform raw inputs (images, spectra, catalogs, scalars) into discrete tokens +2. **Unified encoder-decoder transformer** processes all token streams via multimodal masked modeling (4M) + +The model comes in three variants: Base (300M), Large (800M), and XLarge (3B parameters). + +## Development Commands + +### Testing +```bash +pytest # Run all tests +pytest tests/codecs/ # Run codec tests only +pytest tests/test_data/ # Uses pre-computed test data for validation +``` + +### Linting and Code Quality +```bash +ruff check . # Check code style and lint +ruff check . --fix # Auto-fix linting issues +``` + +### Installation for Development +```bash +pip install -e .[torch,dev] # Install in editable mode with dev dependencies +``` + +### Documentation +```bash +cd docs && make html # Build Sphinx documentation +``` + +## Architecture Overview + +### Core Components + +- **`aion/model.py`**: Main AION wrapper class, inherits from FM (4M) transformer +- **`aion/fourm/`**: 4M (Four-Modal) transformer implementation + - `fm.py`: Core transformer architecture with encoder-decoder blocks + - `modality_info.py`: Configuration for all 39 supported modalities + - `encoder_embeddings.py` / `decoder_embeddings.py`: Modality-specific embedding layers +- **`aion/codecs/`**: Modality tokenization system + - `manager.py`: Dynamic codec loading and management + - `base.py`: Abstract base codec class + - Individual codec implementations for images, spectra, scalars, etc. +- **`aion/modalities.py`**: Type definitions for all astronomical data types + +### Key Design Patterns + +1. **Modality System**: Each astronomical data type (flux, spectrum, catalog) has: + - A modality class in `modalities.py` defining data structure + - A codec in `codecs/` for tokenization + - Embedding layers in `fourm/` for the transformer + +2. **Token Keys**: Each modality has a `token_key` (e.g., `tok_image`, `tok_spectrum_sdss`) that maps between modalities and model components + +3. **HuggingFace Integration**: Models and codecs are distributed via HuggingFace Hub with `from_pretrained()` methods + +## Code Conventions + +- Type hints are mandatory, using `jaxtyping` for tensor shapes (e.g., `Float[Tensor, "batch height width"]`) +- Modality classes use `@dataclass` and inherit from `BaseModality` +- All tensor operations should handle device placement explicitly +- Test data is pre-computed and stored in `tests/test_data/` as `.pt` files + +## Testing Strategy + +Tests validate both encoding and decoding for each modality using pre-computed reference data. The test pattern is: +1. Load input, encoded, and decoded reference tensors +2. Run codec encode/decode operations +3. Assert outputs match reference data within tolerance + +Test files follow naming: `test_{modality}_codec.py` + +## Astronomical Context + +The model processes data from major surveys: +- **Legacy Survey**: Optical images and catalogs (g,r,i,z bands + WISE) +- **HSC (Hyper Suprime-Cam)**: Deep optical imaging (g,r,i,z,y bands) +- **Gaia**: Astrometry, photometry, and BP/RP spectra +- **SDSS/DESI**: Optical spectra + +Each modality represents different physical measurements (flux, shape parameters, coordinates, extinction, etc.) that the model learns to correlate. diff --git a/aion/codecs/utils.py b/aion/codecs/utils.py index cdc47c9..3588949 100644 --- a/aion/codecs/utils.py +++ b/aion/codecs/utils.py @@ -139,11 +139,11 @@ def from_pretrained( """ # cls._validate_codec_modality(cls, modality) # Overwrite config and pytorch weights names to load codecs stored as submodels - #cls._override_config_and_weights_names(modality) - #model = super().from_pretrained( + # cls._override_config_and_weights_names(modality) + # model = super().from_pretrained( # pretrained_model_name_or_path, *model_args, **kwargs - #) - #cls._reset_config_and_weights_names() + # ) + # cls._reset_config_and_weights_names() if not issubclass(cls, Codec): raise ValueError("Only codec classes can be loaded using this method.") From 58a9f185fafd0df427a8b53e72a8a3ceba2279ed Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Tue, 17 Jun 2025 18:38:16 +0200 Subject: [PATCH 6/6] minor cleanup --- aion/codecs/utils.py | 37 +++++++++++-------------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/aion/codecs/utils.py b/aion/codecs/utils.py index 3588949..6c0f686 100644 --- a/aion/codecs/utils.py +++ b/aion/codecs/utils.py @@ -81,25 +81,17 @@ class CodecPytorchHubMixin(hub_mixin.PyTorchModelHubMixin): """ @staticmethod - def _override_config_and_weights_names(modality: type[Modality]): - hub_mixin.constants.CONFIG_NAME = ( - f"codecs/{modality.name}/{ORIGINAL_CONFIG_NAME}" - ) - hub_mixin.constants.SAFETENSORS_SINGLE_FILE = ( - f"codecs/{modality.name}/{ORIGINAL_SAFETENSORS_SINGLE_FILE}" - ) - hub_mixin.constants.PYTORCH_WEIGHTS_NAME = ( - f"codecs/{modality.name}/{ORIGINAL_PYTORCH_WEIGHTS_NAME}" - ) + def _validate_codec_modality(codec: type[Codec], modality: type[Modality]): + """Validate that a codec class is compatible with a modality. - @staticmethod - def _reset_config_and_weights_names(): - hub_mixin.constants.PYTORCH_WEIGHTS_NAME = ORIGINAL_PYTORCH_WEIGHTS_NAME - hub_mixin.constants.CONFIG_NAME = ORIGINAL_CONFIG_NAME - hub_mixin.constants.SAFETENSORS_SINGLE_FILE = ORIGINAL_SAFETENSORS_SINGLE_FILE + Args: + codec: The codec class to validate + modality: The modality type to validate against - @staticmethod - def _validate_codec_modality(codec: type[Codec], modality: type[Modality]): + Raises: + TypeError: If the codec is not a valid codec class or is incompatible with the modality + ValueError: If the modality has no corresponding codec configuration + """ # Import MODALITY_CODEC_MAPPING here to avoid circular import from aion.codecs.config import MODALITY_CODEC_MAPPING @@ -137,15 +129,8 @@ def from_pretrained( Raises: ValueError: If the class is not a codec subclass or modality is invalid. """ - # cls._validate_codec_modality(cls, modality) - # Overwrite config and pytorch weights names to load codecs stored as submodels - # cls._override_config_and_weights_names(modality) - # model = super().from_pretrained( - # pretrained_model_name_or_path, *model_args, **kwargs - # ) - # cls._reset_config_and_weights_names() - if not issubclass(cls, Codec): - raise ValueError("Only codec classes can be loaded using this method.") + # Validate codec-modality compatibility + cls._validate_codec_modality(cls, modality) # Validate modality _validate_modality(modality)