diff --git a/pyproject.toml b/pyproject.toml index ce4a1bff7..6f06add96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,6 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ - "backports.entry_points_selectable", "defusedxml", # For safely parsing XML files "pydantic>=2", "pydantic-settings", @@ -55,13 +54,14 @@ developer = [ ] instrument-server = [ "aiohttp", - "fastapi[standard]<0.116.0", + "fastapi[standard-no-fastapi-cloud-cli]>=0.116.0", "python-jose", ] server = [ "aiohttp", "cryptography", - "fastapi[standard]<0.116.0", + "fastapi[standard-no-fastapi-cloud-cli]>=0.116.0", + "graypy", "ispyb>=10.2.4", # Responsible for setting requirements for SQLAlchemy and mysql-connector-python; "jinja2", "mrcfile", @@ -73,7 +73,7 @@ server = [ "python-jose[cryptography]", "sqlalchemy[postgresql]", # Add as explicit dependency "sqlmodel", - "stomp-py<=8.1.0", # 8.1.1 (released 2024-04-06) doesn't work with our project + "stomp-py>8.1.1", # 8.1.1 (released 2024-04-06) doesn't work with our project "zocalo>=1", ] [project.urls] @@ -100,14 +100,18 @@ GitHub = "https://github.com/DiamondLightSource/python-murfey" [project.entry-points."murfey.config.extraction"] "murfey_machine" = "murfey.util.config:get_extended_machine_config" [project.entry-points."murfey.workflows"] +"atlas_update" = "murfey.workflows.register_atlas_update:run" "clem.align_and_merge" = "murfey.workflows.clem.align_and_merge:submit_cluster_request" "clem.process_raw_lifs" = "murfey.workflows.clem.process_raw_lifs:zocalo_cluster_request" "clem.process_raw_tiffs" = "murfey.workflows.clem.process_raw_tiffs:zocalo_cluster_request" "clem.register_align_and_merge_result" = "murfey.workflows.clem.register_align_and_merge_results:register_align_and_merge_result" "clem.register_preprocessing_result" = "murfey.workflows.clem.register_preprocessing_results:run" +"data_collection" = "murfey.workflows.register_data_collection:run" +"data_collection_group" = 
"murfey.workflows.register_data_collection_group:run" "pato" = "murfey.workflows.notifications:notification_setup" "picked_particles" = "murfey.workflows.spa.picking:particles_picked" "picked_tomogram" = "murfey.workflows.tomo.picking:picked_tomogram" +"processing_job" = "murfey.workflows.register_processing_job:run" "spa.flush_spa_preprocess" = "murfey.workflows.spa.flush_spa_preprocess:flush_spa_preprocess" [tool.setuptools] diff --git a/src/murfey/client/context.py b/src/murfey/client/context.py index 51c525b8e..15cc61f0e 100644 --- a/src/murfey/client/context.py +++ b/src/murfey/client/context.py @@ -1,11 +1,10 @@ from __future__ import annotations import logging +from importlib.metadata import entry_points from pathlib import Path from typing import Any, Dict, List, NamedTuple -from backports.entry_points_selectable import entry_points - from murfey.client.instance_environment import MurfeyInstanceEnvironment logger = logging.getLogger("murfey.client.context") diff --git a/src/murfey/server/api/clem.py b/src/murfey/server/api/clem.py index e5fdbe988..128bb2801 100644 --- a/src/murfey/server/api/clem.py +++ b/src/murfey/server/api/clem.py @@ -3,12 +3,14 @@ import re import traceback from ast import literal_eval -from importlib.metadata import EntryPoint # type hinting only +from importlib.metadata import ( + EntryPoint, # type hinting only + entry_points, +) from logging import getLogger from pathlib import Path from typing import Literal, Optional, Type, Union -from backports.entry_points_selectable import entry_points from fastapi import APIRouter from pydantic import BaseModel, field_validator from sqlalchemy.exc import NoResultFound @@ -752,9 +754,7 @@ def process_raw_lifs( try: # Try and load relevant Murfey workflow workflow: EntryPoint = list( - entry_points().select( - group="murfey.workflows", name="clem.process_raw_lifs" - ) + entry_points(group="murfey.workflows", name="clem.process_raw_lifs") )[0] except IndexError: raise RuntimeError("The relevant 
Murfey workflow was not found") @@ -792,9 +792,7 @@ def process_raw_tiffs( try: # Try and load relevant Murfey workflow workflow: EntryPoint = list( - entry_points().select( - group="murfey.workflows", name="clem.process_raw_tiffs" - ) + entry_points(group="murfey.workflows", name="clem.process_raw_tiffs") )[0] except IndexError: raise RuntimeError("The relevant Murfey workflow was not found") @@ -853,7 +851,7 @@ def align_and_merge_stacks( try: # Try and load relevant Murfey workflow workflow: EntryPoint = list( - entry_points().select(group="murfey.workflows", name="clem.align_and_merge") + entry_points(group="murfey.workflows", name="clem.align_and_merge") )[0] except IndexError: raise RuntimeError("The relevant Murfey workflow was not found") diff --git a/src/murfey/server/feedback.py b/src/murfey/server/feedback.py index 59a135520..2e77928f3 100644 --- a/src/murfey/server/feedback.py +++ b/src/murfey/server/feedback.py @@ -12,23 +12,16 @@ import subprocess import time from datetime import datetime -from functools import partial, singledispatch -from importlib.metadata import EntryPoint # For type hinting only +from functools import partial +from importlib.metadata import ( + EntryPoint, # For type hinting only + entry_points, +) from pathlib import Path from typing import Dict, List, NamedTuple, Tuple import mrcfile import numpy as np -from backports.entry_points_selectable import entry_points -from ispyb.sqlalchemy._auto_db_schema import ( - Atlas, - AutoProcProgram, - Base, - DataCollection, - DataCollectionGroup, - ProcessingJob, - ProcessingJobParameter, -) from sqlalchemy import func from sqlalchemy.exc import ( InvalidRequestError, @@ -42,7 +35,6 @@ import murfey.server import murfey.server.prometheus as prom import murfey.util.db as db -from murfey.server.ispyb import ISPyBSession, get_session_id from murfey.server.murfey_db import url # murfey_db from murfey.util import sanitise from murfey.util.config import ( @@ -65,11 +57,6 @@ murfey_db = None 
-class ExtendedRecord(NamedTuple): - record: Base # type: ignore - record_params: List[Base] # type: ignore - - class JobIDs(NamedTuple): dcgid: int dcid: int @@ -1902,7 +1889,6 @@ def _save_bfactor(message: dict, _db, demo: bool = False): def feedback_callback(header: dict, message: dict, _db=murfey_db) -> None: try: - record = None if "environment" in message: params = message["recipe"][str(message["recipe-pointer"])].get( "parameters", {} @@ -2014,254 +2000,6 @@ def feedback_callback(header: dict, message: dict, _db=murfey_db) -> None: if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None - elif message["register"] == "data_collection_group": - ispyb_session_id = get_session_id( - microscope=message["microscope"], - proposal_code=message["proposal_code"], - proposal_number=message["proposal_number"], - visit_number=message["visit_number"], - db=ISPyBSession(), - ) - if dcg_murfey := _db.exec( - select(db.DataCollectionGroup) - .where(db.DataCollectionGroup.session_id == message["session_id"]) - .where(db.DataCollectionGroup.tag == message.get("tag")) - ).all(): - dcgid = dcg_murfey[0].id - else: - if ispyb_session_id is None: - murfey_dcg = db.DataCollectionGroup( - session_id=message["session_id"], - tag=message.get("tag"), - ) - dcgid = murfey_dcg.id - else: - record = DataCollectionGroup( - sessionId=ispyb_session_id, - experimentType=message["experiment_type"], - experimentTypeId=message["experiment_type_id"], - ) - dcgid = _register(record, header) - atlas_record = Atlas( - dataCollectionGroupId=dcgid, - atlasImage=message.get("atlas", ""), - pixelSize=message.get("atlas_pixel_size", 0), - cassetteSlot=message.get("sample"), - ) - if murfey.server._transport_object: - atlas_id = murfey.server._transport_object.do_insert_atlas( - atlas_record - )["return_value"] - else: - atlas_id = None - murfey_dcg = db.DataCollectionGroup( - id=dcgid, - atlas_id=atlas_id, - atlas=message.get("atlas", ""), - 
atlas_pixel_size=message.get("atlas_pixel_size"), - sample=message.get("sample"), - session_id=message["session_id"], - tag=message.get("tag"), - ) - _db.add(murfey_dcg) - _db.commit() - _db.close() - if murfey.server._transport_object: - if dcgid is None: - time.sleep(2) - murfey.server._transport_object.transport.nack(header, requeue=True) - return None - murfey.server._transport_object.transport.ack(header) - if dcg_hooks := entry_points().select( - group="murfey.hooks", name="data_collection_group" - ): - try: - for hook in dcg_hooks: - hook.load()(dcgid, session_id=message["session_id"]) - except Exception: - logger.error( - "Call to data collection group hook failed", exc_info=True - ) - return None - elif message["register"] == "atlas_update": - if murfey.server._transport_object: - murfey.server._transport_object.do_update_atlas( - message["atlas_id"], - message["atlas"], - message["atlas_pixel_size"], - message["sample"], - ) - murfey.server._transport_object.transport.ack(header) - if dcg_hooks := entry_points().select( - group="murfey.hooks", name="data_collection_group" - ): - try: - for hook in dcg_hooks: - hook.load()(message["dcgid"], session_id=message["session_id"]) - except Exception: - logger.error( - "Call to data collection group hook failed", exc_info=True - ) - return None - elif message["register"] == "data_collection": - logger.debug( - "Received message named 'data_collection' containing the following items:\n" - f"{', '.join([f'{sanitise(key)}: {sanitise(str(value))}' for key, value in message.items()])}" - ) - murfey_session_id = message["session_id"] - ispyb_session_id = get_session_id( - microscope=message["microscope"], - proposal_code=message["proposal_code"], - proposal_number=message["proposal_number"], - visit_number=message["visit_number"], - db=ISPyBSession(), - ) - dcg = _db.exec( - select(db.DataCollectionGroup) - .where(db.DataCollectionGroup.session_id == murfey_session_id) - .where(db.DataCollectionGroup.tag == 
message["source"]) - ).all() - if dcg: - dcgid = dcg[0].id - # flush_data_collections(message["source"], _db) - else: - logger.warning( - "No data collection group ID was found for image directory " - f"{sanitise(message['image_directory'])} and source " - f"{sanitise(message['source'])}" - ) - if murfey.server._transport_object: - murfey.server._transport_object.transport.nack(header, requeue=True) - return None - if dc_murfey := _db.exec( - select(db.DataCollection) - .where(db.DataCollection.tag == message.get("tag")) - .where(db.DataCollection.dcg_id == dcgid) - ).all(): - dcid = dc_murfey[0].id - else: - if ispyb_session_id is None: - murfey_dc = db.DataCollection( - tag=message.get("tag"), - dcg_id=dcgid, - ) - else: - record = DataCollection( - SESSIONID=ispyb_session_id, - experimenttype=message["experiment_type"], - imageDirectory=message["image_directory"], - imageSuffix=message["image_suffix"], - voltage=message["voltage"], - dataCollectionGroupId=dcgid, - pixelSizeOnImage=message["pixel_size"], - imageSizeX=message["image_size_x"], - imageSizeY=message["image_size_y"], - slitGapHorizontal=message.get("slit_width"), - magnification=message.get("magnification"), - exposureTime=message.get("exposure_time"), - totalExposedDose=message.get("total_exposed_dose"), - c2aperture=message.get("c2aperture"), - phasePlate=int(message.get("phase_plate", 0)), - ) - dcid = _register( - record, - header, - tag=( - message.get("tag") - if message["experiment_type"] == "tomography" - else "" - ), - ) - murfey_dc = db.DataCollection( - id=dcid, - tag=message.get("tag"), - dcg_id=dcgid, - ) - _db.add(murfey_dc) - _db.commit() - dcid = murfey_dc.id - _db.close() - if dcid is None and murfey.server._transport_object: - murfey.server._transport_object.transport.nack(header, requeue=True) - return None - if murfey.server._transport_object: - murfey.server._transport_object.transport.ack(header) - return None - elif message["register"] == "processing_job": - murfey_session_id = 
message["session_id"] - logger.info("registering processing job") - dc = _db.exec( - select(db.DataCollection, db.DataCollectionGroup) - .where(db.DataCollection.dcg_id == db.DataCollectionGroup.id) - .where(db.DataCollectionGroup.session_id == murfey_session_id) - .where(db.DataCollectionGroup.tag == message["source"]) - .where(db.DataCollection.tag == message["tag"]) - ).all() - if dc: - _dcid = dc[0][0].id - else: - logger.warning( - f"No data collection ID found for {sanitise(message['tag'])}" - ) - if murfey.server._transport_object: - murfey.server._transport_object.transport.nack(header, requeue=True) - return None - if pj_murfey := _db.exec( - select(db.ProcessingJob) - .where(db.ProcessingJob.recipe == message["recipe"]) - .where(db.ProcessingJob.dc_id == _dcid) - ).all(): - pid = pj_murfey[0].id - else: - if ISPyBSession() is None: - murfey_pj = db.ProcessingJob(recipe=message["recipe"], dc_id=_dcid) - else: - record = ProcessingJob( - dataCollectionId=_dcid, recipe=message["recipe"] - ) - run_parameters = message.get("parameters", {}) - assert isinstance(run_parameters, dict) - if message.get("job_parameters"): - job_parameters = [ - ProcessingJobParameter(parameterKey=k, parameterValue=v) - for k, v in message["job_parameters"].items() - ] - pid = _register(ExtendedRecord(record, job_parameters), header) - else: - pid = _register(record, header) - murfey_pj = db.ProcessingJob( - id=pid, recipe=message["recipe"], dc_id=_dcid - ) - _db.add(murfey_pj) - _db.commit() - pid = murfey_pj.id - _db.close() - if pid is None and murfey.server._transport_object: - murfey.server._transport_object.transport.nack(header, requeue=True) - return None - prom.preprocessed_movies.labels(processing_job=pid) - if not _db.exec( - select(db.AutoProcProgram).where(db.AutoProcProgram.pj_id == pid) - ).all(): - if ISPyBSession() is None: - murfey_app = db.AutoProcProgram(pj_id=pid) - else: - record = AutoProcProgram( - processingJobId=pid, processingStartTime=datetime.now() - ) - 
appid = _register(record, header) - if appid is None and murfey.server._transport_object: - murfey.server._transport_object.transport.nack( - header, requeue=True - ) - return None - murfey_app = db.AutoProcProgram(id=appid, pj_id=pid) - _db.add(murfey_app) - _db.commit() - _db.close() - if murfey.server._transport_object: - murfey.server._transport_object.transport.ack(header) - return None elif message["register"] == "flush_tomography_preprocess": _flush_tomography_preprocessing(message, _db) if murfey.server._transport_object: @@ -2457,14 +2195,10 @@ def feedback_callback(header: dict, message: dict, _db=murfey_db) -> None: if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None - elif ( - message["register"] in entry_points().select(group="murfey.workflows").names - ): + elif message["register"] in entry_points(group="murfey.workflows").names: # Search for corresponding workflow workflows: list[EntryPoint] = list( - entry_points().select( - group="murfey.workflows", name=message["register"] - ) + entry_points(group="murfey.workflows", name=message["register"]) ) # Returns either 1 item or empty list if not workflows: logger.error(f"No workflow found for {sanitise(message['register'])}") @@ -2475,17 +2209,17 @@ def feedback_callback(header: dict, message: dict, _db=murfey_db) -> None: return None # Run the workflow if a match is found workflow: EntryPoint = workflows[0] - result = workflow.load()( + result: dict[str, bool] = workflow.load()( message=message, murfey_db=_db, ) if murfey.server._transport_object: - if result: + if result.get("success", False): murfey.server._transport_object.transport.ack(header) else: # Send it directly to DLQ without trying to rerun it murfey.server._transport_object.transport.nack( - header, requeue=False + header, requeue=result.get("requeue", False) ) if not result: logger.error( @@ -2516,65 +2250,6 @@ def feedback_callback(header: dict, message: dict, _db=murfey_db) -> None: 
return None -@singledispatch -def _register(record, header: dict, **kwargs): - raise NotImplementedError(f"Not method to register {record} or type {type(record)}") - - -@_register.register -def _(record: Base, header: dict, **kwargs): # type: ignore - if not murfey.server._transport_object: - logger.error( - f"No transport object found when processing record {record}. Message header: {header}" - ) - return None - try: - if isinstance(record, DataCollection): - return murfey.server._transport_object.do_insert_data_collection( - record, **kwargs - )["return_value"] - if isinstance(record, DataCollectionGroup): - return murfey.server._transport_object.do_insert_data_collection_group( - record - )["return_value"] - if isinstance(record, ProcessingJob): - return murfey.server._transport_object.do_create_ispyb_job(record)[ - "return_value" - ] - if isinstance(record, AutoProcProgram): - return murfey.server._transport_object.do_update_processing_status(record)[ - "return_value" - ] - # session = Session() - # session.add(record) - # session.commit() - # murfey.server._transport_object.transport.ack(header, requeue=False) - return getattr(record, record.__table__.primary_key.columns[0].name) - - except SQLAlchemyError as e: - logger.error(f"Murfey failed to insert ISPyB record {record}", e, exc_info=True) - # murfey.server._transport_object.transport.nack(header) - return None - except AttributeError as e: - logger.error( - f"Murfey could not find primary key when inserting record {record}", - e, - exc_info=True, - ) - return None - - -@_register.register -def _(extended_record: ExtendedRecord, header: dict, **kwargs): - if not murfey.server._transport_object: - raise ValueError( - "Transport object should not be None if a database record is being updated" - ) - return murfey.server._transport_object.do_create_ispyb_job( - extended_record.record, params=extended_record.record_params - )["return_value"] - - def feedback_listen(): if murfey.server._transport_object: if not 
murfey.server._transport_object.feedback_queue: diff --git a/src/murfey/server/ispyb.py b/src/murfey/server/ispyb.py index 70dc58e8c..f2c0cd801 100644 --- a/src/murfey/server/ispyb.py +++ b/src/murfey/server/ispyb.py @@ -614,7 +614,7 @@ def do_create_ispyb_job( dcid = record.dataCollectionId if not dcid: log.error("Can not create job: DCID not specified") - return False + return {"success": False, "return_value": None} jp = self.ispyb.mx_processing.get_job_params() jp["automatic"] = record.automatic diff --git a/src/murfey/server/main.py b/src/murfey/server/main.py index a20c99ef1..a65896aa0 100644 --- a/src/murfey/server/main.py +++ b/src/murfey/server/main.py @@ -1,8 +1,8 @@ from __future__ import annotations import logging +from importlib.metadata import entry_points -from backports.entry_points_selectable import entry_points from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles diff --git a/src/murfey/util/config.py b/src/murfey/util/config.py index 458507b84..86862de1e 100644 --- a/src/murfey/util/config.py +++ b/src/murfey/util/config.py @@ -3,11 +3,11 @@ import os import socket from functools import lru_cache +from importlib.metadata import entry_points from pathlib import Path from typing import Any, Literal, Optional import yaml -from backports.entry_points_selectable import entry_points from pydantic import BaseModel, ConfigDict, RootModel, ValidationInfo, field_validator from pydantic_settings import BaseSettings @@ -270,6 +270,6 @@ def get_extended_machine_config( ) if not machine_config: return None - model = entry_points().select(group="murfey.config", name=extension_name)[0].load() + model = entry_points(group="murfey.config", name=extension_name)[0].load() data = getattr(machine_config, extension_name, {}) return model(**data) diff --git a/src/murfey/workflows/clem/align_and_merge.py b/src/murfey/workflows/clem/align_and_merge.py index efe358e07..2bc241ae0 100644 --- 
a/src/murfey/workflows/clem/align_and_merge.py +++ b/src/murfey/workflows/clem/align_and_merge.py @@ -78,4 +78,4 @@ def submit_cluster_request( }, new_connection=True, ) - return True + return {"success": True} diff --git a/src/murfey/workflows/clem/register_align_and_merge_results.py b/src/murfey/workflows/clem/register_align_and_merge_results.py index b9de46d64..fe52058b6 100644 --- a/src/murfey/workflows/clem/register_align_and_merge_results.py +++ b/src/murfey/workflows/clem/register_align_and_merge_results.py @@ -41,7 +41,7 @@ def parse_stringified_list(cls, value): def register_align_and_merge_result( message: dict, murfey_db: Session, demo: bool = False -) -> bool: +) -> dict[str, bool]: """ session_id (recipe) register (wrapper) @@ -69,13 +69,13 @@ def register_align_and_merge_result( "Invalid type for align-and-merge processing result: " f"{type(message['result'])}" ) - return False + return {"success": False, "requeue": False} except Exception: logger.error( "Exception encountered when parsing align-and-merge processing result: \n" f"{traceback.format_exc()}" ) - return False + return {"success": False, "requeue": False} # Outer try-finally block for tidying up database-related section of function try: @@ -103,8 +103,8 @@ def register_align_and_merge_result( f"{result.series_name!r}: \n" f"{traceback.format_exc()}" ) - return False + return {"success": False, "requeue": False} - return True + return {"success": True} finally: murfey_db.close() diff --git a/src/murfey/workflows/clem/register_preprocessing_results.py b/src/murfey/workflows/clem/register_preprocessing_results.py index f52ccb471..148223984 100644 --- a/src/murfey/workflows/clem/register_preprocessing_results.py +++ b/src/murfey/workflows/clem/register_preprocessing_results.py @@ -56,7 +56,7 @@ class CLEMPreprocessingResult(BaseModel): extent: list[float] -def run(message: dict, murfey_db: Session, demo: bool = False) -> bool: +def run(message: dict, murfey_db: Session, demo: bool = False) -> 
dict[str, bool]: session_id: int = ( int(message["session_id"]) if not isinstance(message["session_id"], int) @@ -72,13 +72,13 @@ def run(message: dict, murfey_db: Session, demo: bool = False) -> bool: logger.error( f"Invalid type for TIFF preprocessing result: {type(message['result'])}" ) - return False + return {"success": False, "requeue": False} except Exception: logger.error( "Exception encountered when parsing TIFF preprocessing result: \n" f"{traceback.format_exc()}" ) - return False + return {"success": False, "requeue": False} # Outer try-finally block for tidying up database-related section of function try: @@ -181,7 +181,7 @@ def run(message: dict, murfey_db: Session, demo: bool = False) -> bool: f"{result.series_name!r}: \n" f"{traceback.format_exc()}" ) - return False + return {"success": False, "requeue": False} # Load instrument name try: @@ -197,7 +197,7 @@ def run(message: dict, murfey_db: Session, demo: bool = False) -> bool: f"Error requesting data from database for {result.series_name!r} series: \n" f"{traceback.format_exc()}" ) - return False + return {"success": False, "requeue": False} # Construct list of files to use for image alignment and merging steps image_combos_to_process = [ @@ -234,12 +234,12 @@ def run(message: dict, murfey_db: Session, demo: bool = False) -> bool: f"{result.series_name!r} series", exc_info=True, ) - return False + return {"success": False, "requeue": False} logger.info( "Successfully requested image alignment and merging job for " f"{result.series_name!r} series" ) - return True + return {"success": True} finally: murfey_db.close() diff --git a/src/murfey/workflows/notifications/__init__.py b/src/murfey/workflows/notifications/__init__.py index f7314e8e8..9055e99bc 100644 --- a/src/murfey/workflows/notifications/__init__.py +++ b/src/murfey/workflows/notifications/__init__.py @@ -8,7 +8,7 @@ def notification_setup( message: dict, murfey_db: Session, num_instances_between_triggers: int = 500 -) -> bool: +) -> 
dict[str, bool]: parameters: Dict[str, Tuple[float, float]] = {} for k in message.keys(): parameter_name = "" @@ -48,4 +48,4 @@ def notification_setup( murfey_db.add_all(existing_notification_parameters + new_notification_parameters) murfey_db.commit() murfey_db.close() - return True + return {"success": True} diff --git a/src/murfey/workflows/register_atlas_update.py b/src/murfey/workflows/register_atlas_update.py new file mode 100644 index 000000000..c063edb60 --- /dev/null +++ b/src/murfey/workflows/register_atlas_update.py @@ -0,0 +1,34 @@ +import logging +from importlib.metadata import entry_points + +from sqlmodel.orm.session import Session as SQLModelSession + +from murfey.server import _transport_object + +logger = logging.getLogger("murfey.workflows.register_atlas_update") + + +def run( + message: dict, + murfey_db: SQLModelSession, # Defined for compatibility but unused + demo: bool = False, +): + if _transport_object is None: + logger.error("Unable to find transport manager") + return {"success": False, "requeue": False} + + logger.info(f"Registering updated atlas: \n{message}") + + _transport_object.do_update_atlas( + message["atlas_id"], + message["atlas"], + message["atlas_pixel_size"], + message["sample"], + ) + if dcg_hooks := entry_points(group="murfey.hooks", name="data_collection_group"): + try: + for hook in dcg_hooks: + hook.load()(message["dcgid"], session_id=message["session_id"]) + except Exception: + logger.error("Call to data collection group hook failed", exc_info=True) + return {"success": True} diff --git a/src/murfey/workflows/register_data_collection.py b/src/murfey/workflows/register_data_collection.py new file mode 100644 index 000000000..15ecc25a9 --- /dev/null +++ b/src/murfey/workflows/register_data_collection.py @@ -0,0 +1,104 @@ +import logging + +import ispyb.sqlalchemy._auto_db_schema as ISPyBDB +from sqlmodel import select +from sqlmodel.orm.session import Session as SQLModelSession + +import murfey.util.db as MurfeyDB +from 
murfey.server import _transport_object +from murfey.server.ispyb import ISPyBSession, get_session_id +from murfey.util import sanitise + +logger = logging.getLogger("murfey.workflows.register_data_collection") + + +def run( + message: dict, murfey_db: SQLModelSession, demo: bool = False +) -> dict[str, bool]: + # Fail immediately if transport manager was not provided + if _transport_object is None: + logger.error("Unable to find transport manager") + return {"success": False, "requeue": False} + + logger.info(f"Registering the following data collection: \n{message}") + + murfey_session_id = message["session_id"] + ispyb_session_id = get_session_id( + microscope=message["microscope"], + proposal_code=message["proposal_code"], + proposal_number=message["proposal_number"], + visit_number=message["visit_number"], + db=ISPyBSession(), + ) + dcg = murfey_db.exec( + select(MurfeyDB.DataCollectionGroup) + .where(MurfeyDB.DataCollectionGroup.session_id == murfey_session_id) + .where(MurfeyDB.DataCollectionGroup.tag == message["source"]) + ).all() + if dcg: + dcgid = dcg[0].id + # flush_data_collections(message["source"], murfey_db) + else: + logger.warning( + "No data collection group ID was found for image directory " + f"{sanitise(message['image_directory'])} and source " + f"{sanitise(message['source'])}" + ) + return {"success": False, "requeue": True} + + if dc_murfey := murfey_db.exec( + select(MurfeyDB.DataCollection) + .where(MurfeyDB.DataCollection.tag == message.get("tag")) + .where(MurfeyDB.DataCollection.dcg_id == dcgid) + ).all(): + dcid = dc_murfey[0].id + else: + if ispyb_session_id is None: + murfey_dc = MurfeyDB.DataCollection( + tag=message.get("tag"), + dcg_id=dcgid, + ) + else: + record = ISPyBDB.DataCollection( + SESSIONID=ispyb_session_id, + experimenttype=message["experiment_type"], + imageDirectory=message["image_directory"], + imageSuffix=message["image_suffix"], + voltage=message["voltage"], + dataCollectionGroupId=dcgid, + 
pixelSizeOnImage=message["pixel_size"], + imageSizeX=message["image_size_x"], + imageSizeY=message["image_size_y"], + slitGapHorizontal=message.get("slit_width"), + magnification=message.get("magnification"), + exposureTime=message.get("exposure_time"), + totalExposedDose=message.get("total_exposed_dose"), + c2aperture=message.get("c2aperture"), + phasePlate=int(message.get("phase_plate", 0)), + ) + dcid = _transport_object.do_insert_data_collection( + record, + tag=( + message.get("tag") + if message["experiment_type"] == "tomography" + else "" + ), + ).get("return_value", None) + murfey_dc = MurfeyDB.DataCollection( + id=dcid, + tag=message.get("tag"), + dcg_id=dcgid, + ) + murfey_db.add(murfey_dc) + murfey_db.commit() + dcid = murfey_dc.id + murfey_db.close() + + if dcid is None: + logger.error( + "Failed to register the following data collection: \n" + f"{message} \n" + "Requeueing message" + ) + return {"success": False, "requeue": True} + return {"success": True} diff --git a/src/murfey/workflows/register_data_collection_group.py b/src/murfey/workflows/register_data_collection_group.py new file mode 100644 index 000000000..18631808a --- /dev/null +++ b/src/murfey/workflows/register_data_collection_group.py @@ -0,0 +1,97 @@ +import logging +import time +from importlib.metadata import entry_points + +import ispyb.sqlalchemy._auto_db_schema as ISPyBDB +from sqlmodel import select +from sqlmodel.orm.session import Session as SQLModelSession + +import murfey.util.db as MurfeyDB +from murfey.server import _transport_object +from murfey.server.ispyb import ISPyBSession, get_session_id + +logger = logging.getLogger("murfey.workflows.register_data_collection_group") + + +def run( + message: dict, murfey_db: SQLModelSession, demo: bool = False +) -> dict[str, bool]: + # Fail immediately if no transport wrapper is found + if _transport_object is None: + logger.error("Unable to find transport manager") + return {"success": False, "requeue": False} + + 
logger.info(f"Registering the following data collection group: \n{message}") + + ispyb_session_id = get_session_id( + microscope=message["microscope"], + proposal_code=message["proposal_code"], + proposal_number=message["proposal_number"], + visit_number=message["visit_number"], + db=ISPyBSession(), + ) + + if dcg_murfey := murfey_db.exec( + select(MurfeyDB.DataCollectionGroup) + .where(MurfeyDB.DataCollectionGroup.session_id == message["session_id"]) + .where(MurfeyDB.DataCollectionGroup.tag == message.get("tag")) + ).all(): + dcgid = dcg_murfey[0].id + else: + if ispyb_session_id is None: + murfey_dcg = MurfeyDB.DataCollectionGroup( + session_id=message["session_id"], + tag=message.get("tag"), + ) + dcgid = murfey_dcg.id + else: + record = ISPyBDB.DataCollectionGroup( + sessionId=ispyb_session_id, + experimentType=message["experiment_type"], + experimentTypeId=message["experiment_type_id"], + ) + + dcgid = _transport_object.do_insert_data_collection_group(record).get( + "return_value", None + ) + + atlas_record = ISPyBDB.Atlas( + dataCollectionGroupId=dcgid, + atlasImage=message.get("atlas", ""), + pixelSize=message.get("atlas_pixel_size", 0), + cassetteSlot=message.get("sample"), + ) + atlas_id = _transport_object.do_insert_atlas(atlas_record).get( + "return_value", None + ) + + murfey_dcg = MurfeyDB.DataCollectionGroup( + id=dcgid, + atlas_id=atlas_id, + atlas=message.get("atlas", ""), + atlas_pixel_size=message.get("atlas_pixel_size"), + sample=message.get("sample"), + session_id=message["session_id"], + tag=message.get("tag"), + ) + murfey_db.add(murfey_dcg) + murfey_db.commit() + murfey_db.close() + + if dcgid is None: + time.sleep(2) + logger.error( + "Failed to register the following data collection group: \n" + f"{message} \n" + "Requeuing message" + ) + return {"success": False, "requeue": True} + + if dcg_hooks := entry_points(group="murfey.hooks", name="data_collection_group"): + try: + for hook in dcg_hooks: + hook.load()(dcgid, 
session_id=message["session_id"]) + except Exception: + logger.error("Call to data collection group hook failed", exc_info=True) + + return {"success": True} diff --git a/src/murfey/workflows/register_processing_job.py b/src/murfey/workflows/register_processing_job.py new file mode 100644 index 000000000..e2d7b1368 --- /dev/null +++ b/src/murfey/workflows/register_processing_job.py @@ -0,0 +1,99 @@ +import logging +from datetime import datetime + +import ispyb.sqlalchemy._auto_db_schema as ISPyBDB +from sqlmodel import select +from sqlmodel.orm.session import Session as SQLModelSession + +import murfey.server.prometheus as prom +import murfey.util.db as MurfeyDB +from murfey.server import _transport_object +from murfey.server.ispyb import ISPyBSession +from murfey.util import sanitise + +logger = logging.getLogger("murfey.workflows.register_processing_job") + + +def run(message: dict, murfey_db: SQLModelSession, demo: bool = False): + # Fail immediately if no transport manager is set + if _transport_object is None: + logger.error("Unable to find transport manager") + return {"success": False, "requeue": False} + + logger.info(f"Registering the following processing job: \n{message}") + + murfey_session_id = message["session_id"] + dc = murfey_db.exec( + select(MurfeyDB.DataCollection, MurfeyDB.DataCollectionGroup) + .where(MurfeyDB.DataCollection.dcg_id == MurfeyDB.DataCollectionGroup.id) + .where(MurfeyDB.DataCollectionGroup.session_id == murfey_session_id) + .where(MurfeyDB.DataCollectionGroup.tag == message["source"]) + .where(MurfeyDB.DataCollection.tag == message["tag"]) + ).all() + + if dc: + _dcid = dc[0][0].id + else: + logger.warning(f"No data collection ID found for {sanitise(message['tag'])}") + return {"success": False, "requeue": True} + if pj_murfey := murfey_db.exec( + select(MurfeyDB.ProcessingJob) + .where(MurfeyDB.ProcessingJob.recipe == message["recipe"]) + .where(MurfeyDB.ProcessingJob.dc_id == _dcid) + ).all(): + pid = pj_murfey[0].id + else: 
+ if ISPyBSession() is None: + murfey_pj = MurfeyDB.ProcessingJob(recipe=message["recipe"], dc_id=_dcid) + else: + record = ISPyBDB.ProcessingJob( + dataCollectionId=_dcid, recipe=message["recipe"] + ) + run_parameters = message.get("parameters", {}) + assert isinstance(run_parameters, dict) + if message.get("job_parameters"): + job_parameters = [ + ISPyBDB.ProcessingJobParameter(parameterKey=k, parameterValue=v) + for k, v in message["job_parameters"].items() + ] + pid = _transport_object.do_create_ispyb_job( + record, params=job_parameters + ).get("return_value", None) + else: + pid = _transport_object.do_create_ispyb_job(record).get( + "return_value", None + ) + murfey_pj = MurfeyDB.ProcessingJob( + id=pid, recipe=message["recipe"], dc_id=_dcid + ) + murfey_db.add(murfey_pj) + murfey_db.commit() + pid = murfey_pj.id + murfey_db.close() + + if pid is None: + return {"success": False, "requeue": True} + + # Update Prometheus counter for preprocessed movies + prom.preprocessed_movies.labels(processing_job=pid) + + # Register AutoProcProgram database entry if it doesn't already exist + if not murfey_db.exec( + select(MurfeyDB.AutoProcProgram).where(MurfeyDB.AutoProcProgram.pj_id == pid) + ).all(): + if ISPyBSession() is None: + murfey_app = MurfeyDB.AutoProcProgram(pj_id=pid) + else: + record = ISPyBDB.AutoProcProgram( + processingJobId=pid, processingStartTime=datetime.now() + ) + appid = _transport_object.do_update_processing_status(record).get( + "return_value", None + ) + if appid is None: + return {"success": False, "requeue": True} + murfey_app = MurfeyDB.AutoProcProgram(id=appid, pj_id=pid) + murfey_db.add(murfey_app) + murfey_db.commit() + murfey_db.close() + return {"success": True} diff --git a/src/murfey/workflows/spa/flush_spa_preprocess.py b/src/murfey/workflows/spa/flush_spa_preprocess.py index 606511bb0..ffbb6eb31 100644 --- a/src/murfey/workflows/spa/flush_spa_preprocess.py +++ b/src/murfey/workflows/spa/flush_spa_preprocess.py @@ -7,7 +7,6 @@ from 
sqlmodel import Session, select from murfey.server import _transport_object -from murfey.server.api.auth import MurfeySessionIDInstrument as MurfeySessionID from murfey.server.feedback import _murfey_id from murfey.util import sanitise, secure_path from murfey.util.config import get_machine_config, get_microscope @@ -39,7 +38,7 @@ def register_grid_square( - session_id: MurfeySessionID, + session_id: int, gsid: int, grid_square_params: GridSquareParameters, murfey_db: Session, @@ -119,7 +118,7 @@ def register_grid_square( def register_foil_hole( - session_id: MurfeySessionID, + session_id: int, gs_name: int, foil_hole_params: FoilHoleParameters, murfey_db: Session, @@ -306,7 +305,9 @@ def _flush_position_analysis( return register_foil_hole(session_id, gs.id, foil_hole_parameters, murfey_db) -def flush_spa_preprocess(message: dict, murfey_db: Session, demo: bool = False) -> bool: +def flush_spa_preprocess( + message: dict, murfey_db: Session, demo: bool = False +) -> dict[str, bool]: session_id = message["session_id"] stashed_files = murfey_db.exec( select(PreprocessStash) @@ -314,7 +315,7 @@ def flush_spa_preprocess(message: dict, murfey_db: Session, demo: bool = False) .where(PreprocessStash.tag == message["tag"]) ).all() if not stashed_files: - return True + return {"success": True} murfey_session = murfey_db.exec( select(MurfeySession).where(MurfeySession.id == message["session_id"]) @@ -348,7 +349,7 @@ def flush_spa_preprocess(message: dict, murfey_db: Session, demo: bool = False) logger.warning( f"No SPA processing parameters found for client processing job ID {collected_ids[2].id}" ) - return False + return {"success": False, "requeue": False} murfey_ids = _murfey_id( collected_ids[3].id, @@ -444,4 +445,4 @@ def flush_spa_preprocess(message: dict, murfey_db: Session, demo: bool = False) ) murfey_db.commit() murfey_db.close() - return True + return {"success": True} diff --git a/src/murfey/workflows/spa/picking.py b/src/murfey/workflows/spa/picking.py index 
894a5f48a..215423927 100644 --- a/src/murfey/workflows/spa/picking.py +++ b/src/murfey/workflows/spa/picking.py @@ -444,7 +444,7 @@ def _check_notifications(message: dict, murfey_db: Session) -> None: return None -def particles_picked(message: dict, murfey_db: Session) -> bool: +def particles_picked(message: dict, murfey_db: Session) -> dict[str, bool]: movie = murfey_db.exec( select(Movie).where(Movie.murfey_id == message["motion_correction_id"]) ).one() @@ -465,4 +465,4 @@ def particles_picked(message: dict, murfey_db: Session) -> bool: processing_job=_pj_id(message["program_id"], murfey_db) ).inc() _check_notifications(message, murfey_db) - return True + return {"success": True} diff --git a/src/murfey/workflows/tomo/picking.py b/src/murfey/workflows/tomo/picking.py index 9d1fafab7..a7432df71 100644 --- a/src/murfey/workflows/tomo/picking.py +++ b/src/murfey/workflows/tomo/picking.py @@ -205,6 +205,6 @@ def _register_picked_tomogram_use_diameter(message: dict, murfey_db: Session): murfey_db.close() -def picked_tomogram(message: dict, murfey_db: Session) -> bool: +def picked_tomogram(message: dict, murfey_db: Session) -> dict[str, bool]: _register_picked_tomogram_use_diameter(message, murfey_db) - return True + return {"success": True} diff --git a/tests/server/test_feedback.py b/tests/server/test_feedback.py new file mode 100644 index 000000000..1fd7ce1de --- /dev/null +++ b/tests/server/test_feedback.py @@ -0,0 +1,60 @@ +from importlib.metadata import entry_points +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +feedback_callback_params_matrix = ( + # Murfey workflows currently present in pyproject.toml + ("atlas_update",), + ("clem.align_and_merge",), + ("clem.process_raw_lifs",), + ("clem.process_raw_tiffs",), + ("clem.register_align_and_merge_result",), + ("clem.register_preprocessing_result",), + ("data_collection",), + ("data_collection_group",), + ("pato",), + ("picked_particles",), + ("picked_tomogram",), + 
("processing_job",), + ("spa.flush_spa_preprocess",), +) + + +@pytest.mark.parametrize("test_params", feedback_callback_params_matrix) +def test_feedback_callback( + mocker: MockerFixture, + test_params: tuple[str], +): + """ + Checks that feedback-callback loop works correctly for the entry points-based workflows + """ + + # Unpack test params + (entry_point_name,) = test_params + + # Patch the functions used to generate the module-level variables + mock_get_security_config = mocker.patch("murfey.util.config.get_security_config") + mock_get_security_config.return_value = MagicMock() + mock_url = mocker.patch("murfey.server.murfey_db.url") + mock_url.return_value = MagicMock() + mock_create_engine = mocker.patch("sqlmodel.create_engine") + mock_create_engine.return_value = MagicMock() + mock_murfey_db = MagicMock() + mock_sql_session = mocker.patch("sqlmodel.Session") + mock_sql_session.return_value = mock_murfey_db + + # Load the entry point and patch the executable it calls + eps = list(entry_points(group="murfey.workflows", name=entry_point_name)) + assert len(eps) == 1 # Entry point should be present and unique + mock_function = mocker.patch(eps[0].value.replace(":", ".")) + + # Initialise after mocking + from murfey.server.feedback import feedback_callback + + # Run the function and check that it calls the entry point correctly + header = {"dummy": "dummy"} + message = {"register": entry_point_name} + feedback_callback(header, message, mock_murfey_db) + mock_function.assert_called_once_with(message=message, murfey_db=mock_murfey_db) diff --git a/tests/workflows/test_processing_job.py b/tests/workflows/test_processing_job.py new file mode 100644 index 000000000..85562d933 --- /dev/null +++ b/tests/workflows/test_processing_job.py @@ -0,0 +1,109 @@ +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from murfey.workflows.register_processing_job import run + +register_processing_job_params_matrix = [ + # ISPyB session 
present | DC search result | PJ search result | APP search result | Insert ISPyB job | Update processing status + (v0, v1, v2, v3, v4, v5) + for v0 in (0, None) + for v1 in (0, None) + for v2 in (0, None) + for v3 in (0, None) + for v4 in (0, None) + for v5 in (0, None) +] + + +@pytest.mark.parametrize("test_params", register_processing_job_params_matrix) +def test_run( + mocker: MockerFixture, + test_params: tuple[ + int | None, int | None, int | None, int | None, int | None, int | None + ], +): + # Unpack test params + ispyb_session, dc_result, pj_result, app_result, insert_job, update_status = ( + test_params + ) + + # Create mocks + # Transport object functions + mock_transport_object = mocker.patch( + "murfey.workflows.register_processing_job._transport_object" + ) + mock_transport_object.do_create_ispyb_job.return_value = { + "return_value": insert_job + } + mock_transport_object.do_update_processing_status.return_value = { + "return_value": update_status + } + + # ISPyB session + mock_ispyb_session = mocker.patch( + "murfey.workflows.register_processing_job.ISPyBSession" + ) + mock_ispyb_session.return_value = ispyb_session + + # Murfey database + mock_murfey_dc = MagicMock() + mock_murfey_dc.id = dc_result + mock_murfey_pj = MagicMock() + mock_murfey_pj.id = pj_result + mock_murfey_app = MagicMock() + mock_murfey_app.id = app_result + + # Set up side effects depending on route taken through the function + db_call_order = [[[mock_murfey_dc]] if dc_result is not None else []] + if dc_result is not None: + db_call_order.append([mock_murfey_pj] if pj_result is not None else []) + if pj_result is not None or insert_job is not None or ispyb_session is None: + db_call_order.append([mock_murfey_app] if app_result is not None else []) + mock_murfey_db = MagicMock() + mock_murfey_db.exec.return_value.all.side_effect = db_call_order + + # Mock Prometheus object + mock_prom = mocker.patch("murfey.workflows.register_processing_job.prom") + + # Run function and check 
results and calls + message = { + "session_id": 0, + "source": "some_path", + "tag": "some_tag", + "recipe": "some_recipe", + "parameters": {}, + "job_parameters": {"dummy": "dummy"}, + } + result = run(message=message, murfey_db=mock_murfey_db) + if dc_result is not None: + if pj_result is not None: + mock_prom.preprocessed_movies.labels.assert_called_once() + if app_result is not None: + assert result == {"success": True} + else: + if update_status is not None: + assert result == {"success": True} + else: + if ispyb_session is not None: + assert result == {"success": False, "requeue": True} + else: + assert result == {"success": True} + else: + if ispyb_session is not None: + mock_transport_object.do_create_ispyb_job.assert_called_once() + if insert_job is not None: + if app_result is not None: + assert result == {"success": True} + else: + if update_status is not None: + assert result == {"success": True} + else: + assert result == {"success": False, "requeue": True} + else: + assert result == {"success": False, "requeue": True} + else: + assert result == {"success": False, "requeue": True} + else: + assert result == {"success": False, "requeue": True} diff --git a/tests/workflows/test_register_atlas_update.py b/tests/workflows/test_register_atlas_update.py new file mode 100644 index 000000000..983c739b8 --- /dev/null +++ b/tests/workflows/test_register_atlas_update.py @@ -0,0 +1,33 @@ +from unittest import mock +from unittest.mock import MagicMock + +from pytest_mock import MockerFixture + +from murfey.workflows.register_atlas_update import run + + +def test_run( + mocker: MockerFixture, +): + # Set up mocks and the dummy message to be registered + mock_transport_object = mocker.patch( + "murfey.workflows.register_atlas_update._transport_object" + ) + mock_murfey_db = MagicMock() + message = { + "register": "atlas_update", + "atlas_id": mock.sentinel, + "atlas": mock.sentinel, + "atlas_pixel_size": mock.sentinel, + "sample": mock.sentinel, + } + + # Run the function and 
check the results and calls made + result = run(message, mock_murfey_db) + mock_transport_object.do_update_atlas.assert_called_once_with( + message["atlas_id"], + message["atlas"], + message["atlas_pixel_size"], + message["sample"], + ) + assert result == {"success": True} diff --git a/tests/workflows/test_register_data_collection.py b/tests/workflows/test_register_data_collection.py new file mode 100644 index 000000000..49be95860 --- /dev/null +++ b/tests/workflows/test_register_data_collection.py @@ -0,0 +1,104 @@ +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from murfey.workflows.register_data_collection import run +from tests.conftest import ExampleVisit + +register_data_collection_params_matrix = ( + # ISPyB session ID return value | DCG search result | DC search result | Insert data collection + (0, 0, 0, 0), + (0, 0, 0, None), + (0, 0, None, 0), + (0, 0, None, None), + (0, None, 0, 0), + (0, None, 0, None), + (0, None, None, 0), + (0, None, None, None), + (None, 0, 0, 0), + (None, 0, 0, None), + (None, 0, None, 0), + (None, 0, None, None), + (None, None, 0, 0), + (None, None, 0, None), + (None, None, None, 0), + (None, None, None, None), +) + + +@pytest.mark.parametrize("test_params", register_data_collection_params_matrix) +def test_run( + mocker: MockerFixture, + test_params: tuple[int | None, int | None, int | None, int | None], +): + # Unpack test params + ispyb_session_id, dcg_result, dc_result, insert_data_collection = test_params + + # Set up mock objects + # 'get_session_id' + mock_get_session_id = mocker.patch( + "murfey.workflows.register_data_collection.get_session_id" + ) + mock_get_session_id.return_value = ispyb_session_id + + # Transport object inserts + mock_transport_object = mocker.patch( + "murfey.workflows.register_data_collection._transport_object" + ) + mock_transport_object.do_insert_data_collection.return_value = { + "return_value": insert_data_collection + } + + # Murfey database + 
mock_murfey_db = MagicMock() + mock_dcg = MagicMock() + mock_dcg.id = dcg_result + + mock_dc = MagicMock() + mock_dc.id = dc_result + mock_murfey_db.exec.return_value.all.side_effect = [ + # Sequence of mock database tables + [mock_dcg] if dcg_result is not None else [], + [mock_dc] if dc_result is not None else [], + ] + + # Run the function and check results and calls + message = { + "session_id": 0, + "microscope": "test_instrument", + "proposal_code": ExampleVisit.proposal_code, + "proposal_number": ExampleVisit.proposal_number, + "visit_number": ExampleVisit.visit_number, + "source": "some_path", + "image_directory": "some_path", + "tag": "some_string", + "experiment_type": "SPA", + "image_suffix": ".jpg", + "voltage": 200, + "pixel_size": 1e-9, + "image_size_x": 2048, + "image_size_y": 2048, + "slit_width": 0.005, + "magnification": 150000, + "exposure_time": 30, + "total_exposed_dose": 30, + "c2aperture": 5, + "phase_plate": 1, + } + result = run(message=message, murfey_db=mock_murfey_db) + if dcg_result is None: + assert result == {"success": False, "requeue": True} + else: + if dc_result is None: + if ispyb_session_id is not None: + mock_transport_object.do_insert_data_collection.assert_called_once() + if insert_data_collection is not None: + assert result == {"success": True} + else: + assert result == {"success": False, "requeue": True} + else: + mock_transport_object.do_insert_data_collection.assert_not_called() + assert result == {"success": False, "requeue": True} + else: + assert result == {"success": True} diff --git a/tests/workflows/test_register_data_collection_group.py b/tests/workflows/test_register_data_collection_group.py new file mode 100644 index 000000000..efe46a6ea --- /dev/null +++ b/tests/workflows/test_register_data_collection_group.py @@ -0,0 +1,87 @@ +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from murfey.workflows.register_data_collection_group import run +from tests.conftest 
import ExampleVisit + +register_data_collection_group_params_matrix = ( + # ISPyB session ID | # DCG search result | # DCG insert result | # Atlas insert result + (0, 0, 0, 0), + (0, 0, 0, None), + (0, 0, None, 0), + (0, 0, None, None), + (0, None, 0, 0), + (0, None, 0, None), + (0, None, None, 0), + (0, None, None, None), + (None, 0, 0, 0), + (None, 0, 0, None), + (None, 0, None, 0), + (None, 0, None, None), + (None, None, 0, 0), + (None, None, 0, None), + (None, None, None, 0), + (None, None, None, None), +) + + +@pytest.mark.parametrize("test_params", register_data_collection_group_params_matrix) +def test_run( + mocker: MockerFixture, + test_params: tuple[int | None, int | None, int | None, int | None], +): + # Unpack test params + (ispyb_session_id, dcg_result, insert_dcg, insert_atlas) = test_params + + # Mock the transport object functions + mock_transport_object = mocker.patch( + "murfey.workflows.register_data_collection_group._transport_object" + ) + mock_transport_object.do_insert_data_collection_group.return_value = { + "return_value": insert_dcg, + } + mock_transport_object.do_insert_atlas.return_value = {"return_value": insert_atlas} + + # Mock the 'get_session_id' return value + mock_get_session_id = mocker.patch( + "murfey.workflows.register_data_collection_group.get_session_id" + ) + mock_get_session_id.return_value = ispyb_session_id + + # Mock the Murfey database + mock_murfey_db = MagicMock() + mock_dcg = MagicMock() + mock_dcg.id = dcg_result + mock_murfey_db.exec.return_value.all.return_value = ( + [mock_dcg] if dcg_result is not None else [] + ) + + # Run the function and check the results and calls + message = { + "microscope": "test", + "proposal_code": ExampleVisit.proposal_code, + "proposal_number": ExampleVisit.proposal_number, + "visit_number": ExampleVisit.visit_number, + "session_id": ExampleVisit.murfey_session_id, + "tag": "some_text", + "experiment_type": "single particle", + "experiment_type_id": 0, + "atlas": "some_file", + 
"atlas_pixel_size": 1e-9, + "sample": 0, + } + result = run(message=message, murfey_db=mock_murfey_db) + if dcg_result is not None: + assert result == {"success": True} + else: + if ispyb_session_id is not None: + mock_transport_object.do_insert_data_collection_group.assert_called_once() + mock_transport_object.do_insert_atlas.assert_called_once() + if insert_dcg is not None: + assert result == {"success": True} + else: + assert result == {"success": False, "requeue": True} + else: + assert result == {"success": False, "requeue": True}