diff --git a/README.md b/README.md index 7b76f92..233b7b0 100644 --- a/README.md +++ b/README.md @@ -167,3 +167,7 @@ Available Operations are introspected from the `data_operations` directory and m #### Delete Operation from a Datasession `DELETE /api/datasessions/datasession_id/operations/operation_id/` + + +### ACKNOWLEDGMENT +The centroiding logic in this project was adapted from AstroImageJ’s implementation. We appreciate the AstroImageJ project and its contributors for the original work that informed this code. diff --git a/datalab/datalab_session/analysis/centroiding.py b/datalab/datalab_session/analysis/centroiding.py new file mode 100644 index 0000000..57ddfce --- /dev/null +++ b/datalab/datalab_session/analysis/centroiding.py @@ -0,0 +1,465 @@ +from dataclasses import dataclass +import logging +import math +from typing import TYPE_CHECKING + +import numpy as np +from astropy.wcs import WCS, WcsError + +from datalab.datalab_session.exceptions import ClientAlertException +from datalab.datalab_session.utils.file_utils import get_hdu, scale_points +from datalab.datalab_session.utils.filecache import FileCache + +if TYPE_CHECKING: + from django.contrib.auth.models import User + + +log = logging.getLogger() +log.setLevel(logging.INFO) + +PIXELCENTER = 0.5 + + +@dataclass(frozen=True) +class PlaneModel: + c0: float + c1: float + c2: float + + def value_at(self, x: float, y: float) -> float: + return self.c0 + self.c1 * x + self.c2 * y + + +@dataclass(frozen=True) +class BackgroundModel: + mean: float + peak: float + plane: PlaneModel | None = None + + +@dataclass(frozen=True) +class CentroidResult: + x: float + y: float + background: float + peak: float + success: bool = True + message: str | None = None + + +def _pixel(image: np.ndarray, x: int, y: int) -> float: + if y < 0 or y >= image.shape[0] or x < 0 or x >= image.shape[1]: + return math.nan + return float(image[y, x]) + +## Finds the maximum pixel value within a circular region around the center. +def _source_max(image: np.ndarray, x_center: float, y_center: float, radius: float) -> float: + radius2 = radius * radius + i1 = int(x_center - radius) + i2 = int(x_center + radius) + j1 = int(y_center - radius) + j2 = int(y_center + radius) + + source_max = -math.inf + for j in range(j1, j2 + 1): + dj = j - y_center + PIXELCENTER + for i in range(i1, i2 + 1): + di = i - x_center + PIXELCENTER + if di * di + dj * dj <= radius2: + value = _pixel(image, i, j) + if not math.isnan(value) and value > source_max: + source_max = value + return source_max + + +def _fit_plane(points: list[tuple[float, float, float]]) -> PlaneModel | None: + if len(points) < 4: + return None + + sum_1 = float(len(points)) + sum_x = sum(x for x, _, _ in points) + sum_y = sum(y for _, y, _ in points) + sum_xx = sum(x * x for x, _, _ in points) + sum_yy = sum(y * y for _, y, _ in points) + sum_xy = sum(x * y for x, y, _ in points) + sum_z = sum(z for _, _, z in points) + sum_xz = sum(x * z for x, _, z in points) + sum_yz = sum(y * z for _, y, z in points) + + matrix = [ + [sum_1, sum_x, sum_y, sum_z], + [sum_x, sum_xx, sum_xy, sum_xz], + [sum_y, sum_xy, sum_yy, sum_yz], + ] + + ## Solve the linear system using Gaussian elimination with partial pivoting. + for pivot in range(3): + pivot_row = max(range(pivot, 3), key=lambda row: abs(matrix[row][pivot])) + if abs(matrix[pivot_row][pivot]) < 1e-12: + return None + if pivot_row != pivot: + matrix[pivot], matrix[pivot_row] = matrix[pivot_row], matrix[pivot] + + pivot_value = matrix[pivot][pivot] + for col in range(pivot, 4): + matrix[pivot][col] /= pivot_value + + for row in range(3): + if row == pivot: + continue + factor = matrix[row][pivot] + if factor == 0.0: + continue + for col in range(pivot, 4): + matrix[row][col] -= factor * matrix[pivot][col] + + return PlaneModel(matrix[0][3], matrix[1][3], matrix[2][3]) + + +def _background( + image: np.ndarray, + x_center: float, + y_center: float, + radius: float, + r_back1: float, + r_back2: float, + remove_background_stars: bool, + use_plane_background: bool, +) -> BackgroundModel: + source_max = _source_max(image, x_center, y_center, radius) + if r_back2 <= r_back1: + return BackgroundModel(0.0, source_max) + + r12 = r_back1 * r_back1 + r22 = r_back2 * r_back2 + i1 = int(x_center - r_back2) + i2 = int(x_center + r_back2) + j1 = int(y_center - r_back2) + j2 = int(y_center + r_back2) + + annulus_pixels: list[tuple[float, float, float]] = [] + if remove_background_stars: + for j in range(j1, j2 + 1): + dj = j - y_center + PIXELCENTER + for i in range(i1, i2 + 1): + di = i - x_center + PIXELCENTER + radius2 = di * di + dj * dj + if r12 <= radius2 <= r22: + value = _pixel(image, i, j) + if not math.isnan(value): + annulus_pixels.append((di, dj, value)) + + back_mean = 0.0 + back2_mean = 0.0 + previous_back_mean = 0.0 + for iteration in range(9): + back_stdev = math.sqrt(max(0.0, back2_mean - back_mean * back_mean)) + lower = back_mean - 2.0 * back_stdev + upper = back_mean + 2.0 * back_stdev + clipped = [ + value + for _, _, value in annulus_pixels + if iteration == 0 or (lower <= value <= upper) + ] + if clipped: + back_mean = sum(clipped) / len(clipped) + back2_mean = sum(value * value for value in clipped) / len(clipped) + if abs(previous_back_mean - back_mean) < 0.1: + break + previous_back_mean = back_mean + else: + back_mean = 0.0 + back2_mean = 0.0 + + back_stdev = math.sqrt(max(0.0, back2_mean - back_mean * back_mean)) + lower = back_mean - 2.0 * back_stdev + upper = back_mean + 2.0 * back_stdev + + kept: list[tuple[float, float, float]] = [] + for j in range(j1, j2 + 1): + dj = j - y_center + PIXELCENTER + for i in range(i1, i2 + 1): + di = i - x_center + PIXELCENTER + radius2 = di * di + dj * dj + if r12 <= radius2 <= r22: + value = _pixel(image, i, j) + if math.isnan(value): + continue + if not remove_background_stars or (lower <= value <= upper): + kept.append((di, dj, value)) + + background = sum(value for _, _, value in kept) / len(kept) if kept else 0.0 + plane = _fit_plane(kept) if use_plane_background else None + return BackgroundModel(background, source_max - background, plane) + + +def _background_value( + background_model: BackgroundModel, + x_center: float, + y_center: float, + i: int, + j: int, +) -> float: + if background_model.plane is None: + return background_model.mean + return background_model.plane.value_at( + i - x_center + PIXELCENTER, + j - y_center + PIXELCENTER, + ) + + +def _failed_centroid( + x: float, + y: float, + background: float, + peak: float, + message: str, +) -> CentroidResult: + log.warning(f"Centroiding failed: {message}") + return CentroidResult(x, y, background, peak, success=False, message=message) + + +def centroid( + image: np.ndarray, + x_click: float, + y_click: float, + radius: float, + r_back1: float, + r_back2: float, + *, + find_centroid: bool = True, + remove_background_stars: bool = True, + use_plane_background: bool = False, +) -> CentroidResult: + image = np.asarray(image, dtype=float) + x_center = x_click + y_center = y_click + radius = max(radius, 3.0) + width = int(2.0 * radius) + height = width + + i1 = int(x_click - radius) + i2 = i1 + width + j1 = int(y_click - radius) + j2 = j1 + height + + x_start = x_center + y_start = y_center + background_model = _background( + image, + x_center, + y_center, + radius, + r_back1, + r_back2, + remove_background_stars, + use_plane_background, + ) + + still_moving = True + iteration = 100 if find_centroid else 0 + while still_moving and iteration > 0: + x_delta = 0.0 + y_delta = 0.0 + total_signal = 0.0 + samples = 0 + + for j in range(j1, j2 + 1): + for i in range(i1, i2 + 1): + value = _pixel(image, i, j) + if not math.isnan(value): + total_signal += value - _background_value(background_model, x_center, y_center, i, j) + samples += 1 + + if samples == 0: + return _failed_centroid( + x_start, + y_start, + background_model.mean, + background_model.peak, + "No valid pixels in centroid box.", + ) + + i_bar = total_signal / (i2 - i1 + 1) + j_bar = total_signal / (j2 - j1 + 1) + + weight_i = 0.0 + for i in range(i1, i2 + 1): + column_signal = 0.0 + di = i - x_center + PIXELCENTER + for j in range(j1, j2 + 1): + value = _pixel(image, i, j) + if not math.isnan(value): + column_signal += value - _background_value(background_model, x_center, y_center, i, j) + delta = column_signal - i_bar + if delta > 0.0: + weight_i += delta + x_delta += delta * di + + weight_j = 0.0 + for j in range(j1, j2 + 1): + row_signal = 0.0 + dj = j - y_center + PIXELCENTER + for i in range(i1, i2 + 1): + value = _pixel(image, i, j) + if not math.isnan(value): + row_signal += value - _background_value(background_model, x_center, y_center, i, j) + delta = row_signal - j_bar + if delta > 0.0: + weight_j += delta + y_delta += delta * dj + + if weight_i == 0.0 and weight_j == 0.0: + return _failed_centroid( + x_start, + y_start, + background_model.mean, + background_model.peak, + "Centroid calculation has zero weight in both dimensions.", + ) + if weight_i == 0.0: + return _failed_centroid( + x_start, + y_start, + background_model.mean, + background_model.peak, + "Centroid calculation has zero weight in the x dimension.", + ) + if weight_j == 0.0: + return _failed_centroid( + x_start, + y_start, + background_model.mean, + background_model.peak, + "Centroid calculation has zero weight in the y dimension.", + ) + + x_delta /= weight_i + y_delta /= weight_j + + if find_centroid and ( + abs(x_center + x_delta - x_start) > width + or abs(y_center + y_delta - y_start) > height + ): + return _failed_centroid( + x_start, + y_start, + background_model.mean, + background_model.peak, + "Centroid repositioning exceeded centroid box size.", + ) + + if abs(x_delta) < 0.01 and abs(y_delta) < 0.01: + still_moving = False + + if find_centroid: + x_center += x_delta + y_center += y_delta + i1 = int(x_center) - width // 2 + i2 = i1 + width + j1 = int(y_center) - height // 2 + j2 = j1 + height + background_model = _background( + image, + x_center, + y_center, + radius, + r_back1, + r_back2, + remove_background_stars, + use_plane_background, + ) + + iteration -= 1 + + return CentroidResult( + x_center, + y_center, + background_model.mean, + background_model.peak, + message="Centroid calculation completed.", + ) + + +def centroiding(input: dict, user: 'User'): + """ + Finds an AIJ-like Howell centroid for a clicked source position. + input = { + basename (str): The name of the file to analyze + height (int): The displayed image height + width (int): The displayed image width + x (float): Click x coordinate in displayed image space + y (float): Click y coordinate in displayed image space + radius (float): Centroid radius + r_back1 (float): Inner background annulus radius + r_back2 (float): Outer background annulus radius + } + """ + try: + file_path = FileCache().get_fits(input['basename'], input.get('source', 'archive'), user) + sci_hdu = get_hdu(file_path, 'SCI') + except TimeoutError: + raise ClientAlertException(f"Download of {input['basename']} timed out") + except TypeError as e: + raise ClientAlertException(f'Error: {e}') + + image = np.asarray(sci_hdu.data, dtype=float) + if image.ndim != 2: + message = f"Centroiding requires a 2D image, received shape {image.shape}." + log.error(message) + raise ClientAlertException(message) + + fits_height, fits_width = image.shape + x_points, y_points = scale_points( + input['height'], + input['width'], + fits_height, + fits_width, + x_points=[input['x']], + y_points=[input['y']], + ) + + result = centroid( + image, + x_click=float(x_points[0]), + y_click=float(y_points[0]), + radius=float(input.get('radius', 8.0)), + r_back1=float(input.get('r_back1', 10.0)), + r_back2=float(input.get('r_back2', 15.0)), + find_centroid=bool(input.get('find_centroid', True)), + remove_background_stars=bool(input.get('remove_background_stars', True)), + use_plane_background=bool(input.get('use_plane_background', False)), + ) + + output_x, output_y = scale_points( + fits_height, + fits_width, + input['height'], + input['width'], + x_points=[result.x], + y_points=[result.y], + ) + + ra = None + dec = None + try: + wcs = WCS(sci_hdu.header) + if wcs.get_axis_types()[0].get('coordinate_type') is None: + raise WcsError("No valid WCS solution") + sky_coord = wcs.pixel_to_world(result.y - 1, result.x - 1) + ra = float(sky_coord.ra.deg) + dec = float(sky_coord.dec.deg) + except (AttributeError, IndexError, KeyError, TypeError, ValueError, WcsError): + log.info(f"No valid WCS solution for centroiding on {input['basename']}") + pass + + return { + 'x': float(output_x[0]), + 'y': float(output_y[0]), + 'ra': ra, + 'dec': dec, + 'background': result.background, + 'peak': result.peak, + 'success': result.success, + 'message': result.message, + } diff --git a/datalab/datalab_session/tests/test_analysis.py b/datalab/datalab_session/tests/test_analysis.py index c7bca00..061904b 100644 --- a/datalab/datalab_session/tests/test_analysis.py +++ b/datalab/datalab_session/tests/test_analysis.py @@ -1,10 +1,13 @@ from unittest import mock import json +from types import SimpleNamespace +from astropy.io import fits from django.test import TestCase +import numpy as np from numpy.testing import assert_almost_equal -from datalab.datalab_session.analysis import line_profile, source_catalog +from datalab.datalab_session.analysis import centroiding, line_profile, source_catalog class TestAnalysis(TestCase): analysis_test_path = 'datalab/datalab_session/tests/test_files/analysis/' @@ -49,3 +52,126 @@ def test_source_catalog(self, mock_file_cache): }, None) self.assertEqual(output, self.test_source_catalog_data) + + def test_centroid_finds_pixels_center(self): + image = np.zeros((21, 21), dtype=float) + image[10, 10] = 100.0 + + result = centroiding.centroid( + image, + x_click=10.0, + y_click=10.0, + radius=3.0, + r_back1=4.0, + r_back2=5.0, + ) + + self.assertTrue(result.success) + self.assertAlmostEqual(result.x, 10.5, places=9) + self.assertAlmostEqual(result.y, 10.5, places=9) + self.assertEqual(result.background, 0.0) + self.assertEqual(result.peak, 100.0) + self.assertEqual(result.message, 'Centroid calculation completed.') + + @mock.patch('datalab.datalab_session.analysis.centroiding.get_hdu') + @mock.patch('datalab.datalab_session.analysis.centroiding.FileCache') + def test_centroiding_scales_display_coordinates(self, mock_file_cache, mock_get_hdu): + mock_instance = mock_file_cache.return_value + mock_instance.get_fits.return_value = self.analysis_fits_1_path + + fits_image = np.zeros((80, 120), dtype=float) + fits_image[48, 36] = 1200.0 + mock_get_hdu.return_value = SimpleNamespace(data=fits_image) + input_data = { + 'basename': 'fits_1', + 'height': 160, + 'width': 240, + 'x': 72.0, + 'y': 96.0, + 'radius': 3.0, + 'r_back1': 4.0, + 'r_back2': 5.0, + 'source': 'archive', + } + + output = centroiding.centroiding(input_data, None) + + self.assertTrue(output['success']) + self.assertAlmostEqual(output['x'], 73.0, places=9) + self.assertAlmostEqual(output['y'], 97.0, places=9) + self.assertEqual(output['background'], 0.0) + self.assertEqual(output['peak'], 1200.0) + self.assertEqual(output['message'], 'Centroid calculation completed.') + self.assertIsNone(output['ra']) + self.assertIsNone(output['dec']) + + @mock.patch('datalab.datalab_session.analysis.centroiding.get_hdu') + @mock.patch('datalab.datalab_session.analysis.centroiding.FileCache') + def test_centroiding_returns_ra_dec(self, mock_file_cache, mock_get_hdu): + mock_instance = mock_file_cache.return_value + mock_instance.get_fits.return_value = self.analysis_fits_1_path + + fits_image = np.zeros((80, 120), dtype=float) + fits_image[48, 36] = 1200.0 + header = fits.Header() + header['CTYPE1'] = 'RA---TAN' + header['CTYPE2'] = 'DEC--TAN' + header['CRVAL1'] = 150.0 + header['CRVAL2'] = 2.0 + header['CRPIX1'] = 1.0 + header['CRPIX2'] = 1.0 + header['CD1_1'] = 0.01 + header['CD1_2'] = 0.0 + header['CD2_1'] = 0.0 + header['CD2_2'] = 0.01 + mock_get_hdu.return_value = SimpleNamespace(data=fits_image, header=header) + input_data = { + 'basename': 'fits_1', + 'height': 160, + 'width': 240, + 'x': 72.0, + 'y': 96.0, + 'radius': 3.0, + 'r_back1': 4.0, + 'r_back2': 5.0, + 'source': 'archive', + } + + output = centroiding.centroiding(input_data, None) + + self.assertTrue(output['success']) + self.assertAlmostEqual(output['x'], 73.0, places=9) + self.assertAlmostEqual(output['y'], 97.0, places=9) + self.assertEqual(output['message'], 'Centroid calculation completed.') + self.assertIsNotNone(output['ra']) + self.assertIsNotNone(output['dec']) + + def test_centroid_returns_message_when_no_valid_pixels_in_box(self): + image = np.full((21, 21), np.nan, dtype=float) + + result = centroiding.centroid( + image, + x_click=10.0, + y_click=10.0, + radius=3.0, + r_back1=4.0, + r_back2=5.0, + ) + + self.assertFalse(result.success) + self.assertEqual(result.message, 'No valid pixels in centroid box.') + + def test_centroid_returns_message_when_zero_weight(self): + image = np.zeros((21, 21), dtype=float) + + result = centroiding.centroid( + image, + x_click=10.0, + y_click=10.0, + radius=3.0, + r_back1=4.0, + r_back2=5.0, + ) + + self.assertFalse(result.success) + self.assertEqual(result.message, 'Centroid calculation has zero weight in both dimensions.') diff --git a/datalab/datalab_session/tests/test_api.py b/datalab/datalab_session/tests/test_api.py index 4912d59..a9badce 100644 --- a/datalab/datalab_session/tests/test_api.py +++ b/datalab/datalab_session/tests/test_api.py @@ -2,6 +2,10 @@ from mixer.backend.django import mixer from django.contrib.auth.models import User from django.urls import reverse +from unittest import mock +from types import SimpleNamespace + +import numpy as np from datalab.datalab_session.models import DataOperation, DataSession @@ -25,3 +29,37 @@ def test_bulk_delete(self): # Only operation2 was not deleted self.assertEqual(DataOperation.objects.all().count(), 1) self.assertEqual(DataOperation.objects.first().id, operation2.id) + + @mock.patch('datalab.datalab_session.analysis.centroiding.get_hdu') + @mock.patch('datalab.datalab_session.analysis.centroiding.FileCache') + def test_centroiding_analysis_endpoint(self, mock_file_cache, mock_get_hdu): + mock_instance = mock_file_cache.return_value + mock_instance.get_fits.return_value = 'test.fits' + + fits_image = np.zeros((80, 120), dtype=float) + fits_image[48, 36] = 1200.0 + mock_get_hdu.return_value = SimpleNamespace(data=fits_image) + data = { + 'basename': 'fits_1', + 'height': 160, + 'width': 240, + 'x': 72.0, + 'y': 96.0, + 'radius': 3.0, + 'r_back1': 4.0, + 'r_back2': 5.0, + 'source': 'archive', + } + + response = self.client.post(reverse('analysis', args=('centroiding',)), data=data, format='json') + response_data = response.json() + + self.assertEqual(response.status_code, 200) + self.assertTrue(response_data['success']) + self.assertAlmostEqual(response_data['x'], 73.0, places=9) + self.assertAlmostEqual(response_data['y'], 97.0, places=9) + self.assertEqual(response_data['background'], 0.0) + self.assertEqual(response_data['peak'], 1200.0) + self.assertEqual(response_data['message'], 'Centroid calculation completed.') + self.assertIsNone(response_data['ra']) + self.assertIsNone(response_data['dec']) diff --git a/datalab/datalab_session/views.py b/datalab/datalab_session/views.py index 7cde640..35a4db7 100644 --- a/datalab/datalab_session/views.py +++ b/datalab/datalab_session/views.py @@ -5,6 +5,7 @@ from rest_framework.response import Response from datalab.datalab_session.data_operations.utils import available_operations +from datalab.datalab_session.analysis.centroiding import centroiding from datalab.datalab_session.analysis.line_profile import line_profile from datalab.datalab_session.analysis.source_catalog import source_catalog from datalab.datalab_session.analysis.get_tif import get_tif @@ -32,6 +33,7 @@ class AnalysisView(RetrieveAPIView): """ View to handle analysis actions and return the results. """ ACTIONS = { + "centroiding": centroiding, "line-profile": line_profile, "source-catalog": source_catalog, "get-tif": get_tif,