diff --git a/src/murfey/cli/repost_failed_calls.py b/src/murfey/cli/repost_failed_calls.py index a3fe219c9..525b60598 100644 --- a/src/murfey/cli/repost_failed_calls.py +++ b/src/murfey/cli/repost_failed_calls.py @@ -7,7 +7,7 @@ from pathlib import Path from queue import Empty, Queue -from sqlmodel import Session +from sqlmodel import Session, create_engine from workflows.transport.pika_transport import PikaTransport import murfey.server.api.auth @@ -25,7 +25,7 @@ import murfey.server.api.session_info import murfey.server.api.websocket import murfey.server.api.workflow -from murfey.server.murfey_db import get_murfey_db_session +from murfey.server.murfey_db import url from murfey.util.config import security_from_file @@ -162,7 +162,10 @@ def run(): - feedback messages that can be sent back to rabbitmq """ parser = argparse.ArgumentParser( - description="Purge and reinject failed murfey messages" + description=( + "Purge and reinject failed murfey messages. " + "Provide security configuration and set machine configuration." + ) ) parser.add_argument( "-c", @@ -177,7 +180,6 @@ def run(): # Read the security config file security_config = security_from_file(args.config) - murfey_db = get_murfey_db_session(security_config) # Purge the queue and repost/reinject any messages found dlq_dump_path = Path(args.dir) @@ -187,7 +189,14 @@ def run(): security_config.feedback_queue, security_config.rabbitmq_credentials, ) - handle_failed_posts(exported_messages, murfey_db) + + # Set up database and retry api calls + _url = url(security_config) + engine = create_engine(_url) + with Session(engine) as murfey_db: + handle_failed_posts(exported_messages, murfey_db) + + # Reinject all remaining messages to rabbitmq handle_dlq_messages(exported_messages, security_config.rabbitmq_credentials) # Clean up any created directories diff --git a/src/murfey/server/api/workflow.py b/src/murfey/server/api/workflow.py index ff67c4a89..b8bf130f3 100644 --- a/src/murfey/server/api/workflow.py +++ b/src/murfey/server/api/workflow.py @@ -768,7 +768,7 @@ def register_completed_tilt_series( db.commit() for ts in tilt_series_db: if ( - check_tilt_series_mc(ts.id) + check_tilt_series_mc(ts.id, db) and not ts.processing_requested and ts.tilt_series_length > 2 ): @@ -795,9 +795,9 @@ def register_completed_tilt_series( machine_config = get_machine_config(instrument_name=instrument_name)[ instrument_name ] - tilts = get_all_tilts(ts.id) - ids = get_job_ids(ts.id, collected_ids[3].id) - preproc_params = get_tomo_proc_params(ids.dcgid) + tilts = get_all_tilts(ts.id, db) + ids = get_job_ids(ts.id, collected_ids[3].id, db) + preproc_params = get_tomo_proc_params(ids.dcgid, db) first_tilt = db.exec( select(Tilt).where(Tilt.tilt_series_id == ts.id) diff --git a/src/murfey/server/demo_api.py b/src/murfey/server/demo_api.py index c9fa3ba7d..24280d5de 100644 --- a/src/murfey/server/demo_api.py +++ b/src/murfey/server/demo_api.py @@ -748,7 +748,7 @@ def register_dc_group( if dcg_params.atlas: _flush_grid_square_records( - {"session_id": session_id, "tag": dcg_params.tag}, demo=True + {"session_id": session_id, "tag": dcg_params.tag}, _db=db, demo=True ) return dcg_params diff --git a/src/murfey/server/feedback.py b/src/murfey/server/feedback.py index 4d40a7b1a..3ba15f6c8 100644 --- a/src/murfey/server/feedback.py +++ b/src/murfey/server/feedback.py @@ -84,8 +84,8 @@ def get_angle(tilt_file_name: str) -> float: raise ValueError(f"Tilt angle not found for file {tilt_file_name}") -def check_tilt_series_mc(tilt_series_id: int) -> bool: - results = murfey_db.exec( +def check_tilt_series_mc(tilt_series_id: int, _db) -> bool: + results = _db.exec( select(db.Tilt, db.TiltSeries) .where(db.Tilt.tilt_series_id == db.TiltSeries.id) .where(db.TiltSeries.id == tilt_series_id) @@ -97,8 +97,8 @@ def check_tilt_series_mc(tilt_series_id: int) -> bool: ) -def get_all_tilts(tilt_series_id: int) -> List[str]: - complete_results = murfey_db.exec( +def get_all_tilts(tilt_series_id: int, _db) -> List[str]: + complete_results = _db.exec( select(db.Tilt, db.TiltSeries, db.Session) .where(db.Tilt.tilt_series_id == db.TiltSeries.id) .where(db.TiltSeries.id == tilt_series_id) @@ -118,8 +118,8 @@ def get_all_tilts(tilt_series_id: int) -> List[str]: ] -def get_job_ids(tilt_series_id: int, appid: int) -> JobIDs: - results = murfey_db.exec( +def get_job_ids(tilt_series_id: int, appid: int, _db) -> JobIDs: + results = _db.exec( select( db.TiltSeries, db.AutoProcProgram, @@ -144,8 +144,8 @@ def get_job_ids(tilt_series_id: int, appid: int) -> JobIDs: ) -def get_tomo_proc_params(dcg_id: int, *args) -> db.TomographyProcessingParameters: - results = murfey_db.exec( +def get_tomo_proc_params(dcg_id: int, _db) -> db.TomographyProcessingParameters: + results = _db.exec( select(db.TomographyProcessingParameters).where( db.TomographyProcessingParameters.dcg_id == dcg_id ) @@ -326,7 +326,7 @@ def _get_spa_params( return relion_params, feedback_params -def _release_2d_hold(message: dict, _db=murfey_db): +def _release_2d_hold(message: dict, _db): relion_params, feedback_params = _get_spa_params(message["program_id"], _db) if not feedback_params.star_combination_job: feedback_params.star_combination_job = feedback_params.next_job + ( @@ -403,7 +403,7 @@ def _release_2d_hold(message: dict, _db=murfey_db): _db.close() -def _release_3d_hold(message: dict, _db=murfey_db): +def _release_3d_hold(message: dict, _db): pj_id_params = _pj_id(message["program_id"], _db, recipe="em-spa-preprocess") pj_id = _pj_id(message["program_id"], _db, recipe="em-spa-class3d") relion_params = _db.exec( @@ -481,7 +481,7 @@ def _release_3d_hold(message: dict, _db=murfey_db): _db.close() -def _release_refine_hold(message: dict, _db=murfey_db): +def _release_refine_hold(message: dict, _db): pj_id_params = _pj_id(message["program_id"], _db, recipe="em-spa-preprocess") pj_id = _pj_id(message["program_id"], _db, recipe="em-spa-refine") relion_params = _db.exec( @@ -562,7 +562,7 @@ def _release_refine_hold(message: dict, _db=murfey_db): _db.close() -def _register_incomplete_2d_batch(message: dict, _db=murfey_db, demo: bool = False): +def _register_incomplete_2d_batch(message: dict, _db, demo: bool = False): """Received first batch from particle selection service""" # the general parameters are stored using the preprocessing auto proc program ID logger.info("Registering incomplete particle batch for 2D classification") @@ -686,7 +686,7 @@ def _register_incomplete_2d_batch(message: dict, _db=murfey_db, demo: bool = Fal _db.close() -def _register_complete_2d_batch(message: dict, _db=murfey_db, demo: bool = False): +def _register_complete_2d_batch(message: dict, _db, demo: bool = False): """Received full batch from particle selection service""" instrument_name = ( _db.exec(select(db.Session).where(db.Session.id == message["session_id"])) @@ -998,7 +998,7 @@ def _flush_class2d( _db.commit() -def _register_class_selection(message: dict, _db=murfey_db, demo: bool = False): +def _register_class_selection(message: dict, _db, demo: bool = False): """Received selection score from class selection service""" pj_id_params = _pj_id(message["program_id"], _db, recipe="em-spa-preprocess") pj_id = _pj_id(message["program_id"], _db, recipe="em-spa-class2d") @@ -1181,7 +1181,7 @@ def _resize_intial_model( return None -def _register_3d_batch(message: dict, _db=murfey_db, demo: bool = False): +def _register_3d_batch(message: dict, _db, demo: bool = False): """Received 3d batch from class selection service""" class3d_message = message.get("class3d_message") assert isinstance(class3d_message, dict) @@ -1375,7 +1375,7 @@ def _register_3d_batch(message: dict, _db=murfey_db, demo: bool = False): _db.close() -def _register_initial_model(message: dict, _db=murfey_db, demo: bool = False): +def _register_initial_model(message: dict, _db, demo: bool = False): """Received initial model from 3d classification service""" pj_id_params = _pj_id(message["program_id"], _db, recipe="em-spa-preprocess") # Add the initial model file to the database @@ -1390,31 +1390,31 @@ def _register_initial_model(message: dict, _db=murfey_db, demo: bool = False): _db.close() -def _flush_tomography_preprocessing(message: dict): +def _flush_tomography_preprocessing(message: dict, _db): session_id = message["session_id"] instrument_name = ( - murfey_db.exec(select(db.Session).where(db.Session.id == session_id)) + _db.exec(select(db.Session).where(db.Session.id == session_id)) .one() .instrument_name ) machine_config = get_machine_config(instrument_name=instrument_name)[ instrument_name ] - stashed_files = murfey_db.exec( + stashed_files = _db.exec( select(db.PreprocessStash) .where(db.PreprocessStash.session_id == session_id) .where(db.PreprocessStash.group_tag == message["data_collection_group_tag"]) ).all() if not stashed_files: return - collected_ids = murfey_db.exec( + collected_ids = _db.exec( select( db.DataCollectionGroup, ) .where(db.DataCollectionGroup.session_id == session_id) .where(db.DataCollectionGroup.tag == message["data_collection_group_tag"]) ).first() - proc_params = get_tomo_proc_params(collected_ids.id) + proc_params = get_tomo_proc_params(collected_ids.id, _db) if not proc_params: visit_name = message["visit_name"].replace("\r\n", "").replace("\n", "") logger.warning( @@ -1425,7 +1425,7 @@ def _flush_tomography_preprocessing(message: dict): recipe_name = machine_config.recipes.get("em-tomo-preprocess", "em-tomo-preprocess") for f in stashed_files: - collected_ids = murfey_db.exec( + collected_ids = _db.exec( select( db.DataCollectionGroup, db.DataCollection, @@ -1442,7 +1442,7 @@ def _flush_tomography_preprocessing(message: dict): ).one() detached_ids = [c.id for c in collected_ids] - murfey_ids = _murfey_id(detached_ids[3], murfey_db, number=1, close=False) + murfey_ids = _murfey_id(detached_ids[3], _db, number=1, close=False) p = Path(f.mrc_out) if not p.parent.exists(): p.parent.mkdir(parents=True) @@ -1452,7 +1452,7 @@ def _flush_tomography_preprocessing(message: dict): image_number=f.image_number, tag=f.tag, ) - murfey_db.add(movie) + _db.add(movie) zocalo_message: dict = { "recipes": [recipe_name], "parameters": { @@ -1497,13 +1497,14 @@ def _flush_tomography_preprocessing(message: dict): "movie_id": murfey_ids[0], "program_id": detached_ids[3], }, + _db, ) - murfey_db.delete(f) - murfey_db.commit() - murfey_db.close() + _db.delete(f) + _db.commit() + _db.close() -def _flush_grid_square_records(message: dict, _db=murfey_db, demo: bool = False): +def _flush_grid_square_records(message: dict, _db, demo: bool = False): tag = message["tag"] session_id = message["session_id"] gs_ids = [] @@ -1519,7 +1520,7 @@ def _flush_grid_square_records(message: dict, _db=murfey_db, demo: bool = False) _flush_foil_hole_records(i, _db=_db, demo=demo) -def _flush_foil_hole_records(grid_square_id: int, _db=murfey_db, demo: bool = False): +def _flush_foil_hole_records(grid_square_id: int, _db, demo: bool = False): for fh in _db.exec( select(db.FoilHole).where(db.FoilHole.grid_square_id == grid_square_id) ).all(): @@ -1527,7 +1528,7 @@ def _flush_foil_hole_records(grid_square_id: int, _db=murfey_db, demo: bool = Fa logger.info(f"Flushing foil hole: {fh.name}") -def _register_refinement(message: dict, _db=murfey_db, demo: bool = False): +def _register_refinement(message: dict, _db, demo: bool = False): """Received class to refine from 3D classification""" instrument_name = ( _db.exec(select(db.Session).where(db.Session.id == message["session_id"])) @@ -1675,7 +1676,7 @@ def _register_refinement(message: dict, _db=murfey_db, demo: bool = False): _db.close() -def _register_bfactors(message: dict, _db=murfey_db, demo: bool = False): +def _register_bfactors(message: dict, _db, demo: bool = False): """Received refined class to calculate b-factor""" instrument_name = ( _db.exec(select(db.Session).where(db.Session.id == message["session_id"])) @@ -1796,7 +1797,7 @@ def _register_bfactors(message: dict, _db=murfey_db, demo: bool = False): return True -def _save_bfactor(message: dict, _db=murfey_db, demo: bool = False): +def _save_bfactor(message: dict, _db, demo: bool = False): """Received b-factor from refinement run""" pj_id = _pj_id(message["program_id"], _db, recipe="em-spa-refine") bfactor_run = _db.exec( @@ -1842,11 +1843,11 @@ def _save_bfactor(message: dict, _db=murfey_db, demo: bool = False): # Clean up the b-factors table and release the hold [_db.delete(bf) for bf in all_bfactors] _db.commit() - _release_refine_hold(message) + _release_refine_hold(message, _db) _db.close() -def feedback_callback(header: dict, message: dict) -> None: +def feedback_callback(header: dict, message: dict, _db=murfey_db) -> None: try: record = None if "environment" in message: @@ -1856,7 +1857,7 @@ def feedback_callback(header: dict, message: dict) -> None: message = message["payload"] message.update(params) if message["register"] == "motion_corrected": - collected_ids = murfey_db.exec( + collected_ids = _db.exec( select( db.DataCollectionGroup, db.DataCollection, @@ -1871,7 +1872,7 @@ def feedback_callback(header: dict, message: dict) -> None: session_id = collected_ids[0].session_id # Find the autoprocprogram id for the alignment recipe - alignment_ids = murfey_db.exec( + alignment_ids = _db.exec( select( db.DataCollection, db.ProcessingJob, @@ -1883,7 +1884,7 @@ def feedback_callback(header: dict, message: dict) -> None: .where(db.ProcessingJob.recipe == "em-tomo-align") ).one() - relevant_tilt_and_series = murfey_db.exec( + relevant_tilt_and_series = _db.exec( select(db.Tilt, db.TiltSeries) .where(db.Tilt.movie_path == message.get("movie")) .where(db.Tilt.tilt_series_id == db.TiltSeries.id) @@ -1892,26 +1893,24 @@ def feedback_callback(header: dict, message: dict) -> None: relevant_tilt = relevant_tilt_and_series[0] relevant_tilt_series = relevant_tilt_and_series[1] relevant_tilt.motion_corrected = True - murfey_db.add(relevant_tilt) - murfey_db.commit() + _db.add(relevant_tilt) + _db.commit() if ( - check_tilt_series_mc(relevant_tilt_series.id) + check_tilt_series_mc(relevant_tilt_series.id, _db) and not relevant_tilt_series.processing_requested and relevant_tilt_series.tilt_series_length > 2 ): instrument_name = ( - murfey_db.exec( - select(db.Session).where(db.Session.id == session_id) - ) + _db.exec(select(db.Session).where(db.Session.id == session_id)) .one() .instrument_name ) machine_config = get_machine_config(instrument_name=instrument_name)[ instrument_name ] - tilts = get_all_tilts(relevant_tilt_series.id) - ids = get_job_ids(relevant_tilt_series.id, alignment_ids[2].id) - preproc_params = get_tomo_proc_params(ids.dcgid) + tilts = get_all_tilts(relevant_tilt_series.id, _db) + ids = get_job_ids(relevant_tilt_series.id, alignment_ids[2].id, _db) + preproc_params = get_tomo_proc_params(ids.dcgid, _db) stack_file = ( Path(message["mrc_out"]).parents[3] / "Tomograms" @@ -1954,11 +1953,11 @@ def feedback_callback(header: dict, message: dict) -> None: f"No transport object found. Zocalo message would be {zocalo_message}" ) relevant_tilt_series.processing_requested = True - murfey_db.add(relevant_tilt_series) + _db.add(relevant_tilt_series) prom.preprocessed_movies.labels(processing_job=collected_ids[2].id).inc() - murfey_db.commit() - murfey_db.close() + _db.commit() + _db.close() if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None @@ -1970,7 +1969,7 @@ def feedback_callback(header: dict, message: dict) -> None: visit_number=message["visit_number"], db=ISPyBSession(), ) - if dcg_murfey := murfey_db.exec( + if dcg_murfey := _db.exec( select(db.DataCollectionGroup) .where(db.DataCollectionGroup.session_id == message["session_id"]) .where(db.DataCollectionGroup.tag == message.get("tag")) @@ -2011,9 +2010,9 @@ def feedback_callback(header: dict, message: dict) -> None: session_id=message["session_id"], tag=message.get("tag"), ) - murfey_db.add(murfey_dcg) - murfey_db.commit() - murfey_db.close() + _db.add(murfey_dcg) + _db.commit() + _db.close() if murfey.server._transport_object: if dcgid is None: time.sleep(2) @@ -2064,14 +2063,14 @@ def feedback_callback(header: dict, message: dict) -> None: visit_number=message["visit_number"], db=ISPyBSession(), ) - dcg = murfey_db.exec( + dcg = _db.exec( select(db.DataCollectionGroup) .where(db.DataCollectionGroup.session_id == murfey_session_id) .where(db.DataCollectionGroup.tag == message["source"]) ).all() if dcg: dcgid = dcg[0].id - # flush_data_collections(message["source"], murfey_db) + # flush_data_collections(message["source"], _db) else: logger.warning( "No data collection group ID was found for image directory " @@ -2081,7 +2080,7 @@ def feedback_callback(header: dict, message: dict) -> None: if murfey.server._transport_object: murfey.server._transport_object.transport.nack(header, requeue=True) return None - if dc_murfey := murfey_db.exec( + if dc_murfey := _db.exec( select(db.DataCollection) .where(db.DataCollection.tag == message.get("tag")) .where(db.DataCollection.dcg_id == dcgid) @@ -2125,10 +2124,10 @@ def feedback_callback(header: dict, message: dict) -> None: tag=message.get("tag"), dcg_id=dcgid, ) - murfey_db.add(murfey_dc) - murfey_db.commit() + _db.add(murfey_dc) + _db.commit() dcid = murfey_dc.id - murfey_db.close() + _db.close() if dcid is None and murfey.server._transport_object: murfey.server._transport_object.transport.nack(header, requeue=True) return None @@ -2138,7 +2137,7 @@ def feedback_callback(header: dict, message: dict) -> None: elif message["register"] == "processing_job": murfey_session_id = message["session_id"] logger.info("registering processing job") - dc = murfey_db.exec( + dc = _db.exec( select(db.DataCollection, db.DataCollectionGroup) .where(db.DataCollection.dcg_id == db.DataCollectionGroup.id) .where(db.DataCollectionGroup.session_id == murfey_session_id) @@ -2154,7 +2153,7 @@ def feedback_callback(header: dict, message: dict) -> None: if murfey.server._transport_object: murfey.server._transport_object.transport.nack(header, requeue=True) return None - if pj_murfey := murfey_db.exec( + if pj_murfey := _db.exec( select(db.ProcessingJob) .where(db.ProcessingJob.recipe == message["recipe"]) .where(db.ProcessingJob.dc_id == _dcid) @@ -2180,15 +2179,15 @@ def feedback_callback(header: dict, message: dict) -> None: murfey_pj = db.ProcessingJob( id=pid, recipe=message["recipe"], dc_id=_dcid ) - murfey_db.add(murfey_pj) - murfey_db.commit() + _db.add(murfey_pj) + _db.commit() pid = murfey_pj.id - murfey_db.close() + _db.close() if pid is None and murfey.server._transport_object: murfey.server._transport_object.transport.nack(header, requeue=True) return None prom.preprocessed_movies.labels(processing_job=pid) - if not murfey_db.exec( + if not _db.exec( select(db.AutoProcProgram).where(db.AutoProcProgram.pj_id == pid) ).all(): if ISPyBSession() is None: @@ -2204,20 +2203,20 @@ def feedback_callback(header: dict, message: dict) -> None: ) return None murfey_app = db.AutoProcProgram(id=appid, pj_id=pid) - murfey_db.add(murfey_app) - murfey_db.commit() - murfey_db.close() + _db.add(murfey_app) + _db.commit() + _db.close() if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None elif message["register"] == "flush_tomography_preprocess": - _flush_tomography_preprocessing(message) + _flush_tomography_preprocessing(message, _db) if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None elif message["register"] == "spa_processing_parameters": session_id = message["session_id"] - collected_ids = murfey_db.exec( + collected_ids = _db.exec( select( db.DataCollectionGroup, db.DataCollection, @@ -2232,15 +2231,13 @@ def feedback_callback(header: dict, message: dict) -> None: .where(db.ProcessingJob.recipe == "em-spa-preprocess") ).one() pj_id = collected_ids[2].id - if not murfey_db.exec( + if not _db.exec( select(db.SPARelionParameters).where( db.SPARelionParameters.pj_id == pj_id ) ).all(): instrument_name = ( - murfey_db.exec( - select(db.Session).where(db.Session.id == session_id) - ) + _db.exec(select(db.Session).where(db.Session.id == session_id)) .one() .instrument_name ) @@ -2271,13 +2268,13 @@ def feedback_callback(header: dict, message: dict) -> None: initial_model="", next_job=0, ) - murfey_db.add(params) - murfey_db.add(feedback_params) - murfey_db.commit() + _db.add(params) + _db.add(feedback_params) + _db.commit() logger.info( f"SPA processing parameters registered for processing job {collected_ids[2].id}" ) - murfey_db.close() + _db.close() else: logger.info( f"SPA processing parameters already exist for processing job ID {pj_id}" @@ -2287,7 +2284,7 @@ def feedback_callback(header: dict, message: dict) -> None: return None elif message["register"] == "tomography_processing_parameters": session_id = message["session_id"] - collected_ids = murfey_db.exec( + collected_ids = _db.exec( select( db.DataCollectionGroup, db.DataCollection, @@ -2302,7 +2299,7 @@ def feedback_callback(header: dict, message: dict) -> None: .where(db.AutoProcProgram.pj_id == db.ProcessingJob.id) .where(db.ProcessingJob.recipe == "em-tomo-preprocess") ).one() - if not murfey_db.exec( + if not _db.exec( select(db.TomographyProcessingParameters.dcg_id).where( db.TomographyProcessingParameters.dcg_id == collected_ids[0].id ) @@ -2318,46 +2315,46 @@ def feedback_callback(header: dict, message: dict) -> None: gain_ref=message["gain_ref"], eer_fractionation_file=message["eer_fractionation_file"], ) - murfey_db.add(params) - murfey_db.commit() - murfey_db.close() + _db.add(params) + _db.commit() + _db.close() if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None elif message["register"] == "done_incomplete_2d_batch": - _release_2d_hold(message) + _release_2d_hold(message, _db) if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None elif message["register"] == "incomplete_particles_file": - _register_incomplete_2d_batch(message) + _register_incomplete_2d_batch(message, _db) if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None elif message["register"] == "complete_particles_file": - _register_complete_2d_batch(message) + _register_complete_2d_batch(message, _db) if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None elif message["register"] == "save_class_selection_score": - _register_class_selection(message) + _register_class_selection(message, _db) if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None elif message["register"] == "done_3d_batch": - _release_3d_hold(message) + _release_3d_hold(message, _db) if message.get("do_refinement"): - _register_refinement(message) + _register_refinement(message, _db) if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None elif message["register"] == "run_class3d": - _register_3d_batch(message) + _register_3d_batch(message, _db) if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None elif message["register"] == "save_initial_model": - _register_initial_model(message) + _register_initial_model(message, _db) if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None @@ -2370,12 +2367,12 @@ def feedback_callback(header: dict, message: dict) -> None: murfey.server._transport_object.transport.ack(header) return None elif message["register"] == "atlas_registered": - _flush_grid_square_records(message) + _flush_grid_square_records(message, _db) if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None elif message["register"] == "done_refinement": - bfactors_registered = _register_bfactors(message) + bfactors_registered = _register_bfactors(message, _db) if murfey.server._transport_object: if bfactors_registered: murfey.server._transport_object.transport.ack(header) @@ -2383,7 +2380,7 @@ def feedback_callback(header: dict, message: dict) -> None: murfey.server._transport_object.transport.nack(header) return None elif message["register"] == "done_bfactor": - _save_bfactor(message) + _save_bfactor(message, _db) if murfey.server._transport_object: murfey.server._transport_object.transport.ack(header) return None @@ -2407,7 +2404,7 @@ def feedback_callback(header: dict, message: dict) -> None: workflow: EntryPoint = workflows[0] result = workflow.load()( message=message, - murfey_db=murfey_db, + murfey_db=_db, ) if murfey.server._transport_object: if result: @@ -2427,8 +2424,8 @@ def feedback_callback(header: dict, message: dict) -> None: murfey.server._transport_object.transport.nack(header, requeue=False) return None except PendingRollbackError: - murfey_db.rollback() - murfey_db.close() + _db.rollback() + _db.close() logger.warning("Murfey database required a rollback") if murfey.server._transport_object: murfey.server._transport_object.transport.nack(header, requeue=True) diff --git a/tests/cli/test_repost_failed_calls.py b/tests/cli/test_repost_failed_calls.py index ea8fc76e5..1a2f6df12 100644 --- a/tests/cli/test_repost_failed_calls.py +++ b/tests/cli/test_repost_failed_calls.py @@ -190,15 +190,23 @@ def test_handle_failed_posts(tmp_path): @mock.patch("murfey.cli.repost_failed_calls.dlq_purge") @mock.patch("murfey.cli.repost_failed_calls.handle_failed_posts") @mock.patch("murfey.cli.repost_failed_calls.handle_dlq_messages") -@mock.patch("murfey.cli.repost_failed_calls.get_murfey_db_session") +@mock.patch("murfey.cli.repost_failed_calls.url") +@mock.patch("murfey.cli.repost_failed_calls.create_engine") +@mock.patch("murfey.cli.repost_failed_calls.Session") def test_run_repost_failed_calls( - mock_db, + mock_db_session, + mock_db_engine, + mock_db_url, mock_reinject, mock_repost, mock_purge, mock_security_configuration, ): - mock_db.return_value = "db" + mock_session = mock.MagicMock() + + mock_db_url.return_value = "db_url" + mock_db_engine.return_value = "db_engine" + mock_db_session.return_value = mock_session mock_purge.return_value = ["/path/to/msg1"] config_file = mock_security_configuration @@ -215,14 +223,16 @@ def test_run_repost_failed_calls( repost_failed_calls.run() security_config_class = security_from_file(config_file) - mock_db.assert_called_with(security_config_class) + mock_db_url.assert_called_with(security_config_class) + mock_db_engine.assert_called_with("db_url") + mock_db_session.assert_called_with("db_engine") mock_purge.assert_called_once_with( Path("DLQ_dir"), "murfey_feedback", Path(security_config_dict["rabbitmq_credentials"]), ) - mock_repost.assert_called_once_with(["/path/to/msg1"], "db") + mock_repost.assert_called_once_with(["/path/to/msg1"], mock_session.__enter__()) mock_reinject.assert_called_once_with( ["/path/to/msg1"], Path(security_config_dict["rabbitmq_credentials"]) )