diff --git a/.flake8 b/.flake8 index 9f27978..4749b27 100644 --- a/.flake8 +++ b/.flake8 @@ -5,6 +5,29 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. [flake8] -exclude = - examples - setup.py \ No newline at end of file +max-line-length = 100 +extend-ignore = E203, E501, W503, W293, W291, F541, F841 +exclude = + .git, + __pycache__, + docs/source/conf.py, + old, + build, + examples, + setup.py + dist, + *.egg-info, + .venv, + venv, + env, + .pytest_cache, + htmlcov, + logs, + tmp, + screenshots, + workflow_*.json, + tracking*.json, + conftest.py, + playwright.config.py, + .csv, + paper_scripts \ No newline at end of file diff --git a/.github/workflows/dead_code.yml b/.github/workflows/dead_code.yml new file mode 100644 index 0000000..5f1d511 --- /dev/null +++ b/.github/workflows/dead_code.yml @@ -0,0 +1,44 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +name: Dead Code Detection + +on: [pull_request] + +jobs: + vulture-strict: + name: Vulture (100% confidence - blocking) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install vulture + run: pip install vulture>=2.10 + - name: Run vulture (100% confidence) + run: | + echo "Running vulture dead code detection (100% confidence - blocking)..." + vulture anomaly_match/ .vulture_whitelist.py --min-confidence 100 + + vulture-warnings: + name: Vulture (60% confidence - not required) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install vulture + run: pip install vulture>=2.10 + - name: Run vulture (60% confidence) + run: | + echo "Running vulture dead code detection (60% confidence)..." + echo "This check fails if potential dead code is found, but is not required to pass." + echo "" + vulture anomaly_match/ .vulture_whitelist.py --min-confidence 60 diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 252d3b5..e8a92b6 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -30,6 +30,7 @@ jobs: flake8 . --count --show-source --statistics --max-line-length=127 --ignore=E402,W503,E203 build: runs-on: ubuntu-latest + timeout-minutes: 10 permissions: pull-requests: write contents: read diff --git a/.gitignore b/.gitignore index 012aae2..0270407 100644 --- a/.gitignore +++ b/.gitignore @@ -189,8 +189,11 @@ paper_scripts/test_plots_output paper_scripts/**/*.png paper_scripts/**/*.jpg paper_scripts/**/*.jpeg +paper_scripts/**/*.pdf science_paper/**/*.png science_paper/**/*.jpg science_paper/**/*.jpeg pytest-coverage.txt pytest.xml +# IDE and editor settings +.vscode/ diff --git a/.vulture_whitelist.py b/.vulture_whitelist.py new file mode 100644 index 0000000..2682d5d --- /dev/null +++ b/.vulture_whitelist.py @@ -0,0 +1,64 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +""" +Vulture whitelist file. + +Add entries here for code that vulture incorrectly identifies as unused. +Format: function_name # noqa - comment explaining why it's used +""" + +# SessionIOHandler methods - public API used in tests +save_model_checkpoint # noqa - Used in test_session_io_handler.py, test_model_io_integration.py +load_model_checkpoint # noqa - Used in test_model_io_integration.py +list_sessions # noqa - Used in test_session_io_handler.py +save_run # noqa - Used in test_run_label_migration.py +save_labels_to_output_dir # noqa - Used in test_run_label_migration.py + +# FixMatch class attributes +requires_grad # noqa - PyTorch tensor property set to disable gradient for EMA model + +# AnomalyDetectionDataset methods used in tests (tests/dataset_test.py) +_read_and_resize_image # noqa - Used in test_read_and_resize_different_formats +unlabeled_filepaths # noqa - Used in test_anomaly_detection_dataset_properties +save_as_hdf5 # noqa - Used in test_anomaly_detection_dataset_hdf5 +load_from_hdf5 # noqa - Used in test_anomaly_detection_dataset_hdf5 + +# Transform functions used in paper_scripts/ +get_strong_transforms # noqa - Used in paper_scripts/get_example_images.py + +# File I/O utility functions - public API +get_image_paths_from_folder # noqa - Companion to get_image_names_from_folder, tested + +# Session class public API +start_UI # noqa - Public API - used in StarterNotebook.ipynb + +# Widget methods - public API +update_image_display # noqa - Public API method for updating image display + +# ipywidgets style/layout attributes - used by ipywidgets framework +_.style # noqa - Widget.py: progress_bar.style for visual feedback +_.button_color # noqa - ipywidgets button styling +_.font_size # noqa - ipywidgets widget styling +_.width # noqa - ipywidgets layout attribute +_.height # noqa - ipywidgets layout attribute + +# Learning rate scheduler utility - tested in tests/utils_test.py +get_cosine_schedule_with_warmup # noqa - Used in tests and available for external use + +# Configuration attributes - validated and documented +bn_momentum # noqa - Part of default config for batch normalization momentum +N_batch_prediction # noqa - Used in prediction scripts for batch size + +# Seed utility function - used in paper_scripts/paper_benchmark.py and tests +set_seeds # noqa - Used for reproducibility in benchmarks and testing + +# PyTorch CUDA attribute - set in set_seeds.py for deterministic/performance mode +_.benchmark # noqa - torch.backends.cudnn.benchmark attribute + +# Image processing functions used in prediction scripts (root level, excluded from scan) +process_single_wrapper # noqa - Used in prediction_process_hdf5.py, prediction_process_zarr.py +_.n_expected_channels # noqa - fitsbolt config attribute set dynamically diff --git a/CHANGELOG.MD b/CHANGELOG.MD index 3cba838..2af4657 100644 --- a/CHANGELOG.MD +++ b/CHANGELOG.MD @@ -5,6 +5,35 @@ [//]: # (this file, may be copied, modified, propagated, or distributed except according to) [//]: # (the terms contained in the file 'LICENCE.txt'.) +## [v1.2.0] – 2025-01-13 + +### Added +- **Cutana streaming integration** for catalogue-based predictions with parquet and CSV support +- **FitsBolt integration** for consistent FITS normalization across training and prediction +- **Iteration score storage** for tracking unlabeled and test data scores per iteration +- **Automatic batch size estimation** using exponential and binary search for optimal GPU memory usage +- **Full resolution image preview** button in the UI for detailed inspection +- **Dead code detection** CI workflow using Vulture for codebase maintenance + +### Changed +- **Refactored Widget architecture** by extracting PreviewWidget for better code organization +- **FitsBolt config persistence** in model checkpoints for reproducible normalization +- **Parquet format** for Cutana buffer instead of CSV for improved performance +- **Black line-length** updated to 100 characters for better readability + +### Fixed +- **Gallery filename display** for long filenames with improved shortening (#237) +- **Duplicate result accumulation** in prediction process (#238) +- **Error handling** for iteration score CSV saves (#236) +- **FITS extension handling** in Cutana streaming +- **Tensor handling** improvements throughout the codebase + +### Removed +- **Dead code cleanup** removing unused functions and imports identified by Vulture +- **IDE/editor files** from repository with updated .gitignore + +--- + ## [v1.1.0] – 2025-07-04 ### Added diff --git a/CITATION.cff b/CITATION.cff index c01d0c6..e29483d 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,6 +5,12 @@ authors: - family-names: "Gómez" given-names: "Pablo" orcid: "https://orcid.org/0000-0002-5631-8240" +- family-names: "Ruhberg" + given-names: "Laslo E." + orcid: "https://orcid.org/0009-0003-3810-1245" +- family-names: "Nardone" + given-names: "Maria Teresa" + orcid: "https://orcid.org/0009-0001-4102-9630" - family-names: "O'Ryan" given-names: "David" orcid: "https://orcid.org/0000-0003-1217-4617" @@ -19,11 +25,18 @@ preferred-citation: - family-names: "Gómez" given-names: "Pablo" orcid: "https://orcid.org/0000-0002-5631-8240" + - family-names: "Ruhberg" + given-names: "Laslo E." + orcid: "https://orcid.org/0009-0003-3810-1245" + - family-names: "Nardone" + given-names: "Maria Teresa" + orcid: "https://orcid.org/0009-0001-4102-9630" - family-names: "O'Ryan" given-names: "David" orcid: "https://orcid.org/0000-0003-1217-4617" title: "AnomalyMatch: Discovering Rare Objects of Interest with Semi-supervised and Active Learning" - journal: "arXiv preprint" + journal: "arXiv e-prints" year: 2025 + month: 5 doi: 10.48550/arXiv.2505.03509 - url: "https://arxiv.org/abs/2505.03509" + url: "https://arxiv.org/abs/2505.03509" \ No newline at end of file diff --git a/README.md b/README.md index f597cd0..a5f3954 100644 --- a/README.md +++ b/README.md @@ -75,9 +75,15 @@ session_name_timestamp/ ├── session_metadata.json # Complete session tracking data ├── labeled_data.csv # All labelled samples ├── config.toml # Final configuration -└── model.pth # Model checkpoint +├── model.pth # Model checkpoint +└── iteration_scores/ # Per-iteration prediction scores + ├── iteration_1_unlabelled_scores.csv + ├── iteration_1_test_scores.csv + └── ... ``` +**Iteration Scores:** After each training iteration, AnomalyMatch stores prediction scores for both unlabelled and test data (if `test_ratio > 0`). These CSV files contain filenames and their corresponding anomaly scores, enabling analysis of how predictions evolve across training iterations. + You can view any saved session using: ```python import anomaly_match as am @@ -178,11 +184,46 @@ cfg.prediction_search_dir = "/path/to/directory/containing/zarr/files" AnomalyMatch will automatically discover all `.zarr` files in the specified directory and process them efficiently in parallel. Each Zarr file should contain image data with optional metadata in a corresponding `.parquet` file. +#### Multiple Zarr Files for Prediction + +When running predictions on large datasets split across multiple Zarr files, AnomalyMatch automatically discovers and processes all Zarr stores in `prediction_search_dir`. Two folder structures are supported: + +**Option 1: Direct Zarr files** +``` +prediction_search_dir/ +├── dataset_part1.zarr/ +│ └── images/ # Zarr array with shape (N, H, W, C) +├── dataset_part1_metadata.parquet +├── dataset_part2.zarr/ +│ └── images/ +└── dataset_part2_metadata.parquet +``` + +**Option 2: Batch folders with images.zarr subdirectory** +``` +prediction_search_dir/ +├── batch_001/ +│ ├── images.zarr/ +│ │ └── images/ +│ └── images_metadata.parquet +├── batch_002/ +│ ├── images.zarr/ +│ │ └── images/ +│ └── images_metadata.parquet +``` + +**Metadata requirements:** +- Parquet files should contain a `filename`, `original_filename`, or `source_id` column +- For direct zarr files: `_metadata.parquet` in the same directory +- For batch folders: `images_metadata.parquet` in the batch folder + +**Filename handling:** To prevent collisions across zarr files, filenames are automatically prefixed with the zarr/batch folder name (e.g., `batch_001__image_000042`). + ### FITS File Handling - By default, the first extension (index 0) is used when loading FITS files - You can specify a particular extension using the `fits_extension` parameter in the configuration: - - Set `cfg.fits_extension` in your code to control which FITS extensions to use + - Set `cfg.normalisation.fits_extension` in your code to control which FITS extensions to use - Integer values (e.g., `0`, `1`, `2`) to access extensions by index - String values (e.g., `"PRIMARY"`, `"SCIENCE"`) to access extensions by name - List of integers or strings (e.g., `[0, 1, 2]` or `["PRIMARY", "SCIENCE", "ERROR"]`) to combine multiple extensions @@ -198,9 +239,23 @@ AnomalyMatch will automatically discover all `.zarr` files in the specified dire When working with FITS files containing multiple images or data products, specify which extension(s) to use in the configuration. +### Cutana Streaming Integration + +AnomalyMatch supports streaming predictions via [Cutana](https://github.com/esa/cutana), which enables on-the-fly cutout extraction from FITS tiles. This is particularly useful for Euclid mission data, which Cutana primarily targets. + +**How to use Cutana streaming:** + +1. Prepare a Cutana-compatible source catalogue (CSV or Parquet) with columns for coordinates and FITS file paths +2. Set `cfg.prediction_search_dir` to a folder containing your catalogue files +3. AnomalyMatch will automatically detect the catalogues and stream cutouts via Cutana + +**FITS extension configuration:** When using Cutana streaming, ensure `cfg.normalisation.fits_extension` matches the FITS extensions referenced in your catalogue. For multi-band Euclid data, this might be `["VIS", "NIR-H", "NIR-J"]` or similar, depending on your catalogue structure. + +For more details on catalogue format and Cutana configuration, see the [Cutana documentation](https://github.com/esa/cutana). + ## Normalisation and Stretching - Normalisation can be selected in the UI via a drop-down. Alternatively it can be changed by setting e.g. - `cfg.normalisation_method = am.NormalisationMethod.ZSCALE` + `cfg.normalisation.normalisation_method = am.NormalisationMethod.ZSCALE` - Current options are - `CONVERSION_ONLY`: no normalisation - `LOG`: [logarithmic normalisation](https://docs.astropy.org/en/stable/api/astropy.visualization.LogStretch.html#astropy.visualization.LogStretch) @@ -218,8 +273,9 @@ When working with FITS files containing multiple images or data products, specif - `logLevel`: Controls verbosity of training/session logs. - `test_ratio`: Proportion of data used for evaluation (0.0 disables test evaluation, > 0 shows AUROC/AUPRC curves). - `size`: Dimensions to which images are resized (below 96x96 is not recommended). -- `N_to_load`: Number of unlabeled images loaded into the training dataset at once. +- `N_to_load`: Number of unlabeled images loaded into the training dataset at once. From this (`uratio`*`batch_size`*`num_train_iter`) (5*16*200) unlabeled images will be sampled for training. - `output_dir`: Folder for storing results (e.g., labeled_data.csv or final logs). +- `prediction_batch_size`: Batch size for prediction. If not set, AnomalyMatch automatically estimates an optimal batch size based on available GPU memory. ## Advanced CFG Parameters @@ -249,6 +305,7 @@ The following advanced parameters can be configured: ### Additional Parameters - `fits_extension`: Extension(s) to use for FITS files, can be int, string, or list of int/string (default: None) +- `fits_combination`: Dictonary with keys `R`,`G`,`B` of lists of length of `fits_extension` denoting how the specified fits_extensions are (linearly) mapped to the R,G,B channels. - `interpolation_order`: 0-5 corresponding to [skimage resize interpolation orders](https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.warp) (default: 1 (Bi-linear)) - `normalisation_method`: Normalisation method to be applied during file loading. Can also be selected in the UI dropdown. Correspons to an entry from the class NormalisationMethod (default: `NormalisationMethod.CONVERSION_ONLY`) diff --git a/StarterNotebook.ipynb b/StarterNotebook.ipynb index b1a69f0..9703d1c 100644 --- a/StarterNotebook.ipynb +++ b/StarterNotebook.ipynb @@ -1,191 +1,192 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Copyright (c) European Space Agency, 2025.\n", - "\n", - "This file is subject to the terms and conditions defined in file 'LICENCE.txt', which is part of this source code package. No part of the package, including this file, may be copied, modified, propagated, or distributed except according to the terms contained in the file ‘LICENCE.txt’.\n", - "\n", - "### How to AnomalyMatch\n", - "\n", - "#### 1. Recommended Folder Structure\n", - "\n", - "- project/\n", - " - labeled_data.csv | containing annotations of labeled examples\n", - " - training_images/ | the cfg.data_dir\n", - " - image1.jpeg\n", - " - image2.jpeg\n", - " - data_to_predict/ | the cfg.prediction_search_dir\n", - " - unlabeled_file_part1.hdf5\n", - " - unlabeled_file_part2.hdf5\n", - "\n", - "Example of a minimal labeled_data.csv:\n", - "\n", - "```\n", - "filename,label,your_custom_source_id\n", - "image1.jpeg,normal,123456\n", - "image2.jpeg,anomaly,424242\n", - "```\n", - "\n", - "#### 2. Specify paths and configuration parameters below.\n", - "\n", - "#### 3. Refer to the \"UI Explanation\" section at the bottom for details on how to use the interface.\n", - "\n", - "#### 4. Datalabs-specific hints\n", - "\n", - "If you are using Datalabs, you can install additional modules with conda / mamba in the terminal via e.g. `conda install scipy`\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import anomaly_match as am" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# We use a cfg DotMap (a dictionary with dot accessors) to store the configuration for the run\n", - "cfg = am.get_default_cfg()\n", - "cfg.name = \"my_test_run\"\n", - "\n", - "# Path to a model you may load\n", - "cfg.model_path = \"anomaly_match_results/sessions/my_model.pth\"\n", - "\n", - "# Set the data directory\n", - "# This directory should contain the images to be used for active labeling and training and testing\n", - "cfg.data_dir = \"/media/home/AnomalyMatch/tests/test_data/\"\n", - "\n", - "# Set the label file\n", - "cfg.label_file = \"/media/home/AnomalyMatch/tests/test_data/labeled_data.csv\" # CSV mapping annotated images to labels\n", - "\n", - "# Set metadata file\n", - "cfg.metadata_file = \"/media/home/AnomalyMatch/tests/test_data/metadata.csv\" # CSV mapping images to metadata such as sourceID, ra, dec (optional)\n", - "\n", - "\n", - "# Set the search directory\n", - "# You can predict on a large unlabeled dataset (*.hdf5, ideally) by setting this to the directory containing the unlabeled images / files\n", - "# This will be triggered when you press evaluate_search_dir\n", - "cfg.prediction_search_dir = None\n", - "\n", - "# Normalisation method to use when loading images, can be adjusted in the GUI\n", - "cfg.normalisation_method = am.NormalisationMethod.CONVERSION_ONLY\n", - "\n", - "# Set the test ratio\n", - "cfg.test_ratio = 0.0 # Proportion of data used for evaluation (0.0 disables test evaluation, > 0 shows AUROC/AUPRC curves)\n", - "\n", - "# Set the number of unlabeled images to load\n", - "cfg.N_to_load = 100 # Number of unlabeled images loaded into the training dataset at once\n", - "\n", - "# Set the image size\n", - "cfg.size = [64, 64] # Dimensions to which images are resized (below 96x96 is not recommended)\n", - "\n", - "# Set the logger level (options: \"trace\",\"debug\", \"info\", \"warning\", \"error\", \"critical\")\n", - "logger_level = \"info\"\n", - "am.set_log_level(logger_level, cfg)\n", - "\n", - "# Create a session\n", - "session = am.Session(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Start the UI\n", - "session.start_UI()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### UI Explanation\n", - "\n", - "The UI consists of several components:\n", - "\n", - "1. **Image Display Area**: This area shows the currently selected image along with its score and label. The image can be manipulated using the controls below it.\n", - "\n", - "2. **Control Buttons**:\n", - "\n", - " - **Save Model**: Saves the current model state to disk.\n", - " - **Load Model**: Loads a previously saved model from model path.\n", - " - **Save Labels**: Saves the current labels to disk (will not overwrite the original labels file).\n", - " - **Load Top Files**: Loads the top anomalies from a search run.\n", - " - **Remember**: Adds the current image to the remembered list for follow-up.\n", - "\n", - "3. **Image Manipulation Controls**:\n", - "\n", - " - **Invert Image**: Inverts the colors of the image.\n", - " - **Restore**: Restores the image to its original state.\n", - " - **Apply Unsharp Mask**: Applies an unsharp mask to the image to enhance edges.\n", - " - **Brightness and Contrast Sliders**: Adjust the brightness and contrast of the image.\n", - " - **RGB Channel Checkboxes**: Adjust which channels are currently displayed.\n", - " - **Normalisation Dropdown**: Select normalisation to be applied when loading the image. Selection affects training.\n", - "\n", - "4. **Navigation Buttons**:\n", - "\n", - " - **Previous**: Moves to the previous image.\n", - " - **Anomalous**: Marks the image as anomalous for next trainings (Original label_file will not be overwritten).\n", - " - **Nominal**: Marks the image as nominal for next trainings (Original label_file will not be overwritten).\n", - " - **Next**: Moves to the next image.\n", - "\n", - "5. **Training Controls**:\n", - "\n", - " - **Train Iterations**: Sets the number of training iterations.\n", - " - **Batch Size**: Sets the amount of unlabeled images to be used in each training batch (watch out for memory constraints).\n", - " - **Train**: Starts the training process.\n", - " - **Evaluate Search Dir**: Evaluates the images in the search directory.\n", - "\n", - "6. **Model Controls**:\n", - "\n", - " - **Reset Model**: Resets the model to its initial state.\n", - " - **Next Batch**: Loads the next batch unlabeled batch of images for prediction.\n", - "\n", - "7. **Top Images Display**: Shows the top 4 anomalous and top 4 nominal images based on the scores.\n", - "\n", - "This UI allows users to interactively label images, adjust image properties, and manage the training and evaluation process.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "am", - "language": "python", - "name": "am" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.12" - } - }, - "nbformat": 4, - "nbformat_minor": 4 + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) European Space Agency, 2025.\n", + "\n", + "This file is subject to the terms and conditions defined in file 'LICENCE.txt', which is part of this source code package. No part of the package, including this file, may be copied, modified, propagated, or distributed except according to the terms contained in the file ‘LICENCE.txt’.\n", + "\n", + "### How to AnomalyMatch\n", + "\n", + "#### 1. Recommended Folder Structure\n", + "\n", + "- project/\n", + " - labeled_data.csv | containing annotations of labeled examples\n", + " - training_images/ | the cfg.data_dir\n", + " - image1.jpeg\n", + " - image2.jpeg\n", + " - data_to_predict/ | the cfg.prediction_search_dir\n", + " - unlabeled_file_part1.hdf5\n", + " - unlabeled_file_part2.hdf5\n", + "\n", + "Example of a minimal labeled_data.csv:\n", + "\n", + "```\n", + "filename,label,your_custom_source_id\n", + "image1.jpeg,normal,123456\n", + "image2.jpeg,anomaly,424242\n", + "```\n", + "\n", + "#### 2. Specify paths and configuration parameters below.\n", + "\n", + "#### 3. Refer to the \"UI Explanation\" section at the bottom for details on how to use the interface.\n", + "\n", + "#### 4. Datalabs-specific hints\n", + "\n", + "If you are using Datalabs, you can install additional modules with conda / mamba in the terminal via e.g. `conda install scipy`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import anomaly_match as am" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# We use a cfg DotMap (a dictionary with dot accessors) to store the configuration for the run\n", + "cfg = am.get_default_cfg()\n", + "cfg.name = \"my_test_run\"\n", + "\n", + "# Path to a model you may load\n", + "cfg.model_path = \"anomaly_match_results/sessions/my_model.pth\"\n", + "\n", + "# Set the data directory\n", + "# This directory should contain the images to be used for active labeling and training and testing\n", + "cfg.data_dir = \"/media/home/AnomalyMatch/tests/test_data/\"\n", + "\n", + "# Set the label file\n", + "cfg.label_file = \"/media/home/AnomalyMatch/tests/test_data/labeled_data.csv\" # CSV mapping annotated images to labels\n", + "\n", + "# Set metadata file\n", + "cfg.metadata_file = \"/media/home/AnomalyMatch/tests/test_data/metadata.csv\" # CSV mapping images to metadata such as sourceID, ra, dec (optional)\n", + "\n", + "\n", + "# Set the search directory\n", + "# You can predict on a large unlabeled dataset by setting this to the directory containing the unlabeled data.\n", + "# Supported formats: HDF5, Zarr, or CSV/Parquet catalogues for Cutana streaming.\n", + "# This will be triggered when you press evaluate_search_dir\n", + "cfg.prediction_search_dir = None\n", + "\n", + "# Normalisation method to use when loading images, can be adjusted in the GUI\n", + "cfg.normalisation.normalisation_method = am.NormalisationMethod.CONVERSION_ONLY\n", + "\n", + "# Set the test ratio\n", + "cfg.test_ratio = 0.0 # Proportion of data used for evaluation (0.0 disables test evaluation, > 0 shows AUROC/AUPRC curves)\n", + "\n", + "# Set the number of unlabeled images to load\n", + "cfg.N_to_load = 100 # Number of unlabeled images loaded into the training dataset at once\n", + "\n", + "# Set the image size\n", + "cfg.normalisation.image_size = [64, 64] # Dimensions to which images are resized (below 96x96 is not recommended)\n", + "\n", + "# Set the logger level (options: \"trace\",\"debug\", \"info\", \"warning\", \"error\", \"critical\")\n", + "logger_level = \"info\"\n", + "am.set_log_level(logger_level, cfg)\n", + "\n", + "# Create a session\n", + "session = am.Session(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Start the UI\n", + "session.start_UI()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### UI Explanation\n", + "\n", + "The UI consists of several components:\n", + "\n", + "1. **Image Display Area**: This area shows the currently selected image along with its score and label. The image can be manipulated using the controls below it.\n", + "\n", + "2. **Control Buttons**:\n", + "\n", + " - **Save Model**: Saves the current model state to disk.\n", + " - **Load Model**: Loads a previously saved model from model path.\n", + " - **Save Labels**: Saves the current labels to disk (will not overwrite the original labels file).\n", + " - **Load Top Files**: Loads the top anomalies from a search run.\n", + " - **Remember**: Adds the current image to the remembered list for follow-up.\n", + "\n", + "3. **Image Manipulation Controls**:\n", + "\n", + " - **Invert Image**: Inverts the colors of the image.\n", + " - **Restore**: Restores the image to its original state.\n", + " - **Apply Unsharp Mask**: Applies an unsharp mask to the image to enhance edges.\n", + " - **Brightness and Contrast Sliders**: Adjust the brightness and contrast of the image.\n", + " - **RGB Channel Checkboxes**: Adjust which channels are currently displayed.\n", + " - **Normalisation Dropdown**: Select normalisation to be applied when loading the image. Selection affects training.\n", + "\n", + "4. **Navigation Buttons**:\n", + "\n", + " - **Previous**: Moves to the previous image.\n", + " - **Anomalous**: Marks the image as anomalous for next trainings (Original label_file will not be overwritten).\n", + " - **Nominal**: Marks the image as nominal for next trainings (Original label_file will not be overwritten).\n", + " - **Next**: Moves to the next image.\n", + "\n", + "5. **Training Controls**:\n", + "\n", + " - **Train Iterations**: Sets the number of training iterations.\n", + " - **# unlabelled to use**: Sets the amount of unlabeled images to be loaded in each training batch (watch out for memory constraints). Out of this the unlabelled used for training will be sampled.\n", + " - **Train**: Starts the training process.\n", + " - **Evaluate Search Dir**: Evaluates the images in the search directory.\n", + "\n", + "6. **Model Controls**:\n", + "\n", + " - **Reset Model**: Resets the model to its initial state.\n", + " - **Next Batch**: Loads the next batch unlabeled batch of images for prediction.\n", + "\n", + "7. **Top Images Display**: Shows the top 4 anomalous and top 4 nominal images based on the scores.\n", + "\n", + "This UI allows users to interactively label images, adjust image properties, and manage the training and evaluation process.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "am", + "language": "python", + "name": "am" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/anomaly_match/__init__.py b/anomaly_match/__init__.py index 49a7c17..5d73855 100644 --- a/anomaly_match/__init__.py +++ b/anomaly_match/__init__.py @@ -4,14 +4,14 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -from .image_processing.NormalisationMethod import NormalisationMethod +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod from .pipeline.session import Session from .utils.get_default_cfg import get_default_cfg from .utils.set_log_level import set_log_level from .utils.print_cfg import print_cfg from .data_io.SessionIOHandler import print_session -__version__ = "1.1.0" +__version__ = "1.2.0" __all__ = [ "get_default_cfg", diff --git a/anomaly_match/data_io/SessionIOHandler.py b/anomaly_match/data_io/SessionIOHandler.py index 8da8218..ad875d5 100644 --- a/anomaly_match/data_io/SessionIOHandler.py +++ b/anomaly_match/data_io/SessionIOHandler.py @@ -55,7 +55,10 @@ def get_session_save_path(self, session_tracker: SessionTracker) -> Path: return self.base_save_path / session_folder def save_session( - self, session_tracker: SessionTracker, save_path: Optional[Path] = None, cfg=None + self, + session_tracker: SessionTracker, + save_path: Optional[Path] = None, + cfg=None, ) -> Path: """ Save complete session data to disk. @@ -126,6 +129,62 @@ def _save_config(self, session_tracker: SessionTracker, save_path: Path, cfg=Non except Exception as e: logger.warning(f"Failed to save configuration: {e}") + def save_iteration_scores( + self, + session_tracker: SessionTracker, + unlabelled_scores: Dict[str, float] = None, + test_scores: Dict[str, float] = None, + save_path: Optional[Path] = None, + ) -> None: + """ + Save per-sample scores for the current iteration. + + Saves unlabelled data scores and test set scores (if available) as CSV files. + Updates the session tracker with the file paths. + + Args: + session_tracker: SessionTracker instance to update. + unlabelled_scores: Dict mapping filename to anomaly score for unlabelled data. + test_scores: Dict mapping filename to anomaly score for test set. + save_path: Optional custom save path. If None, uses default session path. + """ + if save_path is None: + save_path = self.get_session_save_path(session_tracker) + + save_path.mkdir(parents=True, exist_ok=True) + scores_dir = save_path / "iteration_scores" + scores_dir.mkdir(exist_ok=True) + + iteration_num = ( + len(session_tracker.session_iterations) - 1 if session_tracker.session_iterations else 0 + ) + + # Save unlabelled scores + if unlabelled_scores is not None and len(unlabelled_scores) > 0: + unlabelled_df = pd.DataFrame( + list(unlabelled_scores.items()), columns=["filename", "score"] + ) + unlabelled_path = scores_dir / f"unlabelled_scores_iter_{iteration_num}.csv" + try: + unlabelled_df.to_csv(unlabelled_path, index=False) + session_tracker.update_unlabelled_scores_path(str(unlabelled_path)) + logger.debug( + f"Saved {len(unlabelled_scores)} unlabelled scores to: {unlabelled_path}" + ) + except Exception as e: + logger.warning(f"Failed to save unlabelled scores: {e}") + + # Save test scores + if test_scores is not None and len(test_scores) > 0: + test_df = pd.DataFrame(list(test_scores.items()), columns=["filename", "score"]) + test_path = scores_dir / f"test_scores_iter_{iteration_num}.csv" + try: + test_df.to_csv(test_path, index=False) + session_tracker.update_test_scores_path(str(test_path)) + logger.debug(f"Saved {len(test_scores)} test scores to: {test_path}") + except Exception as e: + logger.warning(f"Failed to save test scores: {e}") + def save_model_checkpoint( self, model_state: Dict[str, Any], @@ -207,6 +266,11 @@ def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str: model.eval_model.module if hasattr(model.eval_model, "module") else model.eval_model ) + # Get fitsbolt config if present (DotMap pickles directly) + fitsbolt_cfg = getattr(cfg, "fitsbolt_cfg", None) + if fitsbolt_cfg is not None: + logger.debug("Including fitsbolt config in model checkpoint") + # Create save state save_state = { "train_model": train_model.state_dict(), @@ -218,7 +282,8 @@ def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str: "best_eval_acc": model.best_eval_acc, "best_it": model.best_it, "last_normalisation_method": getattr(model, "last_normalisation_method", None), - "normalisation_method": cfg.normalisation_method, + "normalisation_method": cfg.normalisation.normalisation_method, + "fitsbolt_cfg": fitsbolt_cfg, } # Save model @@ -309,22 +374,36 @@ def load_model(self, model, cfg, model_path: str = None) -> bool: model.best_it = checkpoint["best_it"] # Handle normalisation method updates + # Checkpoints store two fields (both should have the same value in normal operation): + # - "last_normalisation_method": from model.last_normalisation_method (what model was trained with) + # - "normalisation_method": from cfg.normalisation.normalisation_method (config at save time) + # We prefer "last_normalisation_method" (actual training method) over config value + # Note: last_normalisation_method can be None if model was saved before first training normalisation_updated = False - if "last_normalisation_method" in checkpoint: - cfg.normalisation_method = checkpoint["normalisation_method"] - model.last_normalisation_method = checkpoint["normalisation_method"] - normalisation_updated = True - logger.debug(f"Updated normalisation method to: {cfg.normalisation_method}") - elif "normalisation_method" in checkpoint: - cfg.normalisation_method = checkpoint["last_normalisation_method"] + if checkpoint.get("last_normalisation_method") is not None: + cfg.normalisation.normalisation_method = checkpoint["last_normalisation_method"] model.last_normalisation_method = checkpoint["last_normalisation_method"] normalisation_updated = True logger.debug( - f"Updated normalisation method from legacy field: {cfg.normalisation_method}" + f"Updated normalisation method to: {cfg.normalisation.normalisation_method.name}" + ) + elif checkpoint.get("normalisation_method") is not None: + cfg.normalisation.normalisation_method = checkpoint["normalisation_method"] + model.last_normalisation_method = checkpoint["normalisation_method"] + normalisation_updated = True + logger.debug( + f"Updated normalisation method from config field: {cfg.normalisation.normalisation_method.name}" ) if normalisation_updated: - logger.info(f"Model loaded with normalisation method: {cfg.normalisation_method}") + logger.info( + f"Model loaded with normalisation method: {cfg.normalisation.normalisation_method.name}" + ) + + # Load fitsbolt config if present in checkpoint (DotMap pickles directly) + if "fitsbolt_cfg" in checkpoint and checkpoint["fitsbolt_cfg"] is not None: + cfg.fitsbolt_cfg = checkpoint["fitsbolt_cfg"] + logger.debug("Loaded fitsbolt config from model checkpoint") # Update config to point to the successfully loaded model path cfg.model_path = load_path @@ -413,6 +492,8 @@ def load_session(self, session_path: Path) -> SessionTracker: model_state_path=iter_data.get("model_state_path"), num_newly_labeled_anomalous=iter_data.get("num_newly_labeled_anomalous", 0), num_newly_labeled_nominal=iter_data.get("num_newly_labeled_nominal", 0), + unlabelled_scores_file=iter_data.get("unlabelled_scores_file"), + test_scores_file=iter_data.get("test_scores_file"), ) session_tracker.session_iterations.append(iteration_info) @@ -499,6 +580,13 @@ def save_run( else model.eval_model ) + # Get fitsbolt config if present (DotMap pickles directly) + fitsbolt_cfg = None + if cfg is not None: + fitsbolt_cfg = getattr(cfg, "fitsbolt_cfg", None) + if fitsbolt_cfg is not None: + logger.debug("Including fitsbolt config in training run checkpoint") + # Save model state save_state = { "train_model": train_model.state_dict(), @@ -510,6 +598,7 @@ def save_run( "best_eval_acc": getattr(model, "best_eval_acc", None), "best_it": getattr(model, "best_it", None), "last_normalisation_method": getattr(model, "last_normalisation_method", None), + "fitsbolt_cfg": fitsbolt_cfg, } torch.save(save_state, save_filename) @@ -545,7 +634,10 @@ def save_run( return save_filename def save_labels_to_output_dir( - self, labeled_data_df: pd.DataFrame, output_dir: str, session_tracker: SessionTracker = None + self, + labeled_data_df: pd.DataFrame, + output_dir: str, + session_tracker: SessionTracker = None, ) -> str: """ Save labeled data to the session directory if session_tracker is available, diff --git a/anomaly_match/data_io/find_images_in_folder.py b/anomaly_match/data_io/find_images_in_folder.py index 97b6c7c..0c04665 100644 --- a/anomaly_match/data_io/find_images_in_folder.py +++ b/anomaly_match/data_io/find_images_in_folder.py @@ -10,8 +10,7 @@ import os from pathlib import Path from loguru import logger - -from anomaly_match.utils.constants import SUPPORTED_IMAGE_EXTENSIONS +from fitsbolt import SUPPORTED_IMAGE_EXTENSIONS def get_image_names_from_folder(folder_path, recursive=False, extensions=None): diff --git a/anomaly_match/data_io/load_images.py b/anomaly_match/data_io/load_images.py index 8a124c8..e561eab 100644 --- a/anomaly_match/data_io/load_images.py +++ b/anomaly_match/data_io/load_images.py @@ -5,393 +5,143 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. """ -Functions for loading and processing images. +Functions for loading and processing images as a wrapper around fitsbolt """ -import os -import numpy as np -from PIL import Image -from skimage.transform import resize -from concurrent.futures import ThreadPoolExecutor -from tqdm import tqdm -from loguru import logger -from astropy.io import fits +from dotmap import DotMap +from fitsbolt.image_loader import load_and_process_images, _process_image +from fitsbolt.cfg.create_config import create_config as fb_create_cfg -from anomaly_match.image_processing.normalisation import normalise_image -from anomaly_match.utils.constants import SUPPORTED_IMAGE_EXTENSIONS - -def read_image_data(filepath, cfg): - """ - Read raw image data from a file without processing. +def get_fitsbolt_config(cfg, size_override="default"): + """Get the fitsbolt configuration from the provided configuration. Args: - filepath (str): Path to the image file - cfg: Configuration object containing fits_extension + cfg (DotMap): The configuration object. + size_override: If "default", use cfg.normalisation.image_size. + If None, disable resizing (full resolution). + Otherwise use the provided value. Returns: - numpy.ndarray: Raw image array + DotMap: The cfg with a cfg.fitsbolt_cfg subdotmap. """ - fits_extension = cfg.fits_extension - # Get file extension - file_ext = os.path.splitext(filepath.lower())[1] - - # Validate file extension - assert file_ext in SUPPORTED_IMAGE_EXTENSIONS, ( - f"Unsupported file extension {file_ext} for file {filepath}. " - f"Supported extensions: {SUPPORTED_IMAGE_EXTENSIONS}" + if not hasattr(cfg, "normalisation"): + raise ValueError("Configuration must include a normalisation config for fitsbolt.") + size = cfg.normalisation.image_size if size_override == "default" else size_override + cfg.fitsbolt_cfg = fb_create_cfg( + output_dtype=cfg.normalisation.output_dtype, + size=size, + fits_extension=cfg.normalisation.fits_extension, + interpolation_order=cfg.normalisation.interpolation_order, + n_output_channels=cfg.normalisation.n_output_channels, + normalisation_method=cfg.normalisation.normalisation_method, + channel_combination=cfg.normalisation.channel_combination, + num_workers=cfg.num_workers, + norm_maximum_value=cfg.normalisation.norm_maximum_value, + norm_minimum_value=cfg.normalisation.norm_minimum_value, + norm_log_calculate_minimum_value=cfg.normalisation.norm_log_calculate_minimum_value, + norm_crop_for_maximum_value=cfg.normalisation.norm_crop_for_maximum_value, + norm_asinh_scale=cfg.normalisation.norm_asinh_scale, + norm_asinh_clip=cfg.normalisation.norm_asinh_clip, + log_level="WARNING", + force_dtype=True, ) - logger.trace(f"Reading image {filepath} with extension {file_ext}") - - if file_ext == ".fits": - # Handle FITS files with astropy - with fits.open(filepath) as hdul: - try: - # Handle different extension types (None, int, string, or list) - if fits_extension is None: - # Default to first extension (index 0) - image = hdul[0].data - elif isinstance(fits_extension, list): - # Handle list of extensions - need to load and combine them - extension_images = [] - extension_shapes = [] - extension_names = [] - - # First load all extensions to validate shapes match - for ext in fits_extension: - if isinstance(ext, (int, np.integer)): - # Integer index - check valid bounds - ext_idx = int(ext) - if ext_idx < 0 or ext_idx >= len(hdul): - available_indices = list(range(len(hdul))) - logger.error( - f"Invalid FITS extension index {ext_idx} for file {filepath}. " - f"Available indices: {available_indices}" - ) - raise IndexError( - f"FITS extension index {ext_idx} is out of bounds (0-{len(hdul) - 1})" - ) - ext_data = hdul[ext_idx].data - extension_names.append(f"extension {ext_idx}") - else: - # Try as string extension name - try: - ext_data = hdul[ext].data - extension_names.append(f"'{ext}'") - except KeyError: - available_ext = [ - ext_name.name for ext_name in hdul if hasattr(ext_name, "name") - ] - logger.error( - f"FITS extension name '{ext}' not found in file {filepath}. " - f"Available extensions: {available_ext}" - ) - raise KeyError(f"FITS extension name '{ext}' not found") - - # Check for None data - if ext_data is None: - logger.error(f"FITS extension {ext} in file {filepath} has no data") - raise ValueError(f"FITS extension {ext} in file {filepath} has no data") - - # Record the shape for validation - extension_images.append(ext_data) - extension_shapes.append(ext_data.shape) - - # Validate all shapes match - if len(set(str(shape) for shape in extension_shapes)) > 1: - shape_info = [ - f"{name}: {shape}" - for name, shape in zip(extension_names, extension_shapes) - ] - error_msg = ( - f"Cannot combine FITS extensions with different shapes in file {filepath}. " - f"Extension shapes: {', '.join(shape_info)}" - ) - logger.error(error_msg) - raise ValueError(error_msg) - - # Stack the extensions along a new dimension - image = np.stack(extension_images) - - # If images are 2D (Height, Width), stack results in 3D array (Ext, Height, Width) - # If images are 3D (Height, Width, Channels), stack results in 4D (Ext, Height, Width, Channels) - # For 2D images (now 3D after stacking), treat extensions as channels (RGB) - if len(extension_shapes[0]) == 2: - # Only use up to 3 extensions for RGB (more will be handled later by truncation) - if len(extension_images) > 3: - import warnings - - warnings.warn( - f"More than 3 extensions provided for file {filepath}. " - f"Only the first 3 will be used as RGB channels." - ) - logger.warning( - f"More than 3 extensions provided for file {filepath}. " - f"Only the first 3 will be used as RGB channels." - ) - # Transpose to get (Height, Width, Extensions) which is compatible with RGB format - image = np.transpose(image, (1, 2, 0)) - elif isinstance(fits_extension, (int, np.integer)): - # Integer index - check valid bounds - extension_idx = int(fits_extension) - if extension_idx < 0 or extension_idx >= len(hdul): - logger.error( - f"Invalid FITS extension index {extension_idx} for file {filepath} with {len(hdul)} extensions" - ) - raise IndexError( - f"FITS extension index {extension_idx} is out of bounds (0-{len(hdul) - 1})" - ) - image = hdul[extension_idx].data - else: - # Try as string extension name - try: - image = hdul[fits_extension].data - except KeyError: - available_ext = [ext.name for ext in hdul if hasattr(ext, "name")] - logger.error( - f"FITS extension name '{fits_extension}' not found in file {filepath}. " - + f"Available extensions: {available_ext}" - ) - raise KeyError(f"FITS extension name '{fits_extension}' not found") - except Exception as e: - if isinstance(e, (IndexError, KeyError, ValueError)): - # Re-raise specific extension errors - raise - else: - # For other errors, log and re-raise - logger.error( - f"Error accessing FITS extension {fits_extension} in file {filepath}: {e}" - ) - raise - - # Handle case where data is None - if image is None: - logger.error(f"FITS extension {fits_extension} in file {filepath} has no data") - raise ValueError(f"FITS extension {fits_extension} in file {filepath} has no data") - - # Handle dimension issues in FITS data - if image.ndim > 3: - logger.warning( - f"FITS image {filepath} has more than 3 dimensions. Taking the first 3 dimensions." - ) - image = image[:3] - if image.shape[0] < image.shape[-1]: - logger.warning( - f"FITS image {filepath} seems to be in Channel x Height x Width format. Transposing." - ) - image = np.transpose(image, (1, 2, 0)) - # Normalisation is done later - if image.dtype != np.uint8: - # Safe normalization that handles edge cases - img_min, img_max = image.min(), image.max() - if img_max <= img_min: - # incorrect image, set to zero - logger.warning( - f"FITS image {filepath} has no valid data (min=max). Setting to zero." - ) - image = np.zeros_like(image, dtype=np.uint8) - - # Validate that we have a valid image with at least 2 dimensions - assert ( - image.ndim >= 2 and image.ndim <= 3 - ), f"FITS image {filepath} has less than 2 or more than 3 dimensions: {image.shape}" - else: - # Use PIL for standard image formats - image = np.array(Image.open(filepath)) + return cfg - # Validate the image has appropriate dimensions - assert ( - image.ndim >= 2 and image.ndim <= 3 - ), f"Image {filepath} has less than 2 or more than 3 dimensions: {image.shape}" - return image +def load_and_process_wrapper(filepaths, cfg, desc="Loading images", show_progress=True): + images_list = load_and_process_images( + filepaths, + cfg=None, + output_dtype=cfg.normalisation.output_dtype, + size=cfg.normalisation.image_size, + fits_extension=cfg.normalisation.fits_extension, + interpolation_order=cfg.normalisation.interpolation_order, + normalisation_method=cfg.normalisation.normalisation_method, + channel_combination=cfg.normalisation.channel_combination, + n_output_channels=cfg.normalisation.n_output_channels, + num_workers=cfg.num_workers, + norm_maximum_value=cfg.normalisation.norm_maximum_value, + norm_minimum_value=cfg.normalisation.norm_minimum_value, + norm_log_calculate_minimum_value=cfg.normalisation.norm_log_calculate_minimum_value, + norm_crop_for_maximum_value=cfg.normalisation.norm_crop_for_maximum_value, + norm_asinh_scale=cfg.normalisation.norm_asinh_scale, + norm_asinh_clip=cfg.normalisation.norm_asinh_clip, + desc=desc, + show_progress=show_progress, + ) + # return a list of tuples (filename, image) + # check that images_list has same length as filepaths + if len(images_list) != len(filepaths): + raise ValueError( + f"Mismatch between filepaths ({len(filepaths)}) and images_list ({len(images_list)})" + ) + return [(fp, img) for fp, img in zip(filepaths, images_list)] -def read_and_resize_image( +def load_and_process_single_wrapper( filepath, cfg, - convert_to_rgb=True, -): - """ - Read an image from file and resize it if needed. - - Args: - filepath (str): Path to the image file - cfg: Configuration object containing size, fits_extension, normalisation_method - convert_to_rgb (bool): Whether to convert grayscale/RGBA to RGB - - Returns: - numpy.ndarray: Image array as uint8 - """ - try: - # Read raw image data - image = read_image_data(filepath, cfg) - - # Process the image using the centralized processing function - return process_image_array(image, cfg, convert_to_rgb, filepath) - - except Exception as e: - logger.error(f"Error reading image {filepath}: {e}") - raise e - - -def process_image_array( - image, - cfg, - convert_to_rgb=True, - image_source="array", -): - """ - Process an image array by normalising and resizing it. - Args: - image (numpy.ndarray): Image array to process - cfg: Configuration object containing size, normalisation_method - convert_to_rgb (bool): Whether to convert grayscale/RGBA to RGB - image_source (str): Source of the image for logging - Returns: - numpy.ndarray: Processed image array as uint8 - """ - try: - - # Convert to RGB if requested - if convert_to_rgb: - # Handle grayscale images - if len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1): - image = np.stack((image,) * 3, axis=-1) - # Handle RGBA images - elif len(image.shape) == 3 and image.shape[2] > 3: - logger.trace( - f"Image {image_source} is in RGBA format. Converting to RGB by dropping the alpha channel." - ) - image = image[:, :, :3] - - # Validate RGB structure after conversion - if convert_to_rgb: - assert ( - len(image.shape) == 3 and image.shape[2] == 3 - ), f"After RGB conversion, image {image_source} has unexpected shape: {image.shape}" - - logger.trace(f"Normalising image with setting {cfg.normalisation_method}") - image = normalise_image(image, cfg=cfg) - - # Simple resize that maintains uint8 type if requested - if cfg.size is not None and image.shape[:2] != tuple(cfg.size): - image = resize( - image, - cfg.size, - anti_aliasing=None, - order=cfg.interpolation_order if cfg.interpolation_order is not None else 1, - preserve_range=True, - ) - image = np.clip(image, 0, 255).astype(np.uint8) - - return image - - except Exception as e: - logger.error(f"Error processing image {image_source}: {e}") - raise e - - -def load_images_parallel( - filepaths, - cfg, - transform=None, - max_workers=None, - desc="Loading images", - show_progress=True, + desc="image load and process", + show_progress=False, + prediction=False, + size_override="default", ): - """ - Load multiple images in parallel using ThreadPoolExecutor. + """Load and process a single image file. Creates a fitsbolt config if not part of current cfg Args: - filepaths (list): List of image filepaths to load - size (tuple, optional): Size to resize images to (height, width) - transform (callable, optional): Function to apply to each image after loading - max_workers (int, optional): Max number of worker threads, None for default - desc (str): Description for the progress bar - show_progress (bool): Whether to show a progress bar - fits_extension (int, str, list, optional): The FITS extension(s) to use. Can be: - - An integer index - - A string extension name - - A list of integers or strings to combine multiple extensions - Uses the first extension (0) if None. - + filepath (str): The path to the image file. + cfg (DotMap): The configuration object of AnomalyMatch. + desc (str, optional): Description for the loading process. Defaults to "image load and process". + show_progress (bool, optional): Whether to show progress. Defaults to False. + prediction (bool, optional): Whether this is a prediction step. Defaults to False. + size_override: If "default", use cfg.normalisation.image_size. + If None, disable resizing (full resolution). Returns: - list: List of (filepath, image) tuples for successfully loaded images + np.ndarray: The processed image in H,W,C format. """ - logger.debug( - f"Loading {len(filepaths)} images in parallel with normalisation: {cfg.normalisation_method}" + if not prediction: + # cfg might have changed, update fitsbolt config + cfg = get_fitsbolt_config(cfg, size_override=size_override) + cfg.fitsbolt_cfg.num_workers = 1 # for single image processing, force to 1 + # this breaks down for prediction filepath = os.path.join(cfg.data_dir, os.path.basename(filename)) + return load_and_process_images( + filepath, + cfg=cfg.fitsbolt_cfg, + desc=desc, + show_progress=show_progress, ) - def load_single_image(filepath): - try: - image = read_and_resize_image( - filepath, - cfg, - convert_to_rgb=True, - ) - # Apply transform if provided - if transform is not None: - image = transform(image) - - return filepath, image - except Exception as e: - logger.error(f"Error loading {filepath}: {str(e)}") - return None - - # Use ThreadPoolExecutor for parallel loading - with ThreadPoolExecutor(max_workers=max_workers) as executor: - if show_progress: - results = list( - tqdm( - executor.map(load_single_image, filepaths), - desc=desc, - total=len(filepaths), - ) - ) - else: - results = list(executor.map(load_single_image, filepaths)) - - # Filter out None results (failed loads) - results = [r for r in results if r is not None] - - logger.debug(f"Successfully loaded {len(results)} of {len(filepaths)} images") - return results - - -def load_and_process_batch( - filepaths, - cfg, - transform=None, - max_workers=None, -): - """ - Load and process a batch of images in parallel, returning images and paths separately. +# in the future it might make sense to have a process wrapper, right now using legacy functionality +def process_single_wrapper(image, cfg, desc="source"): + """Process a single image using fitsbolt. Args: - filepaths (list): List of image filepaths to load - size (tuple, optional): Size to resize images to (height, width) - transform (callable, optional): Function to apply to each image after loading - max_workers (int, optional): Max number of worker threads, None for default - fits_extension (int, str, list, optional): The FITS extension(s) to use. Can be: - - An integer index - - A string extension name - - A list of integers or strings to combine multiple extensions - Uses the first extension (0) if None. + image: Input image as numpy array + cfg: Configuration object with fitsbolt_cfg set + desc: Description for logging Returns: - tuple: (loaded_images, valid_filepaths) - lists of successfully loaded images and their paths - """ - results = load_images_parallel( - filepaths, - cfg, - transform, - max_workers, - ) + Processed image - # Split the results into separate lists - valid_filepaths = [filepath for filepath, _ in results] - loaded_images = [image for _, image in results] - - return loaded_images, valid_filepaths + Raises: + ValueError: If fitsbolt_cfg is not properly set in cfg + """ + # Validate fitsbolt_cfg exists and is valid + # DotMap auto-creates empty DotMaps when accessing missing keys, so check for 'size' key + fitsbolt_cfg = cfg.fitsbolt_cfg + if fitsbolt_cfg is None or (isinstance(fitsbolt_cfg, DotMap) and "size" not in fitsbolt_cfg): + raise ValueError( + "fitsbolt_cfg is not set in configuration. " + "Models must be saved with fitsbolt config for prediction. " + "Please retrain and save the model to include fitsbolt config." + ) + + # processing requires n_expected channels from the input image + if image.ndim == 2: + fitsbolt_cfg.n_expected_channels = 1 + else: + fitsbolt_cfg.n_expected_channels = image.shape[-1] + fitsbolt_cfg.num_workers = 1 # for single image processing, force to 1 + return _process_image(image, fitsbolt_cfg, image_source=desc) diff --git a/anomaly_match/data_io/save_config.py b/anomaly_match/data_io/save_config.py index 59f7299..92d0aca 100644 --- a/anomaly_match/data_io/save_config.py +++ b/anomaly_match/data_io/save_config.py @@ -11,7 +11,7 @@ from dotmap import DotMap from loguru import logger -from anomaly_match.image_processing.NormalisationMethod import NormalisationMethod +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod def _critical_optional_fields() -> list: diff --git a/anomaly_match/datasets/AnomalyDetectionDataset.py b/anomaly_match/datasets/AnomalyDetectionDataset.py index de784c9..452ba56 100644 --- a/anomaly_match/datasets/AnomalyDetectionDataset.py +++ b/anomaly_match/datasets/AnomalyDetectionDataset.py @@ -14,11 +14,11 @@ from loguru import logger from .Label import Label + from anomaly_match.data_io.load_images import ( - read_and_resize_image, - load_images_parallel, + load_and_process_single_wrapper, + load_and_process_wrapper, ) - from anomaly_match.data_io.find_images_in_folder import get_image_names_from_folder from anomaly_match.data_io.metadata_handler import MetadataHandler @@ -53,9 +53,8 @@ def __init__( logger.debug(f"Loading AnomalyDetectionDataset from {cfg.data_dir}") # Initialize key variables - self.classes = ["normal", "anomaly"] self.seed = cfg.seed - self.size = cfg.size + self.size = cfg.normalisation.image_size self.num_channels = 3 self.root_dir = cfg.data_dir self.transform = transform @@ -107,7 +106,6 @@ def __init__( else: self._load_initial_data() self.mean, self.std = self.compute_mean_std() - # self._save_split_hdf5() # Load the labels from the CSV file self._load_csv_and_apply_labels() @@ -132,16 +130,18 @@ def _load_csv_and_apply_labels(self): # Check that labels are valid assert set(labeled_data["label"].unique()) <= set( - ["normal", "anomaly"] - ), "Labels should be either 'normal' or 'anomaly' but found" + str( + ["normal", "anomaly", "removed"] + ), "Labels should be either 'normal', 'anomaly' or 'removed' but found" + str( set(labeled_data["label"].unique()) ) # Label distribution in the new CSV normal_count = labeled_data["label"].value_counts().get("normal", 0) anomaly_count = labeled_data["label"].value_counts().get("anomaly", 0) + removed_count = labeled_data["label"].value_counts().get("removed", 0) logger.debug( - f"Label distribution in CSV file: Normal: {normal_count}, Anomaly: {anomaly_count}" + f"Label distribution in CSV file: Normal: {normal_count}, Anomaly: {anomaly_count}, " + f"Removed: {removed_count}" ) # Update the dataset with new labels @@ -149,11 +149,7 @@ def _load_csv_and_apply_labels(self): def _read_and_resize_image(self, filepath): """Read an image file and resize it. Used in testing""" - return read_and_resize_image( - filepath, - cfg=self.cfg, - convert_to_rgb=True, - ) + return load_and_process_single_wrapper(filepath, self.cfg, desc="ADD loading image") def _load_initial_data(self): """Load labeled data and first batch of unlabeled data.""" @@ -167,11 +163,8 @@ def _load_initial_data(self): ] # 1. First load all images in parallel using our centralized image loading function labeled_filepaths = [os.path.join(self.root_dir, filename) for filename in labeled_files] - loading_results = load_images_parallel( - labeled_filepaths, - cfg=self.cfg, - desc="Loading labeled data", - max_workers=None, + loading_results = load_and_process_wrapper( + labeled_filepaths, cfg=self.cfg, desc="Loading labeled images" ) # 2. Then apply labels to the loaded images @@ -209,11 +202,8 @@ def _load_next_unlabeled_batch(self): batch_filepaths = [os.path.join(self.root_dir, filename) for filename in batch_files] # Load the images in parallel - loading_results = load_images_parallel( - batch_filepaths, - cfg=self.cfg, - desc=f"Loading batch {self.current_batch_idx}", - max_workers=None, + loading_results = load_and_process_wrapper( + batch_filepaths, cfg=self.cfg, desc=f"Loading batch {self.current_batch_idx}" ) # Add the loaded images to the data dictionary with UNKNOWN labels @@ -347,12 +337,6 @@ def __len__(self): def get_nr_of_unlabeled_images(self): return len(self.all_filenames) - len(self.train_data[0]) - len(self.test_data[0]) - def get_nr_of_batches(self): - return int(np.ceil(self.get_nr_of_unlabeled_images() / self.N_to_load)) - - def reset_batch_idx(self): - self.current_batch_idx = 0 - def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() @@ -387,6 +371,8 @@ def update_labels(self, new_labels_df, update_training_data=True): label_enum = Label.NORMAL elif label == "anomaly": label_enum = Label.ANOMALY + elif label == "removed": + continue else: raise ValueError(f"Invalid label {label} for {filename}") @@ -477,43 +463,6 @@ def load_from_hdf5(self, hdf5_path): logger.info(f"Dataset loaded from {hdf5_path}") - def _save_split_hdf5(self): - """Save labeled and unlabeled data in separate HDF5 files.""" - # Save labeled data - with h5py.File(self.labeled_hdf5, "w") as f: - labeled_data = {k: v for k, v in self.data_dict.items() if v[1] != Label.UNKNOWN} - dtype = h5py.special_dtype(vlen=str) - data_dtype = np.dtype( - [ - ("filename", dtype), - ("image", h5py.vlen_dtype(np.dtype("uint8"))), - ("label", np.int8), - ] - ) - - data = [(k, v[0].flatten(), v[1]) for k, v in labeled_data.items()] - dset = f.create_dataset("data", (len(data),), dtype=data_dtype) - for i, (filename, image, label) in enumerate(data): - dset[i] = (filename, image, label) - - f.create_dataset("mean", data=self.mean) - f.create_dataset("std", data=self.std) - - # Save current unlabeled batch - batch_file = self.batch_hdf5_template.format(self.current_batch_idx) - with h5py.File(batch_file, "w") as f: - unlabeled_data = {k: v for k, v in self.data_dict.items() if v[1] == Label.UNKNOWN} - data = [(k, v[0].flatten()) for k, v in unlabeled_data.items()] - - dtype = h5py.special_dtype(vlen=str) - data_dtype = np.dtype( - [("filename", dtype), ("image", h5py.vlen_dtype(np.dtype("uint8")))] - ) - - dset = f.create_dataset("data", (len(data),), dtype=data_dtype) - for i, (filename, image) in enumerate(data): - dset[i] = (filename, image) - def _load_labeled_from_hdf5(self): """Load labeled data from HDF5.""" with h5py.File(self.labeled_hdf5, "r") as f: diff --git a/anomaly_match/datasets/BasicDataset.py b/anomaly_match/datasets/BasicDataset.py index 48ae58b..bdc8255 100644 --- a/anomaly_match/datasets/BasicDataset.py +++ b/anomaly_match/datasets/BasicDataset.py @@ -6,12 +6,12 @@ # the terms contained in the file 'LICENCE.txt'. from torchvision import transforms from torch.utils.data import Dataset -from .augmentation.randaugment import RandAugment from PIL import Image import numpy as np import torch -import copy + +from anomaly_match.image_processing.transforms import get_strong_transforms class BasicDataset(Dataset): @@ -31,8 +31,6 @@ def __init__( transform=None, use_strong_transform=False, strong_transform=None, - *args, - **kwargs, ): """ Args @@ -74,17 +72,13 @@ def __init__( self.num_classes = num_classes self.use_strong_transform = use_strong_transform - self.use_ms_augmentations = False self.transform = transform if use_strong_transform: if strong_transform is None: - self.strong_transform = copy.deepcopy(transform) - self.strong_transform.transforms.insert( - 0, RandAugment(3, 5, use_ms_augmentations=self.use_ms_augmentations) - ) - else: - self.strong_transform = strong_transform + self.strong_transform = get_strong_transforms() + else: + self.strong_transform = strong_transform def __getitem__(self, idx): """ @@ -112,46 +106,3 @@ def __getitem__(self, idx): def __len__(self): return len(self.data) - - def plot_example_imgs(self, N=16, img_size=(256, 256)): - """Plot N example images from the dataset with a fixed image size.""" - import matplotlib.pyplot as plt - from PIL import Image - - images_per_row = 8 - # Calculate the number of rows needed - num_rows = (N + images_per_row - 1) // images_per_row - - plt.figure( - figsize=(images_per_row * 2, num_rows * 2 * self.num_classes), - dpi=150, - facecolor="black", - ) - for class_idx in range(self.num_classes): - idxs = np.where(self.targets == class_idx)[0] - np.random.shuffle(idxs) - for i, idx in enumerate(idxs[:N]): - row = i // images_per_row - col = i % images_per_row - plt.subplot( - self.num_classes * num_rows, - images_per_row, - class_idx * num_rows * images_per_row + row * images_per_row + col + 1, - ) - img = Image.fromarray(self.data[idx].numpy()) - img = img.resize(img_size, Image.LANCZOS) - plt.imshow(img) - # Annotate filename in img - plt.text( - 0, - 0, - self.filenames[idx], - color="white", - backgroundcolor="black", - fontsize=5, - ) - plt.axis("off") - if i == 0: - plt.title(f"Class {class_idx}", color="white", fontsize=10) - plt.tight_layout() - plt.show() diff --git a/anomaly_match/datasets/SSL_Dataset.py b/anomaly_match/datasets/SSL_Dataset.py index eeda52c..d848632 100644 --- a/anomaly_match/datasets/SSL_Dataset.py +++ b/anomaly_match/datasets/SSL_Dataset.py @@ -43,7 +43,7 @@ def __init__( self.N_to_load = cfg.N_to_load self.train = train self.num_classes = 2 - self.size = cfg.size + self.size = cfg.normalisation.image_size self.data_dir = cfg.data_dir self.label_file = cfg.label_file self.dset = None diff --git a/anomaly_match/datasets/augmentation/randaugment.py b/anomaly_match/datasets/augmentation/randaugment.py index d18c951..f6abd1c 100644 --- a/anomaly_match/datasets/augmentation/randaugment.py +++ b/anomaly_match/datasets/augmentation/randaugment.py @@ -82,19 +82,6 @@ def Equalize(img, _): return PIL.ImageOps.equalize(img) -def Invert(img, _): - """Invert the colors of the image. - - Args: - img: PIL Image to be processed - _: Unused parameter - - Returns: - PIL Image with inverted colors - """ - return PIL.ImageOps.invert(img) - - def Identity(img, v): """Return the image unchanged. @@ -190,19 +177,6 @@ def TranslateX(img, v): return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) -def TranslateXabs(img, v): - """Translate the image horizontally by an absolute amount. - - Args: - img: PIL Image to be processed - v: Absolute translation amount in pixels - - Returns: - PIL Image with horizontal translation applied - """ - return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) - - def TranslateY(img, v): """Translate the image vertically by a percentage of its height. @@ -217,19 +191,6 @@ def TranslateY(img, v): return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) -def TranslateYabs(img, v): - """Translate the image vertically by an absolute amount. - - Args: - img: PIL Image to be processed - v: Absolute translation amount in pixels - - Returns: - PIL Image with vertical translation applied - """ - return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) - - def Solarize(img, v): """Invert all pixel values above a threshold. @@ -334,13 +295,12 @@ class RandAugment: augment_list: List of available augmentation operations """ - def __init__(self, n, m, use_ms_augmentations=False): + def __init__(self, n, m): """Initialize the RandAugment pipeline. Args: n: Number of augmentation operations to apply m: Magnitude parameter [0, 30] (deprecated) - use_ms_augmentations: Whether to use MS-specific augmentations (currently unused) """ self.n = n self.m = m @@ -365,7 +325,7 @@ def __call__(self, img): if __name__ == "__main__": - randaug = RandAugment(3, 5, True) + randaug = RandAugment(3, 5) test_img = np.zeros([32, 32, 13], dtype="uint8") print(randaug) diff --git a/anomaly_match/datasets/data_utils.py b/anomaly_match/datasets/data_utils.py index 317cd2a..532ea91 100644 --- a/anomaly_match/datasets/data_utils.py +++ b/anomaly_match/datasets/data_utils.py @@ -41,7 +41,6 @@ def get_prediction_dataloader(dset, batch_size=None, num_workers=4, pin_memory=T transform=transform, use_strong_transform=False, strong_transform=transform, - use_ms_augmentations=False, ) return DataLoader( diff --git a/anomaly_match/image_processing/NormalisationMethod.py b/anomaly_match/image_processing/NormalisationMethod.py deleted file mode 100644 index 6849383..0000000 --- a/anomaly_match/image_processing/NormalisationMethod.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) European Space Agency, 2025. -# -# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which -# is part of this source code package. No part of the package, including -# this file, may be copied, modified, propagated, or distributed except according to -# the terms contained in the file 'LICENCE.txt'. -from enum import IntEnum - - -class NormalisationMethod(IntEnum): - """Enum for normalisation methods.""" - - CONVERSION_ONLY = 0 - LOG = 1 - ZSCALE = 2 - ASINH = 3 - - @classmethod - def get_dropdown_options(cls): - """Returns a list of tuples (label, value) for use in dropdown widgets.""" - return [ - ("ConversionOnly", cls.CONVERSION_ONLY), - ("LogStretch", cls.LOG), - ("ZscaleInterval", cls.ZSCALE), - ("Asinh", cls.ASINH), - ] - - @classmethod - def get_test_methods(cls): - """Returns all methods for testing purposes.""" - return [cls.CONVERSION_ONLY, cls.LOG, cls.ZSCALE, cls.ASINH] diff --git a/anomaly_match/image_processing/normalisation.py b/anomaly_match/image_processing/normalisation.py deleted file mode 100644 index a502d19..0000000 --- a/anomaly_match/image_processing/normalisation.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright (c) European Space Agency, 2025. -# -# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which -# is part of this source code package. No part of the package, including -# this file, may be copied, modified, propagated, or distributed except according to -# the terms contained in the file 'LICENCE.txt'. -import numpy as np -from loguru import logger - - -from skimage.util import img_as_ubyte -from astropy.visualization import ( - ImageNormalize, - LogStretch, - LinearStretch, - ZScaleInterval, - AsinhStretch, - PercentileInterval, -) - -from anomaly_match.image_processing.NormalisationMethod import NormalisationMethod - - -def _crop_center(data: np.ndarray, crop_height: int, crop_width: int) -> np.ndarray: - """ - Crop the central region of an image. - - Parameters: - - data: np.ndarray - Input image as (H, W, ...) array. - - crop_height: int - Height of the cropped region. - - crop_width: int - Width of the cropped region. - - Returns: - - np.ndarray - Cropped central region. - """ - h, w = data.shape[:2] - top = (h - crop_height) // 2 - left = (w - crop_width) // 2 - if top < 0 or left < 0: - logger.warning("Crop size is larger than image size, returning original image") - return data - return data[top : top + crop_height, left : left + crop_width] - - -def _compute_max_value(data, cfg=None): - """Compute the maximum value of the image for normalisation - Args: - data (numpy array): Input image array, can be high dynamic range - cfg (DotMap or None): Configuration with optional normalisation values. - Returns: - float: Maximum value for normalisation - """ - - if ( - cfg.normalisation.crop_for_maximum_value is not None - and cfg.normalisation.maximum_value is None - ): - h, w = cfg.normalisation.crop_for_maximum_value - assert ( - h > 0 and w > 0 - ), f"Crop size must be positive integers currently {cfg.normalisation.crop_for_maximum_value}" - # make cutout of the image and compute max value - img_centre_region = _crop_center(data, h, w) - max_value = np.max(img_centre_region) - - else: - # Compute the maximum value of the image - max_value = ( - cfg.normalisation.maximum_value - if cfg.normalisation.maximum_value is not None - else np.max(data) - ) - - return max_value - - -def _compute_min_value(data, cfg): - """Compute the minimum value of the image for normalisation - Args: - data (numpy array): Input image array, can be high dynamic range - cfg (DotMap or None): Configuration with optional normalisation values. - Returns: - float: Maximum value for normalisation - """ - min_value = ( - cfg.normalisation.minimum_value - if cfg.normalisation.minimum_value is not None - else np.min(data) - ) - - return min_value - - -def _log_normalisation(data, cfg): - """A log normalisation based on a minimum as 0 (bkg subtracted) or higher (if calc_vmin is True) - and a dynamically determined maximum. If cfg.normalisation.crop_for_maximum_value is not None the maximum is determined - on a crop around the center, with the shape given by the Tuple crop_for_maximum_value. - - Args: - data (numpy array): Input image array, ideally a float32 or float64 array, can be high dynamic range - cfg (DotMap or None): Configuration with optional normalisation values. - cfg.normalisation.log_calculate_minimum_value (bool): If True, calculate the minimum value of the image, - otherwise set to 0 or cfg.normalisation.minimum_value if set - cfg.normalisation.crop_for_maximum_value (Tuple[int, int], optional): Width and height to crop around the center, - to calculate the maximum value in - - Returns: - numpy array: A normalised image in the [0,255] range as uint8 - """ - - if cfg.normalisation.log_calculate_minimum_value: - minimum = _compute_min_value(data, cfg=cfg) - else: - minimum = ( - cfg.normalisation.minimum_value if cfg.normalisation.minimum_value is not None else 0.0 - ) - - maximum = _compute_max_value(data, cfg=cfg) - if minimum < maximum: - norm = ImageNormalize(data, vmin=minimum, vmax=maximum, stretch=LogStretch(), clip=True) - else: - logger.warning( - "Image minimum value is larger than maximum, ignoring boundaries and using a LinearInterval" - ) - norm = ImageNormalize(data, vmin=None, vmax=None, stretch=LogStretch(), clip=True) - img_normalised = norm(data) # range 0,1 - # Convert back to uint8 range - return img_as_ubyte(img_normalised) - - -def _zscale_normalisation(data, cfg): - """A linear zscale normalisation - - Args: - image (numpy array): Input image array, ideally a float32 or float64 array - - Returns: - numpy array: A normalised image in the [0,255] range as uint8 - """ - # Min Max value do not apply, also no constrain to center - norm = ImageNormalize(data, interval=ZScaleInterval(), stretch=LinearStretch(), clip=True) - img_normalised = norm(data) # range 0,1 - if np.max(img_normalised) > np.min(img_normalised): - # Convert back to uint8 range - return img_as_ubyte(img_normalised) - else: - logger.warning( - "Zscale normalisation: image maximum value not larger than minimum, only converting image" - ) - return _conversiononly_normalisation(data, cfg) - - -def _conversiononly_normalisation(data, cfg): - """A normalisation that does not change the image, but only converts it to uint8 - - Args: - data (numpy array): Input image array, can have a high dynamic range - cfg (DotMap): Configuration with optional normalisation values. - cfg.normalisation.crop_for_maximum_value (Tuple[int, int], optional): Width and height to crop around the center, - to compute the maximum value in - - Returns: - numpy array: A normalised image in the [0,255] range as uint8 - """ - # Check dtype if any conversion is needed - if data.dtype != np.uint8: - # check for uint16 as a simply divison converts - if data.dtype == np.uint16: - # convert to uint8 - return img_as_ubyte(data / (256 * 256 - 1)) # devide by maximum val of unit16 65535 - - else: - # any dtype that is not uint16 or uint8 - # get min or max from config if available - maximum = _compute_max_value(data, cfg) - minimum = _compute_min_value(data, cfg) - - # ensure valid range - if maximum > minimum: - norm = ImageNormalize(data, vmin=minimum, vmax=maximum, clip=True) - img_normalised = norm(data) # range 0,1 - return img_as_ubyte(img_normalised) # Convert back to uint8 range - else: - logger.warning( - "Conversion normalisation: Image minimum value is larger than maximum, setting image to 0" - ) - return np.zeros_like(data, dtype=np.uint8) - else: - # already uint8 - return data - - -def _expand(value, length: int) -> np.ndarray: - """Turn a scalar or sequence into a length-`length` float32 array. - Used in the asinh normalisation to ensure that the scale and clip - parameters are always arrays of the correct length.""" - arr = ( - np.asarray(value, dtype=np.float32) - if isinstance(value, (list, tuple, np.ndarray)) - else np.full(length, value, dtype=np.float32) - ) - if arr.size != length: # keep caller honest - raise ValueError(f"Expected {length} values, got {arr.size}: {value!r}") - return arr - - -def _asinh_normalisation(data, cfg): - """A normalisation based on the asinh stretch. - Allows for per-channel scaling and clipping. - If cfg.normalisation.crop_for_maximum_value is not None the maximum is determined on a cutout around the center - - Args: - ---------- - data : np.ndarray - Image array. Either single-channel (any shape) or RGB with - ``data.ndim == 3`` and ``data.shape[2] == 3``. - cfg : DotMap - Configuration object holding - ``cfg.normalisation.asinh_scale`` and - ``cfg.normalisation.asinh_clip``. Each may be a scalar - or a three-element sequence. - - Returns - ------- - np.ndarray - Asinh-stretched (and possibly clipped) image as ``uint8``. - """ - # Determine whether we are dealing with RGB+.... or not - channels = data.shape[-1] if data.ndim == 3 else 1 - - # Prepare per-channel parameters - scale = _expand(cfg.normalisation.asinh_scale, channels) - clip = _expand(cfg.normalisation.asinh_clip, channels) - - # Get initial min and max and clip values if manual are set - max_value = _compute_max_value(data, cfg) - min_value = _compute_min_value(data, cfg) - data = np.clip(data, min_value, max_value) - - # Apply asinh normalisation & percentile clipping, potentially per-channel - if channels == 1: - norm = ImageNormalize( - data, interval=PercentileInterval(clip[0]), stretch=AsinhStretch(scale[0]), clip=True - ) - normalised = norm(data) - else: - normalised = np.zeros_like(data, dtype=np.float32) - for c in range(channels): - # Apply asinh stretch with scale parameter and percentile clipping for each channel - norm = ImageNormalize( - data[..., c], - interval=PercentileInterval(clip[c]), - stretch=AsinhStretch(scale[c]), - clip=True, - ) - normalised[..., c] = norm(data[..., c]) - # correct to 0-1 range and convert to uint8 - min_value = np.min(normalised) - max_value = np.max(normalised) - if min_value < max_value: - return img_as_ubyte((normalised - min_value) / (max_value - min_value)) - else: - logger.warning( - "Image maximum value is not larger than minimum, using minimal normalisation instead. Check settings" - ) - return _conversiononly_normalisation(data, cfg=cfg) - - -def normalise_image(data, cfg): - """Normalises all images based on the selected normalisation option - - If None is selected and a uint16 array given, it is linearly scaled to uint8 - Otherwise None applies linear normalisation to shift the image to the required [0,255] range if outside of it - - Args: - data (numpy array): Input image array, can have high dynamic range - method (NormalisationMethod): Normalisation method enum for test - cfg (DotMap): Configuration object containing normalisation settings - - Returns: - numpy array: A normalised image based on the selected method - """ - - method = cfg.normalisation_method - # Method selection - if isinstance(method, NormalisationMethod): - pass - else: - logger.critical(f"Normalisation method type {method} , {type(method)} not implemented") - # ensure uint8 - return _conversiononly_normalisation(data, cfg=cfg) - - # execute normalisations based on enum - if method == NormalisationMethod.LOG: - return _log_normalisation(data, cfg=cfg) - elif method == NormalisationMethod.CONVERSION_ONLY: - return _conversiononly_normalisation(data, cfg=cfg) - elif method == NormalisationMethod.ZSCALE: - return _zscale_normalisation(data, cfg=cfg) - elif method == NormalisationMethod.ASINH: - return _asinh_normalisation(data, cfg=cfg) - else: - logger.critical(f"Normalisation method {method} not implemented") - return _conversiononly_normalisation(data, cfg=cfg) diff --git a/anomaly_match/image_processing/transforms.py b/anomaly_match/image_processing/transforms.py index 390807b..3a8c443 100644 --- a/anomaly_match/image_processing/transforms.py +++ b/anomaly_match/image_processing/transforms.py @@ -41,21 +41,19 @@ def get_prediction_transforms(): def get_strong_transforms(): - """Get strong augmentations + """Get strong augmentations for FixMatch. - - Args: - None + Includes RandAugment followed by the same transforms as weak (ToTensor, + RandomHorizontalFlip, RandomAffine). Returns: - torchvision.transforms.Compose: transforms. - with random augmentations and horizontal flips + torchvision.transforms.Compose: Strong augmentation pipeline. """ - return transforms.Compose( [ - RandAugment(3, 5), # Apply RandAugment as first step + RandAugment(3, 5), transforms.ToTensor(), - transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomHorizontalFlip(), + transforms.RandomAffine(0, translate=(0, 0.125)), ] ) diff --git a/anomaly_match/models/FixMatch.py b/anomaly_match/models/FixMatch.py index 31f8c35..555c7ad 100644 --- a/anomaly_match/models/FixMatch.py +++ b/anomaly_match/models/FixMatch.py @@ -31,7 +31,6 @@ def __init__( lambda_u, hard_label=True, logger=None, - current_normalisation_method=None, session_tracker=None, ): """FixMatch implementation for semi-supervised learning. @@ -54,7 +53,6 @@ def __init__( super(FixMatch, self).__init__() # Store parameters - self.loader = {} self.num_classes = num_classes self.ema_m = ema_m @@ -64,7 +62,6 @@ def __init__( self.T = T self.p_cutoff = p_cutoff self.lambda_u = lambda_u - self.use_hard_label = hard_label # initialise the normalisation method "last" used self.last_normalisation_method = None @@ -173,7 +170,7 @@ def train(self, cfg, progressbar=None, progress_callback=None): logger.info( f"Starting FixMatch training for {cfg.num_train_iter} iterations on {ngpus_per_node} GPUs" - + f" with normalisation {cfg.normalisation_method.name}" + + f" with normalisation {cfg.normalisation.normalisation_method.name}" ) self.it = 0 @@ -275,10 +272,6 @@ def train(self, cfg, progressbar=None, progress_callback=None): progress_callback(self.it, cfg.num_train_iter) progressbar.refresh() - # Update widget progress bar if provided - if cfg.progress_bar: - cfg.progress_bar.value = (self.it + 1.0) / cfg.num_train_iter - # Periodic evaluation if cfg.num_eval_iter > 0 and self.it % cfg.num_eval_iter == 0 and self.it > 0: progressbar.close() diff --git a/anomaly_match/pipeline/SessionTracker.py b/anomaly_match/pipeline/SessionTracker.py index 1dc9856..47354f3 100644 --- a/anomaly_match/pipeline/SessionTracker.py +++ b/anomaly_match/pipeline/SessionTracker.py @@ -23,6 +23,9 @@ class IterationInfo: model_state_path: Optional[str] = None num_newly_labeled_anomalous: int = 0 num_newly_labeled_nominal: int = 0 + # Per-sample scores for this iteration (stored separately as CSV files) + unlabelled_scores_file: Optional[str] = None + test_scores_file: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """Convert IterationInfo to dictionary.""" @@ -34,6 +37,8 @@ def to_dict(self) -> Dict[str, Any]: "model_state_path": self.model_state_path, "num_newly_labeled_anomalous": self.num_newly_labeled_anomalous, "num_newly_labeled_nominal": self.num_newly_labeled_nominal, + "unlabelled_scores_file": self.unlabelled_scores_file, + "test_scores_file": self.test_scores_file, } @@ -145,21 +150,6 @@ def add_labeled_sample(self, filename: str, label: str, iteration_number: int = logger.debug(f"Added labeled sample: {filename} -> {label} (iteration {current_iter_num})") - def add_initial_labeled_samples(self, labeled_data_df: pd.DataFrame) -> None: - """ - Add initial labeled samples (before any training iterations). - - Args: - labeled_data_df: DataFrame with columns ['filename', 'label'] for initial data. - """ - initial_df = labeled_data_df.copy() - initial_df["iteration"] = -1 # Mark as initial/pre-training data - - # Concatenate with existing data - self.labeled_data_df = pd.concat([self.labeled_data_df, initial_df], ignore_index=True) - - logger.debug(f"Added {len(initial_df)} initial labeled samples") - def update_test_performance(self, performance_metrics: Dict[str, float]) -> None: """ Update test performance for the current session iteration. @@ -182,6 +172,28 @@ def update_model_state_path(self, model_path: str) -> None: self.session_iterations[-1].model_state_path = model_path logger.debug(f"Updated model state path: {model_path}") + def update_unlabelled_scores_path(self, scores_path: str) -> None: + """ + Update the unlabelled scores file path for the current session iteration. + + Args: + scores_path: Path to the saved unlabelled scores CSV file. + """ + if self.session_iterations: + self.session_iterations[-1].unlabelled_scores_file = scores_path + logger.debug(f"Updated unlabelled scores path: {scores_path}") + + def update_test_scores_path(self, scores_path: str) -> None: + """ + Update the test scores file path for the current session iteration. + + Args: + scores_path: Path to the saved test scores CSV file. + """ + if self.session_iterations: + self.session_iterations[-1].test_scores_file = scores_path + logger.debug(f"Updated test scores path: {scores_path}") + def get_session_info(self) -> Dict[str, Any]: """ Get comprehensive session information. @@ -192,7 +204,8 @@ def get_session_info(self) -> Dict[str, Any]: # Count all labeled samples including initial data (iteration = -1) total_anomalous = len(self.labeled_data_df[self.labeled_data_df["label"] == "anomaly"]) total_nominal = len(self.labeled_data_df[self.labeled_data_df["label"] == "normal"]) - total_labeled = len(self.labeled_data_df) + total_labeled = len(self.labeled_data_df[self.labeled_data_df["label"] != "removed"]) + total_removed = len(self.labeled_data_df[self.labeled_data_df["label"] == "removed"]) # Handle case where iteration column might not exist (legacy data) if "iteration" in self.labeled_data_df.columns: @@ -213,6 +226,7 @@ def get_session_info(self) -> Dict[str, Any]: "total_anomalous_samples": total_anomalous, "total_nominal_samples": total_nominal, "total_labeled_samples": total_labeled, + "total_removed_samples": total_removed, "initial_labeled_samples": initial_samples, "iteration_labeled_samples": iteration_samples, "session_duration_minutes": ( @@ -266,22 +280,6 @@ def get_labeled_data_df(self) -> pd.DataFrame: """ return self.labeled_data_df.copy() - def save_training_run(self, model_path: str, config: Any = None) -> None: - """ - Record a completed training run with model and config. - - Args: - model_path: Path to the saved model - config: Optional configuration used for training - """ - if not self.session_iterations: - self.start_new_session_iteration() - - # Update current iteration with the final model - self.session_iterations[-1].model_state_path = model_path - - logger.info(f"Training run completed. Model saved to: {model_path}") - def update_labeled_data(self, labeled_data_df: pd.DataFrame) -> None: """ Update the complete labeled dataset, preserving existing iteration information. @@ -340,51 +338,3 @@ def update_labeled_data(self, labeled_data_df: pd.DataFrame) -> None: if "iteration" in self.labeled_data_df.columns: iter_counts = self.labeled_data_df["iteration"].value_counts().sort_index() logger.debug(f"Iteration distribution after merge: {dict(iter_counts)}") - - def set_total_model_iterations(self, total_iterations: int) -> None: - """ - Set the total model iterations count directly. - - This is useful when you need to correct or set the iteration count - based on actual training that has occurred. - - Args: - total_iterations: Total number of model training iterations. - """ - self.total_model_iterations = total_iterations - logger.debug(f"Set total model iterations to: {total_iterations}") - - def get_current_iteration_number(self) -> int: - """ - Get the current active iteration number. - - Returns: - int: Current iteration number, or -1 if no iterations started yet. - """ - if self.session_iterations: - return self.session_iterations[-1].iteration_number - else: - return -1 - - def debug_session_state(self) -> Dict[str, Any]: - """ - Get debug information about the current session state. - - Returns: - Dict with debug information. - """ - return { - "current_session_iteration": self.current_session_iteration, - "total_session_iterations": len(self.session_iterations), - "current_active_iteration": self.get_current_iteration_number(), - "total_model_iterations": self.total_model_iterations, - "labeled_data_count": len(self.labeled_data_df), - "iterations_info": [ - { - "iter_num": iter_info.iteration_number, - "anomalous": iter_info.num_newly_labeled_anomalous, - "normal": iter_info.num_newly_labeled_nominal, - } - for iter_info in self.session_iterations - ], - } diff --git a/anomaly_match/pipeline/session.py b/anomaly_match/pipeline/session.py index 9eec73f..83adf92 100644 --- a/anomaly_match/pipeline/session.py +++ b/anomaly_match/pipeline/session.py @@ -19,19 +19,26 @@ import h5py import zarr from pathlib import Path +from fitsbolt import SUPPORTED_IMAGE_EXTENSIONS + from anomaly_match.datasets.SSL_Dataset import SSL_Dataset from anomaly_match.datasets.data_utils import get_prediction_dataloader from anomaly_match.models.FixMatch import FixMatch -from anomaly_match.utils.constants import SUPPORTED_IMAGE_EXTENSIONS + from anomaly_match.utils.print_cfg import print_cfg from anomaly_match.utils.set_log_level import set_log_level from anomaly_match.utils.get_net_builder import get_net_builder +from anomaly_match.utils.cutana_stream_utils import ( + cutana_buffer_generator, + cutana_validate_files_and_count_sources, +) from anomaly_match.utils.get_optimizer import get_optimizer from anomaly_match.utils.validate_config import validate_config -from anomaly_match.image_processing.NormalisationMethod import NormalisationMethod +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod from anomaly_match.pipeline.SessionTracker import SessionTracker from anomaly_match.data_io.SessionIOHandler import SessionIOHandler +from anomaly_match.data_io.load_images import get_fitsbolt_config class Session: @@ -40,7 +47,6 @@ class Session: labeled_train_dataset = None unlabeled_train_dataset = None test_dataset = None - prediction_dataset = None widget = None model: FixMatch = None @@ -81,7 +87,7 @@ def __init__(self, cfg): validate_config(cfg) self.cfg = cfg - self.cached_image_normalisation_enum = cfg.normalisation_method + self.cached_image_normalisation_enum = cfg.normalisation.normalisation_method self.out = None # Initialize out attribute to None # Initialize label cache and distribution cache @@ -93,8 +99,6 @@ def __init__(self, cfg): self._load_datasets() logger.debug("Datasets loaded, initializing model") self._init_model() - self.top_N_filenames_scores = [] - self.eval_predictions = {} # Initialize empty dict for eval predictions def _init_model(self): """Initializes the model with the configuration settings.""" @@ -113,7 +117,6 @@ def _init_model(self): lambda_u=self.cfg.ulb_loss_ratio, hard_label=True, logger=logger, - current_normalisation_method=self.cfg.normalisation_method, session_tracker=self.session_tracker, ) @@ -151,7 +154,7 @@ def _load_datasets(self): self.labeled_train_dataset, self.unlabeled_train_dataset = self.train_dset.get_ssl_dset() # Update information about cached dataset - self.cached_image_normalisation_enum = self.cfg.normalisation_method + self.cached_image_normalisation_enum = self.cfg.normalisation.normalisation_method self.cfg.num_classes = self.train_dset.num_classes self.cfg.num_channels = self.train_dset.num_channels @@ -184,8 +187,8 @@ def update_predictions(self): ) def progress_callback(current, total): - if hasattr(self.cfg, "progress_bar") and self.cfg.progress_bar is not None: - self.cfg.progress_bar.value = current / total + if self.widget is not None and self.widget.ui["progress_bar"] is not None: + self.widget.ui["progress_bar"].value = current / total if self.widget is not None: self.widget.ui["train_label"].value = "Updating predictions..." @@ -363,7 +366,7 @@ def set_normalisation_method(self, method: NormalisationMethod): method (NormalisationMethod): The new normalization method to apply. """ # update norm method in session cfg, should - self.cfg.normalisation_method = method + self.cfg.normalisation.normalisation_method = method def _reload_datasets(self): """Reloads the datasets if normalisation changed.""" @@ -408,6 +411,10 @@ def remember_current_file(self, filename): def save_model(self): """Saves the current model state using SessionIOHandler.""" with self.out if self.out is not None else nullcontext(): + # Ensure fitsbolt config is set before saving model + # This creates cfg.fitsbolt_cfg from normalisation settings for prediction consistency + self.cfg = get_fitsbolt_config(self.cfg) + # Save model using SessionIOHandler model_path = self.session_io.save_model(self.model, self.cfg, self.session_tracker) @@ -416,23 +423,36 @@ def save_model(self): def load_model(self): """Loads the model state using SessionIOHandler.""" with self.out if self.out is not None else nullcontext(): + # Save the current normalisation method before loading + old_normalisation_method = self.cfg.normalisation.normalisation_method + success = self.session_io.load_model(self.model, self.cfg) if success: logger.info("Model loaded successfully") - # Check if normalisation method was updated from the loaded model - if hasattr(self.model, "last_normalisation_method"): - if self.model.last_normalisation_method != self.cfg.normalisation_method: - logger.info( - f"Normalisation method updated from loaded model: " - f"{self.model.last_normalisation_method}" - ) - - # Update cached normalisation and reload datasets if needed - if self.cached_image_normalisation_enum != self.cfg.normalisation_method: - logger.info("Normalisation method changed, reloading datasets...") - self._reload_datasets() + # Always inform user about loaded normalisation settings + # (parameters like asinh_scale may differ even if method is the same) + new_normalisation_method = self.cfg.normalisation.normalisation_method + logger.info( + f"Loaded model normalisation: {new_normalisation_method.name}. " + f"Note: normalisation parameters were also loaded from the model checkpoint." + ) + + # Warn if the method itself changed + if old_normalisation_method != new_normalisation_method: + logger.warning( + f"Normalisation method changed from {old_normalisation_method.name} " + f"to {new_normalisation_method.name}. Images may need to be refreshed." + ) + + # Update cached normalisation and reload datasets if method changed + if ( + self.cached_image_normalisation_enum + != self.cfg.normalisation.normalisation_method + ): + logger.info("Normalisation method changed, reloading datasets...") + self._reload_datasets() else: logger.error("Failed to load model") @@ -444,7 +464,6 @@ def train(self, cfg, progress_callback=None): progess_callback (function, optional): Callback function to update progress. Defaults to None. """ self.cfg = cfg - self.top_N_filenames_scores = [] # Clear top N filenames and scores with self.out if self.out is not None else nullcontext(): # Start a new session iteration self.session_tracker.start_new_session_iteration() @@ -476,6 +495,7 @@ def train(self, cfg, progress_callback=None): eval_results = self.model.train(cfg, progress_callback=progress_callback) # Update session tracker with training results + test_scores = None if eval_results: # Filter out large data fields that shouldn't be saved to session metadata filtered_eval_results = { @@ -486,13 +506,40 @@ def train(self, cfg, progress_callback=None): } self.session_tracker.update_test_performance(filtered_eval_results) + # Extract test scores for saving (filename -> anomaly probability) + if "eval/predictions_and_labels" in eval_results: + predictions_and_labels = eval_results["eval/predictions_and_labels"] + test_scores = { + filename: float(pred_label[0].item()) + for filename, pred_label in predictions_and_labels.items() + } + # Update total model iterations self.session_tracker.total_model_iterations = self.model.total_it logger.info("Training complete.") # Update cached image normalisation enum - self.cached_image_normalisation_enum = self.cfg.normalisation_method + self.cached_image_normalisation_enum = self.cfg.normalisation.normalisation_method + + # Update predictions to get unlabelled scores for this iteration + self.update_predictions() + + # Extract unlabelled scores (filename -> anomaly score) + unlabelled_scores = None + if self.scores is not None and self.filenames is not None: + unlabelled_scores = { + filename: float(score) for filename, score in zip(self.filenames, self.scores) + } + + # Save iteration scores (unlabelled and test set scores) + self.session_io.save_iteration_scores( + self.session_tracker, + unlabelled_scores=unlabelled_scores, + test_scores=test_scores, + ) + # Save model to session directory using centralized save_model method + # Note: save_model() ensures fitsbolt_cfg is set before saving self.save_model() # Save session again to capture training results (test performance, model path) @@ -584,7 +631,6 @@ def get_label(self, idx): def load_next_batch(self): """Loads the next batch of data and updates predictions.""" logger.debug("Loading next batch of data") - self.top_N_filenames_scores = [] # Clear top N filenames and scores # Note that we are updating also the labeled_dataset since the unlabeled # data are going to disappear from the unlabeled dataset once we call this function. self.labeled_train_dataset, self.unlabeled_train_dataset = self.train_dset.update_dsets( @@ -601,7 +647,7 @@ def load_next_batch(self): # Clear the label cache since active_learning_df is now empty self._label_cache = {} - self.cached_image_normalisation_enum = self.cfg.normalisation_method + self.cached_image_normalisation_enum = self.cfg.normalisation.normalisation_method # We don't rebuild the cache here since active_learning_df is empty # The get_label method will handle finding labels in the main dataset if needed self.update_predictions() @@ -637,6 +683,8 @@ def run_pipeline(self, temp_config_path, input_path, top_N, file_type=None): ".hdf5": "hdf5", ".zarr": "zarr", ".txt": "image", # Grouped image files + ".parquet": "stream", + ".csv": "stream", } file_type = extension_map.get(ext, "image") else: @@ -647,6 +695,7 @@ def run_pipeline(self, temp_config_path, input_path, top_N, file_type=None): "hdf5": "prediction_process_hdf5.py", "image": "prediction_process.py", "zarr": "prediction_process_zarr.py", + "stream": "prediction_process_cutana.py", } script = script_map.get(file_type) @@ -689,7 +738,7 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): """Evaluates all images and updates the session's img_catalog with the top N images.""" logger.info("Evaluating all images") # check if normalisation changed and reload if necessary - if self.cfg.normalisation_method != self.cached_image_normalisation_enum: + if self.cfg.normalisation.normalisation_method != self.cached_image_normalisation_enum: self._reload_datasets() # Check if model exists before proceeding @@ -704,17 +753,37 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): raise FileNotFoundError(error_msg) # Auto-detect file type based on prediction_search_dir - detected_file_type = None - if self.cfg.prediction_search_dir: - detected_file_type = self._auto_detect_prediction_file_type( - self.cfg.prediction_search_dir + if not self.cfg.prediction_search_dir: + error_msg = ( + "No prediction_search_dir configured. " + "Please set cfg.prediction_search_dir to a directory containing " + "images, HDF5 files, Zarr files, or Cutana buffer files." ) + logger.error(error_msg) + if self.widget is not None: + self.widget.ui["train_label"].value = "Error: No prediction directory!" + raise ValueError(error_msg) + + detected_file_type = self._auto_detect_prediction_file_type(self.cfg.prediction_search_dir) + + # Check for Cutana + MIDTONES incompatibility + if detected_file_type == "stream": + if self.cfg.normalisation.normalisation_method == NormalisationMethod.MIDTONES: + error_msg = ( + "MIDTONES normalisation is not supported for Cutana streaming predictions. " + "Please use CONVERSION_ONLY, LOG, ZSCALE, or ASINH." + ) + logger.error(error_msg) + if self.widget is not None: + self.widget.ui["train_label"].value = "Error: MIDTONES not supported!" + raise ValueError(error_msg) # Define supported file extensions supported_extensions = { "hdf5": [".h5", ".hdf5"], "image": SUPPORTED_IMAGE_EXTENSIONS, "zarr": [".zarr"], + "stream": [".csv", ".parquet"], } pattern = supported_extensions.get(detected_file_type) @@ -724,32 +793,57 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): # Get all matching files from the cfg.prediction_search_dir input_files = [] for f in os.listdir(self.cfg.prediction_search_dir): + file_path = os.path.join(self.cfg.prediction_search_dir, f) file_ext = os.path.splitext(f.lower())[1] - if file_ext in pattern: - input_files.append(os.path.join(self.cfg.prediction_search_dir, f)) - num_files = len(input_files) + if detected_file_type == "zarr": + # For zarr, check for direct .zarr files/directories + if file_ext in pattern and (os.path.isfile(file_path) or os.path.isdir(file_path)): + input_files.append(file_path) + # Also check for batch folders containing images.zarr subdirectory + elif os.path.isdir(file_path) and os.path.exists( + os.path.join(file_path, "images.zarr") + ): + # Add the path to the images.zarr subdirectory + input_files.append(os.path.join(file_path, "images.zarr")) + elif file_ext in pattern: + input_files.append(file_path) + total_images = 0 processed_images = 0 start_time = time.time() # First count total images logger.debug("Counting total images to process...") - for input_file in input_files: - try: - if detected_file_type == "hdf5": - with h5py.File(input_file, "r") as h5f: - total_images += len(h5f["images"]) - elif detected_file_type == "zarr": - root = zarr.open_group(input_file, mode="r") - if "images" in root: - total_images += root["images"].shape[0] - else: - logger.warning(f"No 'images' array found in Zarr file {input_file}") - else: # jpeg/image files - single file - total_images += 1 - except Exception as e: - logger.warning(f"Error counting images in {input_file}: {str(e)}") + if detected_file_type != "stream": + for input_file in input_files: + try: + if detected_file_type == "hdf5": + with h5py.File(input_file, "r") as h5f: + total_images += len(h5f["images"]) + elif detected_file_type == "zarr": + root = zarr.open_group(input_file, mode="r") + if "images" in root: + total_images += root["images"].shape[0] + else: + logger.warning(f"No 'images' array found in Zarr file {input_file}") + else: # jpeg/image files - single file + total_images += 1 + except Exception as e: + logger.warning(f"Error counting images in {input_file}: {str(e)}") + + else: # Validates files against cutana and counts sources in valid files + logger.info("Validating files against cutana") + input_files, total_images, total_chunks = cutana_validate_files_and_count_sources( + input_files, chunk_size=self.cfg.subprocess_buffer_size + ) + + if not input_files: + msg = "All found files are not compatible with cutana" + logger.error(msg) + raise RuntimeError(msg) + + num_files = len(input_files) logger.info(f"Found total of {total_images:,} images to process in {num_files} files") @@ -780,6 +874,18 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): f"for {total_input_images} images" ) + # Creating a generator that loads the csv/parquet in chunks and saves to a temporary file + elif detected_file_type == "stream": + + # Files are read in chunks and saved into this intermediate buffer + cutana_buffer_path = Path("tmp") / ".cutana_buffer.parquet" + input_files = cutana_buffer_generator( + files=input_files, + buffer_path=cutana_buffer_path, + chunk_size=self.cfg.subprocess_buffer_size, + ) + num_files = total_chunks + for file_idx, input_file in enumerate(input_files): # Get number of images in current file logger.debug(f"Processing file {file_idx + 1}/{num_files}: {input_file}") if detected_file_type == "hdf5": @@ -796,8 +902,14 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): except Exception as e: logger.error(f"Error reading Zarr file {input_file}: {e}") num_items = 0 + elif detected_file_type == "stream": + # Cutana input buffer file (CSV or parquet) + if str(input_file).endswith(".parquet"): + num_items = len(pd.read_parquet(input_file)) + else: + num_items = len(pd.read_csv(input_file)) else: # image files - if input_file.endswith(".txt"): # This is a group file + if str(input_file).endswith(".txt"): # This is a group file with open(input_file, "r") as f: num_items = len(f.readlines()) else: @@ -849,15 +961,14 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): # Make tmp folder if it doesn't exist os.makedirs("tmp", exist_ok=True) - # Clear progress bar key from cfg is present since it is not pickle serializable - if "progress_bar" in temp_config: - del temp_config["progress_bar"] - # Save the config to a temporary file as pickle with open(temp_config_path, "wb") as f: pickle.dump(temp_config, f) logger.debug(f"Temporary config saved to {temp_config_path}") + # Create output directory if it doesn't exist + os.makedirs(self.cfg.output_dir, exist_ok=True) + # Run the prediction process script self.run_pipeline(temp_config_path, input_file, top_N, detected_file_type) @@ -1037,8 +1148,12 @@ def _auto_detect_prediction_file_type(self, search_dir): ".tif": "image", ".tiff": "image", ".fits": "image", + ".csv": "stream", + ".parquet": "stream", } + tracked_extenstions = {key: 0 for key in extension_map.keys()} + # Count files by type file_type_counts = {} for filename in os.listdir(search_dir): @@ -1048,15 +1163,20 @@ def _auto_detect_prediction_file_type(self, search_dir): if os.path.isfile(file_path): _, ext = os.path.splitext(filename.lower()) if ext in extension_map: + tracked_extenstions[ext] += 1 file_type = extension_map[ext] file_type_counts[file_type] = file_type_counts.get(file_type, 0) + 1 # Check if it's a zarr directory (zarr stores can be directories) elif os.path.isdir(file_path): + # Check for direct zarr store (ends with .zarr or has zarr.json) if filename.lower().endswith(".zarr") or os.path.exists( os.path.join(file_path, "zarr.json") ): file_type_counts["zarr"] = file_type_counts.get("zarr", 0) + 1 + # Check for batch folders containing images.zarr subdirectory + elif os.path.exists(os.path.join(file_path, "images.zarr")): + file_type_counts["zarr"] = file_type_counts.get("zarr", 0) + 1 if not file_type_counts: logger.warning( @@ -1066,6 +1186,7 @@ def _auto_detect_prediction_file_type(self, search_dir): # Return the most common file type detected_type = max(file_type_counts, key=file_type_counts.get) + logger.debug( f"Auto-detected prediction file type: {detected_type} (found {file_type_counts[detected_type]} files)" ) diff --git a/anomaly_match/ui/Widget.py b/anomaly_match/ui/Widget.py index 0d56e5a..86b3865 100644 --- a/anomaly_match/ui/Widget.py +++ b/anomaly_match/ui/Widget.py @@ -17,18 +17,58 @@ from skimage.util import img_as_ubyte +from anomaly_match.data_io.load_images import load_and_process_single_wrapper +from anomaly_match.ui.preview_widget import PreviewWidget + # Import the newly created UI elements from anomaly_match.ui.ui_elements import ( create_ui_elements, attach_click_listeners, HTML_setup, ) -from anomaly_match.image_processing.display_transforms import ( - apply_transforms_ui, - display_image_normalisation, -) -from anomaly_match.utils.numpy_to_byte_stream import numpy_array_to_byte_stream -from anomaly_match.data_io.load_images import read_and_resize_image + + +def shorten_filename(filename: str, max_length: int = 25) -> str: + """ + Shorten a filename to fit within the specified maximum length. + + Preserves the extension and shows beginning/end of the basename. + + Args: + filename: The filename to shorten. + max_length: Maximum total length of the result. + + Returns: + Shortened filename if needed, original otherwise. + """ + if len(filename) <= max_length: + return filename + + # Split into base and extension, handling multiple dots + if "." in filename: + # Find the last dot for extension + last_dot_idx = filename.rfind(".") + basename = filename[:last_dot_idx] + extension = filename[last_dot_idx:] # includes the dot + else: + basename = filename + extension = "" + + # Calculate available space for basename + # Format: "start...end" + extension + ellipsis = "..." + available = max_length - len(extension) - len(ellipsis) + + if available <= 6: + # Very short max_length, just truncate + return filename[: max_length - 3] + "..." + + # Split available space: more at start, less at end + start_len = (available * 2) // 3 + end_len = available - start_len + + shortened = basename[:start_len] + ellipsis + basename[-end_len:] + extension + return shortened class Widget: @@ -45,24 +85,38 @@ def __init__(self, session): self.session = session self.cfg = session.cfg + # Create the preview widget (handles image display and transforms) + self.preview = PreviewWidget(session) + # Create all UI elements (moved from the big monolithic class) self.ui = create_ui_elements() - # This is to keep track of the current image index - self.current_index = 0 + # Replace the ui_elements image/text widgets with preview widget's versions + self.ui["image_widget"] = self.preview.image_widget + self.ui["filename_text"] = self.preview.filename_text - # Initialize transformation states - self.invert = False - self.brightness = 1.0 - self.contrast = 1.0 - self.unsharp_mask_applied = False - self.original_image = None - self.modified_image = None + # Rebuild center_row with preview widget's components + from ipywidgets import VBox + import ipywidgets as widgets - # Initialize RGB channel states - self.show_r = True - self.show_g = True - self.show_b = True + self.ui["center_row"] = VBox( + [self.preview.filename_text, self.preview.image_widget], + layout=widgets.Layout(background_color="black"), + ) + + # Update main_layout to use the new center_row with preview widget's components. + # The main_layout VBox has this structure: + # [0] model_controls, [1] top_row, [2] center_row, [3:] transform_controls + bottom rows + # We replace center_row (index 2) with our new one containing the preview widget. + main_layout = self.ui["main_layout"] + model_controls = main_layout.children[0] + top_row = main_layout.children[1] + remaining_rows = main_layout.children[3:] # transform_controls, bottom_row1-5 + + main_layout.children = (model_controls, top_row, self.ui["center_row"], *remaining_rows) + + # Set the full resolution button reference + self.preview.set_full_res_button(self.ui["transform_buttons"]["full_res"]) # Attach the output widget so the session logs go there session.set_terminal_out(self.ui["out"]) @@ -124,6 +178,16 @@ def _pack_layout(self, main_layout, side_display): """ Helps compose final layout for display. """ + + # Reduce font size on buttons using widget styling + for _, widget in self.ui.items(): + if hasattr(widget, "style") and hasattr(widget, "description"): + widget.style.font_size = "90%" # Slightly smaller font for buttons + + # Adjust the image widget size + self.ui["image_widget"].layout.width = "auto" + self.ui["image_widget"].layout.height = "auto" + return HBox([main_layout, side_display]) def search_all_files(self): @@ -191,16 +255,20 @@ def update_progress( ) # update models last_normalisation_method only after successful eval if self.session.model.last_normalisation_method is None: - self.session.model.last_normalisation_method = self.session.cfg.normalisation_method + self.session.model.last_normalisation_method = ( + self.session.cfg.normalisation.normalisation_method + ) elif ( self.session.model.last_normalisation_method - != self.session.cfg.normalisation_method + != self.session.cfg.normalisation.normalisation_method ): logger.warning( - f"Evaluated with a new normalisation {self.session.cfg.normalisation_method.name} method " + f"Evaluated with a new normalisation {self.session.cfg.normalisation.normalisation_method.name} method " + f"not previously used with the model: {self.session.model.last_normalisation_method.name}" ) - self.session.model.last_normalisation_method = self.session.cfg.normalisation_method + self.session.model.last_normalisation_method = ( + self.session.cfg.normalisation.normalisation_method + ) # Display will be updated by the callback when completed self.display_top_files_scores() @@ -208,210 +276,85 @@ def update_progress( def display_top_files_scores(self): """Displays the top files and their scores.""" - self.current_index = 0 - self.update_image_display() + self.preview.set_index(0) + self.preview.update_display() self.ui["progress_bar"].style = {"bar_color": "green"} self.display_gallery() def update_image_display(self): """Updates the display of the current image.""" - - filename = self.session.filenames[self.current_index] - score = self.session.scores[self.current_index] - if self.session.cfg.normalisation_method != self.session.cached_image_normalisation_enum: - try: - logger.debug( - f"Re-Loading image from {filename} with norm {self.session.cfg.normalisation_method}" - ) - - # Load the image using the centralized function - filepath = os.path.join(self.session.cfg.data_dir, filename) - - img = read_and_resize_image( - filepath, - cfg=self.session.cfg, - convert_to_rgb=True, - ) - - # Normalise the image array to 0-1 range, then to 255 and convert to PIL Image - self.original_image = display_image_normalisation(img) - except Exception as e: - logger.error(f"Error loading image {filepath}: {e}") - return - # If no reload is necessary, use cached images - else: - img = self.session.img_catalog[self.current_index] - - # Normalise the image array to 0-1 range, then to 255 and convert to PIL Image - self.original_image = display_image_normalisation(img) - - # Apply other transforms - self.modified_image = apply_transforms_ui( - self.original_image, - invert=self.invert, - brightness=self.brightness, - contrast=self.contrast, - unsharp_mask_applied=self.unsharp_mask_applied, - show_r=self.show_r, - show_g=self.show_g, - show_b=self.show_b, - ) - - self.display_image(self.modified_image, filename, score) - - def display_image(self, img, filename=None, score=None): - """Displays the given PIL image in the widget.""" - image_byte_stream = numpy_array_to_byte_stream(np.array(img)) - self.ui["image_widget"].value = image_byte_stream - self.update_image_UI_label(filename, score) + self.preview.update_display() def update_image_UI_label(self, filename=None, score=None): """Updates the UI label with the current image's filename, score, and label.""" - label_color = "white" - label_text = "None" - label = self.session.get_label(self.current_index) - if label == "anomaly": - label_color = "red" - label_text = "Anomalous" - elif label == "normal": - label_color = "green" - label_text = "Nominal" - - # Get counts for anomalies and nominal samples - normal_count, anomalous_count = self.session.get_label_distribution() - - # Calculate newly annotated samples (those in active_learning_df) using cached method - new_nominal, new_anomalous = self.session.get_active_learning_counts() - - # Format the file name (shortened version) - fname = self.session.filenames[self.current_index] - fname_short = os.path.basename(fname) # Just show filename without path - # Shorten the filename if it's too long - if len(fname_short) > 32: - fname_short = fname_short[:15] + "..." + fname_short[-14:] - sc = self.session.scores[self.current_index] - total_len = len(self.session.img_catalog) - 1 - - self.ui["filename_text"].value = ( - # first line ─ Name left, Score & Index right - f'' - f"Name: {fname_short}" - f"" - f'' - f"Score: {sc:.2f} | Index: {self.current_index}/{total_len}" - f"" - # clear the float so the next line starts cleanly - f'
' - # second line ─ Label left, overall stats right - f'Label: ' - f'{label_text}' - # right-aligned block - f'' - f'Anomalies: {anomalous_count}(+{new_anomalous}) | ' - f'Nominal: {normal_count}(+{new_nominal})' - f"" - ) + self.preview.update_label_only() # ======== Sorting Methods ======== def sort_by_anomalous(self): """Sorts the images by their anomalous scores and updates the display.""" self.session.sort_by_anomalous() - self.current_index = 0 - self.update_image_display() + self.preview.set_index(0) + self.preview.update_display() def sort_by_nominal(self): """Sorts the images by their nominal scores and updates the display.""" self.session.sort_by_nominal() - self.current_index = 0 - self.update_image_display() + self.preview.set_index(0) + self.preview.update_display() def sort_by_mean(self): """Sorts the images by distance to mean score and updates the display.""" self.session.sort_by_mean() - self.current_index = 0 - self.update_image_display() + self.preview.set_index(0) + self.preview.update_display() def sort_by_median(self): """Sorts the images by distance to median score and updates the display.""" self.session.sort_by_median() - self.current_index = 0 - self.update_image_display() + self.preview.set_index(0) + self.preview.update_display() # ======== Navigation ======== def next_image(self): """Displays the next image in the catalog.""" - self.current_index = min(len(self.session.img_catalog) - 1, self.current_index + 1) - self.update_image_display() + new_index = min(len(self.session.img_catalog) - 1, self.preview.current_index + 1) + self.preview.reset_full_resolution_mode() + self.preview.set_index(new_index) + self.preview.update_display() def previous_image(self): """Displays the previous image in the catalog.""" - self.current_index = max(0, self.current_index - 1) - self.update_image_display() + new_index = max(0, self.preview.current_index - 1) + self.preview.reset_full_resolution_mode() + self.preview.set_index(new_index) + self.preview.update_display() # ======== Image Transformations ======== def restore_image(self): """Restores the current image to its original state.""" - self.invert = False self.ui["brightness_slider"].value = 1.0 self.ui["contrast_slider"].value = 1.0 - self.unsharp_mask_applied = False - - # Reset RGB channels - self.show_r = True - self.show_g = True - self.show_b = True self.ui["red_channel_checkbox"].value = True self.ui["green_channel_checkbox"].value = True self.ui["blue_channel_checkbox"].value = True - - self.modified_image = self.original_image - self.display_image(self.modified_image) + self.preview.restore() def toggle_invert_image(self): """Toggles the inversion of the current image.""" - self.invert = not self.invert - self.modified_image = apply_transforms_ui( - self.original_image, - invert=self.invert, - brightness=self.brightness, - contrast=self.contrast, - unsharp_mask_applied=self.unsharp_mask_applied, - show_r=self.show_r, - show_g=self.show_g, - show_b=self.show_b, - ) - self.display_image(self.modified_image) + self.preview.toggle_invert() def toggle_unsharp_mask(self): """Toggles the application of an unsharp mask.""" - self.unsharp_mask_applied = not self.unsharp_mask_applied - self.modified_image = apply_transforms_ui( - self.original_image, - invert=self.invert, - brightness=self.brightness, - contrast=self.contrast, - unsharp_mask_applied=self.unsharp_mask_applied, - show_r=self.show_r, - show_g=self.show_g, - show_b=self.show_b, - ) - self.display_image(self.modified_image) + self.preview.toggle_unsharp_mask() + + def show_full_resolution(self): + """Toggles full resolution mode and updates the display.""" + self.preview.toggle_full_resolution() def adjust_brightness_contrast(self, _): """Adjusts brightness and contrast of the current image.""" - self.brightness = self.ui["brightness_slider"].value - self.contrast = self.ui["contrast_slider"].value - self.modified_image = apply_transforms_ui( - self.original_image, - invert=self.invert, - brightness=self.brightness, - contrast=self.contrast, - unsharp_mask_applied=self.unsharp_mask_applied, - show_r=self.show_r, - show_g=self.show_g, - show_b=self.show_b, - ) - self.display_image(self.modified_image) + self.preview.set_brightness(self.ui["brightness_slider"].value) + self.preview.set_contrast(self.ui["contrast_slider"].value) def display_gallery(self): """Displays a small gallery of either mispredicted or top anomalous/nominal images.""" @@ -441,13 +384,19 @@ def display_gallery(self): path = os.path.join(self.cfg.data_dir, filename) if os.path.exists(path): try: - img_array = read_and_resize_image( - path, cfg=self.session.cfg, convert_to_rgb=True + img_array = load_and_process_single_wrapper( + path, + self.session.cfg, + desc="widget loading image", + show_progress=False, ) + img = Image.fromarray(img_array) mispredicted_images.append(img_array) + display_name = shorten_filename(filename) + image_text.append( - f"{filename}\nPred: {pred:.2f} | Label: {label:.2f}" + f"{display_name}\nPred: {pred:.2f} | Label: {label:.2f}" ) except Exception as e: @@ -538,8 +487,9 @@ def display_gallery(self): pil_img = Image.fromarray(img_arr) images.append(pil_img) filename = self.session.filenames[idx] + display_name = shorten_filename(filename) score = scores[idx] - image_text.append(f"{filename}\nScore: {score:.4f}") + image_text.append(f"{display_name}\nScore: {score:.4f}") for idx in top_nominal_indices: img_arr = self.session.img_catalog[idx] @@ -551,8 +501,9 @@ def display_gallery(self): pil_img = Image.fromarray(img_arr) images.append(pil_img) filename = self.session.filenames[idx] + display_name = shorten_filename(filename) score = scores[idx] - image_text.append(f"{filename}\nScore: {score:.4f}") + image_text.append(f"{display_name}\nScore: {score:.4f}") num_images = len(images) plt.figure(figsize=(12, 6), facecolor="black") @@ -580,7 +531,7 @@ def save_labels(self): def remember_current_file(self, _): """Remembers the currently displayed file.""" - self.session.remember_current_file(self.session.filenames[self.current_index]) + self.session.remember_current_file(self.session.filenames[self.preview.current_index]) def save_model(self): """Saves the model using the session.""" @@ -589,15 +540,19 @@ def save_model(self): def load_model(self): """Loads the model using the session.""" with self.ui["out"]: - logger.debug(f"Loading model, cfg norm: {self.session.cfg.normalisation_method}, ") + logger.debug( + f"Loading model, cfg norm: {self.session.cfg.normalisation.normalisation_method}, " + ) self.session.load_model() # Update the normalization dropdown to match the session's method - self.ui["normalisation_dropdown"].value = self.session.cfg.normalisation_method + self.ui["normalisation_dropdown"].value = ( + self.session.cfg.normalisation.normalisation_method + ) with self.ui["out"]: logger.debug( - f"Loaded model, cfg norm: {self.session.cfg.normalisation_method}," + f"Loaded model, cfg norm: {self.session.cfg.normalisation.normalisation_method}," + f" model norm: {self.session.model.last_normalisation_method}" ) self.update() @@ -608,14 +563,13 @@ def train(self): with self.ui["out"]: logger.debug( f"Session norm: {self.session.cached_image_normalisation_enum}, " - f"selected norm: {self.session.cfg.normalisation_method}" + f"selected norm: {self.session.cfg.normalisation.normalisation_method}" ) with self.ui["out"]: logger.debug("Starting training...") self.ui["progress_bar"].style = {"bar_color": "blue"} self.ui["progress_bar"].value = 0.0 - self.cfg.progress_bar = self.ui["progress_bar"] self.cfg.num_train_iter = self.ui["train_iteration_slider"].value logger.debug( @@ -669,16 +623,20 @@ def update_training_progress(iteration, total_iterations): # update models last_normalisation_method after successful training if self.session.model.last_normalisation_method is None: - self.session.model.last_normalisation_method = self.session.cfg.normalisation_method + self.session.model.last_normalisation_method = ( + self.session.cfg.normalisation.normalisation_method + ) elif ( self.session.model.last_normalisation_method - != self.session.cfg.normalisation_method + != self.session.cfg.normalisation.normalisation_method ): logger.warning( - f"Trained with a new normalisation {self.session.cfg.normalisation_method.name} method " + f"Trained with a new normalisation {self.session.cfg.normalisation.normalisation_method.name} method " + f"not previously used with the model: {self.session.model.last_normalisation_method.name}" ) - self.session.model.last_normalisation_method = self.session.cfg.normalisation_method + self.session.model.last_normalisation_method = ( + self.session.cfg.normalisation.normalisation_method + ) # Calculate total time taken total_time = time.time() - start_time @@ -693,7 +651,7 @@ def update(self): """Updates the UI components and performs evaluation.""" self.ui["progress_bar"].style = {"bar_color": "cyan"} self.session.update_predictions() - self.current_index = 0 + self.preview.set_index(0) if self.cfg.test_ratio > 0: if self.session.eval_performance is not None: @@ -745,29 +703,24 @@ def load_top_files(self): # Add channel toggle methods def toggle_red_channel(self, change): """Toggles the red channel on/off.""" - self.show_r = change["new"] - self.update_image_display() + self.preview.set_rgb_channels(r=change["new"]) def toggle_green_channel(self, change): """Toggles the green channel on/off.""" - self.show_g = change["new"] - self.update_image_display() + self.preview.set_rgb_channels(g=change["new"]) def toggle_blue_channel(self, change): """Toggles the blue channel on/off.""" - self.show_b = change["new"] - self.update_image_display() + self.preview.set_rgb_channels(b=change["new"]) def select_normalisation(self, change): """Updates the normalization method when dropdown selection changes.""" new_value = change["new"] - if new_value != self.cfg.normalisation_method: + if new_value != self.cfg.normalisation.normalisation_method: self.session.set_normalisation_method(new_value) - self.update_image_display() + self.preview.update_display() def unlabel_current_image(self): """Removes the label from the currently displayed image.""" - # Call session's unlabel_image method - self.session.unlabel_image(self.current_index) - # Update the UI to reflect the change - self.update_image_UI_label() + self.session.unlabel_image(self.preview.current_index) + self.preview.update_label_only() diff --git a/anomaly_match/ui/preview_widget.py b/anomaly_match/ui/preview_widget.py new file mode 100644 index 0000000..0d9df5c --- /dev/null +++ b/anomaly_match/ui/preview_widget.py @@ -0,0 +1,271 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +""" +PreviewWidget: A self-contained widget for displaying and manipulating preview images. +""" +import os +import numpy as np +import ipywidgets as widgets +from ipywidgets import VBox +from loguru import logger + +from anomaly_match.data_io.load_images import load_and_process_single_wrapper +from anomaly_match.image_processing.display_transforms import ( + apply_transforms_ui, + display_image_normalisation, +) +from anomaly_match.utils.numpy_to_byte_stream import numpy_array_to_byte_stream + + +class PreviewWidget: + """A widget component for displaying images with transformation controls.""" + + def __init__(self, session): + """ + Initialize the preview widget. + + Args: + session: The session object providing image data and configuration. + """ + self.session = session + + # Create UI elements + self.filename_text = widgets.HTML( + value="", + layout=widgets.Layout(background_color="black"), + style={"color": "white"}, + ) + + self.image_widget = widgets.Image( + value=b"", + width=600, + height=600, + layout=widgets.Layout(background_color="black"), + ) + + # Compose into a VBox + self.widget = VBox( + [self.filename_text, self.image_widget], + layout=widgets.Layout(background_color="black"), + ) + + # Current image index + self.current_index = 0 + + # Transform states + self.invert = False + self.brightness = 1.0 + self.contrast = 1.0 + self.unsharp_mask_applied = False + self.show_r = True + self.show_g = True + self.show_b = True + self.full_resolution_mode = False + + # Image data + self.original_image = None + self.modified_image = None + + # Reference to full resolution button (set by parent) + self._full_res_button = None + + def set_full_res_button(self, button): + """Set the full resolution button reference for updating its state.""" + self._full_res_button = button + + def set_index(self, index): + """Set the current image index.""" + self.current_index = index + + def update_display(self): + """Updates the display of the current image.""" + filename = self.session.filenames[self.current_index] + score = self.session.scores[self.current_index] + filepath = os.path.join(self.session.cfg.data_dir, filename) + + # Determine if we need to reload from disk + needs_reload = ( + self.full_resolution_mode + or self.session.cfg.normalisation.normalisation_method + != self.session.cached_image_normalisation_enum + ) + + if needs_reload: + try: + size_override = None if self.full_resolution_mode else "default" + logger.debug( + f"Loading image from {filename} (full_res={self.full_resolution_mode})" + ) + + img = load_and_process_single_wrapper( + filepath, + self.session.cfg, + desc="widget loading image", + show_progress=False, + size_override=size_override, + ) + + self.original_image = display_image_normalisation(img) + except Exception as e: + logger.error(f"Error loading image {filepath}: {e}") + return + else: + img = self.session.img_catalog[self.current_index] + self.original_image = display_image_normalisation(img) + + # Apply transforms + self.modified_image = apply_transforms_ui( + self.original_image, + invert=self.invert, + brightness=self.brightness, + contrast=self.contrast, + unsharp_mask_applied=self.unsharp_mask_applied, + show_r=self.show_r, + show_g=self.show_g, + show_b=self.show_b, + ) + + self._display_image(self.modified_image, filename, score) + + def _display_image(self, img, filename=None, score=None): + """Displays the given PIL image in the widget.""" + image_byte_stream = numpy_array_to_byte_stream(np.array(img)) + self.image_widget.value = image_byte_stream + self._update_label(filename, score) + + def _update_label(self, filename=None, score=None): + """Updates the UI label with the current image's filename, score, and label.""" + label_color = "white" + label_text = "None" + label = self.session.get_label(self.current_index) + if label == "anomaly": + label_color = "red" + label_text = "Anomalous" + elif label == "normal": + label_color = "green" + label_text = "Nominal" + + # Get counts for anomalies and nominal samples + normal_count, anomalous_count = self.session.get_label_distribution() + + # Calculate newly annotated samples using cached method + new_nominal, new_anomalous = self.session.get_active_learning_counts() + + # Format the file name (shortened version) + fname = self.session.filenames[self.current_index] + fname_short = os.path.basename(fname) + if len(fname_short) > 57: + fname_short = ( + fname_short[:45] + + "..." + + fname_short.split(".")[-2][-5:] + + "." + + fname_short.split(".")[-1] + ) + sc = self.session.scores[self.current_index] + total_len = len(self.session.img_catalog) - 1 + + self.filename_text.value = ( + f'' + f"Name: {fname_short}" + f"" + f'' + f"Score: {sc:.2f} | Index: {self.current_index}/{total_len}" + f"" + f'
' + f'Label: ' + f'{label_text}' + f'' + f'Anomalies: {anomalous_count}(+{new_anomalous}) | ' + f'Nominal: {normal_count}(+{new_nominal})' + f"" + ) + + def update_label_only(self): + """Updates only the label without reloading the image.""" + self._update_label() + + # ======== Transform Methods ======== + def restore(self): + """Restores the current image to its original state.""" + self.invert = False + self.brightness = 1.0 + self.contrast = 1.0 + self.unsharp_mask_applied = False + self.show_r = True + self.show_g = True + self.show_b = True + + self.modified_image = self.original_image + self._display_image(self.modified_image) + + def toggle_invert(self): + """Toggles the inversion of the current image.""" + self.invert = not self.invert + self._apply_transforms_and_display() + + def toggle_unsharp_mask(self): + """Toggles the application of an unsharp mask.""" + self.unsharp_mask_applied = not self.unsharp_mask_applied + self._apply_transforms_and_display() + + def set_brightness(self, value): + """Sets brightness and updates display.""" + self.brightness = value + self._apply_transforms_and_display() + + def set_contrast(self, value): + """Sets contrast and updates display.""" + self.contrast = value + self._apply_transforms_and_display() + + def set_rgb_channels(self, r=None, g=None, b=None): + """Sets RGB channel visibility.""" + if r is not None: + self.show_r = r + if g is not None: + self.show_g = g + if b is not None: + self.show_b = b + self._apply_transforms_and_display() + + def toggle_full_resolution(self): + """Toggles full resolution mode and updates the display.""" + self.full_resolution_mode = not self.full_resolution_mode + self._update_full_res_button() + self.update_display() + + def reset_full_resolution_mode(self): + """Resets full resolution mode to preview mode.""" + if self.full_resolution_mode: + self.full_resolution_mode = False + self._update_full_res_button() + + def _update_full_res_button(self): + """Updates the full resolution button appearance.""" + if self._full_res_button is None: + return + if self.full_resolution_mode: + self._full_res_button.description = "Show Preview" + self._full_res_button.style.button_color = "#17a2b8" + else: + self._full_res_button.description = "Show Full Resolution" + self._full_res_button.style.button_color = "#ffffff" + + def _apply_transforms_and_display(self): + """Applies current transforms and updates display.""" + self.modified_image = apply_transforms_ui( + self.original_image, + invert=self.invert, + brightness=self.brightness, + contrast=self.contrast, + unsharp_mask_applied=self.unsharp_mask_applied, + show_r=self.show_r, + show_g=self.show_g, + show_b=self.show_b, + ) + self._display_image(self.modified_image) diff --git a/anomaly_match/ui/ui_elements.py b/anomaly_match/ui/ui_elements.py index 190c230..cce2bfa 100644 --- a/anomaly_match/ui/ui_elements.py +++ b/anomaly_match/ui/ui_elements.py @@ -11,7 +11,7 @@ from anomaly_match import __version__ from anomaly_match.ui.memory_monitor import MemoryMonitor -from anomaly_match.image_processing.NormalisationMethod import NormalisationMethod +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod HTML_setup = HTML( """ @@ -80,7 +80,12 @@ "Save Labels", "Load Top Files", ] -transform_button_names = ["Invert Image", "Restore", "Toggle Unsharp Mask"] +transform_button_config = { + "invert": "Invert", + "restore": "Restore", + "unsharp_mask": "Unsharp Mask", + "full_res": "Show Full Resolution", +} normalisation_button_names = ["Normalisation"] @@ -147,7 +152,7 @@ def create_ui_elements(): remember_button = Button( description="Remember", button_style="warning", - layout=widgets.Layout(background_color="black"), + layout=widgets.Layout(background_color="black", width="auto", flex="1 1 auto"), style={"button_color": ORANGE_COLOR}, ) @@ -161,7 +166,7 @@ def create_ui_elements(): # Add a dropdown for normalisation methods normalisation_dropdown = widgets.Dropdown( - options=NormalisationMethod.get_dropdown_options(), + options=NormalisationMethod.get_options(), value=NormalisationMethod.CONVERSION_ONLY, description=normalisation_button_names[0], layout=widgets.Layout(background_color="black", width="250px"), @@ -171,15 +176,16 @@ def create_ui_elements(): ) # Update transform buttons with white background but black text - transform_buttons = [ - Button( - description=w, + # Use flex layout to make buttons share space evenly + transform_buttons = { + key: Button( + description=label, button_style="success", - layout=widgets.Layout(background_color="black"), + layout=widgets.Layout(background_color="black", width="auto", flex="1 1 auto"), style={"button_color": WHITE_COLOR, "text_color": "black"}, ) - for w in transform_button_names - ] + for key, label in transform_button_config.items() + } # Sliders with adjusted widths for side-by-side display brightness_slider = widgets.FloatSlider( @@ -257,9 +263,10 @@ def create_ui_elements(): transform_controls = VBox( [ HBox( - transform_buttons + [remember_button] - ), # Add remember button to transform controls - slider_row, # Single row with both sliders and RGB toggles + list(transform_buttons.values()) + [remember_button], + layout=widgets.Layout(background_color="black", width="600px"), + ), + slider_row, ], layout=widgets.Layout(background_color="black"), ) @@ -299,9 +306,9 @@ def create_ui_elements(): train_iteration_slider = widgets.IntSlider( value=50, - min=50, - max=2000, - step=50, + min=10, + max=600, + step=20, description="Train Iterations", continuous_update=True, style={ @@ -514,20 +521,21 @@ def attach_click_listeners(widget): # Decision buttons def mark_anomalous(_): - widget.session.label_image(widget.current_index, "anomaly") + widget.session.label_image(widget.preview.current_index, "anomaly") widget.update_image_UI_label() def mark_nominal(_): - widget.session.label_image(widget.current_index, "normal") + widget.session.label_image(widget.preview.current_index, "normal") widget.update_image_UI_label() widget.ui["decision_buttons"][0].on_click(mark_anomalous) widget.ui["decision_buttons"][1].on_click(mark_nominal) # Transform buttons - widget.ui["transform_buttons"][0].on_click(lambda _: widget.toggle_invert_image()) - widget.ui["transform_buttons"][1].on_click(lambda _: widget.restore_image()) - widget.ui["transform_buttons"][2].on_click(lambda _: widget.toggle_unsharp_mask()) + widget.ui["transform_buttons"]["invert"].on_click(lambda _: widget.toggle_invert_image()) + widget.ui["transform_buttons"]["restore"].on_click(lambda _: widget.restore_image()) + widget.ui["transform_buttons"]["unsharp_mask"].on_click(lambda _: widget.toggle_unsharp_mask()) + widget.ui["transform_buttons"]["full_res"].on_click(lambda _: widget.show_full_resolution()) # Brightness/Contrast observers widget.ui["brightness_slider"].observe(widget.adjust_brightness_contrast, names="value") diff --git a/anomaly_match/ui/utility_functions.py b/anomaly_match/ui/utility_functions.py deleted file mode 100644 index edf4441..0000000 --- a/anomaly_match/ui/utility_functions.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) European Space Agency, 2025. -# -# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which -# is part of this source code package. No part of the package, including -# this file, may be copied, modified, propagated, or distributed except according to -# the terms contained in the file 'LICENCE.txt'. -from PIL import ImageOps, ImageEnhance, ImageFilter, Image -import numpy as np - - -def apply_transforms( - img, - invert=False, - brightness=1.0, - contrast=1.0, - unsharp_mask_applied=False, - show_r=True, - show_g=True, - show_b=True, -): - """ - Applies the requested transformations to the given PIL Image. - - Args: - img (PIL.Image.Image): The original image. - invert (bool): Whether to invert colors. - brightness (float): Brightness factor. - contrast (float): Contrast factor. - unsharp_mask_applied (bool): Whether to apply an unsharp mask. - show_r (bool): Whether to show the red channel. - show_g (bool): Whether to show the green channel. - show_b (bool): Whether to show the blue channel. - - Returns: - PIL.Image.Image: The transformed image. - """ - # Apply inversion - if invert: - img = ImageOps.invert(img) - - # Apply brightness - if brightness != 1.0: - enhancer = ImageEnhance.Brightness(img) - img = enhancer.enhance(brightness) - - # Apply contrast - if contrast != 1.0: - enhancer = ImageEnhance.Contrast(img) - img = enhancer.enhance(contrast) - - # Apply unsharp mask if enabled - if unsharp_mask_applied: - img = img.filter(ImageFilter.UnsharpMask()) - - # Apply channel toggling - if not (show_r and show_g and show_b): - # Convert PIL image to numpy array - img_array = np.array(img) - - # Create a mask for RGB channels - channels_mask = [show_r, show_g, show_b] - - # Apply masking to the image array (zero out disabled channels) - for i, show_channel in enumerate(channels_mask): - if not show_channel: - img_array[:, :, i] = 0 - - # Convert back to PIL image - img = Image.fromarray(img_array) - - return img diff --git a/anomaly_match/utils/constants.py b/anomaly_match/utils/constants.py index 4454005..9157e1c 100644 --- a/anomaly_match/utils/constants.py +++ b/anomaly_match/utils/constants.py @@ -6,5 +6,4 @@ # the terms contained in the file 'LICENCE.txt'. """Constants used across the AnomalyMatch module.""" -# Supported image file extensions -SUPPORTED_IMAGE_EXTENSIONS = [".fits", ".jpeg", ".jpg", ".png", ".tif", ".tiff"] +# Supported image file extensions was moved to fitsbolt diff --git a/anomaly_match/utils/cutana_stream_utils.py b/anomaly_match/utils/cutana_stream_utils.py new file mode 100644 index 0000000..5b4514c --- /dev/null +++ b/anomaly_match/utils/cutana_stream_utils.py @@ -0,0 +1,130 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. + +import warnings +from pathlib import Path + +import pyarrow.parquet as pq +import pandas as pd +from loguru import logger +from cutana.catalogue_preprocessor import validate_catalogue_columns, check_fits_files_exist + + +def cutana_validate_files_and_count_sources( + files: list[Path | str], chunk_size: int = 100_000 +) -> tuple[list[Path], int, int]: + """Validate catalogue files for cutana compatibility and count total number sources and total number of chunks to process. + + Args: + files (list[Path | str]): list of file paths to validate (CSV or Parquet). + chunk_size (int): number of rows to read per chunk. + + Returns: + tuple[list[Path], int, int]: valid files, total number of sources, and total number of chunks. + """ + + def _validate_against_cutana(index: int, dataframe: pd.DataFrame) -> bool: + # Check header once + if index == 0: + errors = validate_catalogue_columns(dataframe) + if errors: + return errors + + errors, _ = check_fits_files_exist(dataframe) + if errors: + return errors + return [] + + valid_files = [] + total_sources = 0 + total_chunks = 0 + + for file in files: + + is_file_valid = True + + current_file_sources = 0 + current_file_chunks = 0 + + if isinstance(file, Path): + file_type = file.name.split(".")[-1] + else: + file_type = file.split(".")[-1] + + if file_type == "csv": + for i, df in enumerate(pd.read_csv(file, chunksize=chunk_size)): + + errors = _validate_against_cutana(i, df) + if errors: + current_file_sources = 0 + current_file_chunks = 0 + is_file_valid = False + msg = f"File {file} did not pass cutana compatibility check and will be skipped ({errors})" + logger.warning(msg) + warnings.warn(msg, RuntimeWarning) + break + current_file_sources += len(df) + current_file_chunks += 1 + + elif file_type == "parquet": + parquet_file = pq.ParquetFile(file) + for i, batch in enumerate(parquet_file.iter_batches(batch_size=chunk_size)): + df = batch.to_pandas() + + errors = _validate_against_cutana(i, df) + if errors: + current_file_sources = 0 + current_file_chunks = 0 + is_file_valid = False + msg = f"File {file} did not pass cutana compatibility check and will be skipped ({errors})" + logger.warning(msg) + warnings.warn(msg, RuntimeWarning) + break + current_file_sources += len(df) + current_file_chunks += 1 + else: + is_file_valid = False + + total_sources += current_file_sources + total_chunks += current_file_chunks + if is_file_valid: + valid_files.append(file) + + return valid_files, total_sources, total_chunks + + +def cutana_buffer_generator(files: list[Path | str], buffer_path: Path, chunk_size: int = 100_000): + """Generate temporary buffer files by reading catalogue files in chunks. + + Args: + files (list[Path | str]): list of file paths to process (CSV or Parquet). + buffer_path (Path): path where temporary buffer parquet will be written. + chunk_size (int): number of rows to read per chunk. + + Yields: + Path: path to the buffer file containing the current chunk. + """ + buffer_path.parent.mkdir(parents=True, exist_ok=True) + + for file in files: + + if isinstance(file, Path): + file_type = file.name.split(".")[-1] + else: + file_type = file.split(".")[-1] + + if file_type == "csv": + for df in pd.read_csv(file, chunksize=chunk_size): + df.to_parquet(buffer_path, index=False) + yield buffer_path + + else: # if not CSV then Parquet + parquet_file = pq.ParquetFile(file) + for batch in parquet_file.iter_batches(batch_size=chunk_size): + df = batch.to_pandas() + df.to_parquet(buffer_path, index=False) + yield buffer_path diff --git a/anomaly_match/utils/get_default_cfg.py b/anomaly_match/utils/get_default_cfg.py index a780e33..59520c3 100644 --- a/anomaly_match/utils/get_default_cfg.py +++ b/anomaly_match/utils/get_default_cfg.py @@ -7,9 +7,10 @@ from dotmap import DotMap import os +import numpy as np from .create_model_string import create_model_string -from anomaly_match.image_processing.NormalisationMethod import NormalisationMethod +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod def get_default_cfg(): @@ -33,31 +34,60 @@ def get_default_cfg(): cfg.save_path = os.path.join(cfg.save_dir) cfg.save_file = create_model_string(cfg) + ".pth" cfg.model_path = None # Will be set by SessionIOHandler when session is active + cfg.N_batch_prediction = None # User specified batch size for evaluating a directory, if None: determined automatically + cfg.subprocess_buffer_size = ( + 100_000 # Number of sources packed into intermediate files for subprocesses + ) cfg.seed = 42 cfg.test_ratio = 0.0 # DataLoader settings cfg.N_to_load = 1000 - cfg.size = [224, 224] - cfg.num_workers = 4 cfg.pin_memory = True cfg.oversample = True - cfg.interpolation_order = 1 # order of interpolation for resizing with skimage, 0-5 - # Normalisation settings - cfg.normalisation_method = NormalisationMethod.CONVERSION_ONLY - # Optional normalisation settings + + cfg.num_workers = 4 + # normalisation settings for fitsbolt settings cfg.normalisation = DotMap() - cfg.normalisation.maximum_value = None # None or float - cfg.normalisation.minimum_value = None # None or float - cfg.normalisation.crop_for_maximum_value = None # None or integer tuple (height, width) - # Bool, if False assumes min value to be 0 or cfg.normalisation.minimum_value if not None - cfg.normalisation.log_calculate_minimum_value = False - # only used if cfg.normalisation_method == NormalisationMethod.ASINH: - # asinh_scale list of 3 floats > 0, defining the scale for each channel (lower = higher stretch): - cfg.normalisation.asinh_scale = [0.7, 0.7, 0.7] - # asinh_clip list of 3 floats in ]0.,100.], defining the clip for each channel: - cfg.normalisation.asinh_clip = [99.8, 99.8, 99.8] + cfg.normalisation.output_dtype = np.uint8 # output dtype of the images + # NOTE: image_size has no default - user must explicitly set it + cfg.normalisation.n_output_channels = 3 # number of output channels (e.g. 3 for RGB) + + # FITS file handling settings + # fits_extension: Extension(s) to use when loading FITS files + # (can be int, string, or list of int/string, or list of lists of int/string) + cfg.normalisation.fits_extension = None + + # channel_combination: (np.array) combine FITS extensions into n_output (3 = RGB) channels, shape n_out x n_input = len + # cfg.normalisation.fits_extension, or None if only one extension is used or n_out=n_input + cfg.normalisation.channel_combination = None + + # further interpolation and normalisation settings + cfg.normalisation.interpolation_order = ( + 1 # order of interpolation for resizing with skimage, 0-5 + ) + cfg.normalisation.normalisation_method = NormalisationMethod.CONVERSION_ONLY + # settings for normalisation: + cfg.normalisation.norm_maximum_value = None # None or float + cfg.normalisation.norm_minimum_value = None # None or float + cfg.normalisation.norm_crop_for_maximum_value = None # None or integer tuple (height, width) + # Bool, if False assumes min value to be 0 or cfg.normalisation.norm_minimum_value if not None + cfg.normalisation.norm_log_calculate_minimum_value = False + # only used if cfg.normalisation.normalisation_method == NormalisationMethod.ASINH: asinh_scale list of n_output_channel - + # floats > 0, defining the scale for each channel (lower = higher stretch): + cfg.normalisation.norm_asinh_scale = [ + 0.7, + 0.7, + 0.7, + ] + # norm_asinh_clip: asinh_clip list of n_output_channel floats in ]0.,100.], defining the clip for each channel: + cfg.normalisation.norm_asinh_clip = [ + 99.8, + 99.8, + 99.8, + ] + # end of fitsbolt settings # FixMatch settings cfg.ema_m = 0.99 @@ -83,7 +113,4 @@ def get_default_cfg(): cfg.pretrained = True cfg.net = "efficientnet-lite0" - # FITS file handling settings - cfg.fits_extension = None # Extension(s) to use when loading FITS files (can be int, string, or list of int/string) - return cfg diff --git a/anomaly_match/utils/print_cfg.py b/anomaly_match/utils/print_cfg.py index b897ded..947fe00 100644 --- a/anomaly_match/utils/print_cfg.py +++ b/anomaly_match/utils/print_cfg.py @@ -31,7 +31,7 @@ def print_cfg(cfg: DotMap): "seed", "test_ratio", "N_to_load", - "size", + "image_size", "num_workers", "pin_memory", ], diff --git a/anomaly_match/utils/validate_config.py b/anomaly_match/utils/validate_config.py index 2b7b124..f3b37da 100644 --- a/anomaly_match/utils/validate_config.py +++ b/anomaly_match/utils/validate_config.py @@ -9,7 +9,7 @@ from loguru import logger import os -from anomaly_match.image_processing.NormalisationMethod import NormalisationMethod +from fitsbolt.cfg.create_config import create_config as fb_create_cfg def _return_required_and_optional_keys(): @@ -37,7 +37,7 @@ def _return_required_and_optional_keys(): # Required file parameters "label_file": ["file", None, None, False, None], "metadata_file": ["file", None, None, True, None], # Optional, can be None - # Required numeric parameters + # Required numeric parameter" "seed": [float, None, None, False, None], # accepts int or float # Required positive integers "num_workers": [int, 1, None, False, None], @@ -48,6 +48,7 @@ def _return_required_and_optional_keys(): # Required integers >= 10 "N_to_load": [int, 10, None, False, None], "top_N": [int, 10, None, False, None], + "subprocess_buffer_size": [int, 100, None, False, None], # Required floats in range [0, 1] "test_ratio": [float, 0.0, 1.0, False, None], "ema_m": [float, 0.0, 1.0, False, None], @@ -63,7 +64,6 @@ def _return_required_and_optional_keys(): "oversample": [bool, None, None, False, None], "hard_label": [bool, None, None, False, None], "pretrained": [bool, None, None, False, None], - "normalisation.log_calculate_minimum_value": [bool, None, None, False, None], # Required parameters with allowed values "opt": [str, None, None, False, ["SGD", "Adam"]], "net": [str, None, None, False, ["efficientnet-lite0"]], @@ -75,21 +75,13 @@ def _return_required_and_optional_keys(): ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", "TRACE"], ], # Required special parameters - "size": ["special_size", None, None, False, None], "num_eval_iter": ["special_eval_iter", None, None, False, None], - "normalisation_method": ["special_normalisation_method", None, None, False, None], - "normalisation": ["special_normalisation", None, None, False, None], - "normalisation.asinh_scale": ["special_asinh_scale", None, None, False, None], - "normalisation.asinh_clip": ["special_asinh_clip", None, None, False, None], - "interpolation_order": [int, 0, 5, False, None], # 0-5 for skimage interpolation" # Optional directory parameters "prediction_search_dir": ["directory", None, None, True, None], - # Optional numeric parameters - "normalisation.maximum_value": [float, None, None, True, None], - "normalisation.minimum_value": [float, None, None, True, None], - # Optional special parameters - "normalisation.crop_for_maximum_value": ["special_crop", None, None, True, None], - "fits_extension": ["special_fits_extension", None, None, True, None], + "N_batch_prediction": [int, 1, None, True, None], + # fitsbolt config parameters - only validate that it's a DotMap and check size + "normalisation": ["special_fitsbolt", None, None, False, None], + "normalisation.image_size": ["special_size", None, None, False, None], } return config_spec @@ -254,11 +246,12 @@ def _format_constraints(): # Handle special validation cases elif dtype == "special_size": - if not isinstance(value, (list, tuple)) or len(value) != 2: - raise ValueError( - f"{param_name} must be a list or tuple of length 2, got {type(value).__name__}" - + f"with length {len(value) if hasattr(value, '__len__') else 'unknown'}" - ) + if value is not None: + if not isinstance(value, (list, tuple)) or len(value) != 2: + raise ValueError( + f"{param_name} must be a list or tuple of length 2, got {type(value).__name__}" + + f"with length {len(value) if hasattr(value, '__len__') else 'unknown'}" + ) elif dtype == "special_eval_iter": if not isinstance(value, int) or (value != -1 and value <= 0): @@ -266,111 +259,54 @@ def _format_constraints(): f"{param_name} must be an integer > 0 or -1, got {value} (type: {type(value).__name__})" ) - elif dtype == "special_normalisation_method": - if not isinstance(value, NormalisationMethod): - raise ValueError( - f"{param_name} must be a NormalisationMethod enum value, got {type(value).__name__}" - ) - - elif dtype == "special_normalisation": + elif dtype == "special_fitsbolt": + # The fitsbolt DotMap will be validated by fb_create_cfg if not isinstance(value, DotMap): raise ValueError(f"{param_name} must be a DotMap, got {type(value).__name__}") - - elif dtype == "special_asinh_scale": - if not isinstance(value, (list, tuple, int, float)): - raise ValueError( - f"{param_name} must be a number or list/tuple of 3 numbers > 0, got {type(value).__name__}" - ) - if isinstance(value, (list, tuple)): - if len(value) != 3: - raise ValueError( - f"{param_name} if list/tuple, must have length 3, got length {len(value)}" - ) - if not all(isinstance(x, (int, float)) for x in value): - raise ValueError( - f"{param_name} values must be numbers, got types: {[type(x).__name__ for x in value]}" - ) - if not all(0 < x for x in value): - raise ValueError(f"{param_name} values must be > 0, got: {value}") - else: - # Single value - if not isinstance(value, (int, float)): - raise ValueError( - f"{param_name} must be a number > 0, got {type(value).__name__}" - ) - if not (0 < value): - raise ValueError(f"{param_name} must > 0, got: {value}") - - elif dtype == "special_asinh_clip": - if not isinstance(value, (list, tuple, int, float)): - raise ValueError( - f"{param_name} must be a number or list/tuple of 3 numbers in ]0,100.], got {type(value).__name__}" - ) - if isinstance(value, (list, tuple)): - if len(value) != 3: - raise ValueError( - f"{param_name} if list/tuple, must have length 3, got length {len(value)}" - ) - if not all(isinstance(x, (int, float)) for x in value): - raise ValueError( - f"{param_name} values must be numbers, got types: {[type(x).__name__ for x in value]}" - ) - if not all(0 < x <= 100 for x in value): - raise ValueError(f"{param_name} values must be in range ]0,100.], got: {value}") - else: - # Single value - if not isinstance(value, (int, float)): - raise ValueError( - f"{param_name} must be a number in ]0,100.], got {type(value).__name__}" - ) - if not (0 < value <= 100): - raise ValueError(f"{param_name} must be in range ]0,100.], got: {value}") - - elif dtype == "special_crop": - if value is not None: - if not isinstance(value, (tuple, list)) or len(value) != 2: - raise ValueError( - f"{param_name} if set, must be a tuple of two integers, got {type(value).__name__}" - ) - if not all(isinstance(x, int) for x in value): - raise ValueError( - f"{param_name} values must be integers, got types: {[type(x).__name__ for x in value]}" - ) - - elif dtype == "special_fits_extension": - if value is not None: - if isinstance(value, list): - if len(value) not in [1, 3]: - raise ValueError( - f"{param_name} must be a str/int or list of strings/ints of length 1 or 3," - + f" got list of length {len(value)}" - ) - for v in value: - if not isinstance(v, (str, int)): - raise ValueError( - f"{param_name} list elements must be str or int, got {type(v).__name__}" - ) - elif not isinstance(value, (str, int)): - raise ValueError( - f"{param_name} must be a str/int or list of strings/ints, got {type(value).__name__}" - ) - else: raise ValueError(f"Unknown data type for {param_name}: {dtype}") - # Custom cross-parameter validation - if "normalisation" in cfg: - if ( - hasattr(cfg.normalisation, "maximum_value") - and hasattr(cfg.normalisation, "minimum_value") - and isinstance(cfg.normalisation.maximum_value, (int, float)) - and isinstance(cfg.normalisation.minimum_value, (int, float)) - ): - if cfg.normalisation.maximum_value <= cfg.normalisation.minimum_value: - raise ValueError( - f"normalisation.maximum_value {cfg.normalisation.maximum_value} must be larger than " - f"normalisation.minimum_value {cfg.normalisation.minimum_value}" - ) + # Also validate normalisation configuration with its own validation function if possible + if hasattr(cfg, "normalisation"): + try: + # Use fitsbolt's own validation by calling its create_config function + _ = fb_create_cfg( + output_dtype=cfg.normalisation.output_dtype, + size=cfg.normalisation.image_size, + fits_extension=cfg.normalisation.fits_extension, + interpolation_order=cfg.normalisation.interpolation_order, + normalisation_method=cfg.normalisation.normalisation_method, + channel_combination=cfg.normalisation.channel_combination, + num_workers=cfg.num_workers, + norm_maximum_value=cfg.normalisation.norm_maximum_value, + norm_minimum_value=cfg.normalisation.norm_minimum_value, + norm_log_calculate_minimum_value=cfg.normalisation.norm_log_calculate_minimum_value, + norm_crop_for_maximum_value=cfg.normalisation.norm_crop_for_maximum_value, + norm_asinh_scale=cfg.normalisation.norm_asinh_scale, + norm_asinh_clip=cfg.normalisation.norm_asinh_clip, + ) + logger.debug("fitsbolt configuration validated successfully") + # add the fitsbolt keys to expected keys used in above function call to expected_keys + expected_keys.update( + [ + "normalisation.output_dtype", + "normalisation.image_size", + "normalisation.n_output_channels", + "normalisation.fits_extension", + "normalisation.interpolation_order", + "normalisation.normalisation_method", + "normalisation.channel_combination", + "normalisation.norm_maximum_value", + "normalisation.norm_minimum_value", + "normalisation.norm_log_calculate_minimum_value", + "normalisation.norm_crop_for_maximum_value", + "normalisation.norm_asinh_scale", + "normalisation.norm_asinh_clip", + ] + ) + except Exception as e: + logger.error(f"normalisation configuration validation failed: {e}") + raise ValueError(f"normalisation configuration validation failed: {e}") # Check for unexpected keys actual_keys = _get_all_keys(cfg) diff --git a/environment.yml b/environment.yml index c2de0c5..870dc52 100644 --- a/environment.yml +++ b/environment.yml @@ -20,7 +20,7 @@ dependencies: - loguru - matplotlib - numpy - - pandas + - pandas<3 - python=3.11 - pytorch - pytorch-cuda=12.4 @@ -37,3 +37,5 @@ dependencies: - efficientnet_lite_pytorch - efficientnet_lite0_pytorch_model - opencv-python-headless + - fitsbolt>=0.1.6 + - cutana>=0.2.1 diff --git a/environment_CI.yml b/environment_CI.yml index fa644cc..3af1cad 100644 --- a/environment_CI.yml +++ b/environment_CI.yml @@ -19,7 +19,7 @@ dependencies: - loguru - matplotlib - numpy - - pandas + - pandas<3 - python=3.11 - pytorch - pytest @@ -37,3 +37,5 @@ dependencies: - albumentations - efficientnet_lite_pytorch - efficientnet_lite0_pytorch_model + - fitsbolt>=0.1.6 + - cutana>=0.2.1 diff --git a/paper_scripts/README.md b/paper_scripts/README.md index dc93cf3..8ba0b38 100644 --- a/paper_scripts/README.md +++ b/paper_scripts/README.md @@ -97,6 +97,11 @@ python create_results.py --galaxymnist --input-dir /path/to/datasets --output-di # Run all experiments with a different seed python create_results.py --all --seed 123 + +# Run the experiments for the astronomaly comparison with multi seeds +# all settings are optional +python galaxyzoo_multi_seed.py --input_dir X --output_dir Y --seeds [1,2,3,4,5] + ``` ## Additional Visualizations diff --git a/paper_scripts/dataset_plot.py b/paper_scripts/dataset_plot.py index d66a434..dac6293 100644 --- a/paper_scripts/dataset_plot.py +++ b/paper_scripts/dataset_plot.py @@ -196,7 +196,7 @@ def get_galaxyzoo_samples(df, data_dir, anomaly_samples=6, normal_samples=6): ) for _, row in anomaly_sample_rows.iterrows(): - img_path = os.path.join(data_dir, row["filename"]) + img_path = os.path.join(data_dir, row["original_filename"]) if os.path.exists(img_path): samples.append( { @@ -215,7 +215,7 @@ def get_galaxyzoo_samples(df, data_dir, anomaly_samples=6, normal_samples=6): normal_sample_rows = normal_df.sample(min(normal_samples, len(normal_df)), random_state=42) for _, row in normal_sample_rows.iterrows(): - img_path = os.path.join(data_dir, row["filename"]) + img_path = os.path.join(data_dir, row["original_filename"]) if os.path.exists(img_path): samples.append( { @@ -305,17 +305,17 @@ def create_compact_figure( ax.imshow(img) # Add class name as annotation inside the image - rect = Rectangle((0, 0), img.shape[1], 20, color="black", alpha=0.6) + rect = Rectangle((0, 0), img.shape[1], img.shape[1] * 0.25, color="black", alpha=0.6) ax.add_patch(rect) ax.text( img.shape[1] / 2, - 10, + img.shape[1] * 0.005, sample["class_name"], color="white", fontsize=8, ha="center", - va="center", + va="top", ) # Remove ticks and add border @@ -336,11 +336,17 @@ def create_compact_figure( img = np.array(Image.open(sample["path"])) ax.imshow(img) - rect = Rectangle((0, 0), img.shape[1], 20, color="red", alpha=0.6) + rect = Rectangle((0, 0), img.shape[1], img.shape[1] * 0.13, color="red", alpha=0.6) ax.add_patch(rect) ax.text( - img.shape[1] / 2, 10, "Anomaly", color="white", fontsize=8, ha="center", va="center" + img.shape[1] / 2, + img.shape[1] * 0.01, + "Anomaly", + color="white", + fontsize=8, + ha="center", + va="top", ) ax.set_xticks([]) @@ -357,10 +363,18 @@ def create_compact_figure( img = np.array(Image.open(sample["path"])) ax.imshow(img) - rect = Rectangle((0, 0), img.shape[1], 20, color="black", alpha=0.6) + rect = Rectangle((0, 0), img.shape[1], img.shape[1] * 0.13, color="black", alpha=0.6) ax.add_patch(rect) - ax.text(img.shape[1] / 2, 10, "Normal", color="white", fontsize=8, ha="center", va="center") + ax.text( + img.shape[1] / 2, + img.shape[1] * 0.01, + "Normal", + color="white", + fontsize=8, + ha="center", + va="top", + ) ax.set_xticks([]) ax.set_yticks([]) @@ -378,17 +392,17 @@ def create_compact_figure( img = np.array(Image.open(sample["path"])) ax.imshow(img) - rect = Rectangle((0, 0), img.shape[1], 20, color="red", alpha=0.6) + rect = Rectangle((0, 0), img.shape[1], img.shape[1] * 0.13, color="red", alpha=0.6) ax.add_patch(rect) ax.text( img.shape[1] / 2, - 10, + img.shape[1] * 0.01, sample["class_name"], color="white", fontsize=8, ha="center", - va="center", + va="top", ) ax.set_xticks([]) @@ -407,11 +421,17 @@ def create_compact_figure( img = np.array(Image.open(sample["path"])) ax.imshow(img) - rect = Rectangle((0, 0), img.shape[1], 20, color="black", alpha=0.6) + rect = Rectangle((0, 0), img.shape[1], img.shape[1] * 0.13, color="black", alpha=0.6) ax.add_patch(rect) ax.text( - img.shape[1] / 2, 10, "Nominal", color="white", fontsize=8, ha="center", va="center" + img.shape[1] / 2, + img.shape[1] * 0.01, + "Nominal", + color="white", + fontsize=8, + ha="center", + va="top", ) ax.set_xticks([]) @@ -431,17 +451,19 @@ def create_compact_figure( img = np.array(Image.open(sample["path"])) ax.imshow(img) - rect = Rectangle((0, 0), img.shape[1], 20, color="black", alpha=0.6) + rect = Rectangle( + (0, 0), img.shape[1], img.shape[1] * 0.13, color="black", alpha=0.6 + ) ax.add_patch(rect) ax.text( img.shape[1] / 2, - 10, + img.shape[1] * 0.01, "Nominal", color="white", fontsize=8, ha="center", - va="center", + va="top", ) ax.set_xticks([]) @@ -452,12 +474,13 @@ def create_compact_figure( normal_idx += 1 # Add dataset labels - fig.text(0.01, 0.96, "GalaxyMNIST", fontsize=12, fontweight="bold", ha="left") - fig.text(0.01, 0.76, "GalaxyZoo", fontsize=12, fontweight="bold", ha="left") + fig.text(0.01, 0.94, "GalaxyMNIST", fontsize=12, fontweight="bold", ha="left") + fig.text(0.01, 0.75, "Galaxy Zoo 2", fontsize=12, fontweight="bold", ha="left") fig.text(0.01, 0.56, "MiniImageNet", fontsize=12, fontweight="bold", ha="left") # Add "Anomaly Classes in Red" annotation at top right - fig.text(0.99, 0.96, "Anomaly Classes in Red", fontsize=10, color="red", ha="right") + fig.text(0.99, 0.75, "Anomaly Classes in Red", fontsize=10, color="red", ha="right") + fig.text(0.99, 0.56, "Anomaly Classes in Red", fontsize=10, color="red", ha="right") return fig @@ -465,12 +488,12 @@ def create_compact_figure( def main(): """Main function to create and save the dataset visualization.""" # Define base paths - datasets_dir = os.path.join("datasets/") + datasets_dir = os.path.join("/media/team_workspaces/AnomalyMatch/paper_datasets/") # Define paths for dataset files galaxymnist_csv_path = os.path.join(datasets_dir, "labels_galaxymnist.csv") miniimagenet_csv_path = os.path.join(datasets_dir, "labels_miniimagenet.csv") - galaxyzoo_csv_path = os.path.join(datasets_dir, "labels_galaxyzoo.csv") + galaxyzoo_csv_path = os.path.join(datasets_dir, "galaxyzoo_labels.csv") galaxymnist_image_dir = os.path.join(datasets_dir, "galaxymnist") miniimagenet_image_dir = os.path.join(datasets_dir, "miniimagenet") diff --git a/paper_scripts/galaxyzoo_multi_seed.py b/paper_scripts/galaxyzoo_multi_seed.py new file mode 100644 index 0000000..bd9b0cc --- /dev/null +++ b/paper_scripts/galaxyzoo_multi_seed.py @@ -0,0 +1,652 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +""" +Multi-seed GalaxyZoo benchmark script for AnomalyMatch paper experiments. + +This script runs GalaxyZoo experiments with multiple predetermined seeds to generate +robust statistical results with averages and standard deviations. + +The script will: +1. Run GalaxyZoo experiments for each seed using the same configurations as create_results.py +2. Collect all results into a summary folder with statistics (mean ± std) +3. Create average astronomaly comparison plots with standard deviation shading + +Usage: + # Run with default seeds and directories + python galaxyzoo_multi_seed.py + + # Custom output directory + python galaxyzoo_multi_seed.py --output_dir my_results + + # Custom seeds + python galaxyzoo_multi_seed.py --seeds 42 123 456 + + # Custom input directory for datasets + python galaxyzoo_multi_seed.py --input_dir /path/to/datasets + +Output Structure: + results_YYYYMMDD_HHMMSS/ + ├── seed_42/ + │ └── galaxyzoo/ + │ ├── zoo_class1_n200_ratio0.500/ + │ ├── zoo_class1_n400_ratio0.500/ + │ └── zoo_class1_n400_ratio0.015/ + ├── seed_123/ + │ └── ... + ├── summary/ + │ ├── galaxyzoo_multi_seed_summary.csv + │ └── galaxyzoo_results.csv (compatible with results_analysis.py) + └── average_plots/ + ├── average_astronomaly_comparison_zoo_class1_n200_ratio0.500_iter{DEFAULT_TRAINING_RUNS}.pdf + ├── average_astronomaly_comparison_zoo_class1_n400_ratio0.500_iter{DEFAULT_TRAINING_RUNS}.pdf + └── average_astronomaly_comparison_zoo_class1_n400_ratio0.015_iter{DEFAULT_TRAINING_RUNS}.pdf +""" + +import sys +import time +import datetime +import subprocess +import pickle +from pathlib import Path +import argparse +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from typing import List, Dict + +# Import plotting constants from paper_plots +sys.path.append(str(Path(__file__).parent)) +try: + from paper_plots import ( + BLUE, + GREEN, + ORANGE, + PURPLE, + RED, + REFERENCE_LINE_COLOR, + REFERENCE_LINE_STYLE, + PERFECT_LINE_STYLE, + DEFAULT_DPI, + ) +except ImportError: + # Fallback colors if import fails + BLUE = "#1f77b4" + GREEN = "#2ca02c" + ORANGE = "#ff7f0e" + PURPLE = "#9467bd" + RED = "#d62728" + REFERENCE_LINE_COLOR = "gray" + REFERENCE_LINE_STYLE = "--" + PERFECT_LINE_STYLE = "-." + DEFAULT_DPI = 300 + +# ========== CONFIGURATION ========== +# Predetermined seeds for reproducible results +GALAXYZOO_SEEDS = [42, 76032, 730, 83209, 13798, 4538, 5923, 99271, 3762] + +# Default parameters (matching create_results.py) +DEFAULT_IMAGE_SIZE = 224 +DEFAULT_TRAINING_RUNS = 3 +DEFAULT_TRAIN_ITERATIONS = 100 +DEFAULT_N_MISLABELED = 20 +DEFAULT_N_TO_LOAD = 10000 +GALAXYZOO_N_SAMPLES = 500 +GALAXYZOO_ANOMALY_RATIO = 0.05 +GALAXYZOO_THRESHOLDS = [0.95, 0.9, 0.8, 0.7] +GALAXYZOO_CLASSES = [1] # Only anomaly class since it's binary classification + +# Other settings +SKIP_MOCK_UI = True +SAVE_LABELED_IMAGES = False + +# Define specific configurations for GalaxyZoo (matching create_results.py) +GALAXYZOO_CONFIGS = [ + (200, 0.5), # 200 samples with 50% anomaly ratio + (400, 0.5), # 400 samples with 50% anomaly ratio + (400, 0.015), # 400 samples with 1.5% anomaly ratio +] + + +def run_benchmark(args: List[str], log_file: Path) -> int: + """Run paper_benchmark.py with the given arguments and log to file.""" + benchmark_script = "paper_benchmark.py" + + # Build command with all arguments + cmd = [sys.executable, benchmark_script] + args + + # Add save_labeled_images flag if needed + if SAVE_LABELED_IMAGES: + cmd.append("--save_labeled_images") + + # Print command for logging + cmd_str = " ".join(cmd) + print(f"Running: {cmd_str}") + + # Open log file for this run + with open(log_file, "w") as log: + # Write command to log + log.write(f"Command: {cmd_str}\n\n") + log.flush() + + # Run process and capture output + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + bufsize=1, + ) + + # Stream output to both console and log file + for line in process.stdout: + print(line, end="") + log.write(line) + log.flush() + + # Wait for process to complete + process.wait() + + return process.returncode + + +def run_single_seed_experiment( + seed: int, output_dir: Path, input_dir: Path, plot_only=False +) -> Path: + """Run GalaxyZoo experiments for a single seed.""" + print(f"\n======= Running GalaxyZoo Experiments with Seed {seed} =======") + + seed_output_dir = output_dir / f"seed_{seed}" + zoo_output_dir = seed_output_dir / "galaxyzoo" + zoo_output_dir.mkdir(parents=True, exist_ok=True) + if not plot_only: + # Test anomaly detection on Galaxy Zoo dataset (binary classification) + for cls in GALAXYZOO_CLASSES: + for n_samples, anomaly_ratio in GALAXYZOO_CONFIGS: + experiment_name = f"zoo_class{cls}_n{n_samples}_ratio{anomaly_ratio:.3f}" + print(f"\nRunning experiment: {experiment_name} (seed {seed})") + exp_start_time = time.time() + + args = [ + "--dataset", + "galaxyzoo", + "--anomaly_classes", + str(cls), + "--n_samples", + str(n_samples), + "--anomaly_ratio", + str(anomaly_ratio), + "--train_iterations", + str(DEFAULT_TRAIN_ITERATIONS), + "--n_mislabeled", + str(DEFAULT_N_MISLABELED), + "--output_dir", + str(zoo_output_dir / experiment_name), + "--input_dir", + str(input_dir), + "--seed", + str(seed), + "--size", + str(DEFAULT_IMAGE_SIZE), + "--n_to_load", + str(DEFAULT_N_TO_LOAD), + "--training_runs", + str(DEFAULT_TRAINING_RUNS), + "--create_galaxy_plots", + ] + + if SKIP_MOCK_UI: + args.append("--skip_mock_ui") + + log_file = zoo_output_dir / f"{experiment_name}.log" + return_code = run_benchmark(args, log_file) + + if return_code != 0: + print( + f"Warning: Experiment {experiment_name} with seed {seed} failed with return code {return_code}" + ) + + exp_duration = time.time() - exp_start_time + print(f"Experiment completed in {exp_duration:.2f} seconds") + + return seed_output_dir + + +def collect_results_from_seeds(output_dir: Path, seeds: List[int]) -> Dict[str, pd.DataFrame]: + """Collect and organize results from all seed experiments.""" + print("\n======= Collecting Results from All Seeds =======") + + all_results = {} + + # For each configuration, collect results from all seeds + for cls in GALAXYZOO_CLASSES: + for n_samples, anomaly_ratio in GALAXYZOO_CONFIGS: + config_key = f"zoo_class{cls}_n{n_samples}_ratio{anomaly_ratio:.3f}" + config_results = [] + sub_config_key = ( + f"galaxyzoo_anomaly{cls}_n{n_samples-int(n_samples*anomaly_ratio)}" + + f"_a{int(n_samples*anomaly_ratio)}" + ) + for seed in seeds: + seed_dir = output_dir / f"seed_{seed}" / "galaxyzoo" / config_key / sub_config_key + summary_file = seed_dir / "results_summary.csv" + + if summary_file.exists(): + try: + df = pd.read_csv(summary_file) + df["seed"] = seed + df["config"] = config_key + config_results.append(df) + print(f"Loaded results for {config_key}, seed {seed}") + except Exception as e: + print(f"Error loading {summary_file}: {e}") + else: + print(f"Warning: Results file not found: {summary_file}") + + if config_results: + all_results[config_key] = pd.concat(config_results, ignore_index=True) + else: + print(f"No results found for configuration: {config_key}") + + return all_results + + +def create_summary_statistics( + all_results: Dict[str, pd.DataFrame], summary_dir: Path +) -> pd.DataFrame: + """Create summary statistics with means and standard deviations.""" + print("\n======= Creating Summary Statistics =======") + + summary_rows = [] + + for config_key, df in all_results.items(): + if df.empty: + continue + + # Calculate statistics for numeric columns + numeric_cols = [ + "final_auroc", + "final_auprc", + "top_0.1pct_anomalies_found", + "top_0.1pct_precision", + "top_1.0pct_anomalies_found", + "top_1.0pct_precision", + ] + + stats = {} + stats["config"] = config_key + stats["dataset"] = "galaxyzoo" + stats["n_seeds"] = len(df["seed"].unique()) + + # Extract configuration details + if "anomaly_class" in df.columns: + stats["anomaly_class"] = df["anomaly_class"].iloc[0] + + # Calculate mean and std for each metric + for col in numeric_cols: + if col in df.columns: + stats[f"{col}_mean"] = df[col].mean() + stats[f"{col}_std"] = df[col].std() + stats[col] = stats[f"{col}_mean"] # For compatibility + + summary_rows.append(stats) + + # Print summary for this configuration + print(f"\nConfiguration: {config_key}") + print(f" Number of seeds: {stats['n_seeds']}") + for col in numeric_cols: + if col in df.columns: + print(f" {col}: {stats[f'{col}_mean']:.4f} ± {stats[f'{col}_std']:.4f}") + + # Create summary DataFrame + summary_df = pd.DataFrame(summary_rows) + + # Save summary + summary_file = summary_dir / "galaxyzoo_multi_seed_summary.csv" + summary_df.to_csv(summary_file, index=False) + print(f"\nSaved summary statistics to {summary_file}") + + # Also create a format compatible with results_analysis.py + compatible_df = summary_df.copy() + compatible_df = compatible_df.rename(columns={"config": "run_dir"}) + compatible_file = summary_dir / "galaxyzoo_results.csv" + compatible_df.to_csv(compatible_file, index=False) + print(f"Saved compatible results to {compatible_file}") + + return summary_df + + +def load_plot_data_from_seeds( + output_dir: Path, seeds: List[int], iteration: int = DEFAULT_TRAINING_RUNS +) -> Dict[str, List[Dict]]: + """Load plot data from all seeds for creating average plots.""" + print(f"\n======= Loading Plot Data from All Seeds (iteration {iteration}) =======") + + plot_data_by_config = {} + + for cls in GALAXYZOO_CLASSES: + for n_samples, anomaly_ratio in GALAXYZOO_CONFIGS: + config_key = f"zoo_class{cls}_n{n_samples}_ratio{anomaly_ratio:.3f}" + sub_config_key = ( + f"galaxyzoo_anomaly{cls}_n{n_samples-int(n_samples*anomaly_ratio)}" + + f"_a{int(n_samples*anomaly_ratio)}" + ) + + plot_data_list = [] + + for seed in seeds: + seed_dir = output_dir / f"seed_{seed}" / "galaxyzoo" / config_key / sub_config_key + plot_data_dir = seed_dir / "plots" / "plot_data" + plot_data_file = ( + plot_data_dir / f"data_for_astronomaly_comparison_iter{iteration}.pkl" + ) + + if plot_data_file.exists(): + try: + with open(plot_data_file, "rb") as f: + plot_data = pickle.load(f) + plot_data["seed"] = seed + plot_data_list.append(plot_data) + print(f"Loaded plot data for {config_key}, seed {seed}") + except Exception as e: + print(f"Error loading plot data from {plot_data_file}: {e}") + else: + print(f"Warning: Plot data file not found: {plot_data_file}") + + if plot_data_list: + plot_data_by_config[config_key] = plot_data_list + else: + print(f"No plot data found for configuration: {config_key}") + + return plot_data_by_config + + +def create_average_astronomaly_comparison_plot( + plot_data_by_config: Dict[str, List[Dict]], + output_dir: Path, + iteration: int = DEFAULT_TRAINING_RUNS, +): + """Create average astronomaly comparison plots with std deviation shading.""" + print("\n======= Creating Average Astronomaly Comparison Plots =======") + + plots_dir = output_dir / "average_plots" + plots_dir.mkdir(exist_ok=True) + + # Load the Astronomaly Figure 5a reference data + try: + astronomaly_data = pd.read_csv( + Path(__file__).parent / "AstronomalyFigure5a.csv", skiprows=[0] + ) + astronomaly_x = astronomaly_data["xaxis"].values + astronomaly_y = astronomaly_data["yaxis"].values + except Exception as e: + print(f"Error loading Astronomaly reference data: {e}") + astronomaly_x = astronomaly_y = None + + for config_key, plot_data_list in plot_data_by_config.items(): + print(f"\nCreating average plot for {config_key}") + + # Create the figure + plt.figure(figsize=(8, 8)) + + # Plot perfect prediction curve + plt.plot( + np.linspace(0, 2000, 10), + np.linspace(0, 2000, 10), + color=REFERENCE_LINE_COLOR, + linestyle=REFERENCE_LINE_STYLE, + linewidth=4, + label="1:1 limit", + ) + + # Plot lines for each threshold with average and std deviation + colors = [BLUE, GREEN, ORANGE, PURPLE] + thresholds = GALAXYZOO_THRESHOLDS + + for i, threshold in enumerate(thresholds): + # Collect data for this threshold from all seeds + all_anomalies_found = [] + + for plot_data in plot_data_list: + try: + # Extract the data for this seed + scores = np.array(plot_data["scores"]) + filenames = plot_data["filenames"] + true_labels_df = plot_data["true_labels_df"] + + # Create scores dataframe and merge with true labels + scores_df = pd.DataFrame({"filename": filenames, "score": scores}) + merged_df = pd.merge(scores_df, true_labels_df, on="filename") + + if len(merged_df) == 0: + continue + + # Create binary labels based on threshold + merged_df["true_anomaly"] = ( + merged_df["anomaly_score_raw"] >= threshold + ).astype(int) + + # Sort by model scores + filtered_df = merged_df.sort_values("score", ascending=False).reset_index( + drop=True + ) + + # Calculate cumulative sum of anomalies found + filtered_df["cum_anomalies"] = filtered_df["true_anomaly"].cumsum() + + # Create x-axis values and corresponding anomalies found + inspection_indices = np.arange(0, min(2000, len(filtered_df))) + anomalies_found = [] + for idx in inspection_indices: + if idx == 0: + anomalies_found.append(0) + else: + anomalies_found.append(filtered_df.loc[idx - 1, "cum_anomalies"]) + + # Pad or truncate to ensure consistent length + if len(anomalies_found) < 2000: + # Pad with the last value + last_val = anomalies_found[-1] if anomalies_found else 0 + anomalies_found.extend([last_val] * (2000 - len(anomalies_found))) + else: + anomalies_found = anomalies_found[:2000] + + all_anomalies_found.append(anomalies_found) + + except Exception as e: + print( + f"Error processing plot data for seed {plot_data.get('seed', 'unknown')}: {e}" + ) + continue + + if all_anomalies_found: + # Convert to numpy array for easier calculation + all_anomalies_found = np.array(all_anomalies_found) + + # Calculate mean and std + mean_anomalies = np.mean(all_anomalies_found, axis=0) + std_anomalies = np.std(all_anomalies_found, axis=0) + + x_values = np.arange(2000) + + # Plot the mean line + plt.plot( + x_values, + mean_anomalies, + color=colors[i % len(colors)], + linewidth=2, + label=f"AnomalyMatch (t={threshold})", + ) + + # Add std deviation shading + plt.fill_between( + x_values, + mean_anomalies - std_anomalies, + mean_anomalies + std_anomalies, + color=colors[i % len(colors)], + alpha=0.3, + ) + + # Plot Astronomaly reference curve if available + if astronomaly_x is not None and astronomaly_y is not None: + plt.plot( + astronomaly_x, + astronomaly_y, + color=RED, + linestyle=PERFECT_LINE_STYLE, + linewidth=2, + label="Astronomaly (t=0.9)", + ) + + # Add labels and legend + plt.xlabel("Index in ranked list") + plt.ylabel("Number of anomalies detected") + plt.grid(True, alpha=0.3) + plt.legend(loc="lower right", frameon=True, framealpha=0.7) + + # Set axis limits + plt.xlim([0, 2000]) + plt.ylim([0, 300]) + + # Add title with configuration info + # plt.title( + # f"Astronomaly Comparison - {config_key}\n(Average over {len(plot_data_list)} seeds)" + # ) + plt.figtext( + 0.5, + -0.05, # (x, y) position in figure coordinates (0.5 = centered, -0.05 = below axis) + f"(Average over {len(plot_data_list)} seeds)", + ha="center", + va="top", + fontsize=10, + ) + # Save figure + output_path = ( + plots_dir / f"average_astronomaly_comparison_{config_key}_cycle{iteration}.pdf" + ) + plt.tight_layout() + plt.savefig(output_path, dpi=DEFAULT_DPI) + plt.close() + + print(f"Average plot saved to {output_path}") + + +def main(): + """Main function to run multi-seed GalaxyZoo experiments.""" + parser = argparse.ArgumentParser(description="Run GalaxyZoo experiments with multiple seeds") + parser.add_argument( + "--output_dir", + type=str, + default="galaxyzoo_multi_seed_results", + help="Output directory for results", + ) + parser.add_argument( + "--input_dir", + type=str, + default="datasets", + help="Input directory containing prepared datasets", + ) + parser.add_argument( + "--seeds", + type=int, + nargs="+", + default=GALAXYZOO_SEEDS, + help="List of seeds to use for experiments", + ) + # add --plot only argument that when set wil set plot_only in run_single_seed_experiment to True + parser.add_argument( + "--plot_only", + action="store_true", + help="Only create plots from existing results without running experiments", + ) + # add argument to include a timestamp string + parser.add_argument( + "--timestamp", + type=str, + default=datetime.datetime.now().strftime("%Y%m%d_%H%M%S"), + help="Enforce a timestamp/description in the output (sub)directory name", + ) + + args = parser.parse_args() + + # Setup directories + output_dir = Path(args.output_dir) + input_dir = Path(args.input_dir) + + if not input_dir.exists(): + print(f"Error: Input directory not found: {input_dir}") + return 1 + + # Create timestamped output directory + timestamp = args.timestamp # datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = output_dir / f"results_{timestamp}" + output_dir.mkdir(parents=True, exist_ok=True) + + print("Starting GalaxyZoo multi-seed experiments") + print(f"Seeds: {args.seeds}") + print(f"Output directory: {output_dir}") + print(f"Input directory: {input_dir}") + + start_time = time.time() + + # Run experiments for each seed + for i, seed in enumerate(args.seeds): + print(f"\n{'='*60}") + print(f"Running experiments for seed {seed} ({i+1}/{len(args.seeds)})") + print(f"{'='*60}") + + try: + seed_output_dir = run_single_seed_experiment( + seed, output_dir, input_dir, args.plot_only + ) + print(f"Completed experiments for seed {seed} -> saved at {seed_output_dir}") + except Exception as e: + print(f"Error running experiments for seed {seed}: {e}") + continue + + # Collect and analyze results + print(f"\n{'='*60}") + print("Collecting and analyzing results from all seeds") + print(f"{'='*60}") + + # Create summary directory + summary_dir = output_dir / "summary" + summary_dir.mkdir(exist_ok=True) + + # Collect results + all_results = collect_results_from_seeds(output_dir, args.seeds) + + if not all_results: + print("No results found from any seed experiments!") + return 1 + + # Create summary statistics + _ = create_summary_statistics(all_results, summary_dir) + # summary_df + # Load plot data and create average plots + plot_data_by_config = load_plot_data_from_seeds(output_dir, args.seeds) + + if plot_data_by_config: + create_average_astronomaly_comparison_plot(plot_data_by_config, output_dir) + else: + print("Warning: No plot data found for creating average plots") + + # Final summary + total_time = time.time() - start_time + print(f"\n{'='*60}") + print("EXPERIMENT SUMMARY") + print(f"{'='*60}") + print(f"Total execution time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)") + print(f"Seeds processed: {len(args.seeds)}") + print(f"Configurations per seed: {len(GALAXYZOO_CONFIGS)}") + print(f"Results saved to: {output_dir}") + print(f"Summary statistics: {summary_dir}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/paper_scripts/paper_benchmark.py b/paper_scripts/paper_benchmark.py index 6bc8f9b..318e60b 100644 --- a/paper_scripts/paper_benchmark.py +++ b/paper_scripts/paper_benchmark.py @@ -85,6 +85,8 @@ USE_TURBOJPEG = True logger.info("Using TurboJPEG for faster image decoding") +from anomaly_match.utils.set_seeds import set_seeds + def read_and_decode_image(image_data): """Decode image data from HDF5 with optimized handling.""" @@ -276,8 +278,18 @@ def get_prediction_scores(session, labeled_filenames, hdf5_path, progress_bar=No all_filenames = all_filenames[: len(all_scores)] # Filter out labeled files - labeled_set = set(labeled_filenames) - unlabeled_indices = [i for i, fname in enumerate(all_filenames) if fname not in labeled_set] + # Normalise paths before comparison to handle different path formats + from os.path import basename, normpath + + # Convert paths to a standardized format + normalised_labeled = set(basename(normpath(fname)) for fname in labeled_filenames) + + # Get indices of unlabeled files + unlabeled_indices = [ + i + for i, fname in enumerate(all_filenames) + if basename(normpath(fname)) not in normalised_labeled + ] logger.info(f"Filtered out {len(all_filenames) - len(unlabeled_indices)} labeled samples") logger.info(f"Remaining unlabeled samples: {len(unlabeled_indices)}") @@ -321,6 +333,8 @@ def get_available_filenames(session): def run_benchmark(args): """Run the full benchmarking process.""" + set_seeds(args.seed, deterministic=True) # initialise torch, np and random seed + # Set up directories run_dir, model_dir, plots_dir = setup_directories(args) logger.info(f"Results will be saved to {run_dir}") @@ -370,6 +384,7 @@ def run_benchmark(args): labeled_data_path, run_dir, output_widget, + args.seed, None if args.skip_mock_ui else progress_bar, ) @@ -728,21 +743,27 @@ def run_benchmark(args): logger.info(f"Found {len(corrections)} mislabeled samples to correct") - # Apply corrections by labeling the samples in the session - for _, row in corrections.iterrows(): - filename = row["filename"] - label = row["label"] + # if this is the last iteration we will not apply corrections + if iteration == args.training_runs - 1: + logger.info("Skipping corrections in the last iteration") + else: + # Apply corrections by labeling the samples in the session + for _, row in corrections.iterrows(): + filename = row["filename"] + label = row["label"] - # Find index of this filename in session.filenames - try: - idx = session.filenames.tolist().index(filename) - session.label_image(idx, label) - logger.debug(f"Labeled {filename} as {label}") - except ValueError: - logger.warning(f"Could not find {filename} in session filenames") + # Find index of this filename in session.filenames + try: + idx = session.filenames.tolist().index(filename) + session.label_image(idx, label) + logger.debug(f"Labeled {filename} as {label}") + except ValueError: + logger.warning(f"Could not find {filename} in session filenames") # Save updated labels session.save_labels() + session_path = session.session_io.get_session_save_path(session.session_tracker) + labeled_data_path = session_path / "labeled_data.csv" # Update labeled_df with new labels labeled_df = pd.read_csv(labeled_data_path) @@ -755,6 +776,8 @@ def run_benchmark(args): # Copy newly labeled images to output directory after each iteration copy_labeled_images(labeled_df, data_dir, iter_dir) + # End of iterations + # Plot metrics over time with training batches as x-axis plot_metrics_over_time(metrics_history, plots_dir, batch_size=args.train_iterations) @@ -823,6 +846,7 @@ def run_multi_class_benchmark(args): Args: args: Command line arguments """ + set_seeds(args.seed, deterministic=True) # initialise torch, np and random seed # Store metrics for each anomaly class all_class_metrics = {} @@ -898,6 +922,7 @@ def run_multi_class_benchmark(args): labeled_data_path, run_dir, output_widget, + args.seed, None if args.skip_mock_ui else progress_bar, ) @@ -1251,21 +1276,27 @@ def run_multi_class_benchmark(args): logger.info(f"Found {len(corrections)} mislabeled samples to correct") - # Apply corrections by labeling the samples in the session - for _, row in corrections.iterrows(): - filename = row["filename"] - label = row["label"] - - # Find index of this filename in session.filenames - try: - idx = session.filenames.tolist().index(filename) - session.label_image(idx, label) - logger.debug(f"Labeled {filename} as {label}") - except ValueError: - logger.warning(f"Could not find {filename} in session filenames") + # if this is the last iteration we will not apply corrections + if iteration == args.training_runs - 1: + logger.info("Skipping corrections in the last iteration") + else: + # Apply corrections by labeling the samples in the session + for _, row in corrections.iterrows(): + filename = row["filename"] + label = row["label"] + + # Find index of this filename in session.filenames + try: + idx = session.filenames.tolist().index(filename) + session.label_image(idx, label) + logger.debug(f"Labeled {filename} as {label}") + except ValueError: + logger.warning(f"Could not find {filename} in session filenames") # Save updated labels session.save_labels() + session_path = session.session_io.get_session_save_path(session.session_tracker) + labeled_data_path = session_path / "labeled_data.csv" # Update labeled_df with new labels labeled_df = pd.read_csv(labeled_data_path) @@ -1278,6 +1309,8 @@ def run_multi_class_benchmark(args): # Copy newly labeled images to output directory after each iteration copy_labeled_images(labeled_df, data_dir, iter_dir) + # End of training iterations for this anomaly class + # Plot metrics over time with training batches as x-axis plot_metrics_over_time(metrics_history, plots_dir, batch_size=args.train_iterations) diff --git a/paper_scripts/paper_plots.py b/paper_scripts/paper_plots.py index 717f214..305c49b 100644 --- a/paper_scripts/paper_plots.py +++ b/paper_scripts/paper_plots.py @@ -74,7 +74,7 @@ sns.set_style("whitegrid") # Default DPI for saving figures -DEFAULT_DPI = 600 +DEFAULT_DPI = 450 def plot_score_histogram(anomaly_scores, normal_scores, iteration, plots_dir): @@ -92,14 +92,14 @@ def plot_score_histogram(anomaly_scores, normal_scores, iteration, plots_dir): normal_scores = np.array(normal_scores).flatten() # Create figure with square aspect ratio for publication - plt.figure(figsize=(8, 8)) + plt.figure(figsize=(8, 6 * 0.8)) # Plot histograms with density=True for normalization sns.histplot( normal_scores, color=NORMAL_COLOR, alpha=HIST_ALPHA, - label="Normal", + label="Nominal", kde=True, bins=30, stat="density", @@ -115,12 +115,12 @@ def plot_score_histogram(anomaly_scores, normal_scores, iteration, plots_dir): ) # Add labels (no title for publication) - plt.xlabel("Model Anomaly Score") + plt.xlabel("AnomalyMatch score") plt.ylabel("Density") plt.legend(frameon=True, framealpha=0.7) # Save figure with high DPI for publication - output_path = os.path.join(plots_dir, f"score_histogram_iter{iteration}.png") + output_path = os.path.join(plots_dir, f"score_histogram_iter{iteration}.pdf") plt.tight_layout() plt.savefig(output_path, dpi=DEFAULT_DPI) plt.close() @@ -147,7 +147,7 @@ def plot_metrics_over_time(metrics_history, plots_dir, batch_size=None): x_label = "Training Batches" else: x_values = iterations - x_label = "Training Iteration" # Plot metrics with emphasis on data points + x_label = "Training Cycle" # Plot metrics with emphasis on data points plt.plot(x_values, auroc_values, "-", color=BLUE, marker="o", label="AUROC", markersize=8) plt.plot(x_values, auprc_values, "-", color=RED, marker="o", label="AUPRC", markersize=8) @@ -164,7 +164,7 @@ def plot_metrics_over_time(metrics_history, plots_dir, batch_size=None): ) # Save figure with high DPI for publication - output_path = os.path.join(plots_dir, "metrics_over_time.png") + output_path = os.path.join(plots_dir, "metrics_over_time.pdf") plt.tight_layout() plt.savefig(output_path, dpi=DEFAULT_DPI) plt.close() @@ -219,7 +219,7 @@ def plot_roc_prc_curves(metrics, iteration, plots_dir): plt.ylim([0, 1]) # Save the ROC curve - roc_path = os.path.join(plots_dir, f"roc_curve_iter{iteration}.png") + roc_path = os.path.join(plots_dir, f"roc_curve_iter{iteration}.pdf") plt.tight_layout() plt.savefig(roc_path, dpi=DEFAULT_DPI) plt.close() # 2. Precision-Recall Curve @@ -244,13 +244,13 @@ def plot_roc_prc_curves(metrics, iteration, plots_dir): plt.ylim([0, 1]) # Save the PR curve - pr_path = os.path.join(plots_dir, f"pr_curve_iter{iteration}.png") + pr_path = os.path.join(plots_dir, f"pr_curve_iter{iteration}.pdf") plt.tight_layout() plt.savefig(pr_path, dpi=DEFAULT_DPI) plt.close() # 3. Combined figure (side by side) - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8)) # Plot ROC curve + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) # Plot ROC curve ax1.plot(fpr, tpr, color=BLUE, linewidth=2, label=f'AUROC = {metrics["auroc"]:.3f}') ax1.plot( [0, 1], @@ -284,7 +284,7 @@ def plot_roc_prc_curves(metrics, iteration, plots_dir): ax2.set_ylim([0, 1]) # Save the combined figure - combined_path = os.path.join(plots_dir, f"roc_prc_curves_iter{iteration}.png") + combined_path = os.path.join(plots_dir, f"roc_prc_curves_iter{iteration}.pdf") plt.tight_layout() plt.savefig(combined_path, dpi=DEFAULT_DPI) plt.close() @@ -380,7 +380,7 @@ def plot_top_mispredicted( fig.text(0.01, 0.75, "False Positives", ha="left", va="center", fontsize=14, rotation=90) fig.text(0.01, 0.25, "False Negatives", ha="left", va="center", fontsize=14, rotation=90) - output_path = os.path.join(plots_dir, f"mispredicted_images_iter{iteration}.png") + output_path = os.path.join(plots_dir, f"mispredicted_images_iter{iteration}.jpg") plt.tight_layout() plt.savefig(output_path, dpi=DEFAULT_DPI) plt.close() @@ -473,7 +473,7 @@ def plot_top_n_anomaly_detection( anomalies_found, color=BLUE, linewidth=2, - label="Anomaly detection rate", + label=f"Anomaly detection rate (Cycle {iteration})", ) # Add reference line (perfect detection - if all anomalies come first) @@ -497,53 +497,54 @@ def plot_top_n_anomaly_detection( color=PERFECT_LINE_COLOR, linestyle=PERFECT_LINE_STYLE, alpha=PERFECT_LINE_ALPHA, - linewidth=1.5, + linewidth=3, label="Perfect detection", ) # Calculate percentage of anomalies found at key inspection points - percent_at_0_1pct = np.interp(0.1, inspection_points / total_samples * 100, anomalies_found) - percent_at_1pct = np.interp( - 1, inspection_points / total_samples * 100, anomalies_found - ) # Add vertical line at 0.1% inspection + # remove the values which are interpolated - they can be taken from the tables + # percent_at_0_1pct = np.interp(0.1, inspection_points / total_samples * 100, anomalies_found) + # percent_at_1pct = np.interp( + # 1, inspection_points / total_samples * 100, anomalies_found + # ) # Add vertical line at 0.1% inspection plt.axvline(x=0.1, color=VLINE_COLOR, linestyle=VLINE_STYLE, alpha=VLINE_ALPHA) plt.text( 0.07, - 50, + 45, f"0.1% inspected = {int(total_samples * 0.001)} samples", rotation=90, - va="bottom", - fontsize=8 * FONT_SCALE, - ) - plt.text( - 0.095, - 102, - f"found {percent_at_0_1pct:.1f}% \n of anomalies", - ha="center", + va="center", fontsize=8 * FONT_SCALE, ) + # plt.text( + # 0.095, + # 102, + # f"found {percent_at_0_1pct:.1f}% \n of anomalies", + # ha="center", + # fontsize=8 * FONT_SCALE, + # ) # Add vertical line at 1% inspection plt.axvline(x=1, color=VLINE_COLOR, linestyle=VLINE_STYLE, alpha=VLINE_ALPHA) plt.text( 0.7, - 50, + 45, f"1% inspected = {int(total_samples * 0.01)} samples", rotation=90, - va="bottom", + va="center", fontsize=8 * FONT_SCALE, ) - plt.text( - 0.95, - 102, - f"found {percent_at_1pct:.1f}% \n of anomalies", - ha="center", - fontsize=8 * FONT_SCALE, - ) # Add labels - NO TITLE for publication + # plt.text( + # 0.95, + # 102, + # f"found {percent_at_1pct:.1f}% \n of anomalies", + # ha="center", + # fontsize=8 * FONT_SCALE, + # ) # Add labels - NO TITLE for publication plt.xlabel("% of top-scoring predictions inspected") plt.ylabel("% of Total Anomalies Found") plt.grid(True) - plt.legend(loc="lower right", frameon=True, framealpha=0.7) + plt.legend(loc="upper left", frameon=True, framealpha=0.7, fontsize=8 * FONT_SCALE) # Set axis limits and log scale plt.xscale("log") @@ -557,7 +558,7 @@ def plot_top_n_anomaly_detection( ) # Save figure with high resolution for publication - output_path = os.path.join(plots_dir, f"top_n_detection_iter{iteration}.png") + output_path = os.path.join(plots_dir, f"top_n_detection_iter{iteration}.pdf") plt.tight_layout() plt.savefig(output_path, dpi=DEFAULT_DPI, bbox_inches="tight") plt.close() @@ -588,9 +589,9 @@ def plot_combined_anomaly_detection(detection_curves, plots_dir, anomaly_prevale x, y = detection_curves[iteration] # Use special color for last iteration to match anomaly detection rate if iteration == iterations[-1]: # if this is the last iteration - plt.plot(x, y, color=LAST_ITER_COLOR, linewidth=2.5, label=f"Iteration {iteration}") + plt.plot(x, y, color=LAST_ITER_COLOR, linewidth=2.5, label=f"Cycle {iteration}") else: - plt.plot(x, y, color=colors[i], linewidth=2, label=f"Iteration {iteration}") + plt.plot(x, y, color=colors[i], linewidth=2, label=f"Cycle {iteration}") # Add reference line (perfect detection - if all anomalies come first) # Use very high resolution for the perfect line to avoid interpolation issues @@ -622,7 +623,7 @@ def plot_combined_anomaly_detection(detection_curves, plots_dir, anomaly_prevale color=PERFECT_LINE_COLOR, linestyle=PERFECT_LINE_STYLE, alpha=PERFECT_LINE_ALPHA, - linewidth=1.5, + linewidth=3, label="Perfect detection", ) @@ -636,7 +637,7 @@ def plot_combined_anomaly_detection(detection_curves, plots_dir, anomaly_prevale plt.xlabel("% of top-scoring predictions inspected") plt.ylabel("% of Total Anomalies Found") plt.grid(True, alpha=0.3) - plt.legend(loc="lower right", frameon=True, framealpha=0.7) + plt.legend(loc="upper left", frameon=True, framealpha=0.7, fontsize=8 * FONT_SCALE) # Set axis limits and log scale plt.xscale("log") @@ -650,7 +651,7 @@ def plot_combined_anomaly_detection(detection_curves, plots_dir, anomaly_prevale ) # Save figure with high resolution for publication - output_path = os.path.join(plots_dir, "combined_top_n_detection.png") + output_path = os.path.join(plots_dir, "combined_top_n_detection.pdf") plt.tight_layout() plt.savefig(output_path, dpi=DEFAULT_DPI, bbox_inches="tight") plt.close() @@ -761,7 +762,7 @@ def plot_comparative_anomaly_detection(detection_curves, output_dir): results_df.to_csv(csv_path, index=False) # Save figure with high resolution for publication - output_path = os.path.join(output_dir, "comparative_top_n_detection.png") + output_path = os.path.join(output_dir, "comparative_top_n_detection.pdf") plt.tight_layout() plt.savefig(output_path, dpi=DEFAULT_DPI) plt.close() @@ -796,7 +797,7 @@ def plot_comparative_metrics(class_metrics, output_dir): x - width / 2, first_iter_auroc, width, - label="First Iteration", + label="First Cycle", color="lightblue", edgecolor="blue", ) @@ -834,7 +835,7 @@ def autolabel(rects, ax): x - width / 2, first_iter_auprc, width, - label="First Iteration", + label="First Cycle", color="lightpink", edgecolor="red", ) @@ -859,7 +860,7 @@ def autolabel(rects, ax): # Save the figure with high resolution for publication plt.tight_layout() - output_path = os.path.join(output_dir, "comparative_metrics.png") + output_path = os.path.join(output_dir, "comparative_metrics.pdf") plt.savefig(output_path, dpi=DEFAULT_DPI) plt.close() @@ -1028,7 +1029,7 @@ def plot_top_n_with_thresholds( color=PERFECT_LINE_COLOR, linestyle=PERFECT_LINE_STYLE, alpha=PERFECT_LINE_ALPHA, - linewidth=1.5, + linewidth=3, label="Perfect detection", ) @@ -1050,7 +1051,7 @@ def plot_top_n_with_thresholds( ) # Save figure with high resolution for publication - output_path = os.path.join(plots_dir, f"top_n_detection_thresholds_iter{iteration}.png") + output_path = os.path.join(plots_dir, f"top_n_detection_thresholds_iter{iteration}.pdf") plt.tight_layout() plt.savefig(output_path, dpi=DEFAULT_DPI, bbox_inches="tight") plt.close() @@ -1153,7 +1154,7 @@ def plot_roc_with_thresholds( plt.ylim([0, 1]) # Save the ROC curve - output_path = os.path.join(plots_dir, f"roc_curve_thresholds_iter{iteration}.png") + output_path = os.path.join(plots_dir, f"roc_curve_thresholds_iter{iteration}.pdf") plt.tight_layout() plt.savefig(output_path, dpi=DEFAULT_DPI) plt.close() @@ -1231,8 +1232,8 @@ def plot_astronomaly_comparison( np.linspace(0, 2000, 10), color=REFERENCE_LINE_COLOR, linestyle=REFERENCE_LINE_STYLE, - linewidth=2, - label="Perfect Prediction", + linewidth=4, + label="1:1 limit", ) # Plot lines for each threshold @@ -1288,7 +1289,7 @@ def plot_astronomaly_comparison( plt.ylim([0, 300]) # Save figure - output_path = os.path.join(plots_dir, f"astronomaly_comparison_iter{iteration}.png") + output_path = os.path.join(plots_dir, f"astronomaly_comparison_iter{iteration}.pdf") plt.tight_layout() plt.savefig(output_path, dpi=DEFAULT_DPI) plt.close() @@ -1356,47 +1357,47 @@ def plot_score_vs_user_score_grid( # MAIN PLOT - Full grid view create_grid_plot( merged_df, - n_grid, + 12, data_dir, plots_dir, iteration, "full", fig_title="ML Scores vs User Scores - Full Grid", ) + if False: + # PLOT 1 - Top Left Corner: High user scores (top 20%), low ML scores (bottom 50%) + filtered_df1 = merged_df[ + (merged_df["anomaly_score_raw"] > user_score_80th) + & (merged_df["ml_score"] <= ml_score_50th) + ] + if len(filtered_df1) > 0: + create_grid_plot( + filtered_df1, + 7, + data_dir, + plots_dir, + iteration, + "topleft", + fig_title="High User Scores (>80th percentile), Low ML Scores (<50th percentile)", + ) + else: + logger.warning("No data points for top-left quadrant plot") - # PLOT 1 - Top Left Corner: High user scores (top 20%), low ML scores (bottom 50%) - filtered_df1 = merged_df[ - (merged_df["anomaly_score_raw"] > user_score_80th) - & (merged_df["ml_score"] <= ml_score_50th) - ] - if len(filtered_df1) > 0: - create_grid_plot( - filtered_df1, - 8, - data_dir, - plots_dir, - iteration, - "topleft", - fig_title="High User Scores (>P80), Low ML Scores ( ml_score_80th) - ] - if len(filtered_df2) > 0: - create_grid_plot( - filtered_df2, - 8, - data_dir, - plots_dir, - iteration, - "bottomright", - fig_title="Low User Scores (P80)", - ) + # PLOT 2 - Bottom Right Corner: Low user scores (bottom 50%), high ML scores (top 20%) + filtered_df2 = merged_df[ + (merged_df["anomaly_score_raw"] <= user_score_50th) + & (merged_df["ml_score"] > ml_score_80th) + ] + if len(filtered_df2) > 0: + create_grid_plot( + filtered_df2, + 7, + data_dir, + plots_dir, + iteration, + "bottomright", + fig_title="Low User Scores (<50th percentile), High ML Scores (>80th percentile)", + ) else: logger.warning("No data points for bottom-right quadrant plot") @@ -1406,7 +1407,7 @@ def plot_score_vs_user_score_grid( def create_grid_plot(merged_df, n_grid, data_dir, plots_dir, iteration, suffix, fig_title=None): """ - Create a grid plot for the given data. + Create a grid plot for the given data with percentile-based sampling and visual indicators. Args: merged_df: DataFrame with merged scores and filenames @@ -1419,109 +1420,189 @@ def create_grid_plot(merged_df, n_grid, data_dir, plots_dir, iteration, suffix, """ from PIL import Image import matplotlib.gridspec as gridspec - - # Create figure with equal width and height, accounting for histograms - fig = plt.figure(figsize=(12, 12)) + import matplotlib.patches as patches + + # Number of user score bins (fewer than ML score bins) + n_user_grid = 6 + # Number of rows for nominal and anomalies + n_nominal_rows = 3 + n_anomaly_rows = 3 + + # Create figure with square aspect ratio for grid cells + # Calculate better figure size to ensure square grid cells + # Add some margin for labels and spacing + fig_width = 11 + # Adjust height to ensure grid cells are square + # We need to account for the fact that the grid is n_grid columns by n_user_grid rows + # The 0.9 factor helps account for the extra space taken by labels + fig_height = (fig_width * n_user_grid / n_grid) * 1.13 + fig = plt.figure(figsize=(fig_width, fig_height)) # Create grid layout with space for axis labels, histograms, and colorbar gs = gridspec.GridSpec( - n_grid + 2, # Add 1 for x-axis labels + 1 for histogram - n_grid + 2, # Add 1 for y-axis labels + 1 for histogram - width_ratios=[1] + [1] * n_grid + [1], # Extra column for histogram - height_ratios=[1] + [1] * n_grid + [1], # Extra row for histogram + n_user_grid + 2, # Add 1 for x-axis labels + 1 for histogram + n_grid + 1, # Add 1 for y-axis labels + width_ratios=[1] + [1] * n_grid, # Extra column for labels + height_ratios=[1] + [1] * n_user_grid + [1], # Extra row for labels wspace=0.0, hspace=0.0, ) - # Calculate percentile bin edges for both scores - ml_percentiles = np.percentile(merged_df["ml_score"], np.linspace(0, 100, n_grid + 1)) - user_percentiles = np.percentile( - merged_df["anomaly_score_raw"], np.linspace(0, 100, n_grid + 1) - ) - - # Calculate bin centers for labeling - ml_centers = [(ml_percentiles[i] + ml_percentiles[i + 1]) / 2 for i in range(n_grid)] - user_centers = [(user_percentiles[i] + user_percentiles[i + 1]) / 2 for i in range(n_grid)] - - # Create empty grid to store axes - axes = np.empty((n_grid, n_grid), dtype=object) + # We'll use the original ML scores for percentile calculation + # but we'll have separate bins for nominal and anomalous classes - # Create bins and initialize occupancy matrix for tracking which cells have images - bins = np.zeros((n_grid, n_grid), dtype=int) - - # Dictionary to store the representative image for each cell - representative_images = {} - - # Calculate total number of images for alpha scaling - total_images = len(merged_df) - - for _, row in merged_df.iterrows(): - ml_score = row["ml_score"] - user_score = row["anomaly_score_raw"] - - # Determine bin indices based on percentiles - x_idx = np.searchsorted(ml_percentiles, ml_score) - 1 - # Ensure the index is within bounds - x_idx = min(max(x_idx, 0), n_grid - 1) - - y_idx = np.searchsorted(user_percentiles, user_score) - 1 - # Ensure the index is within bounds - y_idx = min(max(y_idx, 0), n_grid - 1) + # Calculate ranks for both scores (1 -> lowest value) + ml_ranks = merged_df["ml_score"].rank(method="first") + user_ranks = merged_df["anomaly_score_raw"].rank(method="first") - # Invert y_idx to make 0,0 at bottom left - y_idx_inverted = (n_grid - 1) - y_idx + # Calculate rank differences with sign (positive = higher in AM, negative = lower in AM) + signed_rank_differences = ml_ranks - user_ranks # eg 10 - 5 = 5 (AM score higher) + rank_differences = np.abs(signed_rank_differences) + merged_df["rank_diff"] = rank_differences + merged_df["signed_rank_diff"] = signed_rank_differences - # Update bin count - bins[y_idx_inverted, x_idx] += 1 + # Create empty grid to store axes + axes = np.empty((n_user_grid, n_grid), dtype=object) - # Calculate distance to bin center (using actual values, not indices) - ml_center = ml_centers[x_idx] - user_center = user_centers[y_idx] - distance = np.sqrt((ml_score - ml_center) ** 2 + (user_score - user_center) ** 2) + # Dictionary to store the selected images for each cell + selected_images = {} - # Store image if it's the closest to the bin center so far - bin_key = (y_idx_inverted, x_idx) - if bin_key not in representative_images or distance < representative_images[bin_key][1]: - representative_images[bin_key] = (row["filename"], distance) + # Split data into nominal and anomalous samples + threshold = 0.9 + anomalous_df = merged_df[merged_df["anomaly_score_raw"] >= threshold].copy() + nominal_df = merged_df[merged_df["anomaly_score_raw"] < threshold].copy() - # Calculate maximum occupancy for alpha normalization - max_occupancy = max(1, np.max(bins)) # Avoid division by zero - min_occupancy = max(1, np.min(bins[bins > 0])) if np.any(bins > 0) else 1 + # Calculate separate percentile bin edges for anomalies and nominal samples + # This ensures each class has its own independent distribution + anomaly_ml_percentiles = np.percentile( + anomalous_df["ml_score"], np.linspace(0, 100, n_grid + 1) + ) + nominal_ml_percentiles = np.percentile(nominal_df["ml_score"], np.linspace(0, 100, n_grid + 1)) - # Define alpha scaling function - dynamic scaling based on distribution - # Use log scale for better visibility with potentially uneven distributions - def alpha_scaling(count, max_count, min_count, total_images): - if count == 0: - return 0 + # Record the actual ML score values at each percentile for axis labels + anomaly_score_values = [] + nominal_score_values = [] + for x in range(n_grid + 1): + anomaly_score_values.append(anomaly_ml_percentiles[x]) + nominal_score_values.append(nominal_ml_percentiles[x]) - # Linear scaling in log space between min_count and max_count - # alpha = 0.05 when count = min_count, alpha = 1.0 when count = max_count - if max_count == min_count: - return 1.0 # All counts are the same + # For each ML score percentile bin + for x in range(n_grid): + # Define ML score bin boundaries using separate percentiles for each class + if x == n_grid - 1: + ml_mask_anomaly = (anomalous_df["ml_score"] >= anomaly_ml_percentiles[x]) & ( + anomalous_df["ml_score"] <= anomaly_ml_percentiles[x + 1] + ) + ml_mask_nominal = (nominal_df["ml_score"] >= nominal_ml_percentiles[x]) & ( + nominal_df["ml_score"] <= nominal_ml_percentiles[x + 1] + ) + else: + ml_mask_anomaly = (anomalous_df["ml_score"] >= anomaly_ml_percentiles[x]) & ( + anomalous_df["ml_score"] < anomaly_ml_percentiles[x + 1] + ) + ml_mask_nominal = (nominal_df["ml_score"] >= nominal_ml_percentiles[x]) & ( + nominal_df["ml_score"] < nominal_ml_percentiles[x + 1] + ) - log_count = np.log10(count + 1) - log_min = np.log10(min_count + 1) - log_max = np.log10(max_count + 1) + # Get samples in this ML score bin for each class + anomaly_bin_samples = anomalous_df[ml_mask_anomaly].copy() + nominal_bin_samples = nominal_df[ml_mask_nominal].copy() - # Linear interpolation in log space - alpha = 0.1 + 0.9 * (log_count - log_min) / (log_max - log_min) + # Process nominal samples for top rows + if len(nominal_bin_samples) > 0: + # Sort by absolute rank difference (lowest to highest) + nominal_bin_samples = nominal_bin_samples.sort_values("rank_diff") - return alpha + # Sample images with increasing rank differences: closest, median, furthest + if len(nominal_bin_samples) <= n_nominal_rows: + # If we have 2 or fewer images, just use what we have + nominal_indices = np.arange(len(nominal_bin_samples)) + else: + # For 3 or more images: get lowest, median, and highest rank difference samples + if n_nominal_rows == 3: + # Get closest (top), median (middle), and furthest (bottom) samples + median_idx = len(nominal_bin_samples) // 2 + nominal_indices = [ + 0, # Closest - smallest rank diff + median_idx, # Median rank diff + len(nominal_bin_samples) - 1, # Furthest - largest rank diff + ] + else: + # Sample with increasing density toward higher rank differences + sample_points = np.linspace(0, 1, n_nominal_rows) + nominal_indices = (sample_points * (len(nominal_bin_samples) - 1)).astype(int) + + # Store selected nominal images + for y_idx, idx in enumerate(nominal_indices): + if y_idx >= n_nominal_rows: + continue + + if idx < len(nominal_bin_samples): + row = nominal_bin_samples.iloc[idx] + selected_images[(y_idx, x)] = { + "filename": row["filename"], + "rank_diff": row["rank_diff"], + "signed_rank_diff": row["signed_rank_diff"], + "is_anomaly": False, + "ml_score": row["ml_score"], + "user_score": row["anomaly_score_raw"], + } + + # Process anomaly samples for bottom rows + if len(anomaly_bin_samples) > 0: + # Sort by absolute rank difference (lowest to highest) + anomaly_bin_samples = anomaly_bin_samples.sort_values("rank_diff") + + # Sample images with increasing rank differences: closest, median, furthest + if len(anomaly_bin_samples) <= n_anomaly_rows: + # If we have 2 or fewer images, just use what we have + anomaly_indices = np.arange(len(anomaly_bin_samples)) + else: + # For 3 or more images: get lowest, median, and highest rank difference samples + if n_anomaly_rows == 3: + # Get closest (top), median (middle), and furthest (bottom) samples + median_idx = len(anomaly_bin_samples) // 2 + anomaly_indices = [ + 0, # Closest - smallest rank diff + median_idx, # Median rank diff + len(anomaly_bin_samples) - 1, # Furthest - largest rank diff + ] + else: + # Sample with increasing density toward higher rank differences + sample_points = np.linspace(0, 1, n_anomaly_rows) + anomaly_indices = (sample_points * (len(anomaly_bin_samples) - 1)).astype(int) + + # Store selected anomaly images (in bottom rows) + for y_idx, idx in enumerate(anomaly_indices): + if y_idx >= n_anomaly_rows: + continue + + # Position in bottom half of grid + y_position = n_nominal_rows + y_idx + + if idx < len(anomaly_bin_samples): + row = anomaly_bin_samples.iloc[idx] + selected_images[(y_position, x)] = { + "filename": row["filename"], + "rank_diff": row["rank_diff"], + "signed_rank_diff": row["signed_rank_diff"], + "is_anomaly": True, + "ml_score": row["ml_score"], + "user_score": row["anomaly_score_raw"], + } # Plot the grid - for y in range(n_grid): + for y in range(n_user_grid): for x in range(n_grid): # Create subplot at the right position (add 1 to account for labels) ax = plt.subplot(gs[y + 1, x + 1]) axes[y, x] = ax - # Get bin count - bin_count = bins[y, x] - # If we have an image for this cell, display it bin_key = (y, x) - if bin_key in representative_images and bin_count > 0: - filename, _ = representative_images[bin_key] + if bin_key in selected_images: + image_data = selected_images[bin_key] + filename = image_data["filename"] try: img_path = os.path.join(data_dir, os.path.basename(filename)) img = np.array(Image.open(img_path)) @@ -1530,297 +1611,397 @@ def alpha_scaling(count, max_count, min_count, total_images): if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1): img = np.repeat(img[..., None], 3, axis=2) - # Calculate alpha based on bin count - img_alpha = alpha_scaling(bin_count, max_occupancy, min_occupancy, total_images) - - # Create a semi-transparent gray overlay - overlay = np.ones_like(img) * 128 # Gray color - overlay_alpha = 1.0 - img_alpha # Invert alpha for overlay - - # Blend the image with the gray overlay - blended_img = img * img_alpha + overlay * overlay_alpha - blended_img = np.clip(blended_img, 0, 255).astype(np.uint8) - - # Display the blended image - ax.imshow(blended_img) + # Display the image + ax.imshow(img) + + # Add frame based on anomaly status - Fix the red box positioning + if image_data["is_anomaly"]: + # Create a properly positioned rectangle that covers the whole image + # Use the axes coordinates instead of data coordinates + rect = patches.Rectangle( + (0, 0), # Start at the bottom left + 1, + 1, # Full width and height in axes coordinates + transform=ax.transAxes, # Use axes coordinates + linewidth=3, # Slightly thinner line for better appearance + edgecolor="red", + facecolor="none", + alpha=0.8, + zorder=10, # Ensure the box is drawn on top + ) + ax.add_patch(rect) - # Add count as small number in corner - if bin_count > 1: - ax.text( - 0.05, - 0.05, - str(bin_count), + # Add a second rectangle slightly smaller to create a thicker border effect + rect_inner = patches.Rectangle( + (0.01, 0.01), # Slightly inset + 0.98, + 0.98, # Slightly smaller transform=ax.transAxes, - color="white", - fontsize=6, - bbox=dict(facecolor="black", alpha=0.7, pad=1), + linewidth=2, + edgecolor="red", + facecolor="none", + alpha=0.6, + zorder=10, ) + ax.add_patch(rect_inner) + + # Add small text showing signed rank difference + signed_diff = int(image_data["signed_rank_diff"]) + sign = "+" if signed_diff > 0 else "" # Plus sign for positive values + text_color = "white" + # Use different background colors to indicate direction + bg_color = ( + "black" if signed_diff > 0 else "black" if signed_diff < 0 else "black" + ) + + ax.text( + 0.95, + 0.05, + f"{sign}{signed_diff}", + transform=ax.transAxes, + color=text_color, + fontsize=6, + fontweight="bold", + ha="right", + va="bottom", + bbox=dict(facecolor=bg_color, alpha=0.8, pad=2), + zorder=20, # Ensure text is on top + ) + except Exception as e: logger.warning(f"Error loading image {filename}: {e}") - ax.set_facecolor("black") + # Create a black empty image for error cases + empty_img = np.zeros((224, 224, 3), dtype=np.uint8) + ax.imshow(empty_img) else: - # Empty cell - ax.set_facecolor("lightgray") + # Empty cell - create a light gray empty image + empty_img = np.ones((224, 224, 3), dtype=np.uint8) * 200 # Light gray (200,200,200) + ax.imshow(empty_img) # Remove axis ticks and labels for grid cells ax.set_xticks([]) - ax.set_yticks([]) # Add y-axis labels (left side) with bin boundaries - for y in range(n_grid): + ax.set_yticks([]) + + # Add a white separator line at the boundary between nominal and anomalies + if y == n_nominal_rows - 1: + # Add a white border at the bottom of the cell + ax.spines["bottom"].set_visible(True) + ax.spines["bottom"].set_color("white") + ax.spines["bottom"].set_linewidth(4) + ax.spines["bottom"].set_zorder(250) + if y == n_nominal_rows: + # Add a white border at the bottom of the cell + ax.spines["top"].set_visible(True) + ax.spines["top"].set_color("white") + ax.spines["top"].set_linewidth(2) + ax.spines["top"].set_zorder(250) + + # Font size for axis labels + if fig_title and not ("Low" in fig_title or "High" in fig_title): + ticks_fontsize = 8 + tick_rotation = 0 + else: + ticks_fontsize = 12 + tick_rotation = 90 + + # Add row markers + for y in range(n_user_grid): ax = plt.subplot(gs[y + 1, 0]) - # Display bin boundaries for user scores (inverted y-axis) - y_idx = n_grid - 1 - y # Invert to match the grid orientation - lower_bound = user_percentiles[y_idx] - upper_bound = user_percentiles[y_idx + 1] - ax.text( - 0.25, - 0.5, - f"{upper_bound:.2f}\nto\n{lower_bound:.2f}", - ha="center", - va="center", - fontsize=10 if n_grid <= 20 else 8, - ) + # Add text indicators for sections + if y == 0: + ax.text( + 0.5, + 0.5, + "Normal\nImages", + ha="center", + va="center", + rotation=tick_rotation, + fontsize=ticks_fontsize, + color="black", + fontweight="bold", + ) + elif y == n_nominal_rows: + ax.text( + 0.5, + 0.5, + "Anomaly\nImages", + ha="center", + va="center", + rotation=tick_rotation, + fontsize=ticks_fontsize, + color="red", + fontweight="bold", + ) ax.set_xticks([]) ax.set_yticks([]) ax.set_facecolor("none") for spine in ax.spines.values(): spine.set_visible(False) - # Add x-axis labels (bottom) with bin boundaries + # Add column markers to show ML score percentiles for x in range(n_grid): - ax = plt.subplot(gs[n_grid + 1, x + 1]) - lower_bound = ml_percentiles[x] - upper_bound = ml_percentiles[x + 1] - ax.text( - 0.5, - 0.25, - f"{lower_bound:.2f}\nto\n{upper_bound:.2f}", - ha="center", - va="center", - rotation=90, - fontsize=10 if n_grid <= 20 else 8, - ) + ax = plt.subplot(gs[n_user_grid + 1, x + 1]) + # If it's a specific column, add percentile indicator + if x == 0: + ax.text(0.5, 0.5, "Low AM", ha="center", va="center", fontsize=ticks_fontsize) + elif x == n_grid - 1: + ax.text(0.5, 0.5, "High AM", ha="center", va="center", fontsize=ticks_fontsize) + # You could also show specific percentile values here if needed + # e.g., ax.text(0.5, 0.5, f"{x * 100/n_grid:.0f}%", ha="center", va="center", fontsize=8) ax.set_xticks([]) ax.set_yticks([]) ax.set_facecolor("none") for spine in ax.spines.values(): spine.set_visible(False) - # Add horizontal histogram (top) for anomaly distribution with one bar per grid column - ax_hist_top = plt.subplot(gs[0, 1 : n_grid + 1]) + # Add horizontal histogram (top) for anomaly distribution + if fig_title and not ("Low" in fig_title or "High" in fig_title) and False: + ax_hist_top = plt.subplot(gs[0, 1 : n_grid + 1]) - # Create anomaly mask based on threshold - threshold = 0.9 - anomaly_mask = merged_df["anomaly_score_raw"] >= threshold - - # Create histogram showing anomaly counts per ML score percentile bin - anomaly_counts = np.zeros(n_grid) - for i in range(n_grid): - # Find samples in this ML score percentile bin - if i == n_grid - 1: - bin_mask = (merged_df["ml_score"] >= ml_percentiles[i]) & ( - merged_df["ml_score"] <= ml_percentiles[i + 1] - ) - else: - bin_mask = (merged_df["ml_score"] >= ml_percentiles[i]) & ( - merged_df["ml_score"] < ml_percentiles[i + 1] - ) + # Create anomaly mask based on threshold + threshold = 0.9 + anomaly_mask = merged_df["anomaly_score_raw"] >= threshold - # Count anomalies in this bin - anomaly_counts[i] = np.sum(anomaly_mask & bin_mask) + # Calculate global percentile bin edges for ML scores for the histogram + global_ml_percentiles = np.percentile( + merged_df["ml_score"], np.linspace(0, 100, n_grid + 1) + ) - bar_positions = np.arange(n_grid) + 0.5 # Center of each grid cell - bar_width = 0.8 # Width of each bar (slightly less than 1 to have small gaps) + # Create histogram showing anomaly counts per ML score percentile bin + anomaly_counts = np.zeros(n_grid) + for i in range(n_grid): + # Find samples in this ML score percentile bin + if i == n_grid - 1: + bin_mask = (merged_df["ml_score"] >= global_ml_percentiles[i]) & ( + merged_df["ml_score"] <= global_ml_percentiles[i + 1] + ) + else: + bin_mask = (merged_df["ml_score"] >= global_ml_percentiles[i]) & ( + merged_df["ml_score"] < global_ml_percentiles[i + 1] + ) - ax_hist_top.bar( - bar_positions, - anomaly_counts, - width=bar_width, - align="center", - color="orange", - edgecolor="darkorange", - alpha=0.7, - ) + # Count anomalies in this bin + anomaly_counts[i] = np.sum(anomaly_mask & bin_mask) - ax_hist_top.spines["top"].set_visible(False) - ax_hist_top.spines["right"].set_visible(False) - ax_hist_top.set_ylabel("Anomalies", fontsize=10) - ax_hist_top.yaxis.set_tick_params(labelsize=10) - ax_hist_top.set_title(f"ML Score Anomaly Distribution (threshold={threshold})", fontsize=12) - ax_hist_top.grid(alpha=0.3) - ax_hist_top.set_xticks([]) # Remove xticks - - # Set x-axis limits to align with grid - ax_hist_top.set_xlim(0, n_grid) - - # Add vertical histogram (right) for user scores with one bar per grid row - ax_hist_right = plt.subplot(gs[1 : n_grid + 1, n_grid + 1]) - - # Create histogram with exactly n_grid bars aligned with grid cells - hist_counts, _ = np.histogram(merged_df["anomaly_score_raw"], bins=user_percentiles) - # We need to invert the counts to match the inverted y-axis in the grid - hist_counts = hist_counts[::-1] - bar_positions = np.arange(n_grid) + 0.5 # Center of each grid cell - bar_width = 0.8 # Width of each bar (slightly less than 1 to have small gaps) - - ax_hist_right.barh( - bar_positions, - hist_counts, - height=bar_width, - align="center", - color="lightgreen", - edgecolor="darkgreen", - alpha=0.7, - ) + # Count anomalies in this bin + anomaly_counts[i] = np.sum(anomaly_mask & bin_mask) - ax_hist_right.spines["top"].set_visible(False) - ax_hist_right.spines["right"].set_visible(False) - ax_hist_right.set_xlabel("Count", fontsize=12) - ax_hist_right.xaxis.set_tick_params(labelsize=10) - ax_hist_right.set_yticks([]) # Remove yticks + bar_positions = np.arange(n_grid) + 0.5 # Center of each grid cell + bar_width = 0.8 # Width of each bar - # Set y-axis limits to align with grid - ax_hist_right.set_ylim(0, n_grid) + ax_hist_top.bar( + bar_positions, + anomaly_counts, + width=bar_width, + align="center", + color="orange", + edgecolor="darkorange", + alpha=0.7, + ) - # Replace the regular title with properly positioned text on the right side - # Remove the original title call - # ax_hist_right.set_title("User Score Distribution", fontsize=12) + ax_hist_top.spines["top"].set_visible(False) + ax_hist_top.spines["right"].set_visible(False) + ax_hist_top.set_ylabel("Anomalies", fontsize=10) + ax_hist_top.yaxis.set_tick_params(labelsize=10) + ax_hist_top.set_title( + f"Distribution of AM Scores of Anomalies (GZ Scores > {threshold})", + fontsize=12, + ) + ax_hist_top.grid(alpha=0.3) + ax_hist_top.set_xticks([]) # Remove xticks - # Add rotated text to the right of the histogram, flipped 180° - ax_hist_right.text( - 1.15, - 0.5, - "User Score Distribution", - rotation=270, # Flipped 180° from original 90° - transform=ax_hist_right.transAxes, - ha="center", - va="center", - fontsize=12, - ) + # Set x-axis limits to align with grid + ax_hist_top.set_xlim(0, n_grid) - ax_hist_right.grid(alpha=0.3) + # Set y-axis limits based on figure title + if fig_title and "Low" in fig_title: + ax_hist_top.set_ylim(0, 1) + ax_hist_top.yaxis.set_label_position("right") + ax_hist_top.yaxis.tick_right() # Add axis titles - fig.text(0.5, 0.01, "AnomalyMatch Scores", ha="center", fontsize=16) - fig.text(0.01, 0.5, "User Scores", va="center", rotation=90, fontsize=16) - - # Add figure title if provided - if fig_title: - fig.suptitle(fig_title, fontsize=16, y=0.99) - - # Add a legend for the alpha transparency - ax_legend = plt.subplot(gs[n_grid + 1, 0]) - ax_legend.set_xticks([]) - ax_legend.set_yticks([]) - ax_legend.set_facecolor("none") - for spine in ax_legend.spines.values(): - spine.set_visible(False) + if fig_title and ("Low User Scores" in fig_title): + fig.text( + 0.5, 0.01, "AnomalyMatch Score (higher than 80th percentile)", ha="center", fontsize=16 + ) + fig.text( + 0.01, + 0.5, + "GZ Score (lower than 50th percentile)", + va="center", + rotation=90, + fontsize=ticks_fontsize + 2, + ) + elif fig_title and ("High User Scores" in fig_title): + fig.text( + 0.5, 0.01, "AnomalyMatch Score (lower than 50th percentile)", ha="center", fontsize=16 + ) + fig.text( + 0.01, + 0.5, + "GZ Score (higher than 80th percentile)", + va="center", + rotation=90, + fontsize=ticks_fontsize + 2, + ) + else: + fig.text( + 0.5, + 0.09, + "AnomalyMatch Score (class-specific percentile bins)", + ha="center", + fontsize=ticks_fontsize + 3, + ) + fig.text( + 0.1, + 0.5, + "Sample Type & Rank Difference", + va="center", + rotation=90, + fontsize=ticks_fontsize + 3, + ) - # Add text explaining the alpha transparency + # Add a legend for the visual indicators in the bottom-left corner + ax_legend = plt.subplot(gs[n_user_grid + 1, 0]) + legend_text = "+/-n: AM ranks\n" "higher/lower\n" "score than GZ\n" ax_legend.text( + 0.35, 0.5, - 0.5, - "Darker = \n More\nSamples", + legend_text, ha="center", va="center", - fontsize=10, - transform=ax_legend.transAxes, + fontsize=8, + fontweight="bold", + linespacing=1.3, ) + ax_legend.set_xticks([]) + ax_legend.set_yticks([]) + ax_legend.set_facecolor("none") + + for spine in ax_legend.spines.values(): + spine.set_visible(False) # Save figure with high resolution - output_path = os.path.join(plots_dir, f"score_vs_user_score_grid_{suffix}_iter{iteration}.png") - plt.tight_layout() - if fig_title: - plt.subplots_adjust(top=0.95) # Make room for the title - plt.savefig(output_path, dpi=DEFAULT_DPI) + output_path = os.path.join(plots_dir, f"score_vs_user_score_grid_{suffix}_iter{iteration}.jpg") + # plt.tight_layout() + if False: + if fig_title: + # Add title with proper padding + plt.suptitle(fig_title, fontsize=16, fontweight="bold", y=0.98) + plt.subplots_adjust(top=0.92) # Make room for the title + else: + plt.subplots_adjust(top=0.95) + plt.savefig(output_path, dpi=300) plt.close() - logger.info(f"Score vs user score grid plot ({suffix}) saved to {output_path}") def create_rank_comparison_plot(merged_df, plots_dir, iteration): """ - Create a scatter plot comparing the rank positions of ML scores vs user scores. - A perfect correlation would show as a diagonal line. - - Args: - merged_df: DataFrame with merged scores and filenames - plots_dir: Directory to save plots - iteration: Current training iteration + Create a histogram showing the relative occurrence of absolute rank deviations + between ML scores and user scores, with an overlaid histogram for anomalies only. """ - plt.figure(figsize=(8, 8)) + plt.figure(figsize=(10, 6)) + + # Calculate ranks for both scores using 'first' method to handle ties consistently - # Calculate ranks for both scores - # Use 'first' method to handle ties consistently ml_ranks = merged_df["ml_score"].rank(method="first") user_ranks = merged_df["anomaly_score_raw"].rank(method="first") - # Convert to percentile ranks (0-100) + # Calculate absolute relative rank deviation n_samples = len(merged_df) - ml_ranks_pct = (ml_ranks / n_samples) * 100 - user_ranks_pct = (user_ranks / n_samples) * 100 + rank_deviation = np.abs((ml_ranks - user_ranks)) + n_bins = 50 + + # Create mask for anomalies (user_score > 0.9) + anomaly_mask = merged_df["anomaly_score_raw"] > 0.9 + n_anomalies = anomaly_mask.sum() + + # Plot histogram for all samples + counts, bins = np.histogram(rank_deviation, bins=n_bins) + percentages = (counts / n_samples) * 100 + bin_centers = 0.5 * (bins[1:] + bins[:-1]) + bar_width = bins[1] - bins[0] + + plt.bar( + bin_centers, + percentages, + width=bar_width, + color=BLUE, + alpha=0.5, # Reduced alpha for better overlay visibility + edgecolor="black", + linewidth=1, + label=f"All Objects, N={n_samples}", + ) - # Plot vertical bars to perfect correlation line first - for ml_rank, user_rank in zip(ml_ranks_pct, user_ranks_pct): - # Calculate the point on the perfect correlation line - perfect_point = ml_rank - plt.plot( - [ml_rank, ml_rank], - [user_rank, perfect_point], - color="gray", - alpha=0.005, - linewidth=1, - zorder=1, - ) + # Plot histogram for anomalies + anomaly_counts, _ = np.histogram(rank_deviation[anomaly_mask], bins=bins) + anomaly_percentages = (anomaly_counts / n_anomalies) * 100 - # Plot perfect correlation line - plt.plot( - [0, 100], - [0, 100], - color=PERFECT_LINE_COLOR, - linestyle=PERFECT_LINE_STYLE, - alpha=PERFECT_LINE_ALPHA, - linewidth=2, - label="Perfect Correlation", - zorder=2, + plt.bar( + bin_centers, + anomaly_percentages, + width=bar_width, + color=RED, + alpha=0.5, + edgecolor="darkred", + linewidth=1, + label=f"Anomalies (GZ Score > 0.9, N={n_anomalies})", ) - # Create scatter plot with small points - plt.scatter( - ml_ranks_pct, user_ranks_pct, s=5, color=BLUE, alpha=0.01, label="Samples", zorder=3 + # Calculate statistics for both distributions + mean_deviation = np.mean(rank_deviation) + median_deviation = np.median(rank_deviation) + anomaly_mean = np.mean(rank_deviation[anomaly_mask]) + anomaly_median = np.median(rank_deviation[anomaly_mask]) + + # Add vertical lines for statistics + plt.axvline( + x=mean_deviation, + color=BLUE, + linestyle="--", + linewidth=1.5, + label=f"Mean (All) = {mean_deviation:.0f}", + ) + plt.axvline( + x=median_deviation, + color=BLUE, + linestyle=":", + linewidth=1.5, + label=f"Median (All) = {median_deviation:.0f}", + ) + plt.axvline( + x=anomaly_mean, + color=RED, + linestyle="--", + linewidth=1.5, + label=f"Mean (Anomalies) = {anomaly_mean:.0f}", + ) + plt.axvline( + x=anomaly_median, + color=RED, + linestyle=":", + linewidth=1.5, + label=f"Median (Anomalies) = {anomaly_median:.0f}", ) # Add grid and labels plt.grid(alpha=0.3) - plt.xlabel("AnomalyMatch Score Rank (%)", fontsize=14) - plt.ylabel("User Score Rank (%)", fontsize=14) - plt.title("Comparison of AnomalyMatch and User Score Rankings", fontsize=16) + plt.xlabel("$|$Rank$_{{AM}} - $Rank$_{{GZ}}|$", fontsize=24) + plt.ylabel("Relative Occurrence [%]", fontsize=24) - # Set axis limits - plt.xlim(0, 100) - plt.ylim(0, 100) - - # Add legend - plt.legend(loc="lower right", frameon=True, framealpha=0.7) - - # Calculate rank correlation coefficients - spearman_corr = merged_df["ml_score"].corr(merged_df["anomaly_score_raw"], method="spearman") - kendall_corr = merged_df["ml_score"].corr(merged_df["anomaly_score_raw"], method="kendall") - - # Add correlation coefficients as text - plt.text( - 5, - 92, - f"Spearman ρ = {spearman_corr:.3f}\nKendall τ = {kendall_corr:.3f}", - fontsize=12, - bbox=dict(facecolor="white", alpha=0.7), - ) + # Add legend with better placement + plt.legend(loc="upper right", frameon=True, framealpha=0.7, bbox_to_anchor=(1.0, 0.95)) # Save the figure - output_path = os.path.join(plots_dir, f"score_rank_correlation_iter{iteration}.png") + output_path = os.path.join(plots_dir, f"score_rank_correlation_iter{iteration}.pdf") plt.tight_layout() - plt.savefig(output_path, dpi=DEFAULT_DPI) + plt.savefig(output_path, dpi=200) plt.close() + # Log statistics + logger.info(f"Mean rank deviation (all): {mean_deviation:.3f}") + logger.info(f"Median rank deviation (all): {median_deviation:.3f}") + logger.info(f"Mean rank deviation (anomalies): {anomaly_mean:.3f}") + logger.info(f"Median rank deviation (anomalies): {anomaly_median:.3f}") logger.info(f"Rank correlation plot saved to {output_path}") - logger.info(f"Spearman correlation: {spearman_corr:.3f}") - logger.info(f"Kendall correlation: {kendall_corr:.3f}") diff --git a/paper_scripts/paper_utils.py b/paper_scripts/paper_utils.py index 94cd3eb..9aabffc 100644 --- a/paper_scripts/paper_utils.py +++ b/paper_scripts/paper_utils.py @@ -227,7 +227,7 @@ def setup_mock_ui(): return out, progress_bar -def setup_pipeline(args, data_dir, labeled_data_path, run_dir, output_widget, progress_bar): +def setup_pipeline(args, data_dir, labeled_data_path, run_dir, output_widget, seed, progress_bar): """Set up the AnomalyMatch pipeline for training.""" model_path = os.path.join(run_dir, "models", "model.pth") @@ -237,20 +237,21 @@ def setup_pipeline(args, data_dir, labeled_data_path, run_dir, output_widget, pr cfg.model_path = model_path cfg.data_dir = data_dir cfg.label_file = labeled_data_path - cfg.size = [args.size, args.size] + cfg.normalisation.image_size = [args.size, args.size] cfg.N_to_load = args.n_to_load # Use parameter for number of images to load cfg.test_ratio = 0.0 # No test evaluation within the session cfg.output_dir = str(run_dir) cfg.num_train_iter = args.train_iterations - cfg.progress_bar = progress_bar cfg.num_workers = 4 cfg.pin_memory = True + cfg.seed = seed # Configure logging am.set_log_level("info", cfg) # Create session session = am.Session(cfg) + # change the output path of io to the run directory session.session_io.base_save_path = Path(run_dir) diff --git a/paper_scripts/recreate_plots.py b/paper_scripts/recreate_plots.py index 1837d1b..1aae0c4 100644 --- a/paper_scripts/recreate_plots.py +++ b/paper_scripts/recreate_plots.py @@ -99,6 +99,8 @@ def recreate_plots_from_data(results_dir, plot_type=None, custom_thresholds=None # Find all plot data files all_plot_files = [] for plot_data_dir in glob.glob(str(results_dir / "**" / "plot_data"), recursive=True): + if "plots_recreated" in plot_data_dir: + continue plot_files = glob.glob(os.path.join(plot_data_dir, "*.pkl")) all_plot_files.extend(plot_files) @@ -172,7 +174,7 @@ def recreate_plots_from_data(results_dir, plot_type=None, custom_thresholds=None elif plot_type_name == "roc_prc_curves": logger.info(f"Recreating ROC/PRC curves for iteration {iteration}") # The function expects a metrics dictionary - plot_roc_prc_curves(plot_data, iteration, plots_dir) + plot_roc_prc_curves(plot_data["metrics"], plot_data["iteration"], plots_dir) elif plot_type_name == "top_n_anomaly_detection": logger.info(f"Recreating top-N anomaly detection for iteration {iteration}") plot_top_n_anomaly_detection( @@ -333,7 +335,7 @@ def plot_active_learning_comparison(results_dir): # Look for directories matching the expected names for exp_name in experiment_dirs.keys(): - matching_dirs = list(al_dir.glob(f"**/*{exp_name}*")) + matching_dirs = list(al_dir.glob(f"**/*{exp_name}/")) if matching_dirs: experiment_dirs[exp_name] = matching_dirs[0] logger.info(f"Found {exp_name} directory: {matching_dirs[0]}") diff --git a/paper_scripts/test_plots.py b/paper_scripts/test_plots.py index 984111d..47e57a9 100644 --- a/paper_scripts/test_plots.py +++ b/paper_scripts/test_plots.py @@ -180,7 +180,7 @@ def generate_detection_data(total_samples=10000, total_anomalies=100): # For score vs user score grid plot -def generate_grid_mock_data(n_samples=500): +def generate_grid_mock_data(n_samples=15000): """Generate mock data for testing the score vs user score grid plot.""" # Generate random scores between 0 and 1 ml_scores = np.random.beta(2, 2, n_samples) # Using beta distribution for better spread diff --git a/prediction_process.py b/prediction_process.py index 5a38054..5efb448 100644 --- a/prediction_process.py +++ b/prediction_process.py @@ -14,50 +14,65 @@ import numpy as np from loguru import logger from concurrent.futures import ThreadPoolExecutor -import pandas as pd from tqdm import tqdm import time +from anomaly_match.data_io.load_images import ( + load_and_process_single_wrapper, +) + from prediction_utils import ( load_model, save_results, process_batch_predictions, + estimate_batch_size, + clear_gpu_cache_if_needed, ) from anomaly_match.image_processing.transforms import ( get_prediction_transforms, ) -from anomaly_match.data_io.load_images import read_and_resize_image - -# Configure logging -logs_dir = os.path.join(os.path.dirname(__file__), "logs") -os.makedirs(logs_dir, exist_ok=True) -logger.remove() -logger.add( - os.path.join(logs_dir, "prediction_thread_{time}.log"), - rotation="1 MB", - format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", - level="DEBUG", -) def load_and_preprocess(args): - filename, transform, cfg = args - image = read_and_resize_image( - filename, - cfg=cfg, - convert_to_rgb=True, + """Load and preprocess a single image file. + + Note: Returns numpy array, not tensor. Tensor conversion is done on main + thread to avoid CUDA context issues in ThreadPoolExecutor. + """ + filepath, cfg = args + image = load_and_process_single_wrapper( + filepath, + cfg, + desc="image prediction process", + show_progress=False, + prediction=True, ) - image = transform(image) - return filename, image + return filepath, image def evaluate_files(file_list, cfg, top_n=1000, batch_size=1000, max_workers=1): - """Evaluate files in batches and return top N scores.""" + """Evaluate files in batches and return top N scores. + file list is a list of cfg.prediction_search_dir+filename + """ logger.trace(f"{len(file_list)} unlabeled images remain.") + # Load model first - this loads the fitsbolt config from the checkpoint + model = load_model(cfg) + model.eval() + + # Require fitsbolt config from model checkpoint for consistent predictions + if not hasattr(cfg, "fitsbolt_cfg") or cfg.fitsbolt_cfg is None: + raise ValueError( + "Fitsbolt config not found in model checkpoint. " + "Please retrain the model with the updated version to include normalisation settings." + ) + logger.debug("Using fitsbolt config loaded from model checkpoint") + transform = get_prediction_transforms() - args_list = [(filename, transform, cfg) for filename in file_list] + + # I/O in ThreadPool (returns numpy arrays) + args_list = [(filepath, cfg) for filepath in file_list] with ThreadPoolExecutor(max_workers=max_workers) as executor: results = list( @@ -68,27 +83,32 @@ def evaluate_files(file_list, cfg, top_n=1000, batch_size=1000, max_workers=1): ) ) - model = load_model(cfg) - model.eval() - # Process in batches scores_list = [] filenames_list = [] imgs_list = [] - for i in range(0, len(results), batch_size): + for batch_idx, i in enumerate(range(0, len(results), batch_size)): batch = results[i : i + batch_size] # noqa: E203 batch_filenames = [item[0] for item in batch] - batch_images = [item[1] for item in batch] + numpy_images = [item[1] for item in batch] - # Stack images into a batch tensor - images = torch.stack(batch_images, dim=0) + # Tensor conversion on main thread (not in ThreadPool) + batch_tensors = [transform(img) for img in numpy_images] + images = torch.stack(batch_tensors, dim=0) + del numpy_images, batch_tensors # Free memory before CUDA ops + + # CUDA inference with explicit cleanup batch_scores, batch_imgs = process_batch_predictions(model, images) + del images # Free CUDA tensor reference scores_list.append(batch_scores) filenames_list.extend(batch_filenames) imgs_list.append(batch_imgs) + # Periodic GPU cache clearing to prevent fragmentation + clear_gpu_cache_if_needed(batch_idx) + # Concatenate results all_scores = np.concatenate(scores_list) all_imgs = np.concatenate(imgs_list) @@ -120,6 +140,12 @@ def main(): logger.error(f"Failed to load config from {args.config_path}: {e}") sys.exit(1) + logger.info("Setting batch size") + batch_size = ( + estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction + ) + logger.info(f"Batch size set to: {batch_size}") + logger.info(f"Loading file list from {args.file_list_path}") with open(args.file_list_path, "r") as f: group_list = [line.strip() for line in f] @@ -128,59 +154,27 @@ def main(): file_list = [line.strip() for line in f] logger.info(f"Found {len(file_list)} files to process") - # Load existing results if they exist - output_csv_path = os.path.join(cfg.output_dir, f"{cfg.save_file}_top{args.top_n}.csv") - output_npy_path = os.path.join(cfg.output_dir, f"{cfg.save_file}_top{args.top_n}.npy") - - if os.path.exists(output_csv_path) and os.path.exists(output_npy_path): - logger.info("Found existing results, loading...") - existing_df = pd.read_csv(output_csv_path) - existing_filenames = existing_df["Filename"].values - existing_scores = existing_df["Score"].values - - existing_imgs = np.load(output_npy_path) - else: - existing_filenames = np.array([]) - existing_scores = np.array([]) - # Define image shape: (num_samples, channels, height, width) - existing_imgs = np.empty((0, 3, cfg.size[0], cfg.size[1]), dtype=np.float32) - logger.info("Starting evaluation...") - scores, filenames, imgs = evaluate_files(file_list, cfg, top_n=args.top_n) - logger.success(f"Evaluation complete. Computed {len(scores)} scores") - - # Merge new results with existing results - all_filenames = np.concatenate([existing_filenames, filenames]) - all_scores = np.concatenate([existing_scores, scores]) - # Merge new results with existing results - if existing_imgs.size == 0: - all_imgs = imgs - else: - all_imgs = np.concatenate([existing_imgs, imgs]) - - # Keep only top N results - top_indices = np.argsort(all_scores)[::-1][: args.top_n] - top_filenames = all_filenames[top_indices] - top_scores = all_scores[top_indices] - top_imgs = all_imgs[top_indices] - - logger.info( - f"Score statistics - Min: {np.min(top_scores):.4f}, Max: {np.max(top_scores):.4f}" - + f", Mean: {np.mean(top_scores):.4f}, Std: {np.std(top_scores):.4f}" + # evaluate_files calls save_results internally which handles accumulation + # across multiple batches - no additional merging needed here + scores, filenames, imgs = evaluate_files( + file_list, cfg, batch_size=batch_size, top_n=args.top_n ) - - logger.info(f"Saving results to {output_csv_path} and {output_npy_path}") - - # Save merged results to CSV using pandas - df = pd.DataFrame({"Filename": top_filenames, "Score": top_scores}) - df.to_csv(output_csv_path, index=False) - - # Save merged images using numpy - np.save(output_npy_path, top_imgs) + logger.success(f"Evaluation complete. Top {len(scores)} scores returned") elapsed_time = time.time() - start_time logger.success(f"Script completed in {elapsed_time:.2f} seconds") if __name__ == "__main__": + # Configure logging + logs_dir = os.path.join(os.path.dirname(__file__), "logs") + os.makedirs(logs_dir, exist_ok=True) + logger.remove() + logger.add( + os.path.join(logs_dir, "prediction_thread_{time}.log"), + rotation="1 MB", + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", + level="DEBUG", + ) main() diff --git a/prediction_process_cutana.py b/prediction_process_cutana.py new file mode 100644 index 0000000..7a661e2 --- /dev/null +++ b/prediction_process_cutana.py @@ -0,0 +1,323 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +import argparse +import os +import sys +import pickle + +from dotmap import DotMap +import torch +import numpy as np +from loguru import logger +from concurrent.futures import ThreadPoolExecutor +import time +from tqdm import tqdm +import cutana + +from anomaly_match.data_io.load_images import process_single_wrapper + +from prediction_utils import ( + load_model, + save_results, + process_batch_predictions, + estimate_batch_size, + clear_gpu_cache_if_needed, +) + +from anomaly_match.image_processing.transforms import ( + get_prediction_transforms, +) + + +def read_and_preprocess_image_from_zarr(image_data, cfg): + """Read and preprocess image data from Zarr array using standardized functions.""" + try: + # Convert Zarr data to numpy array if it's not already + if not isinstance(image_data, np.ndarray): + image_data = np.array(image_data) + + # Check if we need to transpose based on the shape + # If last dimension is 3 (RGB channels), data is already in HWC format + # If first dimension is 3, data is in CHW format and needs transposing + if image_data.shape[0] == cfg.normalisation.n_output_channels: + # In CHW format, convert to HWC + image = image_data.transpose(1, 2, 0) + else: + # Assume HWC format if neither first nor last dimension is 3 + # This handles grayscale or other formats + image = image_data + + # Use the centralized processing function - this handles RGB conversion, + # normalization, and resizing efficiently without temporary files + processed_image = process_single_wrapper(image, cfg, desc="zarr") + return processed_image + + except Exception as e: + logger.error(f"Error processing image from Zarr: {e}") + raise + + +def load_and_preprocess_zarr(args): + """Load and preprocess a single image from Zarr. + + Note: Returns numpy array, not tensor. Tensor conversion is done on main + thread to avoid CUDA context issues in ThreadPoolExecutor. + """ + image_data, cfg = args + return read_and_preprocess_image_from_zarr(image_data, cfg) + + +def evaluate_images_from_cutana( + cutana_sources_path, cfg, top_n=1000, batch_size=1000, max_workers=4 +): + """Evaluate images provided by Cutana stream and return top N scores.""" + + cutana_config = cutana.get_default_config() + + cutana_config.target_resolution = cfg.normalisation.image_size[0] + cutana_config.source_catalogue = cutana_sources_path + + # Configure FITS extensions from AM config, default to PRIMARY if not specified + # fits_extension can be: None, str/int, list of str/int, or list of tuples (name, ext_type) + fits_ext = cfg.normalisation.fits_extension + if fits_ext is None: + fits_ext = ["PRIMARY"] + elif isinstance(fits_ext, (str, int)): + fits_ext = [fits_ext] + + # Build selected_extensions - handle both simple names and (name, ext_type) tuples + selected_extensions = [] + extension_names = [] + for ext in fits_ext: + if isinstance(ext, tuple): + name, ext_type = ext + selected_extensions.append({"name": str(name), "ext": ext_type}) + extension_names.append(name) + else: + selected_extensions.append({"name": str(ext), "ext": "PrimaryHDU"}) + extension_names.append(ext) + + cutana_config.fits_extensions = extension_names + cutana_config.selected_extensions = selected_extensions + + # Pass channel combination - required for multi-extension data + if cfg.normalisation.channel_combination is not None: + cutana_config.channel_weights = cfg.normalisation.channel_combination + elif len(fits_ext) > 1: + raise ValueError( + "cfg.normalisation.channel_combination must be set when using multiple FITS extensions. " + "This defines how extensions are combined into RGB channels." + ) + + # Pass AnomalyMatch's fitsbolt_cfg directly to cutana for normalization + # This ensures cutana uses the exact same normalization settings as training + if hasattr(cfg, "fitsbolt_cfg") and cfg.fitsbolt_cfg is not None: + cutana_config.external_fitsbolt_cfg = cfg.fitsbolt_cfg + logger.debug("Passed fitsbolt_cfg to cutana for normalization") + + try: + logger.info(f"Creating Cutana orchestrator, streaming from {cutana_sources_path}") + logger.debug( + f"Cutana config: target_resolution={cutana_config.target_resolution}, " + f"fits_extensions={cutana_config.fits_extensions}, " + f"selected_extensions={cutana_config.selected_extensions}" + ) + + cutana_orchestrator = cutana.StreamingOrchestrator(cutana_config) + + cutana_orchestrator.init_streaming( + batch_size=batch_size, write_to_disk=False, synchronised_loading=False + ) + except Exception as e: + logger.error(f"Failed to initialize Cutana orchestrator: {e}") + raise + + logger.info("Cutana orchestrator streaming mode initalized") + + logger.info(f"Available batches in cutana: {cutana_orchestrator.get_batch_count()}") + + model = load_model(cfg) + model.eval() + transform = get_prediction_transforms() + + # Process images in batches + scores_list = [] + imgs_list = [] + + start_time = time.time() + last_log_time = start_time + processed_since_last_log = 0 + + # Require fitsbolt config from model checkpoint for consistent predictions + # Note: DotMap auto-creates empty DotMaps when accessing missing keys + # So we check for 'size' key which must exist in a valid fitsbolt config + fitsbolt_cfg = cfg.fitsbolt_cfg + if fitsbolt_cfg is None or (isinstance(fitsbolt_cfg, DotMap) and "size" not in fitsbolt_cfg): + raise ValueError( + "fitsbolt_cfg not found in model checkpoint. " + "Models must be saved with fitsbolt config for prediction. " + "Please retrain and save the model to include fitsbolt config." + ) + logger.debug("Using fitsbolt config loaded from model checkpoint") + + batches_count = cutana_orchestrator.get_batch_count() + + num_images = 0 + filenames = [] + + for batch_idx in tqdm(range(batches_count), desc="Processing batches"): + + loaded_batch = cutana_orchestrator.next_batch() + batch_data = loaded_batch["cutouts"] + + # Debug: Log what we received + logger.debug( + f"Batch {batch_idx}: cutouts type={type(batch_data).__name__}, " + f"metadata count={len(loaded_batch.get('metadata', []))}" + ) + + # Handle empty batches (cutana returns [] if all cutouts failed) + if isinstance(batch_data, list): + if len(batch_data) == 0: + logger.warning(f"Batch {batch_idx} returned empty cutouts (list), skipping") + continue + # Convert list to numpy array if needed + batch_data = np.array(batch_data) + + batch_size_actual = batch_data.shape[0] + num_images += batch_size_actual + + batch_filenames = (source["source_id"] for source in loaded_batch["metadata"]) + filenames.extend(batch_filenames) + + # I/O and preprocessing in ThreadPool (returns numpy arrays) + # CUDA operations are kept on main thread to prevent memory fragmentation + batch_process_start = time.time() + with ThreadPoolExecutor(max_workers=max_workers) as executor: + batch_args = [(batch_data[i], cfg) for i in range(batch_size_actual)] + numpy_images = list(executor.map(load_and_preprocess_zarr, batch_args)) + + # Tensor conversion on main thread (not in ThreadPool) to avoid CUDA context issues + stack_start = time.time() + batch_tensors = [transform(img).detach() for img in numpy_images] + images = torch.stack(batch_tensors, dim=0) + del numpy_images, batch_tensors # Free memory before CUDA ops + + # CUDA inference with explicit cleanup + batch_scores, batch_imgs = process_batch_predictions(model, images) + del images # Free CUDA tensor reference + + scores_list.append(batch_scores) + imgs_list.append(batch_imgs) + + # Periodic GPU cache clearing to prevent fragmentation + clear_gpu_cache_if_needed(batch_idx) + + processed_since_last_log += batch_size_actual + current_time = time.time() + + # Log performance every 10000 images or 60 seconds + if processed_since_last_log >= 10000 or (current_time - last_log_time) >= 60: + elapsed = current_time - last_log_time + rate = processed_since_last_log / elapsed + batch_time = current_time - batch_process_start + logger.info( + f"Performance: {rate:.1f} images/sec " + f"(batch {batch_size_actual}: {batch_time:.2f}s, " + f"load: {stack_start - batch_process_start:.2f}s, " + f"inference: {current_time - stack_start:.2f}s)" + ) + last_log_time = current_time + processed_since_last_log = 0 + + cutana_orchestrator.cleanup() + + total_time = time.time() - start_time + logger.info( + f"Total processing time: {total_time:.1f}s, " + f"Average rate: {num_images / total_time:.1f} images/sec" + ) + + # Concatenate results + all_scores = np.concatenate(scores_list) + all_imgs = np.concatenate(imgs_list) + all_filenames = np.array(filenames) + + return save_results(cfg, all_scores, all_imgs, all_filenames, top_n) + + +def main(): + start_time = time.time() + + parser = argparse.ArgumentParser() + parser.add_argument("config_path", type=str, help="Path to config file") + parser.add_argument( + "cutana_sources_path", type=str, help="Path to the directory to stream from" + ) + parser.add_argument("top_n", type=int, default=1000, help="Number of top scores to keep") + args = parser.parse_args() + + logger.info(f"Loading config from {args.config_path}") + # Load cfg from pkl + try: + with open(args.config_path, "rb") as f: + cfg = pickle.load(f) + cfg = DotMap(cfg) + except Exception as e: + logger.error(f"Failed to load config from {args.config_path}: {e}") + sys.exit(1) + + logger.info("Setting batch size") + batch_size = ( + estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction + ) + logger.info(f"Batch size set to: {batch_size}") + + # Log key configuration parameters + logger.debug("Configuration loaded with parameters:") + logger.debug(f" Save file: {cfg.save_file}") + logger.debug(f" Save path: {cfg.save_path}") + logger.debug(f" Model path: {cfg.model_path}") + logger.debug(f" Output directory: {cfg.output_dir}") + logger.debug(f" Image size: {cfg.normalisation.image_size}") + + # Log full configuration + logger.debug("Full configuration:") + logger.debug(f"{cfg.toDict()}") + + # Create output directory if it doesn't exist + os.makedirs(cfg.output_dir, exist_ok=True) + + logger.info(f"Streaming from directory: {args.cutana_sources_path}") + + try: + evaluate_images_from_cutana( + args.cutana_sources_path, cfg, batch_size=batch_size, top_n=args.top_n + ) + elapsed_time = time.time() - start_time + logger.success(f"Script completed in {elapsed_time:.2f} seconds") + except Exception as e: + logger.exception(f"Error during processing: {str(e)}") + raise + + +if __name__ == "__main__": + + # Configure logging + logs_dir = os.path.join(os.path.dirname(__file__), "logs") + os.makedirs(logs_dir, exist_ok=True) + + # Remove default handler and set up file logging + logger.remove() + script_logger_id = logger.add( + os.path.join(logs_dir, "prediction_cutana_{time}.log"), + rotation="1 MB", + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", + level="DEBUG", + ) + logger.add(sys.stderr, level="INFO") + main() diff --git a/prediction_process_hdf5.py b/prediction_process_hdf5.py index 5c9c6e4..de87f33 100644 --- a/prediction_process_hdf5.py +++ b/prediction_process_hdf5.py @@ -18,31 +18,20 @@ import h5py from tqdm import tqdm +from anomaly_match.data_io.load_images import process_single_wrapper + from prediction_utils import ( load_model, save_results, process_batch_predictions, + clear_gpu_cache_if_needed, jpeg_decoder, + estimate_batch_size, ) from anomaly_match.image_processing.transforms import ( get_prediction_transforms, ) -from anomaly_match.data_io.load_images import process_image_array - -# Configure logging -logs_dir = os.path.join(os.path.dirname(__file__), "logs") -os.makedirs(logs_dir, exist_ok=True) - -# Remove default handler and set up file logging -logger.remove() -logger.add( - os.path.join(logs_dir, "prediction_thread_{time}.log"), - rotation="1 MB", - format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", - level="DEBUG", -) -logger.add(sys.stderr, level="INFO") def read_and_decode_image_from_hdf5(image_data, cfg): @@ -64,21 +53,26 @@ def read_and_decode_image_from_hdf5(image_data, cfg): image = np.array(Image.open(io.BytesIO(image_bytes))) - processed_image = process_image_array(image, cfg, convert_to_rgb=True, image_source="hdf5") + processed_image = process_single_wrapper(image, cfg, desc="hdf5") return processed_image except Exception as e: logger.error(f"Error decoding image from HDF5: {e}") # Return a blank image as fallback - return np.zeros((cfg.size[0], cfg.size[1], 3), dtype=np.uint8) + return np.zeros( + (cfg.normalisation.image_size[0], cfg.normalisation.image_size[1], 3), + dtype=np.uint8, + ) + +def load_and_preprocess_hdf5(args): + """Load and preprocess a single image from HDF5. -def load_and_preprocess(args): - """Load and preprocess a single image.""" - image_data, transform, cfg = args - image = read_and_decode_image_from_hdf5(image_data, cfg) - image = transform(image) - return image + Note: Returns numpy array, not tensor. Tensor conversion is done on main + thread to avoid CUDA context issues in ThreadPoolExecutor. + """ + image_data, cfg = args + return read_and_decode_image_from_hdf5(image_data, cfg) def evaluate_images_in_hdf5(hdf5_path, cfg, top_n=1000, batch_size=1000, max_workers=4): @@ -108,25 +102,36 @@ def evaluate_images_in_hdf5(hdf5_path, cfg, top_n=1000, batch_size=1000, max_wor last_log_time = start_time processed_since_last_log = 0 - for batch_start in tqdm(range(0, num_images, batch_size), desc="Processing batches"): + for batch_idx, batch_start in enumerate( + tqdm(range(0, num_images, batch_size), desc="Processing batches") + ): batch_end = min(batch_start + batch_size, num_images) batch_data = dataset[batch_start:batch_end] batch_size_actual = len(batch_data) - # Process batch in parallel + # I/O and preprocessing in ThreadPool (returns numpy arrays) + # CUDA operations are kept on main thread to prevent memory fragmentation batch_process_start = time.time() with ThreadPoolExecutor(max_workers=max_workers) as executor: - batch_args = [(data, transform, cfg) for data in batch_data] - batch_images = list(executor.map(load_and_preprocess, batch_args)) + batch_args = [(data, cfg) for data in batch_data] + numpy_images = list(executor.map(load_and_preprocess_hdf5, batch_args)) - # Stack images into a batch tensor and get predictions + # Tensor conversion on main thread (not in ThreadPool) stack_start = time.time() - images = torch.stack(batch_images, dim=0) + batch_tensors = [transform(img) for img in numpy_images] + images = torch.stack(batch_tensors, dim=0) + del numpy_images, batch_tensors # Free memory before CUDA ops + + # CUDA inference with explicit cleanup batch_scores, batch_imgs = process_batch_predictions(model, images) + del images # Free CUDA tensor reference scores_list.append(batch_scores) imgs_list.append(batch_imgs) + # Periodic GPU cache clearing to prevent fragmentation + clear_gpu_cache_if_needed(batch_idx) + processed_since_last_log += batch_size_actual current_time = time.time() @@ -177,13 +182,19 @@ def main(): logger.error(f"Failed to load config from {args.config_path}: {e}") sys.exit(1) + logger.info("Setting batch size") + batch_size = ( + estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction + ) + logger.info(f"Batch size set to: {batch_size}") + # Log key configuration parameters logger.debug("Configuration loaded with parameters:") logger.debug(f" Save file: {cfg.save_file}") logger.debug(f" Save path: {cfg.save_path}") logger.debug(f" Model path: {cfg.model_path}") logger.debug(f" Output directory: {cfg.output_dir}") - logger.debug(f" Image size: {cfg.size}") + logger.debug(f" Image size: {cfg.normalisation.image_size}") # Create output directory if it doesn't exist os.makedirs(cfg.output_dir, exist_ok=True) @@ -191,7 +202,7 @@ def main(): logger.info(f"Processing HDF5 file: {args.hdf5_path}") try: - evaluate_images_in_hdf5(args.hdf5_path, cfg, top_n=args.top_n) + evaluate_images_in_hdf5(args.hdf5_path, cfg, batch_size=batch_size, top_n=args.top_n) elapsed_time = time.time() - start_time logger.success(f"Script completed in {elapsed_time:.2f} seconds") except Exception as e: @@ -200,4 +211,18 @@ def main(): if __name__ == "__main__": + + # Configure logging + logs_dir = os.path.join(os.path.dirname(__file__), "logs") + os.makedirs(logs_dir, exist_ok=True) + + # Remove default handler and set up file logging + logger.remove() + logger.add( + os.path.join(logs_dir, "prediction_thread_{time}.log"), + rotation="1 MB", + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", + level="DEBUG", + ) + logger.add(sys.stderr, level="INFO") main() diff --git a/prediction_process_zarr.py b/prediction_process_zarr.py index 336c05d..e375b32 100644 --- a/prediction_process_zarr.py +++ b/prediction_process_zarr.py @@ -20,30 +20,19 @@ from pathlib import Path from tqdm import tqdm +from anomaly_match.data_io.load_images import process_single_wrapper + from prediction_utils import ( load_model, save_results, process_batch_predictions, + estimate_batch_size, + clear_gpu_cache_if_needed, ) from anomaly_match.image_processing.transforms import ( get_prediction_transforms, ) -from anomaly_match.data_io.load_images import process_image_array - -# Configure logging -logs_dir = os.path.join(os.path.dirname(__file__), "logs") -os.makedirs(logs_dir, exist_ok=True) - -# Remove default handler and set up file logging -logger.remove() -logger.add( - os.path.join(logs_dir, "prediction_zarr_{time}.log"), - rotation="1 MB", - format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", - level="DEBUG", -) -logger.add(sys.stderr, level="INFO") def read_and_preprocess_image_from_zarr(image_data, cfg): @@ -53,12 +42,20 @@ def read_and_preprocess_image_from_zarr(image_data, cfg): if not isinstance(image_data, np.ndarray): image_data = np.array(image_data) - # Convert from CHW to HWC format - image = image_data.transpose(1, 2, 0) + # Check if we need to transpose based on the shape + # If last dimension is 3 (RGB channels), data is already in HWC format + # If first dimension is 3, data is in CHW format and needs transposing + if image_data.shape[0] == cfg.normalisation.n_output_channels: + # In CHW format, convert to HWC + image = image_data.transpose(1, 2, 0) + else: + # Assume HWC format if neither first nor last dimension is 3 + # This handles grayscale or other formats + image = image_data # Use the centralized processing function - this handles RGB conversion, # normalization, and resizing efficiently without temporary files - processed_image = process_image_array(image, cfg, convert_to_rgb=True, image_source="zarr") + processed_image = process_single_wrapper(image, cfg, desc="zarr") return processed_image except Exception as e: @@ -67,11 +64,13 @@ def read_and_preprocess_image_from_zarr(image_data, cfg): def load_and_preprocess_zarr(args): - """Load and preprocess a single image from Zarr.""" - image_data, transform, cfg = args - image = read_and_preprocess_image_from_zarr(image_data, cfg) - image = transform(image) - return image + """Load and preprocess a single image from Zarr. + + Note: Returns numpy array, not tensor. Tensor conversion is done on main + thread to avoid CUDA context issues in ThreadPoolExecutor. + """ + image_data, cfg = args + return read_and_preprocess_image_from_zarr(image_data, cfg) def evaluate_images_in_zarr(zarr_path, cfg, top_n=1000, batch_size=1000, max_workers=4): @@ -100,6 +99,15 @@ def evaluate_images_in_zarr(zarr_path, cfg, top_n=1000, batch_size=1000, max_wor filenames = [] metadata_file = None + # Generate a unique prefix for this zarr file to avoid filename collisions + # Use the parent directory name for batch folders, or the zarr file name itself + if zarr_path.name == "images.zarr": + # For batch folders, use the parent directory name + zarr_prefix = zarr_path.parent.name + else: + # For direct zarr files, use the zarr file name + zarr_prefix = zarr_path.stem + # Check for metadata file in Zarr attributes if "metadata_file" in root.attrs: metadata_file = Path(root.attrs["metadata_file"]) @@ -109,9 +117,16 @@ def evaluate_images_in_zarr(zarr_path, cfg, top_n=1000, batch_size=1000, max_wor # Fallback: look for metadata parquet file next to zarr if metadata_file is None or not metadata_file.exists(): + # First try: _metadata.parquet next to zarr file potential_metadata = zarr_path.parent / f"{zarr_path.stem}_metadata.parquet" if potential_metadata.exists(): metadata_file = potential_metadata + # Second try: For batch folders with images.zarr subdirectory, + # look for images_metadata.parquet in parent directory + elif zarr_path.name == "images.zarr": + potential_metadata = zarr_path.parent / "images_metadata.parquet" + if potential_metadata.exists(): + metadata_file = potential_metadata if metadata_file and metadata_file.exists(): logger.info(f"Loading metadata from {metadata_file}") @@ -121,22 +136,29 @@ def evaluate_images_in_zarr(zarr_path, cfg, top_n=1000, batch_size=1000, max_wor filenames = metadata_df["original_filename"].tolist() elif "filename" in metadata_df.columns: filenames = metadata_df["filename"].tolist() + elif "source_id" in metadata_df.columns: + # Use source_id as filename if available + filenames = metadata_df["source_id"].tolist() + logger.info("Using source_id column as filenames") else: - logger.warning("No filename column found in metadata, using indices") - filenames = [f"image_{i:06d}" for i in range(num_images)] + logger.warning( + "No filename column found in metadata, using indices with zarr prefix" + ) + filenames = [f"{zarr_prefix}__image_{i:06d}" for i in range(num_images)] except Exception as e: logger.warning(f"Failed to load metadata: {e}") - filenames = [f"image_{i:06d}" for i in range(num_images)] + logger.info("Using image indices with zarr prefix as fallback") + filenames = [f"{zarr_prefix}__image_{i:06d}" for i in range(num_images)] else: - logger.info("No metadata file found, using image indices as filenames") - filenames = [f"image_{i:06d}" for i in range(num_images)] + logger.info("No metadata file found, using image indices with zarr prefix as filenames") + filenames = [f"{zarr_prefix}__image_{i:06d}" for i in range(num_images)] # Ensure we have the right number of filenames if len(filenames) != num_images: logger.warning( - f"Filename count ({len(filenames)}) doesn't match image count ({num_images})" + f"Filename count ({len(filenames)}) doesn't match image count ({num_images}), regenerating with zarr prefix" ) - filenames = [f"image_{i:06d}" for i in range(num_images)] + filenames = [f"{zarr_prefix}__image_{i:06d}" for i in range(num_images)] model = load_model(cfg) model.eval() @@ -150,27 +172,38 @@ def evaluate_images_in_zarr(zarr_path, cfg, top_n=1000, batch_size=1000, max_wor last_log_time = start_time processed_since_last_log = 0 - for batch_start in tqdm(range(0, num_images, batch_size), desc="Processing batches"): + for batch_idx, batch_start in enumerate( + tqdm(range(0, num_images, batch_size), desc="Processing batches") + ): batch_end = min(batch_start + batch_size, num_images) batch_size_actual = batch_end - batch_start # Read batch data from Zarr batch_data = images_array[batch_start:batch_end] - # Process batch in parallel + # I/O and preprocessing in ThreadPool (returns numpy arrays) + # CUDA operations are kept on main thread to prevent memory fragmentation batch_process_start = time.time() with ThreadPoolExecutor(max_workers=max_workers) as executor: - batch_args = [(batch_data[i], transform, cfg) for i in range(batch_size_actual)] - batch_images = list(executor.map(load_and_preprocess_zarr, batch_args)) + batch_args = [(batch_data[i], cfg) for i in range(batch_size_actual)] + numpy_images = list(executor.map(load_and_preprocess_zarr, batch_args)) - # Stack images into a batch tensor and get predictions + # Tensor conversion on main thread (not in ThreadPool) stack_start = time.time() - images = torch.stack(batch_images, dim=0) + batch_tensors = [transform(img) for img in numpy_images] + images = torch.stack(batch_tensors, dim=0) + del numpy_images, batch_tensors # Free memory before CUDA ops + + # CUDA inference with explicit cleanup batch_scores, batch_imgs = process_batch_predictions(model, images) + del images # Free CUDA tensor reference scores_list.append(batch_scores) imgs_list.append(batch_imgs) + # Periodic GPU cache clearing to prevent fragmentation + clear_gpu_cache_if_needed(batch_idx) + processed_since_last_log += batch_size_actual current_time = time.time() @@ -221,13 +254,19 @@ def main(): logger.error(f"Failed to load config from {args.config_path}: {e}") sys.exit(1) + logger.info("Setting batch size") + batch_size = ( + estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction + ) + logger.info(f"Batch size set to: {batch_size}") + # Log key configuration parameters logger.debug("Configuration loaded with parameters:") logger.debug(f" Save file: {cfg.save_file}") logger.debug(f" Save path: {cfg.save_path}") logger.debug(f" Model path: {cfg.model_path}") logger.debug(f" Output directory: {cfg.output_dir}") - logger.debug(f" Image size: {cfg.size}") + logger.debug(f" Image size: {cfg.normalisation.image_size}") # Log full configuration logger.debug("Full configuration:") @@ -239,7 +278,7 @@ def main(): logger.info(f"Processing Zarr file: {args.zarr_path}") try: - evaluate_images_in_zarr(args.zarr_path, cfg, top_n=args.top_n) + evaluate_images_in_zarr(args.zarr_path, cfg, batch_size=batch_size, top_n=args.top_n) elapsed_time = time.time() - start_time logger.success(f"Script completed in {elapsed_time:.2f} seconds") except Exception as e: @@ -248,4 +287,17 @@ def main(): if __name__ == "__main__": + # Configure logging + logs_dir = os.path.join(os.path.dirname(__file__), "logs") + os.makedirs(logs_dir, exist_ok=True) + + # Remove default handler and set up file logging + logger.remove() + logger.add( + os.path.join(logs_dir, "prediction_zarr_{time}.log"), + rotation="1 MB", + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", + level="DEBUG", + ) + logger.add(sys.stderr, level="INFO") main() diff --git a/prediction_utils.py b/prediction_utils.py index 8f1ce27..55226d0 100644 --- a/prediction_utils.py +++ b/prediction_utils.py @@ -15,13 +15,140 @@ import os import torch import numpy as np -from loguru import logger import pandas as pd + +from loguru import logger from turbojpeg import TurboJPEG + # Initialize TurboJPEG jpeg_decoder = TurboJPEG() +# Memory model coefficients for batch size estimation +# These were derived from empirical measurements (R² > 0.9999) +# Formula: reserved_mb = a * batch_size * image_size² + b * batch_size + c +MEMORY_COEFFICIENTS = { + "efficientnet-lite0": {"a": 0.000391, "b": 0.0637, "c": 31.67}, + "efficientnet-b1": {"a": 0.000513, "b": 0.0710, "c": 42.09}, + "efficientnet-b2": {"a": 0.000513, "b": 0.0739, "c": 47.10}, +} + +# GPU memory management constants +GPU_CACHE_CLEAR_INTERVAL = 5 # Clear GPU cache every N batches + + +def clear_gpu_cache_if_needed(batch_idx: int, interval: int = GPU_CACHE_CLEAR_INTERVAL): + """Clear GPU cache periodically to prevent memory fragmentation. + + Args: + batch_idx: Current batch index (0-based) + interval: Clear cache every N batches + """ + if torch.cuda.is_available() and (batch_idx + 1) % interval == 0: + torch.cuda.empty_cache() + + +def estimate_batch_size( + cfg, + available_vram: float = None, + safety_margin: float = 0.3, +) -> int: + """Calculate optimal batch size based on available GPU VRAM and image dimensions. + + Uses empirically-derived memory consumption model to predict the maximum batch size + that will fit in GPU memory. The model accounts for: + - Input tensor memory (scales with batch_size × image_size²) + - Intermediate activations (scales with batch_size × image_size²) + - Model parameters (constant overhead) + - CUDA memory allocator overhead (~1.65× peak allocation) + + The formula used is: + reserved_mb = a × batch_size × image_size² + b × batch_size + c + + Args: + available_vram: Available GPU VRAM in MB. If None, auto-detects + from the current CUDA device. + safety_margin: Fraction of VRAM to keep free (default: 0.2 = 20%). + Higher values are safer but reduce batch size. + model: Model architecture name. Supported values: + - 'efficientnet-lite0' (default) + - 'efficientnet-b1' + - 'efficientnet-b2' + + Returns: + int: Recommended batch size (minimum 1). + + Example: + >>> # For a 16GB GPU with 64×64 images + >>> # For a 16GB GPU with 64×64 images + >>> batch_size = get_batch_size(image_size=64, available_vram=16384) + >>> print(f"Recommended batch size: {batch_size}") + Recommended batch size: 7852 + + >>> # For 224×224 images with 20% safety margin + >>> batch_size = get_batch_size(image_size=224, safety_margin=0.2) + + Notes: + - The model was calibrated for EfficientNet architectures + - For other architectures, efficientnet-lite0 coefficients provide + a reasonable approximation + - The safety_margin accounts for memory fragmentation and other + processes using GPU memory + """ + + # Auto-detect available VRAM if not provided + if available_vram is None: + if torch.cuda.is_available(): + device_props = torch.cuda.get_device_properties(torch.cuda.current_device()) + available_vram = device_props.total_memory / 1024**2 # Convert to MB + logger.debug(f"Auto-detected GPU VRAM: {available_vram:.0f} MB") + else: + # Default to 4GB if no GPU detected (conservative estimate) + available_vram = 4096 + logger.warning("No CUDA device detected, using default 4GB VRAM estimate") + + # Get coefficients for the specified model + coef = MEMORY_COEFFICIENTS.get(cfg.net, MEMORY_COEFFICIENTS["efficientnet-lite0"]) + + if cfg.net not in MEMORY_COEFFICIENTS: + logger.warning( + f"Unknown model '{cfg.net}', using efficientnet-lite0 coefficients. " + f"Supported models: {list(MEMORY_COEFFICIENTS.keys())}" + ) + + # Calculate usable VRAM after safety margin + usable_vram = available_vram * (1 - safety_margin) + + # Solve for batch_size: + # usable_vram = a * B * S² + b * B + c + # usable_vram - c = B * (a * S² + b) + # B = (usable_vram - c) / (a * S² + b) + S2 = cfg.normalisation.image_size[0] * cfg.normalisation.image_size[1] + # Use num_channels if set, otherwise fall back to normalisation.n_output_channels + num_channels = ( + cfg.num_channels + if isinstance(cfg.num_channels, int) + else cfg.normalisation.n_output_channels + ) + denominator = coef["a"] * S2 * num_channels + coef["b"] + + if denominator <= 0: + logger.warning("Invalid memory model parameters, returning minimum batch size") + return 1 + + batch_size = (usable_vram - coef["c"]) / denominator + + # Ensure batch size is at least 1 + batch_size = max(1, int(batch_size)) + + logger.debug( + f"Calculated batch size: {batch_size} " + f"(image_size={cfg.normalisation.image_size[0]}, available_vram={available_vram:.0f}MB, " + f"safety_margin={safety_margin}, model={cfg.net})" + ) + + return batch_size + def load_model(cfg): """Initialize and load the anomaly detection model. @@ -47,12 +174,18 @@ def load_model(cfg): from anomaly_match.utils.get_net_builder import get_net_builder + # Use num_channels if set, otherwise fall back to normalisation.n_output_channels + num_channels = ( + cfg.num_channels + if isinstance(cfg.num_channels, int) + else cfg.normalisation.n_output_channels + ) net_builder = get_net_builder( cfg.net, pretrained=cfg.pretrained, - in_channels=cfg.num_channels, + in_channels=num_channels, ) - model = net_builder(num_classes=2, in_channels=3) + model = net_builder(num_classes=2, in_channels=num_channels) if torch.cuda.is_available(): gpu_device = getattr(cfg, "gpu", 0) # Default to 0 if not set @@ -73,6 +206,20 @@ def load_model(cfg): ) model.load_state_dict(checkpoint["eval_model"]) + + # Load fitsbolt config from checkpoint (DotMap pickles directly) + if "fitsbolt_cfg" in checkpoint and checkpoint["fitsbolt_cfg"] is not None: + cfg.fitsbolt_cfg = checkpoint["fitsbolt_cfg"] + logger.info("Loaded fitsbolt config from model checkpoint") + elif hasattr(cfg, "fitsbolt_cfg") and cfg.fitsbolt_cfg is not None: + # Allow pre-set fitsbolt_cfg (for testing or advanced use cases) + logger.info("Using fitsbolt config already present in cfg") + else: + raise ValueError( + "Model checkpoint does not contain fitsbolt config. " + "Please retrain the model with the updated version to include normalisation settings." + ) + logger.success(f"Successfully loaded model from {model_path}") return model @@ -105,7 +252,7 @@ def save_results(cfg, all_scores, all_imgs, all_filenames, top_n): predictions_file = os.path.join(cfg.output_dir, f"all_predictions_{cfg.save_file}.npz") # Load and merge existing predictions if they exist - all_scores, all_filenames, existing_top_images = _load_existing_predictions( + all_scores, all_filenames, existing_top_images, old_top_indices = _load_existing_predictions( predictions_file, output_npy_path, all_scores, all_filenames ) @@ -114,11 +261,13 @@ def save_results(cfg, all_scores, all_imgs, all_filenames, top_n): top_scores = all_scores[top_indices] top_filenames = all_filenames[top_indices] - # Build the top images array - top_imgs = _build_top_images_array(all_scores, all_imgs, top_indices, existing_top_images) + # Ensure current batch images are in consistent HWC format BEFORE building top array + all_imgs = _ensure_consistent_image_format(all_imgs) - # Ensure images are in consistent format - top_imgs = _ensure_consistent_image_format(top_imgs) + # Build the top images array + top_imgs = _build_top_images_array( + all_scores, all_imgs, top_indices, existing_top_images, old_top_indices + ) logger.debug( f"Top images shape: {top_imgs.shape}, dtype: {top_imgs.dtype}, range: [{top_imgs.min()}, {top_imgs.max()}]" @@ -160,11 +309,12 @@ def _load_existing_predictions( current_filenames (np.ndarray): Filenames from the current batch. Returns: - tuple: (merged_scores, merged_filenames, existing_top_images) + tuple: (merged_scores, merged_filenames, existing_top_images, old_top_indices) """ existing_scores = [] existing_filenames = [] existing_top_images = None + old_top_indices = None # Load existing predictions if available if os.path.exists(predictions_file): @@ -173,6 +323,10 @@ def _load_existing_predictions( existing_scores = data["scores"] existing_filenames = data["filenames"] + # Calculate the old top indices from existing scores + old_top_indices = np.argsort(existing_scores)[::-1] + logger.debug(f"Calculated old top indices shape: {old_top_indices.shape}") + # Also load existing top images if they exist if os.path.exists(output_npy_path): logger.info("Loading existing top images for preservation") @@ -185,12 +339,14 @@ def _load_existing_predictions( logger.info( f"Combined {len(existing_scores)} existing and {len(current_scores)} new predictions" ) - return merged_scores, merged_filenames, existing_top_images + return merged_scores, merged_filenames, existing_top_images, old_top_indices - return current_scores, current_filenames, existing_top_images + return current_scores, current_filenames, existing_top_images, old_top_indices -def _build_top_images_array(all_scores, current_batch_imgs, top_indices, existing_top_images): +def _build_top_images_array( + all_scores, current_batch_imgs, top_indices, existing_top_images, old_top_indices +): """Build an array of top images from current batch and existing images. This function handles the complex logic of selecting images either from the current batch @@ -201,6 +357,7 @@ def _build_top_images_array(all_scores, current_batch_imgs, top_indices, existin current_batch_imgs (np.ndarray): Images from the current batch only. top_indices (np.ndarray): Indices of top scoring images in the combined dataset. existing_top_images (np.ndarray or None): Previously saved top images. + old_top_indices (np.ndarray or None): Indices of old top results before merging. Returns: np.ndarray: Array of top images. @@ -209,8 +366,15 @@ def _build_top_images_array(all_scores, current_batch_imgs, top_indices, existin current_batch_start = len(all_scores) - len(current_batch_imgs) current_batch_global_indices = set(range(current_batch_start, len(all_scores))) + # Create a mapping from old global index to position in existing_top_images + old_idx_to_position = {} + if old_top_indices is not None and existing_top_images is not None: + for position, global_idx in enumerate(old_top_indices[: len(existing_top_images)]): + old_idx_to_position[global_idx] = position + # Collect images for each top index top_img_list = [] + missing_images = [] for i, global_idx in enumerate(top_indices): # Case 1: This top result is from the current batch @@ -219,29 +383,37 @@ def _build_top_images_array(all_scores, current_batch_imgs, top_indices, existin top_img_list.append(current_batch_imgs[batch_idx]) # Case 2: This top result is from a previous batch - elif existing_top_images is not None and i < len(existing_top_images): - # Use the existing top image at this position - top_img_list.append(existing_top_images[i]) - - # First check if all images have the same shape - shapes = [img.shape for img in top_img_list] - if len(set(shapes)) > 1: - # If different shapes, resize all to the first image's shape - logger.warning(f"Inconsistent image shapes detected: {set(shapes)}, standardizing") - reference_shape = top_img_list[0].shape - for i, img in enumerate(top_img_list): - if img.shape != reference_shape: - # Simple resize by zero-padding or cropping - fixed_img = np.zeros(reference_shape, dtype=np.uint8) - # Copy as much of the original image as will fit - slices = tuple(slice(0, min(s, rs)) for s, rs in zip(img.shape, reference_shape)) - if len(reference_shape) == 3: # For 3D arrays (HWC) - fixed_img[slices[0], slices[1], slices[2]] = img[slices] - else: # For other dimensions - fixed_img[slices] = img[slices] - top_img_list[i] = fixed_img - - # Now convert to numpy array - each element is guaranteed to have the same shape + elif global_idx in old_idx_to_position: + # Use the existing top image at the correct position + old_position = old_idx_to_position[global_idx] + top_img_list.append(existing_top_images[old_position]) + + # Case 3: This image is from a previous batch but wasn't in the old top_N + else: + missing_images.append((i, global_idx)) + # Create a placeholder black image + if len(top_img_list) > 0: + placeholder = np.zeros_like(top_img_list[0]) + else: + # Fallback: create a minimal placeholder + placeholder = np.zeros((64, 64, 3), dtype=np.uint8) + top_img_list.append(placeholder) + + if missing_images: + logger.warning( + f"Could not retrieve {len(missing_images)} images from previous batches " + f"(they were not in the previous top_N). Using placeholder images. " + f"First few missing: {missing_images[:5]}" + ) + + # Verify we have the expected number of images + if len(top_img_list) != len(top_indices): + raise ValueError( + f"Image count mismatch: expected {len(top_indices)} images, " + f"but collected {len(top_img_list)}" + ) + + # Convert to numpy array - all images should have the same shape by now return np.stack(top_img_list) @@ -285,6 +457,8 @@ def process_batch_predictions(model, images, original_images=None): the original images (if provided) or convert the tensor images back to uint8 format suitable for saving. + Note: Includes explicit CUDA tensor cleanup to prevent GPU memory fragmentation. + Args: model (torch.nn.Module): The neural network model for anomaly detection. images (torch.Tensor): Preprocessed tensor images for model inference. @@ -303,16 +477,25 @@ def process_batch_predictions(model, images, original_images=None): logits = model(images) batch_scores = torch.nn.functional.softmax(logits, dim=-1)[:, 1].cpu().numpy() + # Explicit cleanup of CUDA tensors to prevent memory fragmentation + del logits + # Return original uint8 images if provided, otherwise convert tensor back if original_images is not None: + # Clean up CUDA tensor before returning + del images return batch_scores, original_images else: - # Convert tensor images back to uint8 for saving - images_np = images.cpu().numpy() + # Convert tensor images back to uint8 for saving with explicit cleanup + images_np = images.detach().cpu().numpy() + del images # Free CUDA tensor + if images_np.max() <= 1.0: # Tensor format [0,1] -> uint8 [0,255] images_uint8 = (images_np * 255.0).clip(0, 255).astype(np.uint8) else: # Assume already in correct range images_uint8 = images_np.clip(0, 255).astype(np.uint8) + + del images_np # Free intermediate array return batch_scores, images_uint8 diff --git a/pyproject.toml b/pyproject.toml index 8685c14..6ce65ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "anomaly_match" -version = "1.1.0" +version = "1.2.0" description = "A tool for anomaly detection in images using semi-supervised and active learning with a GUI" readme = "README.md" license = { file = "LICENSE.txt" } @@ -44,18 +44,21 @@ dependencies = [ "matplotlib", "numpy", "opencv-python-headless", - "pandas", + "pandas<3", "pyturbojpeg", "scikit-learn", "scikit-image", + "fitsbolt>=0.1.6", "toml", "torch", + "torchvision", "tqdm", "zarr>=3.0.0b0", + "cutana>=0.2.1", ] [project.optional-dependencies] -dev = ["pytest", "pytest-cov", "black", "flake8", "mypy"] +dev = ["pytest", "pytest-cov", "black", "flake8", "mypy", "vulture>=2.10"] [tool.setuptools] packages = ["anomaly_match"] @@ -64,7 +67,7 @@ packages = ["anomaly_match"] "" = "." [tool.black] -line-length = 88 +line-length = 100 target-version = ['py311'] [tool.pytest.ini_options] @@ -72,3 +75,8 @@ testpaths = ["tests"] python_files = ["test_*.py"] python_classes = ["Test*"] python_functions = ["test_*"] + +[tool.vulture] +min_confidence = 60 +paths = ["anomaly_match"] +exclude = ["tests/", "docs/", "examples/", "paper_scripts/"] diff --git a/tests/cfg_validation_test.py b/tests/cfg_validation_test.py index 2bbf092..f4b36da 100644 --- a/tests/cfg_validation_test.py +++ b/tests/cfg_validation_test.py @@ -23,6 +23,8 @@ def test_default_cfg_validation(caplog): """Test that the default configuration passes validation without warnings.""" # Get default config cfg = get_default_cfg() + # image_size has no default - must be set by user + cfg.normalisation.image_size = [224, 224] # Run validation validate_config(cfg) diff --git a/tests/dataset_test.py b/tests/dataset_test.py index 2f460e3..1a93bbc 100644 --- a/tests/dataset_test.py +++ b/tests/dataset_test.py @@ -28,11 +28,12 @@ def base_config(): """Fixture providing base configuration for tests.""" cfg = am.get_default_cfg() cfg.data_dir = "tests/test_data/" - cfg.size = [64, 64] + cfg.normalisation.image_size = [64, 64] + cfg.normalisation.n_output_channels = 3 cfg.num_train_iter = 2 cfg.test_ratio = 0.5 cfg.N_to_load = 10 - cfg.fits_extension = None + cfg.normalisation.fits_extension = None cfg.label_file = None return cfg @@ -59,8 +60,8 @@ def sample_data(): def multi_extension_dataset(): """Create a temporary directory with images of different extensions for testing.""" with tempfile.TemporaryDirectory() as temp_dir: - # Create test images with different extensions - extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff"] + # Create test images with different extensions - only supported formats + extensions = [".jpg", ".jpeg", ".png", ".tiff"] test_images = [] # Create a simple test image @@ -93,7 +94,7 @@ def test_anomaly_detection_dataset_initialization(base_config): dataset = AnomalyDetectionDataset(base_config) assert dataset is not None - assert dataset.size == base_config.size + assert dataset.size == base_config.normalisation.image_size assert dataset.test_ratio == base_config.test_ratio assert dataset.num_channels == 3 assert hasattr(dataset, "data_dict") diff --git a/tests/file_io_test.py b/tests/file_io_test.py index 4d8599a..d90df1d 100644 --- a/tests/file_io_test.py +++ b/tests/file_io_test.py @@ -19,8 +19,74 @@ get_image_names_from_folder, get_image_paths_from_folder, ) -from anomaly_match.data_io.load_images import read_and_resize_image, load_images_parallel -from anomaly_match.image_processing.NormalisationMethod import NormalisationMethod +from anomaly_match.data_io.load_images import ( + load_and_process_wrapper, + load_and_process_single_wrapper, +) +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod + + +def _load_image_with_fitsbolt(filepath, cfg): + """Helper function to load image using fitsbolt with AnomalyMatch config.""" + # Use the new wrapper function instead of directly using fitsbolt + return load_and_process_single_wrapper(filepath, cfg, desc="test loading", show_progress=False) + + +def _load_multiple_images_with_fitsbolt(filepaths, cfg, show_progress=False): + """Helper function to load multiple images using fitsbolt with AnomalyMatch config.""" + # Use the new wrapper function instead of directly using fitsbolt + return load_and_process_wrapper( + filepaths, + cfg, + desc="test loading multiple", + show_progress=show_progress, + ) + + +def _update_config(cfg, **kwargs): + """Update both the main config and fitsbolt config with the given parameters.""" + for key, value in kwargs.items(): + setattr(cfg, key, value) + if key == "size": + cfg.normalisation.image_size = value + elif key == "fits_extension": + cfg.normalisation.fits_extension = value + # When setting fits_extension to a list, automatically set up appropriate channel_combination + # only if channel_combination hasn't been explicitly set + if isinstance(value, list) and len(value) > 1: + # Check if channel_combination is being explicitly set in this call + if "channel_combination" not in kwargs: + import numpy as np + + n_channels = min(3, len(value)) # Limit to 3 output channels (RGB) + n_extensions = len(value) + if n_extensions == n_channels: + cfg.normalisation.channel_combination = np.eye(n_channels) + else: + # Create a combination matrix that uses the first n_channels extensions + combination = np.zeros((n_channels, n_extensions)) + for i in range(n_channels): + combination[i, i] = 1.0 + cfg.normalisation.channel_combination = combination + elif key == "channel_combination": + # Convert dictionary format to numpy array format for fitsbolt + cfg.normalisation.channel_combination = value + elif key == "normalisation_method": + cfg.normalisation.normalisation_method = value + elif key == "interpolation_order": + cfg.normalisation.interpolation_order = value + elif key.startswith("normalisation."): + # Handle nested attributes in normalisation + norm_key = key.split(".")[1] + setattr(cfg.normalisation, norm_key, value) + fitsbolt_key = f"norm_{norm_key}" + setattr(cfg.normalisation, fitsbolt_key, value) + + # Clear the cached fitsbolt config to force regeneration with new parameters + if hasattr(cfg, "fitsbolt_cfg"): + delattr(cfg, "fitsbolt_cfg") + + return cfg class TestImageIO: @@ -35,12 +101,10 @@ def test_config(self): cfg = get_default_cfg() # Override for test specific settings - cfg.size = None # Default no resize (only the rgba test covers this) - cfg.fits_extension = None # Default first extension - cfg.normalisation.maximum_value = None - cfg.normalisation.minimum_value = None - cfg.normalisation.crop_for_maximum_value = None - cfg.normalisation.log_calculate_minimum_value = False + # Add fitsbolt configuration needed by the wrapper functions + cfg.normalisation.image_size = None + cfg.normalisation.n_output_channels = 3 + return cfg @classmethod @@ -204,50 +268,47 @@ def test_get_image_paths_from_folder(self): assert len(recursive_paths) > len(image_paths) assert any("nested" in path for path in recursive_paths) - def test_read_and_resize_image_rgb(self, test_config): - """Test reading and resizing an RGB image.""" + def test_load_and_process_images_rgb(self, test_config): + """Test loading and processing an RGB image.""" # Test loading an RGB image - img = read_and_resize_image(self.rgb_path, cfg=test_config) + img = _load_image_with_fitsbolt(self.rgb_path, cfg=test_config) assert img.shape[2] == 3 # Should be RGB assert img.dtype == np.uint8 # Test with resizing - test_config.size = (50, 50) - resized_img = read_and_resize_image(self.rgb_path, cfg=test_config) - assert resized_img.shape[:2] == test_config.size + _update_config(test_config, size=(50, 50)) + resized_img = _load_image_with_fitsbolt(self.rgb_path, cfg=test_config) + assert resized_img.shape[:2] == test_config.normalisation.image_size assert resized_img.shape[2] == 3 # Still RGB - def test_read_and_resize_image_grayscale(self, test_config): - """Test reading and resizing a grayscale image.""" - # Test loading a grayscale image and converting to RGB - img = read_and_resize_image(self.gray_path, cfg=test_config, convert_to_rgb=True) + def test_load_and_process_images_grayscale(self, test_config): + """Test loading and processing a grayscale image.""" + # Test loading a grayscale image (fitsbolt automatically converts to RGB) + img = _load_image_with_fitsbolt(self.gray_path, cfg=test_config) assert img.shape[2] == 3 # Should be converted to RGB + assert img.dtype == np.uint8 - # Test loading a grayscale image without converting to RGB - img = read_and_resize_image(self.gray_path, cfg=test_config, convert_to_rgb=False) - assert len(img.shape) == 2 or img.shape[2] == 1 # Should remain grayscale - - def test_read_and_resize_image_rgba(self, test_config): - """Test reading and resizing an RGBA image.""" - # Test loading an RGBA image and converting to RGB - img = read_and_resize_image(self.rgba_path, cfg=test_config) + def test_load_and_process_images_rgba(self, test_config): + """Test loading and processing an RGBA image.""" + # Test loading an RGBA image (fitsbolt automatically converts to RGB) + img = _load_image_with_fitsbolt(self.rgba_path, cfg=test_config) assert img.shape[2] == 3 # Alpha channel should be removed # Test resizing target_size = (75, 75) - test_config.size = target_size - resized_img = read_and_resize_image(self.rgba_path, cfg=test_config) + _update_config(test_config, size=target_size) + resized_img = _load_image_with_fitsbolt(self.rgba_path, cfg=test_config) assert resized_img.shape[:2] == target_size assert resized_img.shape[2] == 3 # RGB # Test with fully transparent image - transparent_img = read_and_resize_image(self.transparent_path, cfg=test_config) + transparent_img = _load_image_with_fitsbolt(self.transparent_path, cfg=test_config) assert transparent_img.shape[2] == 3 # Should still be RGB # The green channel should still be present even though the pixels were transparent assert np.any( transparent_img[:, :, 1] > 0 ), "Green channel data lost in transparent image" # Test complex RGBA with gradient alpha - complex_rgba_img = read_and_resize_image(self.complex_rgba_path, cfg=test_config) + complex_rgba_img = _load_image_with_fitsbolt(self.complex_rgba_path, cfg=test_config) assert complex_rgba_img.shape[2] == 3 # Should be RGB # Check that gradients are preserved in RGB channels # Note: Using 195 as threshold since some image formats/conversions may reduce max values slightly @@ -283,7 +344,7 @@ def test_rgba_to_rgb_conversion_values(self, test_config): Image.fromarray(test_rgba).save(test_path) # Load and convert to RGB - rgb_img = read_and_resize_image(test_path, cfg=test_config, convert_to_rgb=True) + rgb_img = _load_image_with_fitsbolt(test_path, cfg=test_config) # Test shape and type assert rgb_img.shape == (height, width, 3), "RGBA should convert to RGB shape" @@ -341,7 +402,7 @@ def test_image_value_preservation(self, test_config): Image.fromarray(test_values).save(values_path) # Load the image and check value preservation - loaded_img = read_and_resize_image(values_path, cfg=test_config) + loaded_img = _load_image_with_fitsbolt(values_path, cfg=test_config) # Check overall shape and type assert loaded_img.shape == (100, 100, 3), "Shape should be preserved" @@ -365,22 +426,18 @@ def test_image_value_preservation(self, test_config): assert g_diff <= max_allowed_diff, f"Green value not preserved at ({i},{j})" assert b_diff <= max_allowed_diff, f"Blue value not preserved at ({i},{j})" - def test_read_and_resize_image_fits(self, test_config): + def test_load_image_with_fitsbolt_fits(self, test_config): """Test reading and resizing a FITS image.""" # Test loading a FITS image - img = read_and_resize_image(self.fits_path, cfg=test_config) + img = _load_image_with_fitsbolt(self.fits_path, cfg=test_config) assert img.shape[2] == 3 # Should be converted to RGB assert img.dtype == np.uint8 # Should be converted to uint8 # Test loading a multi-channel FITS image - multi_img = read_and_resize_image(self.multi_fits_path, cfg=test_config) - assert multi_img.shape[2] == 3 # Should have 3 channels + with pytest.raises(ValueError): + _load_image_with_fitsbolt(self.multi_fits_path, cfg=test_config) + # should fail as this is a 3d fits extension - # Test with resizing - target_size = (60, 60) - test_config.size = target_size - resized_fits = read_and_resize_image(self.multi_fits_path, cfg=test_config) - assert resized_fits.shape[:2] == target_size with fits.open(self.four_dim_fits_path) as hdul: if hdul[0].data.ndim == 4: # Extract a 3D slice from the 4D array (first element of first dimension) @@ -390,22 +447,19 @@ def test_read_and_resize_image_fits(self, test_config): fits.writeto(slice_path, slice_data, overwrite=True) # Now test the 3D slice which should load correctly - slice_img = read_and_resize_image(slice_path, cfg=test_config) - assert slice_img.shape[2] == 3, "FITS slice data should be converted to RGB" - assert slice_img.dtype == np.uint8, "Should be converted to uint8" - - # Test with specific dimensions - target_size = (40, 40) - test_config.size = target_size - resized_slice = read_and_resize_image(slice_path, cfg=test_config) - assert resized_slice.shape == (40, 40, 3), "Resizing failed for 4D FITS slice" + # fitsbolt has no 3D fits support this will fail + with pytest.raises(ValueError): + _load_image_with_fitsbolt(slice_path, cfg=test_config) + + # If this test fails in the future, implement further testing + # changing target size for example else: # If it's not 4D, we'll skip this specific assertion pytest.skip("FITS file doesn't have 4D data structure") # Test value normalization with extreme values - test_config.size = None # No resizing for this test - extreme_img = read_and_resize_image(self.extreme_fits_path, cfg=test_config) + _update_config(test_config, size=None) # No resizing for this test + extreme_img = _load_image_with_fitsbolt(self.extreme_fits_path, cfg=test_config) assert extreme_img.shape[2] == 3, "Should be converted to RGB" assert extreme_img.dtype == np.uint8, "Should be converted to uint8" assert np.min(extreme_img) >= 0, "Minimum value should be normalized to at least 0" @@ -420,9 +474,9 @@ def test_read_and_resize_image_fits(self, test_config): def test_fits_extension_parameter(self, test_config): """Test the fits_extension parameter for FITS files.""" # Test explicit extension 0 (should be the same as default) - img_default = read_and_resize_image(self.fits_path, cfg=test_config) - test_config.fits_extension = 0 - img_ext0 = read_and_resize_image(self.fits_path, cfg=test_config) + img_default = _load_image_with_fitsbolt(self.fits_path, cfg=test_config) + _update_config(test_config, fits_extension=0) + img_ext0 = _load_image_with_fitsbolt(self.fits_path, cfg=test_config) assert np.array_equal(img_default, img_ext0) # For multi_fits_path, we created it with 3 channels in the test setup @@ -431,8 +485,8 @@ def test_fits_extension_parameter(self, test_config): # Opening directly to check the contents with fits.open(self.multi_fits_path) as hdul: if len(hdul) > 1: # Only test if there are multiple extensions - test_config.fits_extension = 1 - img_ext1 = read_and_resize_image(self.multi_fits_path, cfg=test_config) + _update_config(test_config, fits_extension=1) + img_ext1 = _load_image_with_fitsbolt(self.multi_fits_path, cfg=test_config) # Should be different from extension 0 assert not np.array_equal(img_default, img_ext1) @@ -464,12 +518,12 @@ def test_fits_extension_string_parameter(self, test_config): hdul.writeto(named_fits_path, overwrite=True) # Now test accessing by string name - test_config.fits_extension = "PRIMARY" - img_primary = read_and_resize_image(named_fits_path, cfg=test_config) - test_config.fits_extension = "SCIENCE" - img_science = read_and_resize_image(named_fits_path, cfg=test_config) - test_config.fits_extension = "ERROR" - img_error = read_and_resize_image(named_fits_path, cfg=test_config) + _update_config(test_config, fits_extension="PRIMARY") + img_primary = _load_image_with_fitsbolt(named_fits_path, cfg=test_config) + _update_config(test_config, fits_extension="SCIENCE") + img_science = _load_image_with_fitsbolt(named_fits_path, cfg=test_config) + _update_config(test_config, fits_extension="ERROR") + img_error = _load_image_with_fitsbolt(named_fits_path, cfg=test_config) # Verify that each extension has different data assert not np.array_equal(img_primary, img_science) @@ -477,12 +531,12 @@ def test_fits_extension_string_parameter(self, test_config): assert not np.array_equal(img_science, img_error) # Test accessing using index vs name (should be equivalent) - test_config.fits_extension = 0 - img_primary_idx = read_and_resize_image(named_fits_path, cfg=test_config) - test_config.fits_extension = 1 - img_science_idx = read_and_resize_image(named_fits_path, cfg=test_config) - test_config.fits_extension = 2 - img_error_idx = read_and_resize_image(named_fits_path, cfg=test_config) + _update_config(test_config, fits_extension=0) + img_primary_idx = _load_image_with_fitsbolt(named_fits_path, cfg=test_config) + _update_config(test_config, fits_extension=1) + img_science_idx = _load_image_with_fitsbolt(named_fits_path, cfg=test_config) + _update_config(test_config, fits_extension=2) + img_error_idx = _load_image_with_fitsbolt(named_fits_path, cfg=test_config) assert np.array_equal(img_primary, img_primary_idx) assert np.array_equal(img_science, img_science_idx) @@ -492,21 +546,24 @@ def test_fits_extension_error_handling(self, test_config): """Test error handling for invalid FITS extensions.""" # Test out-of-bounds extension index with pytest.raises(IndexError): - test_config.fits_extension = 999 - read_and_resize_image(self.fits_path, cfg=test_config) + _update_config(test_config, fits_extension=999) + _load_image_with_fitsbolt(self.fits_path, cfg=test_config) # Test negative extension index with pytest.raises(IndexError): - test_config.fits_extension = -1 - read_and_resize_image(self.fits_path, cfg=test_config) + _update_config(test_config, fits_extension=-1) + _load_image_with_fitsbolt(self.fits_path, cfg=test_config) - def test_load_images_parallel(self, test_config): + def test_load_and_process_images(self, test_config): """Test loading multiple images in parallel.""" # Make a copy of the file list to avoid permission issues - test_files = self.image_files[:5] # Include a variety of image types - + test_files = self.image_files[ + 2:5 + ] # Include a variety of image with the same number of channels # Test basic functionality with multiple file types - results = load_images_parallel(test_files, cfg=test_config, show_progress=False) + results = _load_multiple_images_with_fitsbolt( + test_files, cfg=test_config, show_progress=False + ) # Should return all the files we passed assert len(results) == len(test_files) @@ -519,30 +576,16 @@ def test_load_images_parallel(self, test_config): # Test with resizing target_size = (30, 30) - test_config.size = target_size - resized_results = load_images_parallel(test_files, cfg=test_config, show_progress=False) + _update_config(test_config, size=target_size) + resized_results = _load_multiple_images_with_fitsbolt( + test_files, cfg=test_config, show_progress=False + ) # Check that all resized images have the correct size for filepath, img in resized_results: assert img.shape[:2] == target_size, f"Image {filepath} wasn't resized correctly" assert img.shape[2] == 3, f"Image {filepath} should be RGB after resize" - # Test with a custom transform function - def custom_transform(image): - # Simple transform that inverts the image - return 255 - image - - transformed_results = load_images_parallel( - test_files, cfg=test_config, transform=custom_transform, show_progress=False - ) - - # Test that the transform was applied - for i, (filepath, img) in enumerate(transformed_results): - # Compare with original loaded image - original_img = read_and_resize_image(filepath, cfg=test_config) - # Check that some pixels are different (transformation was applied) - assert not np.array_equal(original_img, img), f"Transform not applied to {filepath}" - def test_fits_multiple_extensions(self, test_config): """Test loading and combining multiple FITS extensions.""" # Create a test FITS file with multiple extensions of the same shape @@ -572,18 +615,18 @@ def test_fits_multiple_extensions(self, test_config): # Test loading with list of integer indices int_indices = [0, 1, 2] - test_config.fits_extension = int_indices - combined_img1 = read_and_resize_image(multi_ext_path, cfg=test_config) + _update_config(test_config, fits_extension=int_indices) + combined_img1 = _load_image_with_fitsbolt(multi_ext_path, cfg=test_config) # Test loading with list of string names str_names = ["PRIMARY", "EXT1", "EXT2"] - test_config.fits_extension = str_names - combined_img2 = read_and_resize_image(multi_ext_path, cfg=test_config) + _update_config(test_config, fits_extension=str_names) + combined_img2 = _load_image_with_fitsbolt(multi_ext_path, cfg=test_config) # Test loading with mixed list of indices and names mixed_list = [0, "EXT2", "EXT1"] - test_config.fits_extension = mixed_list - combined_img3 = read_and_resize_image(multi_ext_path, cfg=test_config) + _update_config(test_config, fits_extension=mixed_list) + combined_img3 = _load_image_with_fitsbolt(multi_ext_path, cfg=test_config) # All should result in RGB images with shape (50, 50, 3) assert combined_img1.shape == (50, 50, 3), "Combined image should have shape (50, 50, 3)" @@ -620,46 +663,18 @@ def test_fits_multiple_extensions(self, test_config): # Test that combining different shapes raises a ValueError with pytest.raises(ValueError) as e_info: - test_config.fits_extension = [0, 1] - read_and_resize_image(diff_shapes_path, cfg=test_config) + _update_config(test_config, fits_extension=[0, 1], channel_combination=None) + _load_image_with_fitsbolt(diff_shapes_path, cfg=test_config) - # Validate the error message contains information about the shapes - assert "different shapes" in str(e_info.value), "Error should mention different shapes" - assert "(50, 50)" in str(e_info.value), "Error should include the first shape" - assert "(60, 40)" in str(e_info.value), "Error should include the second shape" - - # Test with more than 3 extensions (should use only first 3 as RGB channels) - many_ext_path = os.path.join(self.test_dir, "many_extensions.fits") - - # Create 5 extensions with same shape but different patterns - hdu_list = [fits.PrimaryHDU(np.ones((40, 40), dtype=np.float32) * 0.1)] - for i in range(4): - data = np.ones((40, 40), dtype=np.float32) * (i + 1) * 0.2 - data[10 + i * 5 : 20 + i * 5, 10 + i * 5 : 20 + i * 5] = ( - 0.9 # Different pattern in each - ) - hdu = fits.ImageHDU(data) - hdu.header["EXTNAME"] = f"EXT{i + 1}" - hdu_list.append(hdu) - - # Create FITS file with 5 extensions - many_hdul = fits.HDUList(hdu_list) - many_hdul.writeto(many_ext_path, overwrite=True) - - # Try loading all 5 extensions (should use only first 3) - with pytest.warns(UserWarning): # Should warn about using only first 3 - test_config.fits_extension = [0, 1, 2, 3, 4] - five_ext_img = read_and_resize_image(many_ext_path, cfg=test_config) - - # Should still be RGB image with 3 channels - assert five_ext_img.shape == ( - 40, - 40, - 3, - ), "Image should have 3 channels even with >3 extensions" + # Validate the error message contains information about channel mismatch + error_message = str(e_info.value) + assert ( + "channel" in error_message.lower() or "extension" in error_message.lower() + ), f"Error should mention channel or extension mismatch: {error_message}" + # could expand test if needed with more extensions - def test_load_images_parallel_fits_extension(self, test_config): - """Test that load_images_parallel correctly passes the fits_extension parameter.""" + def test_load_and_process_images_fits_extension(self, test_config): + """Test that load_and_process_images correctly passes the fits_extension parameter.""" # Create a test FITS file with multiple extensions of the same shape multi_ext_path = os.path.join(self.test_dir, "multi_extension_parallel.fits") @@ -692,17 +707,23 @@ def test_load_images_parallel_fits_extension(self, test_config): # List of files to test test_files = [multi_ext_path, multi_ext_path2] - # Test load_images_parallel with a single extension index - test_config.fits_extension = 0 - results_ext0 = load_images_parallel(test_files, cfg=test_config, show_progress=False) + # Test load_and_process_images with a single extension index + _update_config(test_config, fits_extension=0) + results_ext0 = _load_multiple_images_with_fitsbolt( + test_files, cfg=test_config, show_progress=False + ) - # Test load_images_parallel with a different extension index - test_config.fits_extension = 1 - results_ext1 = load_images_parallel(test_files, cfg=test_config, show_progress=False) + # Test load_and_process_images with a different extension index + _update_config(test_config, fits_extension=1) + results_ext1 = _load_multiple_images_with_fitsbolt( + test_files, cfg=test_config, show_progress=False + ) - # Test load_images_parallel with a list of extensions - test_config.fits_extension = [0, 1, 2] - results_combined = load_images_parallel(test_files, cfg=test_config, show_progress=False) + # Test load_and_process_images with a list of extensions + _update_config(test_config, fits_extension=[0, 1, 2]) + results_combined = _load_multiple_images_with_fitsbolt( + test_files, cfg=test_config, show_progress=False + ) # Verify that all files were loaded assert len(results_ext0) == len(test_files) @@ -746,8 +767,10 @@ def test_load_images_parallel_fits_extension(self, test_config): ), "Combined extensions should differ from single extension" # Also test with string extension names - test_config.fits_extension = ["PRIMARY", "EXT1", "EXT2"] - results_named = load_images_parallel(test_files, cfg=test_config, show_progress=False) + _update_config(test_config, fits_extension=["PRIMARY", "EXT1", "EXT2"]) + results_named = _load_multiple_images_with_fitsbolt( + test_files, cfg=test_config, show_progress=False + ) img_named = results_named[0][1] # Should be identical to using numeric indices [0, 1, 2] @@ -772,15 +795,15 @@ def test_image_normalisation(self, test_config): Image.fromarray(test_values).save(test_path) # Test with no normalisation (default) - test_config.normalisation_method = NormalisationMethod.CONVERSION_ONLY - img_none = read_and_resize_image(test_path, cfg=test_config) + _update_config(test_config, normalisation_method=NormalisationMethod.CONVERSION_ONLY) + img_none = _load_image_with_fitsbolt(test_path, cfg=test_config) assert np.array_equal( img_none, test_values ), "NONE normalisation should preserve original values" # Test with LOG normalisation - test_config.normalisation_method = NormalisationMethod.LOG - img_log = read_and_resize_image(test_path, cfg=test_config) + _update_config(test_config, normalisation_method=NormalisationMethod.LOG) + img_log = _load_image_with_fitsbolt(test_path, cfg=test_config) assert not np.array_equal(img_log, test_values), "LOG normalisation should modify values" # Log normalisation should enhance darker regions # Check that dark regions (low values) have relatively higher values after log normalisation @@ -790,8 +813,8 @@ def test_image_normalisation(self, test_config): assert ratio_dark > 1, "LOG normalisation should enhance dark regions" # Test with ZSCALE normalisation - test_config.normalisation_method = NormalisationMethod.ZSCALE - img_zscale = read_and_resize_image(test_path, cfg=test_config) + _update_config(test_config, normalisation_method=NormalisationMethod.ZSCALE) + img_zscale = _load_image_with_fitsbolt(test_path, cfg=test_config) assert not np.array_equal( img_zscale, test_values ), "ZSCALE normalisation should modify values" @@ -842,8 +865,8 @@ def test_image_interpolation_orders(self, test_config): Image.fromarray(large_img).save(large_path) # Set target size to 100x100 for both images - test_config.size = (100, 100) - test_config.normalisation_method = NormalisationMethod.CONVERSION_ONLY + _update_config(test_config, size=(100, 100)) + _update_config(test_config, normalisation_method=NormalisationMethod.CONVERSION_ONLY) # Define the expected center pixel colors for each quadrant after resizing # We're checking center points of each quadrant to avoid edge effects @@ -859,10 +882,10 @@ def test_image_interpolation_orders(self, test_config): # Check each interpolation order (0-5) for order in range(6): - test_config.interpolation_order = order + _update_config(test_config, interpolation_order=order) # Resize small image (40x40 → 100x100) - upsampling - resized_small = read_and_resize_image(small_path, cfg=test_config) + resized_small = _load_image_with_fitsbolt(small_path, cfg=test_config) upscaled_results.append(resized_small) assert resized_small.shape == ( 100, @@ -874,7 +897,7 @@ def test_image_interpolation_orders(self, test_config): ), f"Resized small image should be uint8 with order {order}" # Resize large image (200x200 → 100x100) - downsampling - resized_large = read_and_resize_image(large_path, cfg=test_config) + resized_large = _load_image_with_fitsbolt(large_path, cfg=test_config) assert resized_large.shape == ( 100, 100, @@ -978,3 +1001,65 @@ def test_image_interpolation_orders(self, test_config): assert not np.array_equal( upscaled_results[i - 1], upscaled_results[i] ), f"Order {i - 1} and order {i} interpolation should produce different results" + + def test_fits_combination_configurations(self, test_config): + """Test different configurations of the fits_combination dictionary.""" + # Create a FITS file with 4 extensions for testing + test_data = [] + for i in range(4): + data = np.zeros((31, 31), dtype=np.float32) + data[20:30, 20:30] = float(i + 1) # Different value in each extension + test_data.append(data) + + multi_ext_fits_path = os.path.join(self.test_dir, "multi_ext.fits") + primary_hdu = fits.PrimaryHDU(test_data[0]) + hdul = fits.HDUList([primary_hdu]) + for i, data in enumerate(test_data[1:], 1): + hdul.append(fits.ImageHDU(data, name=f"EXT{i}")) + hdul.writeto(multi_ext_fits_path, overwrite=True) + self.image_files.append(multi_ext_fits_path) # Add to cleanup list + _update_config(test_config, normalisation_method=NormalisationMethod.CONVERSION_ONLY) + # Test 1: None combination dictionary (should use first three extensions as default) + _update_config(test_config, fits_extension=[0, 1, 2]) # Use first three extensions + _update_config(test_config, channel_combination=None) + img_empty = _load_image_with_fitsbolt(multi_ext_fits_path, cfg=test_config) + assert img_empty.shape[2] == 3 + # Check if the first three extensions are used directly + assert np.any(img_empty[:, :, :] > 0) # Should have non-zero values + + # Test 2: Combination with all 4 extensions + _update_config(test_config, fits_extension=[0, 1, 2, 3]) + _update_config( + test_config, + channel_combination=np.array( + [ + [0.5, 0.5, 0, 0], # Combine ext 0 and 1 + [0, 0, 1, 0], # Just ext 2 + [0, 0, 0, 1], # Just ext 3 + ] + ), + ) + img_four = _load_image_with_fitsbolt(multi_ext_fits_path, cfg=test_config) + assert img_four.shape[2] == 3 + # Check if red channel contains combination of first two extensions + assert np.any(img_four[:, :, 0] > 0) # Red channel should have data + assert np.any(img_four[:, :, 1] > 0) # Green channel should have data + assert np.any(img_four[:, :, 2] > 0) # Blue channel should have data + + # Test 3: Combination with 2 extensions only + _update_config(test_config, fits_extension=[0, 1]) + _update_config( + test_config, + channel_combination=np.array( + [ + [1, 0], # Just ext 0 + [0.5, 0.5], # Mix of both + [0, 1], # Just ext 1 + ] + ), + ) + img_two = _load_image_with_fitsbolt(multi_ext_fits_path, cfg=test_config) + assert img_two.shape[2] == 3 + assert np.any(img_two[:, :, 0] > 0) # Red channel should have data + assert np.any(img_two[:, :, 1] > 0) # Green channel should have data + assert np.any(img_two[:, :, 2] > 0) # Blue channel should have data diff --git a/tests/fixmatch_test.py b/tests/fixmatch_test.py index 8b8798a..c426da5 100644 --- a/tests/fixmatch_test.py +++ b/tests/fixmatch_test.py @@ -75,7 +75,6 @@ def test_initialization(self, fixmatch_model): assert fixmatch_model.T == 0.5 assert fixmatch_model.p_cutoff == 0.95 assert fixmatch_model.lambda_u == 1.0 - assert fixmatch_model.use_hard_label is True # Check optimizer assert fixmatch_model.optimizer is not None diff --git a/tests/metadata_test.py b/tests/metadata_test.py index 6f1e778..f9c381f 100644 --- a/tests/metadata_test.py +++ b/tests/metadata_test.py @@ -10,6 +10,8 @@ import tempfile import pandas as pd import pytest +import numpy as np +from PIL import Image from dotmap import DotMap from anomaly_match.datasets.AnomalyDetectionDataset import AnomalyDetectionDataset @@ -30,12 +32,14 @@ def setup_test_files(self): os.makedirs(data_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True) - # Create dummy image files + # Create dummy image files - create actual valid images instead of empty files for i in range(3): - # Create an empty file + # Create a simple test image + img_data = np.zeros((100, 100, 3), dtype=np.uint8) + # Use modulo to cycle through RGB channels (0, 1, 2) + img_data[25:75, 25:75, i % 3] = 255 # Different colored squares for each image img_path = os.path.join(data_dir, f"test_image_{i}.jpg") - with open(img_path, "w") as f: - f.write("") + Image.fromarray(img_data).save(img_path) # Create labeled_data.csv - label only the first two images label_file = os.path.join(test_dir, "labeled_data.csv") @@ -72,26 +76,26 @@ def setup_test_files(self): def test_metadata_loading(self, setup_test_files, monkeypatch): """Test that metadata is correctly loaded in AnomalyDetectionDataset.""" - # Mock image reading functions - def mock_read_and_resize(*args, **kwargs): - import numpy as np + # Mock the load_and_process_wrapper function to avoid actual image processing + def mock_load_and_process_wrapper( + filepaths, cfg, desc="Loading images", show_progress=True + ): + # Return a list of (filepath, mock_image) tuples + results = [] + for filepath in filepaths: + mock_image = np.zeros((224, 224, 3), dtype=np.uint8) + results.append((filepath, mock_image)) + return results - return np.zeros((224, 224, 3), dtype=np.uint8) - - def mock_get_image_names(dir_path, recursive=False): - return [os.path.join(dir_path, f"test_image_{i}.jpg") for i in range(3)] - - monkeypatch.setattr( - "anomaly_match.data_io.load_images.read_and_resize_image", mock_read_and_resize - ) monkeypatch.setattr( - "anomaly_match.data_io.find_images_in_folder.get_image_names_from_folder", - mock_get_image_names, + "anomaly_match.data_io.load_images.load_and_process_wrapper", + mock_load_and_process_wrapper, ) # Set up configuration paths = setup_test_files cfg = get_default_cfg() + cfg.normalisation.image_size = [64, 64] cfg.data_dir = paths["data_dir"] cfg.label_file = paths["label_file"] cfg.metadata_file = paths["metadata_file"] @@ -115,21 +119,20 @@ def mock_get_image_names(dir_path, recursive=False): def test_metadata_saving_in_session(self, setup_test_files, monkeypatch): """Test that metadata is included when saving labels in Session.""" - # Mock required functions - def mock_read_and_resize(*args, **kwargs): - import numpy as np + # Mock the load_and_process_wrapper function to avoid actual image processing + def mock_load_and_process_wrapper( + filepaths, cfg, desc="Loading images", show_progress=True + ): + # Return a list of (filepath, mock_image) tuples + results = [] + for filepath in filepaths: + mock_image = np.zeros((224, 224, 3), dtype=np.uint8) + results.append((filepath, mock_image)) + return results - return np.zeros((224, 224, 3), dtype=np.uint8) - - def mock_get_image_names(dir_path, recursive=False): - return [os.path.join(dir_path, f"test_image_{i}.jpg") for i in range(3)] - - monkeypatch.setattr( - "anomaly_match.data_io.load_images.read_and_resize_image", mock_read_and_resize - ) monkeypatch.setattr( - "anomaly_match.data_io.find_images_in_folder.get_image_names_from_folder", - mock_get_image_names, + "anomaly_match.data_io.load_images.load_and_process_wrapper", + mock_load_and_process_wrapper, ) # Patch model initialization to avoid issues @@ -142,6 +145,7 @@ def mock_init_model(self): # Set up configuration paths = setup_test_files cfg = get_default_cfg() + cfg.normalisation.image_size = [64, 64] cfg.data_dir = paths["data_dir"] cfg.label_file = paths["label_file"] cfg.metadata_file = paths["metadata_file"] @@ -165,35 +169,35 @@ def mock_init_model(self): assert col in saved_data.columns # Check that values were preserved - assert ( - saved_data[saved_data["filename"] == "test_image_0.jpg"]["sourceID"].values[0] - == "source_0" - ) - assert saved_data[saved_data["filename"] == "test_image_1.jpg"]["ra"].values[0] == 11.0 + test_img_0_data = saved_data[saved_data["filename"] == "test_image_0.jpg"] + assert test_img_0_data["sourceID"].values[0] == "source_0" + + test_img_1_data = saved_data[saved_data["filename"] == "test_image_1.jpg"] + assert test_img_1_data["ra"].values[0] == 11.0 def test_missing_metadata_file(self, setup_test_files, monkeypatch): """Test behavior when metadata file is specified but doesn't exist.""" - # Mock required functions - def mock_read_and_resize(*args, **kwargs): - import numpy as np - - return np.zeros((224, 224, 3), dtype=np.uint8) + # Mock the load_and_process_wrapper function to avoid actual image processing + def mock_load_and_process_wrapper( + filepaths, cfg, desc="Loading images", show_progress=True + ): + # Return a list of (filepath, mock_image) tuples + results = [] + for filepath in filepaths: + mock_image = np.zeros((224, 224, 3), dtype=np.uint8) + results.append((filepath, mock_image)) + return results - def mock_get_image_names(dir_path, recursive=False): - return [os.path.join(dir_path, f"test_image_{i}.jpg") for i in range(3)] - - monkeypatch.setattr( - "anomaly_match.data_io.load_images.read_and_resize_image", mock_read_and_resize - ) monkeypatch.setattr( - "anomaly_match.data_io.find_images_in_folder.get_image_names_from_folder", - mock_get_image_names, + "anomaly_match.data_io.load_images.load_and_process_wrapper", + mock_load_and_process_wrapper, ) # Set up configuration paths = setup_test_files cfg = get_default_cfg() + cfg.normalisation.image_size = [64, 64] cfg.data_dir = paths["data_dir"] cfg.label_file = paths["label_file"] cfg.metadata_file = os.path.join(paths["test_dir"], "nonexistent_metadata.csv") diff --git a/tests/normalisations_test.py b/tests/normalisations_test.py deleted file mode 100644 index 8ca5ab3..0000000 --- a/tests/normalisations_test.py +++ /dev/null @@ -1,775 +0,0 @@ -# Copyright (c) European Space Agency, 2025. -# -# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which -# is part of this source code package. No part of the package, including -# this file, may be copied, modified, propagated, or distributed except according to -# the terms contained in the file 'LICENCE.txt'. -import numpy as np -import pytest -from loguru import logger -from dotmap import DotMap - -from anomaly_match.utils.get_default_cfg import get_default_cfg -from anomaly_match.image_processing.normalisation import normalise_image -from anomaly_match.image_processing.NormalisationMethod import NormalisationMethod - - -def get_test_config(method): - """Returns a test config with the specified normalisation method""" - cfg = get_default_cfg() - cfg.N_to_load = 10 - cfg.size = [64, 64] - cfg.normalisation_method = method - cfg.normalisation = DotMap() - cfg.normalisation.maximum_value = None - cfg.normalisation.minimum_value = None - cfg.normalisation.crop_for_maximum_value = None - cfg.normalisation.log_calculate_minimum_value = False - cfg.normalisation.asinh_scale = [10.0, 10.0, 10.0] # Default for ASINH - cfg.normalisation.asinh_clip = [99.0, 99.0, 99.0] # Default for ASINH - return cfg - - -def get_asinh_test_config(asinh_scale=[1.0, 1.0, 1.0], asinh_clip=[99.0, 99.0, 99.0]): - """Create a test config specifically for ASINH normalisation""" - cfg = get_test_config(NormalisationMethod.ASINH) - if asinh_scale is not None: - cfg.normalisation.asinh_scale = asinh_scale - if asinh_clip is not None: - cfg.normalisation.asinh_clip = asinh_clip - return cfg - - -@pytest.fixture -def caplog(caplog): - """Configure loguru to use the caplog handler""" - handler_id = logger.add(caplog.handler) - yield caplog - logger.remove(handler_id) - - -def create_gradient_rgb(height=16, width=16, dtype=np.uint8): - """Create a test RGB image with gradients in different channels""" - if dtype == np.uint16: - max_val = 65535 - elif dtype == np.float32: - max_val = 1e-3 - else: - max_val = 255 - - # Create gradients for each channel - r = np.linspace(0, max_val, width) - g = np.linspace(0, max_val / 2, width) - b = np.linspace(max_val / 4, max_val, width) - - # Create meshgrids - r_mesh, _ = np.meshgrid(r, np.linspace(0, max_val, height)) - g_mesh, _ = np.meshgrid(g, np.linspace(0, max_val / 2, height)) - b_mesh, _ = np.meshgrid(b, np.linspace(max_val / 4, max_val, height)) - - # Stack channels - image = np.stack([r_mesh, g_mesh, b_mesh], axis=2).astype(dtype) - return image - - -def create_gradient_single_channel(height=16, width=16, dtype=np.uint8): - """Create a test single channel image with gradient""" - if dtype == np.uint16: - max_val = 65535 - elif dtype == np.float32: - max_val = 1e-3 - else: - max_val = 255 - - x = np.linspace(0, max_val, width) - x_mesh, _ = np.meshgrid(x, np.linspace(0, max_val, height)) - return x_mesh.astype(dtype) - - -def create_multi_channel_image(height=16, width=16, dtype=np.uint8): - """Create a test multi-channel image simulating 4 channels: V,Y,J,H astronomical bands - with different intensity ranges to test proper scaling across channels""" - if dtype == np.uint16: - max_vals = [65535, 45000, 55000, 35000] # Different max for each channel - elif dtype == np.float32: - max_vals = [1e-3, 7e-4, 8e-4, 5e-4] - else: - max_vals = [255, 180, 220, 180] - - channels = [] - for max_val in max_vals: - # Create gradient with different ranges for each channel - x = np.linspace(0, max_val, width) - channel_mesh, _ = np.meshgrid(x, np.linspace(0, max_val, height)) - channels.append(channel_mesh) - - # Stack channels - image = np.stack(channels, axis=2).astype(dtype) - return image - - -def create_test_pattern(height=16, width=16, dtype=np.uint8): - """Create a test pattern with border 0s, background 10, and center 100, single channel""" - if dtype == np.float32: - values = [0.01, 0.1, 1.0] # normalized values for float - else: - values = [1, 10, 100] - - # Create base image with value 10 - image = np.full((height, width), values[1], dtype=dtype) - - # Set borders to 0 - image[0, :] = values[0] - image[-1, :] = values[0] - image[:, 0] = values[0] - image[:, -1] = values[0] - - # Set center to 100 - center_h = height // 2 - center_w = width // 2 - center_size = 2 - image[ - center_h - center_size : center_h + center_size + 1, - center_w - center_size : center_w + center_size + 1, - ] = values[2] - - return image - - -@pytest.mark.parametrize("method", NormalisationMethod.get_test_methods()) -def test_normalise_uint16_image(method): - """Test normalisation with uint16 image""" - # Create test image and config - test_image = create_gradient_rgb(dtype=np.uint16) - cfg = get_test_config(method) - - # Apply normalisation - result = normalise_image(test_image, cfg=cfg) - - # Common assertions for all methods - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - assert np.min(result) >= 0 - assert np.max(result) <= 255 - # assert good use of dynamic range - assert np.max(result) > 250 - assert np.min(result) < 5 - - if method == NormalisationMethod.CONVERSION_ONLY: - # For NONE, we expect the values to be scaled down from uint16 to uint8 - expected = np.round(((test_image) / (256 * 256 - 1) * 255)).astype(np.uint8) - np.testing.assert_array_equal(result, expected) - - -@pytest.mark.parametrize("method", NormalisationMethod.get_test_methods()) -def test_normalise_float32_image(method): - """Test normalisation with float32 RGB image""" - base_image = create_gradient_rgb(dtype=np.float32) - scale_factor = 1e-9 - test_image = base_image * scale_factor - cfg = get_test_config(method) - result = normalise_image(test_image, cfg=cfg) - - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - assert np.min(result) >= 0 - assert np.max(result) <= 255 - assert np.max(result) > 250 - assert np.min(result) < 5 - - -@pytest.mark.parametrize("method", NormalisationMethod.get_test_methods()) -def test_normalise_single_channel(method): - """Test normalisation with single channel uint8 gradient image""" - test_image = create_gradient_single_channel() - cfg = get_test_config(method) - cfg.normalisation.asinh_scale = [10.0] # ASINH scale for single channel - cfg.normalisation.asinh_clip = [99.0] # ASINH clip for single channel - result = normalise_image(test_image, cfg=cfg) - - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - assert np.min(result) >= 0 - assert np.max(result) <= 255 - - if method == NormalisationMethod.CONVERSION_ONLY: - np.testing.assert_array_equal(result, test_image) - - -@pytest.mark.parametrize("method", NormalisationMethod.get_test_methods()) -def test_normalise_multi_channel(method): - """Test normalisation with multi-channel image (e.g., V,Y,J,H bands)""" - test_image = create_multi_channel_image() - cfg = get_test_config(method) - # unintended asinh parameters for not yet supported multi(=/=3) channel - cfg.normalisation.asinh_scale = [10.0, 10.0, 10.0, 10.0] # ASINH scale for each channel - cfg.normalisation.asinh_clip = [99.0, 99.0, 99.0, 99.0] # ASINH clip for each channel - result = normalise_image(test_image, cfg=cfg) - - # Basic checks - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - assert np.min(result) >= 0 - assert np.max(result) <= 255 - - if method == NormalisationMethod.CONVERSION_ONLY: - np.testing.assert_array_equal(result, test_image) - elif method == NormalisationMethod.LOG: - # Check that log normalization preserves order in all channels - for channel in range(4): # V,Y,J,H channels - # Get values from first row which has a gradient - channel_vals = result[0, :, channel] - # Check that values are strictly increasing (gradient preserved) - assert np.all(np.diff(channel_vals) > 0) - - elif method == NormalisationMethod.ZSCALE: - # Check that zscale maps each channel to use full range effectively - for channel in range(4): # V,Y,J,H channels - channel_vals = result[..., channel] - # Check that we use most of the range (allowing some margin) - assert np.min(channel_vals) <= 40, f"Channel {channel} min value too high" - assert np.max(channel_vals) >= 180, f"Channel {channel} max value too low" - # Check for reasonable distribution - median_val = np.median(channel_vals) - assert 80 < median_val < 175, f"Channel {channel} median outside expected range" - - -@pytest.mark.parametrize("method", NormalisationMethod.get_test_methods()) -@pytest.mark.parametrize("dtype", [np.uint8, np.float32]) -def test_normalise_pattern(method, dtype): - """Test normalisation with specific pattern image with one channel""" - test_image = create_test_pattern(dtype=dtype) - cfg = get_test_config(method) - cfg.normalisation.asinh_scale = [10.0] # ASINH scale for single channel - cfg.normalisation.asinh_clip = [99.0] # ASINH clip for single channel - result = normalise_image(test_image, cfg=cfg) - - # Basic assertions - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - assert np.min(result) >= 0 - assert np.max(result) <= 255 - - # Get values for testing - center_val = result[8:9, 8:9][0, 0] # center value (should be brightest) - border_val = result[0, 0] # border value (should be darkest) - bg_val = result[2, 2] # background value (should be intermediate) - - if method == NormalisationMethod.CONVERSION_ONLY: - if dtype == np.uint8: - # For uint8, values should remain unchanged - np.testing.assert_array_equal(result, test_image) - assert center_val == 100 - assert border_val == 1 - assert bg_val == 10 - else: - # For float32, values should be scaled to uint8 range, here manually, normally with astropy - expected = np.round((test_image - 0.01) / (1 - 0.01) * 255).astype(np.uint8) - np.testing.assert_array_equal(result, expected) - elif method == NormalisationMethod.LOG: - # Check that log normalization preserves order and maps values logarithmically - assert border_val < bg_val < center_val - # Ensure reasonable value ranges - assert border_val < 90 # dark border - assert bg_val > 50 and bg_val < 175 # mid-range background - assert center_val > 200 # bright center - - # Get the unique input and output values in order [border, background, center] - if dtype == np.float32: - input_values = np.array([0.01, 0.1, 1.0]) - else: - input_values = np.array([1, 10, 100]) - output_values = np.array([border_val, bg_val, center_val]) - - # Calculate log10 of input values - log_values = np.log10( - 1000 * (input_values - np.min([0])) / (np.max(input_values) - np.min([0])) + 1 - ) / np.log10(1000 + 1) - - # Find scaling factor between log values and output values - # Using least squares to find the best scaling factor - # scale_factor = np.sum(output_values * log_values) / np.sum(log_values * log_values) - - # Check if scaled log values match output values within tolerance - log_values = np.round(log_values * 255) - np.testing.assert_allclose( - output_values, log_values, rtol=0.1 - ) # 10% tolerance for uint8 quantization - - elif method == NormalisationMethod.ZSCALE: - # ZScale should map the values to use the full range effectively - assert border_val < 10 # very dark border - assert bg_val > 50 and bg_val < 175 # mid-range background - assert center_val > 245 # very bright center - # Background should be closer to border than to center due to outlier handling - assert (bg_val - border_val) < (center_val - bg_val) - - -def test_normalise_invalid_method(caplog): - """Test normalisation with invalid method""" - test_image = create_gradient_rgb() - - # Test with invalid string - cfg = get_test_config("invalid") - result = normalise_image(test_image, cfg=cfg) - np.testing.assert_array_equal(result, test_image) - assert "Normalisation method type invalid" in caplog.text - assert "CRITICAL" in caplog.text - caplog.clear() - - # Test with invalid integer - cfg = get_test_config(999) - result = normalise_image(test_image, cfg=cfg) - np.testing.assert_array_equal(result, test_image) - assert "Normalisation method type 999" in caplog.text - assert "CRITICAL" in caplog.text - caplog.clear() - - # Test with invalid type - cfg = get_test_config(None) - result = normalise_image(test_image, cfg=cfg) - np.testing.assert_array_equal(result, test_image) - assert "Normalisation method type None" in caplog.text - assert "CRITICAL" in caplog.text - - -def test_gradient_creation(): - """Test the gradient creation helper function""" - # Test uint8 - img_uint8 = create_gradient_rgb(dtype=np.uint8) - assert img_uint8.dtype == np.uint8 - assert img_uint8.shape == (16, 16, 3) - assert np.min(img_uint8) >= 0 - assert np.max(img_uint8) <= 255 - - # Test uint16 - img_uint16 = create_gradient_rgb(dtype=np.uint16) - assert img_uint16.dtype == np.uint16 - assert img_uint16.shape == (16, 16, 3) - assert np.min(img_uint16) >= 0 - assert np.max(img_uint16) <= 65535 - - # Test float32 - img_float32 = create_gradient_rgb(dtype=np.float32) - assert img_float32.dtype == np.float32 - assert img_float32.shape == (16, 16, 3) - assert np.min(img_float32) >= 0 - assert np.max(img_float32) <= 1e-3 - - -def test_normalise_float32_max_value(): - """Test log normalisation with maximum_value setting""" - base_image = create_gradient_rgb(dtype=np.float32) - test_image = base_image * 1e-9 # Values from 0 to 1e-12 - - # Then with clipping - cfg = get_test_config(NormalisationMethod.LOG) - cfg.normalisation.maximum_value = 0.5e-12 # Clip the top half of values - result = normalise_image(test_image, cfg=cfg) - - # Results should be uint8 and keep dimensions - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - - # Values above maximum_value should all map to the same uint8 value - high_values = test_image > cfg.normalisation.maximum_value - unique_high_values = np.unique(result[high_values]) - assert len(unique_high_values) == 1, "Clipped values should map to the same output" - assert unique_high_values[0] == 255, "Clipped values should map to 255" - - # Values below maximum_value should maintain relative order - low_values = test_image <= cfg.normalisation.maximum_value - low_values_result = result[low_values] - low_values_orig = test_image[low_values] - order_preserved = np.all(np.diff(low_values_result[np.argsort(low_values_orig)]) >= 0) - assert order_preserved, "Order of non-clipped values should be preserved" - - -def test_normalise_float32_min_value(): - """Test log normalisation with minimum_value setting""" - base_image = create_gradient_rgb(dtype=np.float32) - test_image = base_image * 1e-9 # Values from 0 to 1e-12 - min_val = 0.2e-12 - - # Then with minimum value clipping - cfg = get_test_config(NormalisationMethod.LOG) - cfg.normalisation.minimum_value = min_val - result = normalise_image(test_image, cfg=cfg) - # Results should be uint8 and keep dimensions - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - # Values below minimum_value should all map to the same uint8 value - low_values = test_image < min_val - unique_low_values = np.unique(result[low_values]) - assert len(unique_low_values) == 1, "Clipped values should map to same output" - assert unique_low_values[0] == 0, "Clipped values should map to 0" - # Values above minimum_value should maintain relative order - high_values = test_image >= min_val - high_values_result = result[high_values] - high_values_orig = test_image[high_values] - order_preserved = np.all(np.diff(high_values_result[np.argsort(high_values_orig)]) >= 0) - assert order_preserved, "Order of non-clipped values should be preserved" - - -def test_normalise_float32_crop_max(): - """Test log normalisation with crop_for_maximum_value setting""" - base_image = create_gradient_rgb(dtype=np.float32) - test_image = base_image * 1e-9 - # Create a bright spot outside crop region - test_image[0, 0] = 5e-12 # Much brighter than rest - # First without crop - cfg_no_crop = get_test_config(NormalisationMethod.LOG) - result_no_crop = normalise_image(test_image, cfg=cfg_no_crop) - # Then with center crop that excludes bright spot - cfg = get_test_config(NormalisationMethod.LOG) - crop_pixels = 6 - cfg.normalisation.crop_for_maximum_value = (crop_pixels, crop_pixels) # Center crop - result = normalise_image(test_image, cfg=cfg) - # Results should be uint8 and keep dimensions - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - # Without crop, bright spot dominates scaling - assert np.all(result_no_crop[0, 0] == 255) - assert np.max(result_no_crop[1:, 1:]) < 200 - # With crop, center region uses full dynamic range - crop_im_h = test_image.shape[0] - crop_pixels // 2 - crop_im_w = test_image.shape[1] - crop_pixels // 2 - center_region = result[crop_im_h : crop_im_h + crop_pixels, crop_im_w : crop_im_w + crop_pixels] - center_region_no_crop = result_no_crop[ - crop_im_h : crop_im_h + crop_pixels, crop_im_w : crop_im_w + crop_pixels - ] - - # center region should now be the upper dynamic range limit, and before not - assert np.max(center_region) == 255 - assert np.max(center_region_no_crop) < 255 - # Bright spot outside crop still saturates - assert np.all(result[0, 0] == 255) - - -def test_normalise_float32_log_min(): - """Test log normalisation with log_calculate_minimum_value setting""" - base_image = create_gradient_rgb(dtype=np.float32) - test_image = base_image * 1e-9 - # Add negative values to test minimum handling - test_image[0:4, 0:4] = -0.5e-12 - test_image[0, 0] = -1e-12 # dark spot outside crop region - # Without log min calculation - cfg = get_test_config(NormalisationMethod.LOG) - cfg.normalisation.log_calculate_minimum_value = False - result_no_log = normalise_image(test_image, cfg=cfg) - # With log min calculation - cfg.normalisation.log_calculate_minimum_value = True - result_with_log = normalise_image(test_image, cfg=cfg) - # Both results should be valid uint8 images - assert isinstance(result_no_log, np.ndarray) - assert isinstance(result_with_log, np.ndarray) - assert result_no_log.dtype == np.uint8 - assert result_with_log.dtype == np.uint8 - assert result_no_log.shape == test_image.shape - assert result_with_log.shape == test_image.shape - # Without log min: negative values should be clipped to 0 - neg_region_no_log = result_no_log[0:4, 0:4] - assert np.all(neg_region_no_log == 0) - # With log min: negative values should be handled by shifting minimum - neg_region_with_log = result_with_log[0:4, 0:4] - assert not np.all(neg_region_with_log == 0) - assert np.all(result_with_log[0, 0] == 0) # lowest value has to be 0 - # Rest of image should preserve order in both cases - pos_vals = test_image > 0 - order_no_log = np.all(np.diff(result_no_log[pos_vals][np.argsort(test_image[pos_vals])]) >= 0) - order_with_log = np.all( - np.diff(result_with_log[pos_vals][np.argsort(test_image[pos_vals])]) >= 0 - ) - assert order_no_log, "Order of positive values should be preserved without log min" - assert order_with_log, "Order of positive values should be preserved with log min" - - -def create_rgb_test_image(height=16, width=16, dtype=np.float32): - """Create a test RGB image specifically for ASINH testing with different channel characteristics""" - if dtype == np.float32: - # Create different intensity ranges for each channel to test per-channel scaling - r_vals = np.linspace(0.001, 1.0, width) - g_vals = np.linspace(0.01, 0.5, width) - b_vals = np.linspace(0.1, 2.0, width) - else: - # For uint8/uint16 - max_val = 255 if dtype == np.uint8 else 65535 - r_vals = np.linspace(1, max_val, width) - g_vals = np.linspace(10, max_val // 2, width) - b_vals = np.linspace(50, max_val, width) - - # Create meshgrids for each channel - r_mesh, _ = np.meshgrid( - r_vals, - np.linspace( - 0.001 if dtype == np.float32 else 1, 1.0 if dtype == np.float32 else max_val, height - ), - ) - g_mesh, _ = np.meshgrid( - g_vals, - np.linspace( - 0.01 if dtype == np.float32 else 10, - 0.5 if dtype == np.float32 else max_val // 2, - height, - ), - ) - b_mesh, _ = np.meshgrid( - b_vals, - np.linspace( - 0.1 if dtype == np.float32 else 50, 2.0 if dtype == np.float32 else max_val, height - ), - ) - - # Stack channels to create RGB image - image = np.stack([r_mesh, g_mesh, b_mesh], axis=2).astype(dtype) - return image - - -def test_asinh_basic_functionality(): - """Test basic ASINH normalisation functionality with RGB image""" - test_image = create_rgb_test_image(dtype=np.float32) - cfg = get_asinh_test_config() - result = normalise_image(test_image, cfg=cfg) - - # Basic checks - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - assert len(result.shape) == 3 # Should be RGB - assert result.shape[2] == 3 # Should have 3 channels - assert np.min(result) >= 0 - assert np.max(result) <= 255 - - # Check that any channel uses reasonable dynamic range - assert np.max(result) > 200 # Should use upper range - assert np.min(result) < 50 # Should use lower range - - -def test_asinh_scaling_parameters(): - """Test ASINH normalisation with different scaling parameters""" - test_image = create_rgb_test_image(dtype=np.float32) - - # Test with low scaling (more linear-like behavior) - cfg_low = get_asinh_test_config(asinh_scale=[0.1, 0.1, 0.1]) - result_low = normalise_image(test_image, cfg=cfg_low) - - # Test with high scaling (more log-like behavior) - cfg_high = get_asinh_test_config(asinh_scale=[3.0, 3.0, 3.0]) - result_high = normalise_image(test_image, cfg=cfg_high) - - # Test with per-channel scaling - cfg_mixed = get_asinh_test_config(asinh_scale=[1.0, 0.1, 0.05]) - result_mixed = normalise_image(test_image, cfg=cfg_mixed) - - # All results should be valid uint8 images - for result in [result_low, result_high, result_mixed]: - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - assert np.min(result) >= 0 - assert np.max(result) <= 255 - - # With low scaling, gradient should be more preserved (more linear) - # With high scaling, contrast should be enhanced (more compressed dynamic range) - # Check that different scalings produce different results - assert not np.array_equal(result_low, result_high) - assert not np.array_equal(result_low, result_mixed) - assert not np.array_equal(result_high, result_mixed) - - # Test per-channel differences with mixed scaling - # Red channel (scale=1.0) should be different from Green (scale=10.0) and Blue (scale=50.0) - red_channel = result_mixed[:, :, 0] - green_channel = result_mixed[:, :, 1] - blue_channel = result_mixed[:, :, 2] - - # Channels should have different distributions due to different scaling - assert not np.array_equal(red_channel, green_channel) - assert not np.array_equal(green_channel, blue_channel) - - -def test_asinh_clipping_parameters(): - """Test ASINH normalisation with different clipping parameters""" - test_image = create_rgb_test_image(dtype=np.float32) - - # Test with no clipping (100% percentile) - cfg_no_clip = get_asinh_test_config(asinh_clip=[100.0, 100.0, 100.0]) - result_no_clip = normalise_image(test_image, cfg=cfg_no_clip) - - # Test with aggressive clipping (90% percentile) - cfg_clip = get_asinh_test_config(asinh_clip=[70.0, 70.0, 70.0]) - result_clip = normalise_image(test_image, cfg=cfg_clip) - - # Test with per-channel clipping - cfg_mixed_clip = get_asinh_test_config(asinh_clip=[85.0, 55.0, 98.0]) - result_mixed_clip = normalise_image(test_image, cfg=cfg_mixed_clip) - - # Test with single value clipping - cfg_single_clip = get_asinh_test_config(asinh_clip=92.0) - result_single_clip = normalise_image(test_image, cfg=cfg_single_clip) - - # All results should be valid uint8 images - for result in [result_no_clip, result_clip, result_mixed_clip, result_single_clip]: - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - assert np.min(result) >= 0 - assert np.max(result) <= 255 - - # Clipping should produce different results - assert not np.array_equal(result_no_clip, result_clip) - assert not np.array_equal(result_no_clip, result_mixed_clip) - assert not np.array_equal(result_clip, result_mixed_clip) - - # With clipping, the distributions should be different - for channel in range(3): - # The distributions should be different - assert not np.array_equal(result_clip[:, :, channel], result_no_clip[:, :, channel]) - - -def test_asinh_with_uint8_image(): - """Test ASINH normalisation with uint8 input image""" - test_image = create_gradient_rgb(dtype=np.uint8) - cfg = get_asinh_test_config() - result = normalise_image(test_image, cfg=cfg) - - # Basic checks - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - assert np.min(result) >= 0 - assert np.max(result) <= 255 - - # ASINH should preserve the overall gradient structure - # Check that each channel maintains some order - for channel in range(3): - channel_vals = result[0, :, channel] # First row gradient - # Should generally increase or at least not decrease significantly - # (allowing for some variation due to asinh transformation) - diffs = np.diff(channel_vals.astype(np.int16)) # Use int16 to handle negative diffs - # At least 70% of differences should be non-negative (allowing for some noise) - non_negative_ratio = np.sum(diffs >= 0) / len(diffs) - assert non_negative_ratio > 0.7, f"Channel {channel} gradient not well preserved" - - -def test_asinh_with_uint16_image(): - """Test ASINH normalisation with uint16 input image""" - test_image = create_gradient_rgb(dtype=np.uint16) - cfg = get_asinh_test_config() - result = normalise_image(test_image, cfg=cfg) - - # Basic checks - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - assert np.min(result) >= 0 - assert np.max(result) <= 255 - - # Should use good dynamic range - assert np.max(result) > 200 - assert np.min(result) < 50 - - -def test_asinh_edge_cases(): - """Test ASINH normalisation edge cases""" - - # Test with zeros - zero_image = np.zeros((8, 8, 3), dtype=np.float32) - cfg = get_asinh_test_config() - result_zeros = normalise_image(zero_image, cfg=cfg) - - assert isinstance(result_zeros, np.ndarray) - assert result_zeros.dtype == np.uint8 - assert result_zeros.shape == zero_image.shape - # All values should be 0 for zero input - assert np.all(result_zeros == 0) - - # Test with very small values - small_image = np.full((8, 8, 3), 1e-10, dtype=np.float32) - result_small = normalise_image(small_image, cfg=cfg) - - assert isinstance(result_small, np.ndarray) - assert result_small.dtype == np.uint8 - assert result_small.shape == small_image.shape - - # Test with identical values (should result in min/max error handling) - uniform_image = np.full((8, 8, 3), 0.5, dtype=np.float32) - result_uniform = normalise_image(uniform_image, cfg=cfg) - - assert isinstance(result_uniform, np.ndarray) - assert result_uniform.dtype == np.uint8 - assert result_uniform.shape == uniform_image.shape - - -def test_asinh_channel_independence(): - """Test that ASINH normalisation processes channels independently""" - # Create an image where channels have very different ranges - height, width = 12, 12 - - # Red channel: very bright - red_channel = np.full((height, width), 1.0, dtype=np.float32) - # Green channel: medium brightness - green_channel = np.full((height, width), 0.1, dtype=np.float32) - # Blue channel: very dim - blue_channel = np.full((height, width), 0.01, dtype=np.float32) - - test_image = np.stack([red_channel, green_channel, blue_channel], axis=2) - - cfg = get_asinh_test_config(asinh_scale=[10.0, 10.0, 10.0]) - result = normalise_image(test_image, cfg=cfg) - - # Basic checks - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - - # Each channel should be processed independently - # So they should not all have the same values despite input differences - red_result = result[:, :, 0] - green_result = result[:, :, 1] - blue_result = result[:, :, 2] - - # Due to per-channel normalization, they might end up with similar values - # But the processing should still be per-channel - assert red_result.shape == (height, width) - assert green_result.shape == (height, width) - assert blue_result.shape == (height, width) - - -def test_asinh_with_crop_for_maximum(): - """Test ASINH normalisation with crop_for_maximum_value setting""" - test_image = create_rgb_test_image(dtype=np.float32) - - # Add a bright spot outside the center region - test_image[0, 0, :] = 10.0 # Very bright spot in corner - - # Test without crop, no clip to keep bright spot - cfg_no_crop = get_asinh_test_config(asinh_clip=100.0) - result_no_crop = normalise_image(test_image, cfg=cfg_no_crop) - - # Test with center crop that excludes the bright spot, no clip to keep bright spot - cfg_crop = get_asinh_test_config(asinh_clip=100.0) - cfg_crop.normalisation.crop_for_maximum_value = (8, 8) # Center crop - result_crop = normalise_image(test_image, cfg=cfg_crop) - - # Both results should be valid - for result in [result_no_crop, result_crop]: - assert isinstance(result, np.ndarray) - assert result.dtype == np.uint8 - assert result.shape == test_image.shape - assert np.min(result) >= 0 - assert np.max(result) <= 255 - - # Results should be different due to different maximum calculation - assert not np.array_equal(result_no_crop, result_crop) - - # The bright spot should still be present in both results - assert np.all(result_no_crop[0, 0, :] > 200) # Should be bright - assert np.all(result_crop[0, 0, :] > 200) # Should still be bright diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index 8444afd..e1137ca 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -19,18 +19,13 @@ def pipeline_config(): ), style={"color": "white"}, ) - progress_bar = widgets.FloatProgress( - value=0.0, - min=0.0, - max=1.0, - ) cfg = am.get_default_cfg() am.set_log_level("trace", cfg) cfg.data_dir = "tests/test_data/" - cfg.size = [64, 64] + cfg.normalisation.image_size = [64, 64] + cfg.normalisation.n_output_channels = 3 cfg.num_train_iter = 10 - cfg.progress_bar = progress_bar return cfg, out diff --git a/tests/session_test.py b/tests/session_test.py index 2b24a3e..624a6cc 100644 --- a/tests/session_test.py +++ b/tests/session_test.py @@ -21,20 +21,15 @@ def base_config(): border="1px solid white", height="400px", background_color="black", overflow="auto" ), ) - progress_bar = widgets.FloatProgress( - value=0.0, - min=0.0, - max=1.0, - ) cfg = am.get_default_cfg() am.set_log_level("debug", cfg) cfg.data_dir = "tests/test_data/" - cfg.size = [64, 64] + cfg.normalisation.image_size = [64, 64] + cfg.normalisation.n_output_channels = 3 cfg.num_train_iter = 2 cfg.test_ratio = 0.5 cfg.output_dir = "tests/test_output" - cfg.progress_bar = progress_bar return cfg, out @@ -669,3 +664,88 @@ def test_no_double_counting_after_load_next_batch(base_config): active_normal_after, active_anomalous_after = session.get_active_learning_counts() assert active_normal_after == 0 assert active_anomalous_after == 0 + + +def test_iteration_scores_saved_after_training(base_config): + """Test that unlabelled and test scores are saved after training with correct mappings.""" + cfg, out = base_config + + # Create a fresh session for this test + session = Session(cfg) + session.out = out + + # Train the model + session.train(cfg) + + # Check that session tracker has the iteration + assert len(session.session_tracker.session_iterations) >= 1 + + # Get the latest iteration + latest_iteration = session.session_tracker.session_iterations[-1] + + # Verify unlabelled scores file was created and path stored + assert latest_iteration.unlabelled_scores_file is not None + assert os.path.exists(latest_iteration.unlabelled_scores_file) + + # Load and verify unlabelled scores + unlabelled_df = pd.read_csv(latest_iteration.unlabelled_scores_file) + assert "filename" in unlabelled_df.columns + assert "score" in unlabelled_df.columns + assert len(unlabelled_df) > 0 + + # Verify score values are valid probabilities + assert unlabelled_df["score"].min() >= 0.0 + assert unlabelled_df["score"].max() <= 1.0 + + # Verify that filenames in the CSV match the session's filenames + csv_filenames = set(unlabelled_df["filename"].tolist()) + session_filenames = set(session.filenames.tolist()) + assert csv_filenames == session_filenames, "Saved filenames don't match session filenames" + + # Verify score mapping: check a few samples match between CSV and session + for idx, (filename, score) in enumerate(zip(session.filenames[:5], session.scores[:5])): + csv_score = unlabelled_df[unlabelled_df["filename"] == filename]["score"].values[0] + assert ( + abs(csv_score - score) < 1e-6 + ), f"Score mismatch for {filename}: {csv_score} vs {score}" + + # If test_ratio > 0, verify test scores were also saved + if cfg.test_ratio > 0: + assert latest_iteration.test_scores_file is not None + assert os.path.exists(latest_iteration.test_scores_file) + + # Load and verify test scores + test_df = pd.read_csv(latest_iteration.test_scores_file) + assert "filename" in test_df.columns + assert "score" in test_df.columns + assert len(test_df) > 0 + + # Verify test score values are valid probabilities + assert test_df["score"].min() >= 0.0 + assert test_df["score"].max() <= 1.0 + + +def test_iteration_scores_no_test_set(base_config): + """Test that only unlabelled scores are saved when test_ratio is 0.""" + cfg, out = base_config + + # Modify config to have no test set + cfg_no_test = cfg.copy() + cfg_no_test.test_ratio = 0.0 + + # Create a fresh session + session = Session(cfg_no_test) + session.out = out + + # Train the model + session.train(cfg_no_test) + + # Get the latest iteration + latest_iteration = session.session_tracker.session_iterations[-1] + + # Unlabelled scores should still be saved + assert latest_iteration.unlabelled_scores_file is not None + assert os.path.exists(latest_iteration.unlabelled_scores_file) + + # Test scores should not be saved (no test set) + assert latest_iteration.test_scores_file is None diff --git a/tests/test_batch_size_estimation.py b/tests/test_batch_size_estimation.py new file mode 100644 index 0000000..746f2f9 --- /dev/null +++ b/tests/test_batch_size_estimation.py @@ -0,0 +1,167 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. + +import pytest +from prediction_utils import estimate_batch_size, MEMORY_COEFFICIENTS +from anomaly_match import get_default_cfg + + +FAKE_VRAM_BYTES = 16 * 1024**3 # 16GB + + +@pytest.fixture +def test_config(): + cfg = get_default_cfg() + cfg.normalisation.image_size = [100, 100] + cfg.net = "efficientnet-lite0" + cfg.num_channels = 3 + + return cfg + + +@pytest.fixture +def no_cuda(monkeypatch): + """Cuda GPU mock showing no GPU available""" + import torch + + monkeypatch.setattr(torch.cuda, "is_available", lambda: False) + monkeypatch.setattr(torch.cuda, "get_device_properties", lambda _: None) + + +@pytest.fixture +def fake_cuda(monkeypatch): + """Cuda GPU mock with fixed 16 GB of vram""" + import torch + + class Props: + total_memory = FAKE_VRAM_BYTES + + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "current_device", lambda: 0) + monkeypatch.setattr(torch.cuda, "get_device_properties", lambda _: Props()) + + +def verify_batch_size_estimation( + test_config, estimated_batch_size, available_vram, safety_margin, net_name=None +): + """Helper function to verify batch size estimation""" + S2 = test_config.normalisation.image_size[0] * test_config.normalisation.image_size[1] + coeffs = MEMORY_COEFFICIENTS[net_name or test_config.net] + + # Calculate memory occupied according to model with the estimated batch size + memory_occupied = ( + coeffs["a"] * estimated_batch_size * S2 * test_config.num_channels + + coeffs["b"] * estimated_batch_size + + coeffs["c"] + ) + + # How much memory actually is avaiable according to the specification + usable_vram = available_vram * (1 - safety_margin) + + # What the batch size should be if everything works properly + true_batch_size = max( + 1, + int( + (usable_vram - coeffs["c"]) + / (coeffs["a"] * S2 * test_config.num_channels + coeffs["b"]) + ), + ) + + assert true_batch_size == estimated_batch_size + assert memory_occupied <= usable_vram + + +def test_estimate_batch_size_manual_vram(test_config): + """Test batch size estimation with manual vram specification""" + + # Setup + available_vram = 1024 # MB + safety_margin = 0.2 + estimated_batch_size = estimate_batch_size( + test_config, available_vram=available_vram, safety_margin=safety_margin + ) + + verify_batch_size_estimation( + test_config=test_config, + estimated_batch_size=estimated_batch_size, + available_vram=available_vram, + safety_margin=safety_margin, + ) + + +def test_estimate_batch_size_unknown_model(test_config): + """Test batch size estimation with unknown model""" + + # Setup + available_vram = 1024 # MB + test_config.net = "UnknownNet" + safety_margin = 0.2 + estimated_batch_size = estimate_batch_size( + test_config, available_vram=available_vram, safety_margin=safety_margin + ) + + verify_batch_size_estimation( + test_config=test_config, + estimated_batch_size=estimated_batch_size, + available_vram=available_vram, + safety_margin=safety_margin, + net_name="efficientnet-lite0", + ) + + +def test_estimate_batch_size_cuda_16gb(test_config, fake_cuda): + """Test batch size estimation with efficientnet-lite0 and 16GB vram cuda GPU""" + + # Setup + safety_margin = 0.2 + estimated_batch_size = estimate_batch_size(test_config, safety_margin=safety_margin) + + verify_batch_size_estimation( + test_config=test_config, + estimated_batch_size=estimated_batch_size, + safety_margin=safety_margin, + available_vram=FAKE_VRAM_BYTES / 1024**2, + ) + + +def test_estimate_batch_size_no_cuda(test_config, no_cuda): + """Test batch size estimation with eddicientnet-lite0 and no cuda GPU available""" + + # Setup + safety_margin = 0.2 + estimated_batch_size = estimate_batch_size(test_config, safety_margin=safety_margin) + + verify_batch_size_estimation( + test_config=test_config, + estimated_batch_size=estimated_batch_size, + safety_margin=safety_margin, + available_vram=4096, + ) + + +def test_estimate_batch_size_invalid_memory(test_config): + """Test safeguard with invalid memory available""" + + batch_size = estimate_batch_size(test_config, available_vram=0) + + assert batch_size == 1 + + +def test_estimate_batch_size_invalid_model_coefficients(test_config, monkeypatch): + + import prediction_utils + + # Inject invalid coefficients + monkeypatch.setattr( + prediction_utils, + "MEMORY_COEFFICIENTS", + {"efficientnet-lite0": {"a": -0.1, "b": -0.1, "c": -0.1}}, + ) + + batch_size = estimate_batch_size(test_config, available_vram=4096) + + assert batch_size == 1 diff --git a/tests/test_fitsbolt_config_persistence.py b/tests/test_fitsbolt_config_persistence.py new file mode 100644 index 0000000..179b54f --- /dev/null +++ b/tests/test_fitsbolt_config_persistence.py @@ -0,0 +1,330 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. + +"""Tests for fitsbolt configuration persistence in model checkpoints. + +The fitsbolt DotMap configuration can be pickled directly via torch.save/load +without explicit serialization. +""" + +import shutil +import tempfile +from pathlib import Path + +import numpy as np +import torch +from dotmap import DotMap +from fitsbolt.cfg.create_config import create_config as fb_create_cfg, validate_config +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod + +from anomaly_match.data_io.load_images import get_fitsbolt_config + + +class TestFitsboltConfigPickling: + """Test cases for fitsbolt config pickling via torch.save/load.""" + + def test_pickle_roundtrip_basic(self): + """Test basic pickle roundtrip via torch checkpoint.""" + original_cfg = fb_create_cfg( + output_dtype=np.uint8, + size=[64, 64], + n_output_channels=3, + normalisation_method=NormalisationMethod.CONVERSION_ONLY, + num_workers=4, + ) + + # Save via torch + with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: + checkpoint_path = f.name + + try: + torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) + loaded = torch.load(checkpoint_path, weights_only=False) + loaded_cfg = loaded["fitsbolt_cfg"] + + assert loaded_cfg.size == original_cfg.size + assert loaded_cfg.n_output_channels == original_cfg.n_output_channels + assert loaded_cfg.num_workers == original_cfg.num_workers + assert loaded_cfg.normalisation_method == original_cfg.normalisation_method + finally: + Path(checkpoint_path).unlink(missing_ok=True) + + def test_pickle_numpy_dtype(self): + """Test pickling of numpy dtypes.""" + original_cfg = fb_create_cfg( + output_dtype=np.float32, + size=[128, 128], + n_output_channels=3, + ) + + with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: + checkpoint_path = f.name + + try: + torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) + loaded = torch.load(checkpoint_path, weights_only=False) + loaded_cfg = loaded["fitsbolt_cfg"] + + assert loaded_cfg.output_dtype == np.float32 + finally: + Path(checkpoint_path).unlink(missing_ok=True) + + def test_pickle_all_normalisation_methods(self): + """Test pickling with all normalisation methods.""" + for method in NormalisationMethod: + original_cfg = fb_create_cfg( + output_dtype=np.uint8, + size=[64, 64], + n_output_channels=3, + normalisation_method=method, + ) + + with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: + checkpoint_path = f.name + + try: + torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) + loaded = torch.load(checkpoint_path, weights_only=False) + loaded_cfg = loaded["fitsbolt_cfg"] + + assert loaded_cfg.normalisation_method == method + finally: + Path(checkpoint_path).unlink(missing_ok=True) + + def test_pickle_channel_combination(self): + """Test pickling of numpy array channel_combination.""" + original_cfg = fb_create_cfg( + output_dtype=np.uint8, + size=[64, 64], + fits_extension=[0, 1, 2], + n_output_channels=3, + channel_combination=np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + ) + + with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: + checkpoint_path = f.name + + try: + torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) + loaded = torch.load(checkpoint_path, weights_only=False) + loaded_cfg = loaded["fitsbolt_cfg"] + + np.testing.assert_array_equal( + loaded_cfg.channel_combination, original_cfg.channel_combination + ) + finally: + Path(checkpoint_path).unlink(missing_ok=True) + + def test_pickle_asinh_settings(self): + """Test pickling of asinh normalisation settings.""" + original_cfg = fb_create_cfg( + output_dtype=np.uint8, + size=[64, 64], + n_output_channels=3, + normalisation_method=NormalisationMethod.ASINH, + norm_asinh_scale=[0.5, 0.6, 0.7], + norm_asinh_clip=[99.0, 99.5, 99.8], + ) + + with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: + checkpoint_path = f.name + + try: + torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) + loaded = torch.load(checkpoint_path, weights_only=False) + loaded_cfg = loaded["fitsbolt_cfg"] + + assert loaded_cfg.normalisation.asinh_scale == original_cfg.normalisation.asinh_scale + assert loaded_cfg.normalisation.asinh_clip == original_cfg.normalisation.asinh_clip + finally: + Path(checkpoint_path).unlink(missing_ok=True) + + +class TestFitsboltConfigValidation: + """Test cases for fitsbolt config validation after pickling.""" + + def test_validate_pickled_config(self): + """Test that pickled config passes fitsbolt validation.""" + original_cfg = fb_create_cfg( + output_dtype=np.uint8, + size=[64, 64], + n_output_channels=3, + normalisation_method=NormalisationMethod.CONVERSION_ONLY, + num_workers=4, + ) + + with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: + checkpoint_path = f.name + + try: + torch.save({"fitsbolt_cfg": original_cfg}, checkpoint_path) + loaded = torch.load(checkpoint_path, weights_only=False) + loaded_cfg = loaded["fitsbolt_cfg"] + + # Validate using fitsbolt's validate_config + validate_config(loaded_cfg) + finally: + Path(checkpoint_path).unlink(missing_ok=True) + + +class TestFitsboltConfigCompatibility: + """Test compatibility with fitsbolt's create_config function.""" + + def test_compatibility_with_fits_extension_settings(self): + """Test pickling with various fits_extension configurations.""" + # Single integer extension + cfg1 = fb_create_cfg( + output_dtype=np.uint8, + size=[64, 64], + n_output_channels=3, + fits_extension=0, + ) + + with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: + checkpoint_path = f.name + + try: + torch.save({"fitsbolt_cfg": cfg1}, checkpoint_path) + loaded = torch.load(checkpoint_path, weights_only=False) + validate_config(loaded["fitsbolt_cfg"]) + finally: + Path(checkpoint_path).unlink(missing_ok=True) + + # List of extensions + cfg2 = fb_create_cfg( + output_dtype=np.uint8, + size=[64, 64], + n_output_channels=3, + fits_extension=[0, 1, 2], + ) + + with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: + checkpoint_path = f.name + + try: + torch.save({"fitsbolt_cfg": cfg2}, checkpoint_path) + loaded = torch.load(checkpoint_path, weights_only=False) + validate_config(loaded["fitsbolt_cfg"]) + finally: + Path(checkpoint_path).unlink(missing_ok=True) + + +class TestGetFitsboltConfigIntegration: + """Test get_fitsbolt_config integration with pickling.""" + + def test_get_fitsbolt_config_pickling(self): + """Test that config from get_fitsbolt_config can be pickled.""" + # Create an AnomalyMatch-style config + cfg = DotMap() + cfg.normalisation = DotMap() + cfg.normalisation.output_dtype = np.uint8 + cfg.normalisation.image_size = [64, 64] + cfg.normalisation.fits_extension = None + cfg.normalisation.interpolation_order = 1 + cfg.normalisation.n_output_channels = 3 + cfg.normalisation.normalisation_method = NormalisationMethod.CONVERSION_ONLY + cfg.normalisation.channel_combination = None + cfg.normalisation.norm_maximum_value = None + cfg.normalisation.norm_minimum_value = None + cfg.normalisation.norm_log_calculate_minimum_value = False + cfg.normalisation.norm_crop_for_maximum_value = None + cfg.normalisation.norm_asinh_scale = [0.7] + cfg.normalisation.norm_asinh_clip = [99.8] + cfg.num_workers = 4 + + # Get fitsbolt config + cfg = get_fitsbolt_config(cfg) + + # Save and load via torch + with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as f: + checkpoint_path = f.name + + try: + torch.save({"fitsbolt_cfg": cfg.fitsbolt_cfg}, checkpoint_path) + loaded = torch.load(checkpoint_path, weights_only=False) + loaded_cfg = loaded["fitsbolt_cfg"] + + # Validate + validate_config(loaded_cfg) + + # Verify key properties + assert loaded_cfg.size == [64, 64] + assert loaded_cfg.n_output_channels == 3 + assert loaded_cfg.normalisation_method == NormalisationMethod.CONVERSION_ONLY + finally: + Path(checkpoint_path).unlink(missing_ok=True) + + +class TestFitsboltConfigE2EWithCheckpoint: + """End-to-end tests for fitsbolt config persistence with model checkpoints.""" + + def setup_method(self): + """Set up test environment.""" + self.temp_dir = tempfile.mkdtemp() + + def teardown_method(self): + """Clean up test environment.""" + shutil.rmtree(self.temp_dir) + + def test_fitsbolt_config_in_checkpoint_dict(self): + """Test that fitsbolt config can be saved and loaded in a checkpoint-like dict.""" + # Create a fitsbolt config + fitsbolt_cfg = fb_create_cfg( + output_dtype=np.uint8, + size=[64, 64], + n_output_channels=3, + normalisation_method=NormalisationMethod.ASINH, + norm_asinh_scale=[0.5, 0.6, 0.7], + norm_asinh_clip=[99.0, 99.5, 99.8], + ) + + # Create a mock checkpoint + checkpoint = { + "model_state": {"dummy": "data"}, + "optimizer_state": None, + "fitsbolt_cfg": fitsbolt_cfg, + } + + # Save checkpoint + checkpoint_path = Path(self.temp_dir) / "test_checkpoint.pth" + torch.save(checkpoint, checkpoint_path) + + # Load checkpoint + loaded_checkpoint = torch.load(checkpoint_path, weights_only=False) + loaded_fitsbolt_cfg = loaded_checkpoint["fitsbolt_cfg"] + + # Verify + assert loaded_fitsbolt_cfg.size == [64, 64] + assert loaded_fitsbolt_cfg.n_output_channels == 3 + assert loaded_fitsbolt_cfg.normalisation_method == NormalisationMethod.ASINH + assert loaded_fitsbolt_cfg.normalisation.asinh_scale == [0.5, 0.6, 0.7] + assert loaded_fitsbolt_cfg.normalisation.asinh_clip == [99.0, 99.5, 99.8] + + # Validate loaded config + validate_config(loaded_fitsbolt_cfg) + + def test_backward_compatibility_checkpoint_without_fitsbolt(self): + """Test loading checkpoints that don't have fitsbolt_cfg.""" + # Create a mock checkpoint without fitsbolt_cfg (legacy format) + checkpoint = { + "model_state": {"dummy": "data"}, + "optimizer_state": None, + } + + # Save checkpoint + checkpoint_path = Path(self.temp_dir) / "legacy_checkpoint.pth" + torch.save(checkpoint, checkpoint_path) + + # Load checkpoint + loaded_checkpoint = torch.load(checkpoint_path, weights_only=False) + + # Check that fitsbolt_cfg is not present + assert "fitsbolt_cfg" not in loaded_checkpoint + + # Accessing non-existent key should return None via .get() + result = loaded_checkpoint.get("fitsbolt_cfg") + assert result is None diff --git a/tests/test_image_io.py b/tests/test_image_io.py index 5b70929..5b280fb 100644 --- a/tests/test_image_io.py +++ b/tests/test_image_io.py @@ -12,12 +12,49 @@ from pathlib import Path from PIL import Image import torch +import copy -from anomaly_match.data_io.load_images import read_image_data, process_image_array +from anomaly_match.data_io.load_images import ( + load_and_process_single_wrapper, + process_single_wrapper, + get_fitsbolt_config, +) from anomaly_match.utils.get_default_cfg import get_default_cfg from prediction_utils import save_results +def _load_image_with_fitsbolt(filepath, cfg): + """Helper function to load image using fitsbolt with AnomalyMatch config.""" + # Use the new wrapper function instead of directly using fitsbolt + return load_and_process_single_wrapper(filepath, cfg, desc="test loading", show_progress=False) + + +def _update_config(cfg, **kwargs): + """Update both the main config and fitsbolt config with the given parameters.""" + for key, value in kwargs.items(): + setattr(cfg, key, value) + if key == "size": + cfg.normalisation.image_size = value + elif key == "fits_extension": + cfg.normalisation.fits_extension = value + elif key.startswith("normalisation."): + # Handle nested attributes in normalisation + norm_key = key.split(".")[1] + setattr(cfg.normalisation, norm_key, value) + fitsbolt_key = f"norm_{norm_key}" + setattr(cfg.normalisation, fitsbolt_key, value) + return cfg + + +def _process_image_array_with_fitsbolt(image_array, cfg): + """Helper function to process image array using fitsbolt.""" + # Use the new wrapper function instead of directly using fitsbolt + testcfg = copy.deepcopy(cfg) + testcfg = get_fitsbolt_config(testcfg) + + return process_single_wrapper(image_array, testcfg, desc="array") + + def create_fits_file(image_data, filepath): """ Create a FITS file from the provided image data. @@ -106,9 +143,13 @@ def test_image(self): def test_config(self): """Get test configuration.""" cfg = get_default_cfg() - cfg.size = [224, 224] # Set fits_extension to use the proper extensions from our test FITS file - cfg.fits_extension = ["R", "G", "B"] # Use the named extensions we created + + # Add fitsbolt configuration needed by the wrapper functions + cfg.normalisation.image_size = [224, 224] + cfg.normalisation.fits_extension = ["R", "G", "B"] + cfg.normalisation.n_output_channels = 3 + return cfg def test_image_format_consistency(self, test_image, test_config): @@ -126,13 +167,13 @@ def test_image_format_consistency(self, test_image, test_config): f.create_dataset("image", data=test_image) # Load and compare PNG vs HDF5 - loaded_png = read_image_data(str(png_path), test_config) + loaded_png = _load_image_with_fitsbolt(str(png_path), test_config) loaded_hdf5 = None # Load from HDF5 with h5py.File(hdf5_path, "r") as f: hdf5_data = f["image"][:] - loaded_hdf5 = process_image_array(hdf5_data, test_config) + loaded_hdf5 = _process_image_array_with_fitsbolt(hdf5_data, test_config) # Both should have the same shape after processing if hasattr(loaded_png, "shape"): @@ -228,14 +269,14 @@ def test_image_pipeline_integration(self, test_image, test_config): Image.fromarray(test_image).save(input_path) # Load through the pipeline - processed_image = read_image_data(str(input_path), test_config) + processed_image = _load_image_with_fitsbolt(str(input_path), test_config) # Should be processed correctly assert processed_image is not None # The function returns a numpy array assert isinstance(processed_image, np.ndarray) - expected_size = tuple(test_config.size) + expected_size = tuple(test_config.normalisation.image_size) assert ( processed_image.shape[:2] == expected_size or processed_image.shape[:2] == expected_size[::-1] @@ -312,17 +353,22 @@ def test_prediction_process_integration(self, test_image, test_config): # Check that critical fields are present assert hasattr( - reloaded_config, "fits_extension" - ), "fits_extension field missing from reloaded config" - assert hasattr(reloaded_config, "size"), "size field missing from reloaded config" + reloaded_config, "normalisation" + ), "normalisation field missing from reloaded config" + assert hasattr( + reloaded_config.normalisation, "fits_extension" + ), "fits_extension field missing from reloaded config.normalisation" + assert hasattr( + reloaded_config.normalisation, "size" + ), "size field missing from reloaded config.normalisation" assert hasattr( - reloaded_config, "normalisation_method" - ), "normalisation_method field missing from reloaded config" + reloaded_config.normalisation, "normalisation_method" + ), "normalisation_method field missing from reloaded config.normalisation" # Test image loading with reloaded config for test_file in test_files: try: - loaded_image = read_image_data(test_file, reloaded_config) + loaded_image = _load_image_with_fitsbolt(test_file, reloaded_config) assert ( loaded_image is not None ), f"Failed to load {test_file} with reloaded config" @@ -358,7 +404,6 @@ def test_image_formats_comprehensive(self, test_image, test_config): ("test.jpg", lambda p: Image.fromarray(test_image).save(p, quality=95), True), ("test.jpeg", lambda p: Image.fromarray(test_image).save(p, quality=95), True), ("test.tiff", lambda p: Image.fromarray(test_image).save(p), True), - ("test.tif", lambda p: Image.fromarray(test_image).save(p), True), ( "test.fits", lambda p: create_fits_file(test_image, p), @@ -378,7 +423,7 @@ def test_image_formats_comprehensive(self, test_image, test_config): assert file_path.is_file(), f"File {file_path} is not a regular file" try: - loaded_image = read_image_data(str(file_path), test_config) + loaded_image = _load_image_with_fitsbolt(str(file_path), test_config) if should_succeed: assert loaded_image is not None, f"Failed to load {filename}" diff --git a/tests/test_model_io_integration.py b/tests/test_model_io_integration.py index 0f2cd78..6c7f6a6 100644 --- a/tests/test_model_io_integration.py +++ b/tests/test_model_io_integration.py @@ -14,7 +14,7 @@ from anomaly_match.data_io.SessionIOHandler import SessionIOHandler from anomaly_match.pipeline.SessionTracker import SessionTracker -from anomaly_match.image_processing.NormalisationMethod import NormalisationMethod +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod class MockModel(nn.Module): @@ -112,14 +112,14 @@ def test_load_model_with_normalisation_update(self): # Update config to match model for saving save_cfg = DotMap(self.cfg) - save_cfg.normalisation_method = NormalisationMethod.LOG + save_cfg.normalisation.normalisation_method = NormalisationMethod.LOG # Save model self.session_io.save_model(self.mock_model, save_cfg, session_tracker=None) # Create config with different normalisation for loading test_cfg = DotMap(self.cfg) - test_cfg.normalisation_method = NormalisationMethod.CONVERSION_ONLY + test_cfg.normalisation.normalisation_method = NormalisationMethod.CONVERSION_ONLY # Load model new_model = MockFixMatch() @@ -127,7 +127,7 @@ def test_load_model_with_normalisation_update(self): # Verify normalisation was updated from model assert success - assert test_cfg.normalisation_method == NormalisationMethod.LOG + assert test_cfg.normalisation.normalisation_method == NormalisationMethod.LOG assert new_model.last_normalisation_method == NormalisationMethod.LOG def test_load_model_nonexistent_file(self): diff --git a/tests/test_prediction_process.py b/tests/test_prediction_process.py index 164ead9..32897f9 100644 --- a/tests/test_prediction_process.py +++ b/tests/test_prediction_process.py @@ -5,6 +5,7 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. import pytest +import csv import os import numpy as np import h5py @@ -14,34 +15,58 @@ import pandas as pd import torch from loguru import logger +from astropy.io import fits +from astropy.wcs import WCS +from astropy.table import Table from prediction_process import evaluate_files from prediction_process_hdf5 import evaluate_images_in_hdf5 from prediction_process_zarr import evaluate_images_in_zarr +from prediction_process_cutana import evaluate_images_from_cutana from prediction_utils import save_results -from anomaly_match.image_processing.NormalisationMethod import NormalisationMethod +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod +from fitsbolt.cfg.create_config import create_config as fb_create_cfg from anomaly_match.utils.get_default_cfg import get_default_cfg @pytest.fixture def test_config(): cfg = get_default_cfg() - cfg.size = [150, 150] + cfg.normalisation.image_size = [150, 150] + cfg.normalisation.n_output_channels = 3 cfg.net = "efficientnet-lite0" cfg.pretrained = True cfg.num_channels = 3 cfg.model_path = "tests/test_data/test_model.pth" cfg.gpu = 0 cfg.output_dir = tempfile.mkdtemp() - cfg.normalisation_method = NormalisationMethod.CONVERSION_ONLY + cfg.normalisation.normalisation_method = NormalisationMethod.CONVERSION_ONLY cfg.log_level = "INFO" # Add proper log level cfg.name = "test_session" # Add session name cfg.seed = 42 # Add seed cfg.test_ratio = 0.0 # Add test ratio cfg.save_dir = tempfile.mkdtemp() # Add save directory cfg.data_dir = "tests/test_data/" # Add data directory + + # Create fb_cfg for fitsbolt + cfg.fitsbolt_cfg = fb_create_cfg( + output_dtype=np.uint8, + size=cfg.normalisation.image_size, + fits_extension=cfg.normalisation.fits_extension, + interpolation_order=cfg.normalisation.interpolation_order, + normalisation_method=cfg.normalisation.normalisation_method, + channel_combination=cfg.normalisation.channel_combination, + num_workers=cfg.num_workers, + norm_maximum_value=cfg.normalisation.norm_maximum_value, + norm_minimum_value=cfg.normalisation.norm_minimum_value, + norm_log_calculate_minimum_value=cfg.normalisation.norm_log_calculate_minimum_value, + norm_crop_for_maximum_value=cfg.normalisation.norm_crop_for_maximum_value, + norm_asinh_scale=cfg.normalisation.norm_asinh_scale, + norm_asinh_clip=cfg.normalisation.norm_asinh_clip, + ) + return cfg @@ -163,15 +188,67 @@ def multiple_test_zarr(sample_images, tmp_path): return zarr_files, str(tmp_path) +@pytest.fixture +def zarr_batch_folders(sample_images, tmp_path): + """Create multiple batch folders with images.zarr subdirectories (mimics real structure).""" + batch_folders = [] + + # Create 3 different batch folders + for batch_idx in range(3): + batch_folder = tmp_path / f"batch_{batch_idx:03d}" + batch_folder.mkdir() + + zarr_path = batch_folder / "images.zarr" + + # Create zarr store + root = zarr.open_group(str(zarr_path), mode="w") + + # Use different images in each batch (split sample_images) + start_idx = batch_idx * 3 + end_idx = min(start_idx + 3, len(sample_images)) + batch_images = sample_images[start_idx:end_idx] + + if not batch_images: # If no more images, create a minimal one + # Create a simple test image with different color per batch + color_map = {0: (255, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255)} # RGB + img_array = np.zeros((150, 150, 3), dtype=np.uint8) + img_array[50:100, 50:100] = color_map[batch_idx] + batch_images = [Image.fromarray(img_array)] + + # Convert PIL images to numpy arrays + img_arrays = [] + filenames = [] + for i, img in enumerate(batch_images): + img_array = np.array(img) + img_arrays.append(img_array) + filenames.append(f"batch_{batch_idx:03d}_img_{i}.jpg") + + # Stack images into a single array and save to zarr + images_array = np.stack(img_arrays, axis=0) + zarr_images = root.create_dataset( + "images", shape=images_array.shape, chunks=(1, 150, 150, 3), dtype=np.uint8 + ) + zarr_images[:] = images_array + + # Create metadata as a separate parquet file in the batch folder + metadata_path = batch_folder / "images_metadata.parquet" + metadata_df = pd.DataFrame({"original_filename": filenames}) + metadata_df.to_parquet(metadata_path, index=False) + + batch_folders.append(str(zarr_path)) + + return batch_folders, str(tmp_path) + + @pytest.fixture def mixed_format_images(tmp_path): - """Create a directory with sample images in different formats (jpg, png, tif, tiff)""" + """Create a directory with sample images in different formats (jpg, png, tiff)""" img_dir = tmp_path / "mixed_formats" img_dir.mkdir() - # Create images in different formats + # Create images in different formats - only use supported extensions image_paths = [] - formats = {"jpg": "JPEG", "png": "PNG", "tif": "TIFF", "tiff": "TIFF"} + formats = {"jpg": "JPEG", "png": "PNG", "tiff": "TIFF"} # Create a simple test image base_img = np.zeros((150, 150, 3), dtype=np.uint8) @@ -186,6 +263,232 @@ def mixed_format_images(tmp_path): return image_paths, str(img_dir) +@pytest.fixture +def test_cutana(tmp_path): + """Create a directory with sample FITS files for cutana streaming.""" + data_dir = tmp_path / "cutana_test" + data_dir.mkdir() + + img_size = 512 + ra_center, dec_center = 150.14, 2.34 + tile_id = "102018211" + num_sources = 10 + + wcs = WCS(naxis=2) + pixel_scale = 0.1 / 3600.0 + wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] + wcs.wcs.crval = [ra_center, dec_center] + wcs.wcs.crpix = [img_size / 2, img_size / 2] + wcs.wcs.cd = [[-pixel_scale, 0], [0, pixel_scale]] + wcs.wcs.cunit = ["deg", "deg"] + wcs.wcs.radesys = "ICRS" + wcs.wcs.equinox = 2000.0 + + img_data = np.random.normal(0, 0.005, (img_size, img_size)).astype(np.float32) + primary_hdu = fits.PrimaryHDU(img_data) + header = primary_hdu.header + header.update(wcs.to_header()) + header["TELESCOP"] = "EUCLID" + header["INSTRUME"] = "VIS" + header["TILEID"] = tile_id + header["BUNIT"] = "electron/s" + header["DATATYPE"] = "BGSUB-MOSAIC" + header["EXPTIME"] = 565.0 + header["GAIN"] = 3.1 + header["READNOIS"] = 4.2 + header["MAGZERO"] = 24.6 + + fits_path = ( + data_dir / f"EUC_MER_BGSUB-MOSAIC-VIS_TILE{tile_id}-ACBD03_20251124T100053.096Z_00.00.fits" + ) + primary_hdu.writeto(fits_path, overwrite=True) + + field_of_view_deg = img_size * pixel_scale + fov_margin = field_of_view_deg * 0.45 + + ra_values = ra_center + np.random.uniform(-fov_margin, fov_margin, num_sources) + dec_values = dec_center + np.random.uniform(-fov_margin, fov_margin, num_sources) + object_ids = (np.arange(1, num_sources + 1) + int(tile_id) * 1000000).astype(np.int64) + + cat_table = Table() + cat_table["OBJECT_ID"] = object_ids + cat_table["RIGHT_ASCENSION"] = ra_values + cat_table["DECLINATION"] = dec_values + cat_table["RIGHT_ASCENSION_PSF_FITTING"] = ra_values + cat_table["DECLINATION_PSF_FITTING"] = dec_values + + catalog_fits = ( + data_dir / f"EUC_MER_FINAL-CAT_TILE{tile_id}-CC66F6_20251124T100053.096Z_00.00.fits" + ) + primary_hdu_cat = fits.PrimaryHDU() + table_hdu = fits.BinTableHDU(cat_table, name="EUC_MER__FINAL_CATALOG") + hdul = fits.HDUList([primary_hdu_cat, table_hdu]) + hdul.writeto(catalog_fits, overwrite=True) + + csv_directory = data_dir / "csv" + csv_directory.mkdir() + csv_path = csv_directory / "mock_sources_malformed.csv" + + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["SourceID", "RA", "Dec", "diameter_pixel", "fits_file_paths"]) + for i, (ra, dec) in enumerate(zip(ra_values, dec_values)): + writer.writerow( + [ + f"MockSource_{object_ids[i]}", + ra, + dec, + np.random.randint(100, 250), + str([str(fits_path)]), + ] + ) + + return str(csv_path) + + +@pytest.fixture +def test_cutana_parquet(tmp_path): + """Create a directory with sample FITS files and parquet catalogue for cutana streaming.""" + data_dir = tmp_path / "cutana_parquet_test" + data_dir.mkdir() + + img_size = 512 + ra_center, dec_center = 150.14, 2.34 + tile_id = "102018212" + num_sources = 10 + + wcs = WCS(naxis=2) + pixel_scale = 0.1 / 3600.0 + wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] + wcs.wcs.crval = [ra_center, dec_center] + wcs.wcs.crpix = [img_size / 2, img_size / 2] + wcs.wcs.cd = [[-pixel_scale, 0], [0, pixel_scale]] + wcs.wcs.cunit = ["deg", "deg"] + wcs.wcs.radesys = "ICRS" + wcs.wcs.equinox = 2000.0 + + img_data = np.random.normal(0, 0.005, (img_size, img_size)).astype(np.float32) + primary_hdu = fits.PrimaryHDU(img_data) + header = primary_hdu.header + header.update(wcs.to_header()) + header["TELESCOP"] = "EUCLID" + header["INSTRUME"] = "VIS" + header["TILEID"] = tile_id + header["BUNIT"] = "electron/s" + header["DATATYPE"] = "BGSUB-MOSAIC" + header["EXPTIME"] = 565.0 + header["GAIN"] = 3.1 + header["READNOIS"] = 4.2 + header["MAGZERO"] = 24.6 + + fits_path = ( + data_dir / f"EUC_MER_BGSUB-MOSAIC-VIS_TILE{tile_id}-ACBD03_20251124T100053.096Z_00.00.fits" + ) + primary_hdu.writeto(fits_path, overwrite=True) + + field_of_view_deg = img_size * pixel_scale + fov_margin = field_of_view_deg * 0.45 + + ra_values = ra_center + np.random.uniform(-fov_margin, fov_margin, num_sources) + dec_values = dec_center + np.random.uniform(-fov_margin, fov_margin, num_sources) + object_ids = (np.arange(1, num_sources + 1) + int(tile_id) * 1000000).astype(np.int64) + + # Create parquet catalogue + parquet_directory = data_dir / "parquet" + parquet_directory.mkdir() + parquet_path = parquet_directory / "mock_sources.parquet" + + import pandas as pd + + df = pd.DataFrame( + { + "SourceID": [f"MockSource_{oid}" for oid in object_ids], + "RA": ra_values, + "Dec": dec_values, + "diameter_pixel": np.random.randint(100, 250, num_sources), + "fits_file_paths": [str([str(fits_path)]) for _ in range(num_sources)], + } + ) + df.to_parquet(parquet_path, index=False) + + return str(parquet_path) + + +@pytest.fixture +def test_cutana_malformed_header(tmp_path): + """Create a directory with sample CSV file with malformed header.""" + data_dir = tmp_path / "cutana_malformed_test" + data_dir.mkdir() + + tile_id = "102018211" + + fits_path = ( + data_dir / f"EUC_MER_BGSUB-MOSAIC-VIS_TILE{tile_id}-ACBD03_20251124T100053.096Z_00.00.fits" + ) + + csv_directory = data_dir / "csv" + csv_directory.mkdir() + csv_path = csv_directory / "mock_sources_malformed.csv" + + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + # Write bad header, should report missing headers and raise RuntimeError + writer.writerow( + [ + "SourceID_MALFORMED", + "RA_MALFORMED", + "Dec", + "diameter_pixel_MALFORMED", + "fits_file_paths", + ] + ) + for i in range(10): + writer.writerow( + [ + f"MockSource_{i}", + (np.random.rand() - 0.5) * 2 * 5 + 150, + np.random.rand() - 0.5 + 2, + np.random.randint(100, 250), + str([str(fits_path)]), + ] + ) + + return str(csv_directory) + + +@pytest.fixture +def test_cutana_missing_images(tmp_path): + """Create a directory with sample CSV file with mcorrect header and missing images.""" + data_dir = tmp_path / "cutana_missing_test" + data_dir.mkdir() + + tile_id = "102018211" + + fits_path = ( # Fits in CSV, but not actually saved to disk (missing) + data_dir / f"EUC_MER_BGSUB-MOSAIC-VIS_TILE{tile_id}-ACBD03_20251124T100053.096Z_00.00.fits" + ) + + csv_directory = data_dir / "csv" + csv_directory.mkdir() + csv_path = csv_directory / "mock_sources_malformed.csv" + + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["SourceID", "RA", "Dec", "diameter_pixel", "fits_file_paths"]) + for i in range(10): + writer.writerow( + [ + f"MockSource_{i}", + (np.random.rand() - 0.5) * 2 * 5 + 150, + np.random.rand() - 0.5 + 2, + np.random.randint(100, 250), + str([str(fits_path)]), + ] + ) + + return str(csv_directory) + + def test_evaluate_files(test_config, sample_images, tmp_path): """Test evaluation of individual files.""" image_paths = [] @@ -219,6 +522,99 @@ def test_evaluate_images_in_zarr(test_config, test_zarr): assert imgs.shape[0] == 10 +def test_evaluate_images_cutana(test_config, test_cutana): + """Test evaluation of images via cutana streaming with CSV catalogue.""" + scores, filenames, imgs = evaluate_images_from_cutana(test_cutana, test_config, batch_size=5) + assert len(scores) == 10 + assert len(filenames) == 10 + assert imgs.shape[0] == 10 + + +def test_evaluate_images_cutana_parquet(test_config, test_cutana_parquet): + """Test evaluation of images via cutana streaming with parquet catalogue.""" + scores, filenames, imgs = evaluate_images_from_cutana( + test_cutana_parquet, test_config, batch_size=5 + ) + assert len(scores) == 10 + assert len(filenames) == 10 + assert imgs.shape[0] == 10 + + +def test_prediction_file_type_cutana_malformed_header(test_config, test_cutana_malformed_header): + """Test for meaningful exception when streaming from cutana and csv files have malformed headers.""" + from anomaly_match.pipeline.session import Session + from anomaly_match.utils.get_default_cfg import get_default_cfg + + cfg = get_default_cfg() + cfg.normalisation.image_size = [64, 64] + cfg.prediction_search_dir = test_cutana_malformed_header + cfg.model_path = test_config.model_path + + session = Session(cfg) + + with pytest.warns( + RuntimeWarning, + match=r"File .* did not pass cutana compatibility check and will be skipped \(.*\)", + ): + with pytest.raises(RuntimeError, match="All found files are not compatible with cutana"): + session.evaluate_all_images() + + +def test_prediction_file_type_cutana_missing_images(test_config, test_cutana_missing_images): + """Test for meaningful exception when streaming from cutana and images are missing.""" + from anomaly_match.pipeline.session import Session + from anomaly_match.utils.get_default_cfg import get_default_cfg + + cfg = get_default_cfg() + cfg.normalisation.image_size = [64, 64] + cfg.prediction_search_dir = test_cutana_missing_images + cfg.model_path = test_config.model_path + + session = Session(cfg) + + with pytest.warns( + RuntimeWarning, + match=r"File .* did not pass cutana compatibility check and will be skipped \(.*\)", + ): + with pytest.raises(RuntimeError, match="All found files are not compatible with cutana"): + session.evaluate_all_images() + + +def test_stream_file_type_detection_csv_and_parquet(tmp_path): + """Test that CSV and parquet files are correctly detected as stream type for cutana.""" + # Create test CSV file + csv_file = tmp_path / "test_catalogue.csv" + csv_file.write_text("SourceID,RA,Dec\n1,0.0,0.0\n") + + # Create test parquet file + parquet_file = tmp_path / "test_catalogue.parquet" + import pandas as pd + + pd.DataFrame({"SourceID": [1], "RA": [0.0], "Dec": [0.0]}).to_parquet(parquet_file) + + # Test file type detection via extension map (same logic as run_pipeline) + import os + + extension_map = { + ".h5": "hdf5", + ".hdf5": "hdf5", + ".zarr": "zarr", + ".txt": "image", + ".parquet": "stream", + ".csv": "stream", + } + + # Test CSV detection + _, csv_ext = os.path.splitext(str(csv_file).lower()) + assert csv_ext == ".csv" + assert extension_map.get(csv_ext) == "stream" + + # Test parquet detection + _, parquet_ext = os.path.splitext(str(parquet_file).lower()) + assert parquet_ext == ".parquet" + assert extension_map.get(parquet_ext) == "stream" + + def test_predictions_output(test_config, test_hdf5): """Test that predictions are saved correctly.""" evaluate_images_in_hdf5(test_hdf5, test_config) @@ -300,23 +696,25 @@ def mock_save_results(cfg, all_scores, all_imgs, all_filenames, top_n): assert len(scores) >= len(image_paths), "Not enough scores returned for all image formats" -def test_read_and_resize_multiple_formats(test_config, mixed_format_images): - """Test the read_and_resize_image function can handle multiple formats.""" - from prediction_process import read_and_resize_image +def test_load_and_preprocess_multiple_formats(test_config, mixed_format_images): + """Test the load_and_preprocess function can handle multiple formats.""" + from prediction_process import load_and_preprocess + from anomaly_match.image_processing.transforms import get_prediction_transforms image_paths, _ = mixed_format_images + transform = get_prediction_transforms() for path in image_paths: ext = os.path.splitext(path)[1].lower() - image = read_and_resize_image(path, cfg=test_config) + # load_and_preprocess now returns (filepath, numpy_image) + filename, numpy_image = load_and_preprocess((path, test_config)) + + # Apply transform to get tensor (transform is now applied on main thread) + image = transform(numpy_image) # Check image shape and type - assert image.shape == ( - test_config.size[0], - test_config.size[1], - 3, - ), f"Image resizing failed for {ext}" - assert image.dtype == np.uint8, f"Image type incorrect for {ext}" + assert isinstance(image, torch.Tensor), f"Expected tensor output for {ext}" + assert image.shape[0] == 3, f"Expected 3 channels for {ext}" # RGB channels class MockModel(torch.nn.Module): @@ -402,14 +800,25 @@ def mock_process_batch(model, images): # Load final results output_csv = os.path.join(test_config.output_dir, f"{test_config.save_file}_top{top_n}.csv") + output_npy = os.path.join(test_config.output_dir, f"{test_config.save_file}_top{top_n}.npy") final_results = pd.read_csv(output_csv) final_scores = final_results["Score"].values + final_images = np.load(output_npy) # Check if we got the highest scores from the second batch assert len(final_scores) == top_n assert np.all(final_scores >= 0.85) # All top scores should be from second batch assert np.all(final_scores <= 0.95) # Maximum probability capped at 0.95 + # CRITICAL: Verify that the image array size matches the CSV + assert len(final_images) == len(final_scores), ( + f"Image array size ({len(final_images)}) doesn't match CSV size ({len(final_scores)}). " + f"This indicates a bug in image accumulation logic." + ) + assert ( + final_images.shape[0] == top_n + ), f"Expected {top_n} images, got {final_images.shape[0]}" + def test_all_predictions_accumulation(test_config, monkeypatch): """Test that all predictions are correctly saved when processing multiple batches.""" @@ -618,12 +1027,11 @@ def test_image_directory_processing(test_config, mixed_format_images): # Create a temporary file list with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as file_list: - # List all image files and write them to the file + # List all image files and write them to the file - only supported extensions image_paths = ( list(Path(directory_path).glob("*.jpg")) + list(Path(directory_path).glob("*.jpeg")) + list(Path(directory_path).glob("*.png")) - + list(Path(directory_path).glob("*.tif")) + list(Path(directory_path).glob("*.tiff")) ) @@ -830,7 +1238,7 @@ def test_prediction_file_type_zarr(test_config, monkeypatch, test_zarr): cfg.save_file = "test_zarr_type" cfg.output_dir = os.path.join(os.path.dirname(test_zarr), "output") cfg.model_path = test_config.model_path - cfg.size = [150, 150] + cfg.normalisation.image_size = [150, 150] called_processes = [] @@ -904,10 +1312,13 @@ def test_zarr_image_processing_consistency(test_config, test_zarr): # Verify the output dimensions and type assert processed_image.shape == ( - test_config.size[0], - test_config.size[1], + test_config.normalisation.image_size[0], + test_config.normalisation.image_size[1], 3, - ), f"Wrong image shape: {processed_image.shape}, expected {(test_config.size[0], test_config.size[1], 3)}" + ), ( + f"Wrong image shape: {processed_image.shape}, " + f"expected {(test_config.normalisation.image_size[0], test_config.normalisation.image_size[1], 3)}" + ) assert processed_image.dtype == np.uint8, f"Wrong dtype: {processed_image.dtype}" # Verify the image is not all zeros (should have some content) @@ -1015,3 +1426,152 @@ def test_zarr_auto_detection_basic(test_config, multiple_test_zarr): max(file_type_counts, key=file_type_counts.get) if file_type_counts else "zarr" ) assert detected_type == "zarr" + + +def test_zarr_batch_folders_detection(test_config, zarr_batch_folders): + """Test auto-detection for zarr batch folders with images.zarr subdirectories.""" + batch_folders, batch_dir = zarr_batch_folders + + from anomaly_match.pipeline.session import Session + + try: + session = Session.__new__(Session) + session.cfg = test_config + + # Test auto-detection method + detected_type = session._auto_detect_prediction_file_type(batch_dir) + + # Should detect zarr file type + assert detected_type == "zarr" + except Exception: + # Manual test + import os + + file_type_counts = {} + for filename in os.listdir(batch_dir): + file_path = os.path.join(batch_dir, filename) + if os.path.isdir(file_path): + # Check for batch folders containing images.zarr subdirectory + if os.path.exists(os.path.join(file_path, "images.zarr")): + file_type_counts["zarr"] = file_type_counts.get("zarr", 0) + 1 + + detected_type = ( + max(file_type_counts, key=file_type_counts.get) if file_type_counts else "image" + ) + assert detected_type == "zarr" + + +def test_zarr_batch_folders_processing(test_config, zarr_batch_folders): + """Test processing multiple zarr batch folders.""" + batch_folders, batch_dir = zarr_batch_folders + + # Test each batch folder individually + all_scores = [] + all_filenames = [] + + for batch_folder in batch_folders: + scores, filenames, imgs = evaluate_images_in_zarr(batch_folder, test_config, top_n=100) + all_scores.extend(scores) + all_filenames.extend(filenames) + + # Should have processed all batch folders successfully + assert len(all_scores) > 0 + assert len(all_filenames) > 0 + + # Verify that filenames from different batches are present + batch_prefixes = set() + for filename in all_filenames: + if isinstance(filename, bytes): + filename_str = filename.decode("utf-8") + elif isinstance(filename, np.ndarray): + filename_str = str(filename.item()) if filename.size == 1 else str(filename) + else: + filename_str = str(filename) + + # Extract batch prefix (batch_000, batch_001, etc.) + if filename_str.startswith("batch_"): + parts = filename_str.split("_") + if len(parts) >= 2: + batch_prefix = parts[0] + "_" + parts[1] + batch_prefixes.add(batch_prefix) + + # Should have processed multiple batches + assert len(batch_prefixes) >= 2 + + +def test_zarr_batch_metadata_loading(test_config, zarr_batch_folders): + """Test that metadata is correctly loaded from batch folders.""" + batch_folders, batch_dir = zarr_batch_folders + + # Test the first batch folder + first_batch = batch_folders[0] + scores, filenames, imgs = evaluate_images_in_zarr(first_batch, test_config, top_n=100) + + # Verify filenames are loaded from metadata + assert len(filenames) > 0 + # Filenames should not be generic "image_000000" format + assert not all(f.startswith("image_") for f in filenames) + # Should contain batch identifier + assert any("batch_" in str(f) for f in filenames) + + +def test_zarr_fallback_filenames_have_prefix(tmp_path, test_config): + """Test that when metadata loading fails, fallback filenames include zarr prefix to avoid collisions.""" + import zarr + + # Create two zarr stores WITHOUT metadata to trigger fallback filename generation + for batch_idx in range(2): + batch_folder = tmp_path / f"batch_{batch_idx:03d}" + batch_folder.mkdir() + zarr_path = batch_folder / "images.zarr" + + # Create minimal zarr store + root = zarr.open_group(str(zarr_path), mode="w") + + # Create a simple image array + img_array = np.ones((5, 64, 64, 3), dtype=np.uint8) * (50 + batch_idx * 50) + zarr_images = root.create_dataset( + "images", shape=img_array.shape, chunks=(1, 64, 64, 3), dtype=np.uint8 + ) + zarr_images[:] = img_array + + # Intentionally NO metadata file to trigger fallback + + # Process both batches + batch_filenames = [] + for batch_idx in range(2): + zarr_path = tmp_path / f"batch_{batch_idx:03d}" / "images.zarr" + + # Use a unique output dir for each batch to avoid accumulation + batch_config = test_config.copy() + batch_config.output_dir = str(tmp_path / f"output_{batch_idx}") + os.makedirs(batch_config.output_dir, exist_ok=True) + + scores, filenames, imgs = evaluate_images_in_zarr(str(zarr_path), batch_config, top_n=100) + + # Convert filenames to strings + filenames_str = [] + for filename in filenames: + if isinstance(filename, bytes): + filename_str = filename.decode("utf-8") + elif isinstance(filename, np.ndarray): + filename_str = str(filename.item()) if filename.size == 1 else str(filename) + else: + filename_str = str(filename) + filenames_str.append(filename_str) + + batch_filenames.append(filenames_str) + + # Verify fallback filenames have zarr prefix + for batch_idx, filenames in enumerate(batch_filenames): + sample_filename = filenames[0] + # Should have format: __image_000000 + assert ( + "__image_" in sample_filename + ), f"Batch {batch_idx} fallback filename doesn't have expected format. Got: {sample_filename}" + + # Verify no collision between batches + set_0 = set(batch_filenames[0]) + set_1 = set(batch_filenames[1]) + overlap = set_0 & set_1 + assert len(overlap) == 0, f"Found filename collision between batches: {overlap}" diff --git a/tests/test_run_label_migration.py b/tests/test_run_label_migration.py index d9b2a81..dfc2fce 100644 --- a/tests/test_run_label_migration.py +++ b/tests/test_run_label_migration.py @@ -67,6 +67,8 @@ def mock_config(self): config = Mock() config.normalisation_method = "min_max" config.model_path = "test_model.pth" + # Explicitly set fitsbolt_cfg to None to avoid pickling issues with Mock + config.fitsbolt_cfg = None return config def test_save_run_basic(self, session_io, mock_model, temp_dir): @@ -162,17 +164,6 @@ def test_save_labels_with_session_tracker(self, session_io, session_tracker, tem expected_df["iteration"] = -1 # Default iteration for initial data pd.testing.assert_frame_equal(session_tracker.labeled_data_df, expected_df) - def test_session_tracker_save_training_run(self, session_tracker): - """Test SessionTracker.save_training_run method.""" - mock_config = Mock() - model_path = "/path/to/model.pth" - - session_tracker.save_training_run(model_path, mock_config) - - # Check that session was updated - assert len(session_tracker.session_iterations) == 1 - assert session_tracker.session_iterations[0].model_state_path == model_path - def test_session_tracker_update_labeled_data(self, session_tracker): """Test SessionTracker.update_labeled_data method.""" labeled_data = pd.DataFrame( diff --git a/tests/test_toml_config.py b/tests/test_toml_config.py index 583d524..b057a57 100644 --- a/tests/test_toml_config.py +++ b/tests/test_toml_config.py @@ -14,7 +14,7 @@ save_config_toml, _convert_enum_to_string, ) -from anomaly_match.image_processing.NormalisationMethod import NormalisationMethod +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod class TestTOMLConfigUtils: @@ -119,21 +119,21 @@ def test_config_types_and_structure(self): assert isinstance(config, DotMap) # Check for required keys that should exist in default config + # Note: normalisation.image_size has no default - user must set it explicitly required_keys = [ - "size", "net", "batch_size", "name", - ] # Removed 'device' as it doesn't exist + ] for key in required_keys: assert key in config, f"Missing required key: {key}" # Check types - assert isinstance(config.size, list) - assert len(config.size) == 2 assert isinstance(config.net, str) assert isinstance(config.batch_size, int) assert isinstance(config.name, str) - # Check that size contains valid dimensions - assert all(isinstance(dim, int) and dim > 0 for dim in config.size) + # Verify image_size is NOT in default config (user must set it) + assert ( + "image_size" not in config.normalisation + ), "image_size should not have a default value" diff --git a/tests/ui_test.py b/tests/ui_test.py index 9a7d004..bbd3070 100644 --- a/tests/ui_test.py +++ b/tests/ui_test.py @@ -6,7 +6,7 @@ # the terms contained in the file 'LICENCE.txt'. import pytest import ipywidgets as widgets -from anomaly_match.ui.Widget import Widget +from anomaly_match.ui.Widget import Widget, shorten_filename from anomaly_match.pipeline.session import Session import anomaly_match as am import os @@ -18,6 +18,59 @@ matplotlib.use("Agg") # Prevent matplotlib windows from opening +class TestShortenFilename: + """Tests for the shorten_filename helper function.""" + + def test_short_filename_unchanged(self): + """Filenames within max length should remain unchanged.""" + assert shorten_filename("short.fits", max_length=25) == "short.fits" + assert shorten_filename("image.jpg", max_length=25) == "image.jpg" + + def test_long_filename_shortened(self): + """Long filenames should be shortened to max_length.""" + long_name = "very_long_filename_that_exceeds_limit.fits" + result = shorten_filename(long_name, max_length=25) + assert len(result) <= 25 + assert result.endswith(".fits") + assert "..." in result + + def test_filename_with_multiple_dots(self): + """Filenames with multiple dots should preserve only the extension.""" + name = "image.2024.01.15.observation.fits" + result = shorten_filename(name, max_length=25) + assert len(result) <= 25 + assert result.endswith(".fits") + assert "..." in result + + def test_filename_without_extension(self): + """Filenames without extension should still be shortened correctly.""" + name = "very_long_filename_without_any_extension" + result = shorten_filename(name, max_length=25) + assert len(result) <= 25 + assert "..." in result + + def test_exact_max_length(self): + """Filename exactly at max_length should be unchanged.""" + name = "exactly_25_chars_long.fit" + assert len(name) == 25 + assert shorten_filename(name, max_length=25) == name + + def test_very_short_max_length(self): + """Very short max_length should still produce valid output.""" + name = "some_filename.fits" + result = shorten_filename(name, max_length=10) + assert len(result) <= 10 + assert "..." in result + + def test_preserves_start_and_end(self): + """Shortened name should contain parts of the original start and end.""" + name = "START_middle_content_END.fits" + result = shorten_filename(name, max_length=20) + assert result.startswith("START") + # Should contain some part of the end before the extension + assert "END" in result or "..." in result + + @pytest.fixture(scope="session") def base_config(): out = widgets.Output( @@ -25,20 +78,15 @@ def base_config(): border="1px solid white", height="400px", background_color="black", overflow="auto" ), ) - progress_bar = widgets.FloatProgress( - value=0.0, - min=0.0, - max=1.0, - ) cfg = am.get_default_cfg() am.set_log_level("debug", cfg) cfg.data_dir = "tests/test_data/" - cfg.size = [64, 64] + cfg.normalisation.image_size = [64, 64] + cfg.normalisation.n_output_channels = 3 cfg.num_train_iter = 2 cfg.test_ratio = 0.5 cfg.output_dir = "tests/test_output" - cfg.progress_bar = progress_bar cfg.prediction_search_dir = "tests/test_data/" # Set a default search directory cfg.top_N = 10 return cfg, out @@ -85,7 +133,7 @@ def test_ui_initialization(self, ui_widget): def test_normalization_dropdown(self, ui_widget): # Get the initial normalization method - initial_method = ui_widget.session.cfg.normalisation_method + initial_method = ui_widget.session.cfg.normalisation.normalisation_method # Get the dropdown options and find a different method dropdown_options = ui_widget.ui["normalisation_dropdown"].options @@ -100,20 +148,20 @@ def test_normalization_dropdown(self, ui_widget): ui_widget.ui["normalisation_dropdown"].value = new_method # Assert that the session config was updated - assert ui_widget.session.cfg.normalisation_method == new_method + assert ui_widget.session.cfg.normalisation.normalisation_method == new_method class TestUINavigation: def test_next_image(self, ui_widget): - initial_index = ui_widget.current_index + initial_index = ui_widget.preview.current_index ui_widget.next_image() - assert ui_widget.current_index == initial_index + 1 + assert ui_widget.preview.current_index == initial_index + 1 def test_previous_image(self, ui_widget): ui_widget.next_image() # Move to next image first - initial_index = ui_widget.current_index + initial_index = ui_widget.preview.current_index ui_widget.previous_image() - assert ui_widget.current_index == initial_index - 1 + assert ui_widget.preview.current_index == initial_index - 1 class TestUISorting: @@ -142,22 +190,22 @@ def test_sort_by_median(self, ui_widget): class TestUIImageProcessing: def test_toggle_invert_image(self, ui_widget): - initial_invert_state = ui_widget.invert + initial_invert_state = ui_widget.preview.invert ui_widget.toggle_invert_image() - assert ui_widget.invert != initial_invert_state + assert ui_widget.preview.invert != initial_invert_state def test_toggle_unsharp_mask(self, ui_widget): - initial_unsharp_mask_state = ui_widget.unsharp_mask_applied + initial_unsharp_mask_state = ui_widget.preview.unsharp_mask_applied ui_widget.toggle_unsharp_mask() - assert ui_widget.unsharp_mask_applied != initial_unsharp_mask_state + assert ui_widget.preview.unsharp_mask_applied != initial_unsharp_mask_state def test_adjust_brightness_contrast(self, ui_widget): initial_brightness = ui_widget.ui["brightness_slider"].value initial_contrast = ui_widget.ui["contrast_slider"].value ui_widget.ui["brightness_slider"].value = initial_brightness + 0.1 ui_widget.ui["contrast_slider"].value = initial_contrast + 0.1 - assert ui_widget.brightness == initial_brightness + 0.1 - assert ui_widget.contrast == initial_contrast + 0.1 + assert ui_widget.preview.brightness == initial_brightness + 0.1 + assert ui_widget.preview.contrast == initial_contrast + 0.1 class TestUIModelOperations: