diff --git a/src/murfey/server/api/auth.py b/src/murfey/server/api/auth.py index 24fe07f5f..d790f457d 100644 --- a/src/murfey/server/api/auth.py +++ b/src/murfey/server/api/auth.py @@ -219,15 +219,7 @@ def get_visit_name(session_id: int) -> str: ) -async def validate_frontend_session_access( - session_id: int, - token: Annotated[str, Depends(oauth2_scheme)], -) -> int: - """ - Validates whether a frontend request can access information about this session - """ - visit_name = get_visit_name(session_id) - +async def submit_to_auth_endpoint(url_subpath: str, token: str) -> None: if auth_url: headers = ( {} @@ -241,7 +233,7 @@ async def validate_frontend_session_access( ) async with aiohttp.ClientSession(cookies=cookies) as session: async with session.get( - f"{auth_url}/validate_visit_access/{visit_name}", + f"{auth_url}/{url_subpath}", headers=headers, ) as response: success = response.status == 200 @@ -253,10 +245,21 @@ async def validate_frontend_session_access( detail="You do not have access to this visit", headers={"WWW-Authenticate": "Bearer"}, ) + + +async def validate_frontend_session_access( + session_id: int, + token: Annotated[str, Depends(oauth2_scheme)], +) -> int: + """ + Validates whether a frontend request can access information about this session + """ + visit_name = get_visit_name(session_id) + await submit_to_auth_endpoint(f"validate_visit_access/{visit_name}", token) return session_id -async def validate_instrument_session_access( +async def validate_instrument_server_session_access( session_id: int, token: Annotated[str, Depends(instrument_oauth2_scheme)], ) -> int: @@ -288,9 +291,26 @@ async def validate_instrument_session_access( return session_id +async def validate_user_instrument_access( + instrument_name: str, + token: Annotated[str, Depends(oauth2_scheme)], +) -> str: + """ + Validates whether a frontend request can access information about this instrument + """ + await submit_to_auth_endpoint( + f"validate_instrument_access/{instrument_name}", token + ) + return instrument_name + + # Set validation conditions for the session ID based on where the request is from MurfeySessionIDFrontend = Annotated[int, Depends(validate_frontend_session_access)] -MurfeySessionIDInstrument = Annotated[int, Depends(validate_instrument_session_access)] +MurfeySessionIDInstrument = Annotated[ + int, Depends(validate_instrument_server_session_access) +] + +MurfeyInstrumentNameFrontend = Annotated[str, Depends(validate_user_instrument_access)] """ diff --git a/src/murfey/server/api/instrument.py b/src/murfey/server/api/instrument.py index 34e238208..bcd3a45a8 100644 --- a/src/murfey/server/api/instrument.py +++ b/src/murfey/server/api/instrument.py @@ -12,6 +12,7 @@ from sqlmodel import select from werkzeug.utils import secure_filename +from murfey.server.api.auth import MurfeyInstrumentNameFrontend as MurfeyInstrumentName from murfey.server.api.auth import MurfeySessionIDFrontend as MurfeySessionID from murfey.server.api.auth import ( create_access_token, @@ -42,7 +43,7 @@ "/instruments/{instrument_name}/sessions/{session_id}/activate_instrument_server" ) async def activate_instrument_server_for_session( - instrument_name: str, + instrument_name: MurfeyInstrumentName, session_id: int, token_in: Annotated[str, Depends(oauth2_scheme)], db=murfey_db, @@ -80,7 +81,9 @@ async def activate_instrument_server_for_session( @router.get("/instruments/{instrument_name}/sessions/{session_id}/active") -async def check_if_session_is_active(instrument_name: str, session_id: int): +async def check_if_session_is_active( + instrument_name: MurfeyInstrumentName, session_id: int +): if instrument_server_tokens.get(session_id) is None: return {"active": False} async with lock: @@ -214,7 +217,7 @@ async def pass_proc_params_to_instrument_server( @router.get("/instruments/{instrument_name}/instrument_server") -async def check_instrument_server(instrument_name: str): +async def check_instrument_server(instrument_name: MurfeyInstrumentName): data = None machine_config = get_machine_config(instrument_name=instrument_name)[ instrument_name @@ -232,7 +235,7 @@ async def check_instrument_server(instrument_name: str): "/instruments/{instrument_name}/sessions/{session_id}/possible_gain_references" ) async def get_possible_gain_references( - instrument_name: str, session_id: MurfeySessionID + instrument_name: MurfeyInstrumentName, session_id: MurfeySessionID ) -> List[File]: data = [] machine_config = get_machine_config(instrument_name=instrument_name)[ @@ -491,7 +494,7 @@ class RSyncerInfo(BaseModel): @router.get("/instruments/{instrument_name}/sessions/{session_id}/rsyncer_info") async def get_rsyncer_info( - instrument_name: str, session_id: MurfeySessionID, db=murfey_db + instrument_name: MurfeyInstrumentName, session_id: MurfeySessionID, db=murfey_db ) -> List[RSyncerInfo]: rsyncer_list = [] analyser_list = [] diff --git a/src/murfey/server/api/session_info.py b/src/murfey/server/api/session_info.py index 804698d87..d44507ad8 100644 --- a/src/murfey/server/api/session_info.py +++ b/src/murfey/server/api/session_info.py @@ -12,6 +12,7 @@ import murfey.server.api.websocket as ws from murfey.server import _transport_object from murfey.server.api import templates +from murfey.server.api.auth import MurfeyInstrumentNameFrontend as MurfeyInstrumentName from murfey.server.api.auth import MurfeySessionIDFrontend as MurfeySessionID from murfey.server.api.auth import validate_token from murfey.server.api.shared import get_foil_hole as _get_foil_hole @@ -74,12 +75,14 @@ def connections_check(): @router.get("/instruments/{instrument_name}/machine") -def machine_info_by_instrument(instrument_name: str) -> Optional[MachineConfig]: +def machine_info_by_instrument( + instrument_name: MurfeyInstrumentName, +) -> Optional[MachineConfig]: return get_machine_config_for_instrument(instrument_name) @router.get("/instruments/{instrument_name}/visits_raw", response_model=List[Visit]) -def get_current_visits(instrument_name: str, db=ispyb_db): +def get_current_visits(instrument_name: MurfeyInstrumentName, db=ispyb_db): logger.debug( f"Received request to look up ongoing visits for {sanitise(instrument_name)}" ) @@ -87,7 +90,9 @@ def get_current_visits(instrument_name: str, db=ispyb_db): @router.get("/instruments/{instrument_name}/visits/") -def all_visit_info(instrument_name: str, request: Request, db=ispyb_db): +def all_visit_info( + instrument_name: MurfeyInstrumentName, request: Request, db=ispyb_db +): visits = get_all_ongoing_visits(instrument_name, db) if visits: @@ -159,7 +164,7 @@ class VisitEndTime(BaseModel): @router.post("/instruments/{instrument_name}/visits/{visit}/session/{name}") def create_session( - instrument_name: str, + instrument_name: MurfeyInstrumentName, visit: str, name: str, visit_end_time: VisitEndTime, @@ -195,7 +200,7 @@ def remove_session(session_id: MurfeySessionID, db=murfey_db): @router.get("/instruments/{instrument_name}/visits/{visit_name}/sessions") def get_sessions_with_visit( - instrument_name: str, visit_name: str, db=murfey_db + instrument_name: MurfeyInstrumentName, visit_name: str, db=murfey_db ) -> List[Session]: sessions = db.exec( select(Session) @@ -207,7 +212,7 @@ def get_sessions_with_visit( @router.get("/instruments/{instrument_name}/sessions") async def get_sessions_by_instrument_name( - instrument_name: str, db=murfey_db + instrument_name: MurfeyInstrumentName, db=murfey_db ) -> List[Session]: sessions = db.exec( select(Session).where(Session.instrument_name == instrument_name)