diff --git a/src/murfey/instrument_server/api.py b/src/murfey/instrument_server/api.py index afd890255..ca31f9c21 100644 --- a/src/murfey/instrument_server/api.py +++ b/src/murfey/instrument_server/api.py @@ -215,6 +215,15 @@ def stop_multigrid_watcher(session_id: MurfeySessionID, label: str): return {"success": True} +@router.get("/sessions/{session_id}/multigrid_controller/status") +def check_multigrid_controller_exists( + session_id: MurfeySessionID, +): + if controllers.get(session_id, None) is not None: + return {"exists": True} + return {"exists": False} + + @router.post("/sessions/{session_id}/multigrid_controller/visit_end_time") def update_multigrid_controller_visit_end_time( session_id: MurfeySessionID, end_time: datetime diff --git a/src/murfey/server/api/instrument.py b/src/murfey/server/api/instrument.py index e4acae82c..9f2860f44 100644 --- a/src/murfey/server/api/instrument.py +++ b/src/murfey/server/api/instrument.py @@ -4,7 +4,7 @@ import datetime import logging from pathlib import Path -from typing import Annotated, List, Optional +from typing import Annotated, Any, List, Optional from urllib.parse import quote import aiohttp @@ -101,6 +101,31 @@ async def check_if_session_is_active( return {"active": response.status == 200} +@router.get("/sessions/{session_id}/multigrid_controller/status") +async def check_multigrid_controller_exists(session_id: MurfeySessionID, db=murfey_db): + session = db.exec(select(Session).where(Session.id == session_id)).one() + instrument_name = session.instrument_name + machine_config = get_machine_config(instrument_name=instrument_name)[ + instrument_name + ] + if machine_config.instrument_server_url: + log.debug( + f"Submitting request to inspect multigrid controller for session {session_id}" + ) + async with aiohttp.ClientSession() as clientsession: + async with clientsession.get( + f"{machine_config.instrument_server_url}{url_path_for('api.router', 'check_multigrid_controller_exists', session_id=session_id)}", + headers={ + "Authorization": f"Bearer {instrument_server_tokens[session_id]['access_token']}" + }, + ) as resp: + data: dict[str, Any] = await resp.json() + else: + data = {"detail": "No instrument server URL found"} + log.debug(f"Received response: {data}") + return data + + @router.post("/sessions/{session_id}/multigrid_watcher") async def setup_multigrid_watcher( session_id: MurfeySessionID, watcher_spec: MultigridWatcherSetup, db=murfey_db @@ -165,6 +190,36 @@ async def start_multigrid_watcher(session_id: MurfeySessionID, db=murfey_db): return data +@router.post("/sessions/{session_id}/multigrid_controller/visit_end_time") +async def update_visit_end_time( + session_id: MurfeySessionID, end_time: datetime.datetime, db=murfey_db +): + # Load data for session + session_entry = db.exec(select(Session).where(Session.id == session_id)).one() + instrument_name = session_entry.instrument_name + + # Update visit end time in database + session_entry.visit_end_time = end_time + db.add(session_entry) + db.commit() + + # Update the multigrid controller + data = {} + machine_config = get_machine_config(instrument_name=instrument_name)[ + instrument_name + ] + if machine_config.instrument_server_url: + async with aiohttp.ClientSession() as clientsession: + async with clientsession.post( + f"{machine_config.instrument_server_url}{url_path_for('api.router', 'update_multigrid_controller_visit_end_time', session_id=session_id)}?end_time={quote(end_time.isoformat())}", + headers={ + "Authorization": f"Bearer {instrument_server_tokens[session_id]['access_token']}" + }, + ) as resp: + data = await resp.json() + return data + + class ProvidedProcessingParameters(BaseModel): dose_per_frame: float extract_downscale: bool = True @@ -397,36 +452,6 @@ async def finalise_session(session_id: MurfeySessionID, db=murfey_db): return data -@router.post("/sessions/{session_id}/multigrid_controller/visit_end_time") -async def update_visit_end_time( - session_id: MurfeySessionID, end_time: datetime.datetime, db=murfey_db -): - # Load data for session - session_entry = db.exec(select(Session).where(Session.id == session_id)).one() - instrument_name = session_entry.instrument_name - - # Update visit end time in database - session_entry.visit_end_time = end_time - db.add(session_entry) - db.commit() - - # Update the multigrid controller - data = {} - machine_config = get_machine_config(instrument_name=instrument_name)[ - instrument_name - ] - if machine_config.instrument_server_url: - async with aiohttp.ClientSession() as clientsession: - async with clientsession.post( - f"{machine_config.instrument_server_url}{url_path_for('api.router', 'update_multigrid_controller_visit_end_time', session_id=session_id)}?end_time={quote(end_time.isoformat())}", - headers={ - "Authorization": f"Bearer {instrument_server_tokens[session_id]['access_token']}" - }, - ) as resp: - data = await resp.json() - return data - - @router.post("/sessions/{session_id}/abandon_session") async def abandon_session(session_id: MurfeySessionID, db=murfey_db): data = {} diff --git a/src/murfey/util/route_manifest.yaml b/src/murfey/util/route_manifest.yaml index 03bfb2af7..a92821c19 100644 --- a/src/murfey/util/route_manifest.yaml +++ b/src/murfey/util/route_manifest.yaml @@ -43,6 +43,11 @@ murfey.instrument_server.api.router: path_params: [] methods: - POST + - path: /sessions/{session_id}/multigrid_controller/status + function: check_multigrid_controller_exists + path_params: [] + methods: + - GET - path: /sessions/{session_id}/stop_rsyncer function: stop_rsyncer path_params: [] @@ -503,6 +508,11 @@ murfey.server.api.instrument.router: path_params: [] methods: - POST + - path: /instrument_server/sessions/{session_id}/multigrid_controller/status + function: check_multigrid_controller_exists + path_params: [] + methods: + - GET - path: /instrument_server/sessions/{session_id}/provided_processing_parameters function: pass_proc_params_to_instrument_server path_params: [] diff --git a/tests/instrument_server/test_api.py b/tests/instrument_server/test_api.py index 785f1e812..310479fd6 100644 --- a/tests/instrument_server/test_api.py +++ b/tests/instrument_server/test_api.py @@ -1,18 +1,33 @@ from pathlib import Path from typing import Optional -from unittest.mock import ANY, Mock, patch +from unittest.mock import ANY, MagicMock, patch from urllib.parse import urlparse -from pytest import mark +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from pytest_mock import MockerFixture -from murfey.instrument_server.api import ( - GainReference, - _get_murfey_url, - upload_gain_reference, -) +from murfey.instrument_server.api import _get_murfey_url +from murfey.instrument_server.api import router as client_router +from murfey.instrument_server.api import validate_session_token from murfey.util import posix_path from murfey.util.api import url_path_for + +def set_up_test_client(session_id: Optional[int] = None): + """ + Helper function to set up a test client for the instrument server with validation + checks disabled. + """ + # Set up the instrument server + client_app = FastAPI() + if session_id: + client_app.dependency_overrides[validate_session_token] = lambda: session_id + client_app.include_router(client_router) + return TestClient(client_app) + + test_get_murfey_url_params_matrix = ( # Server URL to use ("default",), @@ -23,7 +38,7 @@ ) -@mark.parametrize("test_params", test_get_murfey_url_params_matrix) +@pytest.mark.parametrize("test_params", test_get_murfey_url_params_matrix) def test_get_murfey_url( test_params: tuple[str], mock_client_configuration, # From conftest.py @@ -57,6 +72,24 @@ def test_get_murfey_url( assert parsed_server.path == parsed_original.path +def test_check_multigrid_controller_exists(mocker: MockerFixture): + session_id = 1 + + # Patch out the multigrid controllers that have been stored in memory + mocker.patch("murfey.instrument_server.api.controllers", {session_id: MagicMock()}) + + # Set up the test client + client_server = set_up_test_client(session_id=session_id) + url_path = url_path_for( + "api.router", "check_multigrid_controller_exists", session_id=session_id + ) + response = client_server.get(url_path) + + # Check that the result is as expected + assert response.status_code == 200 + assert response.json() == {"exists": True} + + test_upload_gain_reference_params_matrix = ( # Rsync URL settings ("http://1.1.1.1",), # When rsync_url is provided @@ -65,25 +98,23 @@ def test_get_murfey_url( ) -@mark.parametrize("test_params", test_upload_gain_reference_params_matrix) -@patch("murfey.instrument_server.api.subprocess") -@patch("murfey.instrument_server.api.tokens") -@patch("murfey.instrument_server.api._get_murfey_url") -@patch("murfey.instrument_server.api.requests") +@pytest.mark.parametrize("test_params", test_upload_gain_reference_params_matrix) def test_upload_gain_reference( - mock_request, - mock_get_server_url, - mock_tokens, - mock_subprocess, + mocker: MockerFixture, test_params: tuple[Optional[str]], ): - # Unpack test parameters and define other ones (rsync_url_setting,) = test_params - server_url = "http://0.0.0.0:8000" + server_url = "https://murfey.server.test" instrument_name = "murfey" session_id = 1 + # Mock out objects + mock_request = mocker.patch("murfey.instrument_server.api.requests") + mock_get_server_url = mocker.patch("murfey.instrument_server.api._get_murfey_url") + mock_subprocess = mocker.patch("murfey.instrument_server.api.subprocess") + mocker.patch("murfey.instrument_server.api.tokens", {session_id: ANY}) + # Create a mock machine config base on the test params rsync_module = "data" gain_ref_dir = "C:/ProgramData/Gatan/Gain Reference" @@ -95,12 +126,12 @@ def test_upload_gain_reference( mock_machine_config["rsync_url"] = rsync_url_setting # Assign expected values to the mock objects - mock_response = Mock() + mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = mock_machine_config mock_request.get.return_value = mock_response mock_get_server_url.return_value = server_url - mock_subprocess.run.return_value = Mock(returncode=0) + mock_subprocess.run.return_value = MagicMock(returncode=0) # Construct payload and pass request to function gain_ref_file = f"{gain_ref_dir}/gain.mrc" @@ -111,13 +142,18 @@ def test_upload_gain_reference( "visit_path": visit_path, "gain_destination_dir": gain_dest_dir, } - result = upload_gain_reference( + + # Set up instrument server test client + client_server = set_up_test_client(session_id=session_id) + + # Poke the endpoint with the expected data + url_path = url_path_for( + "api.router", + "upload_gain_reference", instrument_name=instrument_name, session_id=session_id, - gain_reference=GainReference( - **payload, - ), ) + response = client_server.post(url_path, json=payload) # Check that the machine config request was called machine_config_url = f"{server_url}{url_path_for('session_control.router', 'machine_info_by_instrument', instrument_name=instrument_name)}" @@ -145,4 +181,4 @@ def test_upload_gain_reference( ) # Check that the function ran through to completion successfully - assert result == {"success": True} + assert response.json() == {"success": True} diff --git a/tests/server/api/test_instrument.py b/tests/server/api/test_instrument.py new file mode 100644 index 000000000..b8c83ae8f --- /dev/null +++ b/tests/server/api/test_instrument.py @@ -0,0 +1,130 @@ +from typing import Literal +from unittest import mock +from unittest.mock import AsyncMock, MagicMock + +from fastapi import FastAPI +from fastapi.testclient import TestClient +from pytest_mock import MockerFixture + +from murfey.server.api.auth import validate_frontend_session_access, validate_token +from murfey.server.api.instrument import router as backend_router +from murfey.server.murfey_db import murfey_db_session +from murfey.util.api import url_path_for + + +def mock_aiohttp_clientsession( + mocker: MockerFixture, + method: Literal["get", "post", "delete"] = "get", + json_data={}, + status=200, +): + """ + Helper function to patch a aiohttp.ClientSession GET request. This returns a + mocked async context manager with a mocked response that, in turn, returns + the given JSON data and status. + + Returns the mocked ClientSession, which can then be inspected to assert that + the expected calls were made. + """ + + # Mock out the async response + mock_response = MagicMock() + mock_response.json = AsyncMock(return_value=json_data) + mock_response.status = status + + # Mock out the context manager returned by clientsession.get() + mock_context_manager = MagicMock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_response) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + + # Mock the client session + mock_clientsession = MagicMock() + mock_clientsession.__aenter__ = AsyncMock(return_value=mock_clientsession) + mock_clientsession.__aexit__ = AsyncMock(return_value=None) + + # Assign the context manager to the request method being tested + getattr(mock_clientsession, method.lower()).return_value = mock_context_manager + + # Patch 'aiohttp.ClientSession' to return the mocked client session + mocker.patch("aiohttp.ClientSession", return_value=mock_clientsession) + + return mock_clientsession, mock_response + + +def test_check_multigrid_controller_exists(mocker: MockerFixture): + # Set up the objects to mock + instrument_name = "test" + session_id = 1 + instrment_server_url = "https://murfey.instrument-server.test" + + # Override the database session generator + mock_session = MagicMock() + mock_session.instrument_name = instrument_name + mock_query_result = MagicMock() + mock_query_result.one.return_value = mock_session + mock_db_session = MagicMock() + mock_db_session.exec.return_value = mock_query_result + + def mock_get_db_session(): + yield mock_db_session + + # Mock the machine config + mock_machine_config = MagicMock() + mock_machine_config.instrument_server_url = instrment_server_url + mock_get_machine_config = mocker.patch( + "murfey.server.api.instrument.get_machine_config" + ) + mock_get_machine_config.return_value = { + instrument_name: mock_machine_config, + } + + # Mock the instrument server tokens dictionary + mock_tokens = mocker.patch( + "murfey.server.api.instrument.instrument_server_tokens", + {session_id: {"access_token": mock.sentinel}}, + ) + + # Mock out the async GET request in the endpoint + mock_clientsession, _ = mock_aiohttp_clientsession( + mocker, + method="get", + json_data={"exists": True}, + status=200, + ) + + # Set up the backend server + backend_app = FastAPI() + + # Override validation and database dependencies + backend_app.dependency_overrides[validate_token] = lambda: None + backend_app.dependency_overrides[validate_frontend_session_access] = ( + lambda: session_id + ) + backend_app.dependency_overrides[murfey_db_session] = mock_get_db_session + backend_app.include_router(backend_router) + backend_server = TestClient(backend_app) + + # Construct the URL paths for poking and sending to + backend_url_path = url_path_for( + "api.instrument.router", + "check_multigrid_controller_exists", + session_id=session_id, + ) + client_url_path = url_path_for( + "api.router", + "check_multigrid_controller_exists", + session_id=session_id, + ) + + # Poke the backend + response = backend_server.get(backend_url_path) + + # Check that the expected calls were made + mock_db_session.exec.assert_called_once() + mock_get_machine_config.assert_called_once_with(instrument_name=instrument_name) + mock_clientsession.get.assert_called_once_with( + f"{instrment_server_url}{client_url_path}", + headers={"Authorization": f"Bearer {mock_tokens[session_id]['access_token']}"}, + ) + assert response.status_code == 200 + assert response.json() == {"exists": True}