Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.10-slim as base
FROM python:3.11-slim as base
LABEL maintainer="jnation@lco.global"

# use bash
Expand Down
152 changes: 148 additions & 4 deletions datalab/datalab_session/data_operations/stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from datalab.datalab_session.exceptions import ClientAlertException
from datalab.datalab_session.utils.format import Format
from datalab.datalab_session.utils.file_utils import crop_arrays
from reproject import reproject_adaptive
Comment thread
capetillo marked this conversation as resolved.
from astropy.io import fits
from reproject.mosaicking import find_optimal_celestial_wcs
from astropy.wcs import WCS

log = logging.getLogger()
log.setLevel(logging.INFO)
Expand Down Expand Up @@ -46,12 +50,137 @@ def wizard_description():
'type': Format.FITS,
'minimum': Stack.MINIMUM_NUMBER_OF_INPUTS,
'maximum': Stack.MAXIMUM_NUMBER_OF_INPUTS,
},
'stacking_mode': {
'name': 'Stacking Mode',
'description': 'Choose simple stacking or reprojection before stacking',
'type': 'select',
'options': ['simple', 'reproject'],
'default': 'simple'
}
}
}
return description

def find_optimal_reference(self, images):
"""
images: list of InputDataHandler
returns: optimized_wcs, optimized_shape
"""

image_hdus = [img.sci_hdu for img in images]

wcs_opt, shape_out = find_optimal_celestial_wcs(image_hdus)
return wcs_opt, shape_out

def crop_bbox_from_footprint(self, footprint: np.ndarray):
"""
Return a bbox (r0,r1,c0,c1) such that the bbox edges contain no invalid pixels.
footprint: 2D array where non-zero means covered.
"""
mask = footprint != 0
if not mask.any():
raise ValueError("Empty footprint")

# start with the loose envelope (any valid somewhere)
rows = np.where(mask.any(axis=1))[0]
cols = np.where(mask.any(axis=0))[0]
r0, r1 = rows[0], rows[-1] + 1
c0, c1 = cols[0], cols[-1] + 1

# now shrink until the *edges* are fully valid
changed = True
while changed:
changed = False

# top edge
if r0 < r1 and not mask[r0, c0:c1].all():
r0 += 1
changed = True

# bottom edge
if r0 < r1 and not mask[r1 - 1, c0:c1].all():
r1 -= 1
changed = True

# left edge
if c0 < c1 and not mask[r0:r1, c0].all():
c0 += 1
changed = True

# right edge
if c0 < c1 and not mask[r0:r1, c1 - 1].all():
c1 -= 1
changed = True

if r0 >= r1 or c0 >= c1:
raise ValueError("No rectangular region without invalid pixels on the edges")

if not changed:
Comment thread
capetillo marked this conversation as resolved.
break

log.info(f'crop bbox from footprint (shrunk): {r0}, {r1}, {c0}, {c1}')
return r0, r1, c0, c1

def intersect_bboxes(self, bboxes):
"""
bboxes: iterable of (r0,r1,c0,c1). Returns intersection bbox.
"""
r0 = max(b[0] for b in bboxes)
r1 = min(b[1] for b in bboxes)
c0 = max(b[2] for b in bboxes)
c1 = min(b[3] for b in bboxes)
if r0 >= r1 or c0 >= c1:
raise ValueError("No overlapping valid region across images")
log.info(f'intersection bbox: {r0}, {r1}, {c0}, {c1}')
return r0, r1, c0, c1

def crop(self, img, bbox):
r0, r1, c0, c1 = bbox
return np.ascontiguousarray(img[r0:r1, c0:c1])

def prepare_for_sum(self, images, footprints):
"""
images: list of 2D float arrays (len=3)
footprints: list of 2D {0,1} arrays (len=3)
Returns: cropped_images, common_bbox
"""


per_bbox = [self.crop_bbox_from_footprint(fp) for fp in footprints]
common_bbox = self.intersect_bboxes(per_bbox)
cropped = [self.crop(im, common_bbox) for im in images]

log.info(f'cropped: {cropped[0].shape}, common_bbox: {common_bbox}')
return cropped, common_bbox

def reproject_images_to_reference(self, input_fits, optimized_wcs, optimized_shape):
"""
input_fits: list of InputDataHandler
optimized_wcs: WCS object for the optimal reference frame
optimized_shape: (ny, nx) shape

returns:
reprojected_arrays: list of 2D float arrays reprojected to the optimal reference frame
footprints: list of 2D numpy arrays indicating valid data regions in
"""
reprojected_arrays = []
footprints = []
for img in input_fits:
array, footprint = reproject_adaptive(
img.sci_hdu,
optimized_wcs,
shape_out=optimized_shape,
return_footprint=True,
conserve_flux=True
)
reprojected_arrays.append(array)
footprints.append(footprint)

return reprojected_arrays, footprints

def operate(self, submitter: User):
stacking_mode = self.input_data.get("stacking_mode")
input_files = self._validate_inputs(
input_key='input_files',
minimum_inputs=self.MINIMUM_NUMBER_OF_INPUTS
Expand All @@ -65,14 +194,29 @@ def operate(self, submitter: User):
log.info(f'input fits list in normalization: {input_fits_list}')
self.set_operation_progress(Stack.PROGRESS_STEPS['STACKING_MIDPOINT'] * (index / len(input_files)))

cropped_data, _ = crop_arrays([image.sci_data for image in input_fits_list])
if stacking_mode == "reproject":
optimized_wcs, optimized_shape = self.find_optimal_reference(input_fits_list)
arrays, footprints = self.reproject_images_to_reference(input_fits_list, optimized_wcs, optimized_shape)
cropped_data, _ = self.prepare_for_sum(arrays, footprints)

optimized_header = optimized_wcs.to_header()
header = input_fits_list[0].sci_hdu.header.copy()
header.update(optimized_header)

else:
arrays = [image.sci_data for image in input_fits_list]
cropped_data, _ = crop_arrays(arrays)
header = input_fits_list[0].sci_hdu.header.copy()

self.set_operation_progress(Stack.PROGRESS_STEPS['STACKING_PERCENTAGE_COMPLETION'])
stacked_sum = np.sum(cropped_data, axis=0)
self.set_operation_progress(Stack.PROGRESS_STEPS['STACKING_OUTPUT_PERCENTAGE_COMPLETION'])

stacked_sum = np.nansum(np.stack(cropped_data), axis=0)

output = FITSOutputHandler(self.cache_key, stacked_sum, self.temp, comment, data_header=input_fits_list[0].sci_hdu.header.copy()).create_and_save_data_products(Format.FITS)
self.set_operation_progress(Stack.PROGRESS_STEPS['STACKING_OUTPUT_PERCENTAGE_COMPLETION'])

output = FITSOutputHandler(self.cache_key, stacked_sum, self.temp, comment, data_header=header).create_and_save_data_products(Format.FITS)
log.info(f'Stacked output: {output}')

self.set_output(output)
self.set_operation_progress(Stack.PROGRESS_STEPS['OUTPUT_PERCENTAGE_COMPLETION'])
self.set_status('COMPLETED')
Loading