diff --git a/pyproject.toml b/pyproject.toml index 41b1a2143..e33dd2365 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,8 +96,6 @@ GitHub = "https://github.com/DiamondLightSource/python-murfey" "murfey.spa_inject" = "murfey.cli.inject_spa_processing:run" "murfey.spa_ispyb_entries" = "murfey.cli.spa_ispyb_messages:run" "murfey.transfer" = "murfey.cli.transfer:run" -[project.entry-points."murfey.auth.token_validation"] -"password" = "murfey.server.api.auth:password_token_validation" [project.entry-points."murfey.config.extraction"] "murfey_machine" = "murfey.util.config:get_extended_machine_config" [project.entry-points."murfey.workflows"] diff --git a/src/murfey/client/contexts/spa.py b/src/murfey/client/contexts/spa.py index 89017f60a..6b249c8a3 100644 --- a/src/murfey/client/contexts/spa.py +++ b/src/murfey/client/contexts/spa.py @@ -567,7 +567,7 @@ def post_transfer( ) if not environment.movie_counters.get(str(source)): movie_counts_get = capture_get( - f"{environment.url.geturl()}{url_path_for('session_info.router', 'count_number_of_movies')}", + f"{environment.url.geturl()}{url_path_for('session_control.router', 'count_number_of_movies')}", ) if movie_counts_get is not None: environment.movie_counters[str(source)] = count( @@ -581,7 +581,7 @@ def post_transfer( eer_fractionation_file = None if file_transferred_to.suffix == ".eer": response = capture_post( - f"{str(environment.url.geturl())}{url_path_for('file_manip.router', 'write_eer_fractionation_file', visit_name=environment.visit, session_id=environment.murfey_session)}", + f"{str(environment.url.geturl())}{url_path_for('file_io_instrument.router', 'write_eer_fractionation_file', visit_name=environment.visit, session_id=environment.murfey_session)}", json={ "eer_path": str(file_transferred_to), "fractionation": environment.data_collection_parameters[ diff --git a/src/murfey/client/contexts/tomo.py b/src/murfey/client/contexts/tomo.py index ab502febf..9562041f7 100644 --- a/src/murfey/client/contexts/tomo.py +++ b/src/murfey/client/contexts/tomo.py @@ -317,7 +317,7 @@ def _add_tilt( eer_fractionation_file = None if environment.data_collection_parameters.get("num_eer_frames"): response = requests.post( - f"{str(environment.url.geturl())}{url_path_for('file_manip.router', 'write_eer_fractionation_file', visit_name=environment.visit, session_id=environment.murfey_session)}", + f"{str(environment.url.geturl())}{url_path_for('file_io_instrument.router', 'write_eer_fractionation_file', visit_name=environment.visit, session_id=environment.murfey_session)}", json={ "num_frames": environment.data_collection_parameters[ "num_eer_frames" diff --git a/src/murfey/client/multigrid_control.py b/src/murfey/client/multigrid_control.py index 03ab2cfaa..d32aa8747 100644 --- a/src/murfey/client/multigrid_control.py +++ b/src/murfey/client/multigrid_control.py @@ -251,7 +251,7 @@ def _start_rsyncer( log.info(f"starting rsyncer: {source}") if transfer: # Always make sure the destination directory exists - make_directory_url = f"{self.murfey_url}{url_path_for('file_manip.router', 'make_rsyncer_destination', session_id=self.session_id)}" + make_directory_url = f"{self.murfey_url}{url_path_for('file_io_instrument.router', 'make_rsyncer_destination', session_id=self.session_id)}" capture_post(make_directory_url, json={"destination": destination}) if self._environment: self._environment.default_destinations[source] = destination @@ -437,7 +437,7 @@ def _start_dc(self, json, from_form: bool = False): log.info("Registering tomography processing parameters") if self._environment.data_collection_parameters.get("num_eer_frames"): eer_response = requests.post( - f"{str(self._environment.url.geturl())}{url_path_for('file_manip.router', 'write_eer_fractionation_file', visit_name=self._environment.visit, session_id=self._environment.murfey_session)}", + f"{str(self._environment.url.geturl())}{url_path_for('file_io_instrument.router', 'write_eer_fractionation_file', visit_name=self._environment.visit, session_id=self._environment.murfey_session)}", json={ "num_frames": self._environment.data_collection_parameters[ "num_eer_frames" diff --git a/src/murfey/client/tui/app.py b/src/murfey/client/tui/app.py index ff6542de4..b6d73ab5b 100644 --- a/src/murfey/client/tui/app.py +++ b/src/murfey/client/tui/app.py @@ -209,7 +209,7 @@ def _start_rsyncer( log.info(f"starting rsyncer: {source}") if transfer: # Always make sure the destination directory exists - make_directory_url = f"{str(self._url.geturl())}{url_path_for('file_manip.router', 'make_rsyncer_destination', session_id=self._environment.murfey_session)}" + make_directory_url = f"{str(self._url.geturl())}{url_path_for('file_io_instrument.router', 'make_rsyncer_destination', session_id=self._environment.murfey_session)}" capture_post(make_directory_url, json={"destination": destination}) if self._environment: self._environment.default_destinations[source] = destination @@ -488,7 +488,7 @@ def _start_dc(self, json, from_form: bool = False): log.info("Registering tomography processing parameters") if self.app._environment.data_collection_parameters.get("num_eer_frames"): eer_response = requests.post( - f"{str(self.app._environment.url.geturl())}{url_path_for('file_manip.router', 'write_eer_fractionation_file', visit_name=self.app._environment.visit, session_id=self.app._environment.murfey_session)}", + f"{str(self.app._environment.url.geturl())}{url_path_for('file_io_instrument.router', 'write_eer_fractionation_file', visit_name=self.app._environment.visit, session_id=self.app._environment.murfey_session)}", json={ "num_frames": self.app._environment.data_collection_parameters[ "num_eer_frames" diff --git a/src/murfey/client/tui/screens.py b/src/murfey/client/tui/screens.py index 04b52c5be..3ff79d5ae 100644 --- a/src/murfey/client/tui/screens.py +++ b/src/murfey/client/tui/screens.py @@ -110,7 +110,7 @@ def determine_default_destination( _default = environment.destination_registry[source_name] else: suggested_path_response = capture_post( - url=f"{str(environment.url.geturl())}{url_path_for('file_manip.router', 'suggest_path', visit_name=visit, session_id=environment.murfey_session)}", + url=f"{str(environment.url.geturl())}{url_path_for('file_io_instrument.router', 'suggest_path', visit_name=visit, session_id=environment.murfey_session)}", json={ "base_path": f"{destination}/{visit}/{mid_path.parent if include_mid_path else ''}/raw", "touch": touch, @@ -906,7 +906,7 @@ def on_button_pressed(self, event): f"Gain reference file {posix_path(self._dir_tree._gain_reference)!r} was not successfully transferred to {visit_path}/processing" ) process_gain_response = requests.post( - url=f"{str(self.app._environment.url.geturl())}{url_path_for('file_manip.router', 'process_gain', session_id=self.app._environment.murfey_session)}", + url=f"{str(self.app._environment.url.geturl())}{url_path_for('file_io_instrument.router', 'process_gain', session_id=self.app._environment.murfey_session)}", json={ "gain_ref": str(self._dir_tree._gain_reference), "eer": bool( diff --git a/src/murfey/instrument_server/api.py b/src/murfey/instrument_server/api.py index 93f49099e..95ff5f506 100644 --- a/src/murfey/instrument_server/api.py +++ b/src/murfey/instrument_server/api.py @@ -51,6 +51,9 @@ def validate_session_token( session_id: int, token: Annotated[str, Depends(oauth2_scheme)] ): + """ + Validates the token received from the backend server + """ try: decoded_data = jwt.decode( token, @@ -62,7 +65,7 @@ def validate_session_token( except JWTError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", + detail="Could not validate credentials from backend", headers={"WWW-Authenticate": "Bearer"}, ) return session_id diff --git a/src/murfey/server/api/auth.py b/src/murfey/server/api/auth.py index edcdc6589..4a61b2868 100644 --- a/src/murfey/server/api/auth.py +++ b/src/murfey/server/api/auth.py @@ -3,18 +3,22 @@ import secrets import time from logging import getLogger -from typing import Annotated, Dict +from typing import Dict from uuid import uuid4 import aiohttp import requests -from backports.entry_points_selectable import entry_points -from fastapi import APIRouter, Depends, HTTPException, Request, status -from fastapi.security import HTTPBearer, OAuth2PasswordBearer, OAuth2PasswordRequestForm +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import ( + APIKeyCookie, + OAuth2PasswordBearer, + OAuth2PasswordRequestForm, +) from jose import JWTError, jwt from passlib.context import CryptContext from pydantic import BaseModel from sqlmodel import Session, create_engine, select +from typing_extensions import Annotated from murfey.server.murfey_db import murfey_db, url from murfey.util.api import url_path_for @@ -26,39 +30,10 @@ logger = getLogger("murfey.server.api.auth") # Set up router -router = APIRouter(tags=["Authentication"]) - - -class CookieScheme(HTTPBearer): - def __init__( - self, - *, - description: str | None = None, - auto_error: bool = True, - cookie_key: str = "cookie_auth", - ): - """ - Args: - cookie_key: Cookie key to look for in requests - """ - super().__init__( - description=description, - auto_error=auto_error, - ) - - self.cookie_key = cookie_key - - async def __call__(self, request: Request): - token = request.cookies.get(self.cookie_key) - if token is None: - if self.auto_error: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authenticated", - ) - else: - return None - return token +router = APIRouter( + prefix="/auth", + tags=["Authentication"], +) # Set up variables used for authentication @@ -69,25 +44,15 @@ async def __call__(self, request: Request): if security_config.auth_type == "password": oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") else: - oauth2_scheme = CookieScheme(cookie_key=security_config.cookie_key) + oauth2_scheme = APIKeyCookie(name=security_config.cookie_key) +if security_config.instrument_auth_type == "token": + instrument_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +else: + instrument_oauth2_scheme = lambda *args, **kwargs: None pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") instrument_server_tokens: Dict[float, dict] = {} - -""" -HELPER FUNCTIONS -""" - - -def verify_password(plain_password: str, hashed_password: str) -> bool: - return pwd_context.verify(plain_password, hashed_password) - - -def hash_password(password: str) -> str: - return pwd_context.hash(password) - - # Set up database engine try: _url = url(security_config) @@ -96,22 +61,19 @@ def hash_password(password: str) -> str: engine = None -def validate_user(username: str, password: str) -> bool: - try: - with Session(engine) as murfey_db: - user = murfey_db.exec(select(User).where(User.username == username)).one() - except Exception: - return False - return verify_password(password, user.hashed_password) +def hash_password(password: str) -> str: + return pwd_context.hash(password) -def validate_visit(visit_name: str, token: str) -> bool: - if validators := entry_points().select( - group="murfey.auth.session_validation", - name=security_config.auth_type, - ): - return validators[0].load()(visit_name, token) - return True +""" +======================================================================================= +TOKEN VALIDATION FUNCTIONS +======================================================================================= + +Functions and helpers used to validate incoming requests from both the client and +the frontend. 'validate_token()' and 'validate_instrument_token()' are imported +int the other FastAPI modules and attached as dependencies to the routers. +""" def check_user(username: str) -> bool: @@ -123,40 +85,12 @@ def check_user(username: str) -> bool: return username in [u.username for u in users] -def validate_instrument_server_token(timestamp: float) -> bool: - return timestamp in instrument_server_tokens.keys() - - -def validate_instrument_server_session_token(session_id: int, visit: str): - with Session(engine) as murfey_db: - session_data = murfey_db.exec( - select(MurfeySession).where(MurfeySession.id == session_id) - ).all() - if len(session_data) != 1: - return False - return visit == session_data[0].visit - - -def password_token_validation(token: str): - decoded_data = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - # first check if the token has expired - if expiry_time := decoded_data.get("expiry_time"): - if expiry_time < time.time(): - raise JWTError - if decoded_data.get("user"): - if not check_user(decoded_data["user"]): - raise JWTError - elif decoded_data.get("session") is not None: - if not validate_instrument_server_session_token( - decoded_data["session"], decoded_data["visit"] - ): - raise JWTError - else: - raise JWTError - - async def validate_token(token: Annotated[str, Depends(oauth2_scheme)]): + """ + Used by the backend routers to validate requests coming in from frontend. + """ try: + # Validate using auth URL if provided; will error if invalid if auth_url: headers = ( {} @@ -170,116 +104,279 @@ async def validate_token(token: Annotated[str, Depends(oauth2_scheme)]): ) async with aiohttp.ClientSession(cookies=cookies) as session: async with session.get( - f"{auth_url}{url_path_for('auth.router', 'simple_token_validation')}", + f"{auth_url}/validate_token", headers=headers, ) as response: success = response.status == 200 validation_outcome = await response.json() if not (success and validation_outcome.get("valid")): raise JWTError + # If authenticating using cookies; an auth URL MUST be provided else: - if validators := entry_points().select( - group="murfey.auth.token_validation", - name=security_config.auth_type, - ): - validators[0].load()(token) + if security_config.auth_type == "cookie": + raise JWTError + # Validate using password + if security_config.auth_type == "password": + decoded_data = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + # Check that the user is present and is valid + if decoded_data.get("user"): + if not check_user(decoded_data["user"]): + raise JWTError else: raise JWTError except JWTError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", + detail="Could not validate credentials from frontend", headers={"WWW-Authenticate": "Bearer"}, ) return None -async def validate_session_access( - session_id: int, token: Annotated[str, Depends(oauth2_scheme)] -) -> int: - await validate_token(token) +def validate_session_against_visit(session_id: int, visit: str): + """ + Checks that the session ID is associated with the claimed visit. + """ with Session(engine) as murfey_db: - visit_name = ( + session_data = murfey_db.exec( + select(MurfeySession).where(MurfeySession.id == session_id) + ).all() + if len(session_data) != 1: + return False + return visit == session_data[0].visit + + +async def validate_instrument_token( + token: Annotated[str, Depends(instrument_oauth2_scheme)] +): + """ + Used by the backend routers to check the incoming instrument server token. + """ + try: + # Validate using auth URL if provided + if security_config.instrument_auth_url: + async with aiohttp.ClientSession() as session: + headers = ( + {} + if not security_config.instrument_auth_type + else {"Authorization": f"Bearer {token}"} + ) + async with session.get( + f"{security_config.instrument_auth_url}/validate_token", + headers=headers, + ) as response: + success = response.status == 200 + validation_outcome = await response.json() + if not (success and validation_outcome.get("valid")): + raise JWTError + else: + # First, check if the token has expired + decoded_data = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + if expiry_time := decoded_data.get("expiry_time"): + if expiry_time < time.time(): + raise JWTError + elif decoded_data.get("session") is not None: + # Check that the decoded session corresponds to the visit + if not validate_session_against_visit( + decoded_data["session"], decoded_data["visit"] + ): + raise JWTError + else: + raise JWTError + except JWTError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials from instrument", + headers={"WWW-Authenticate": "Bearer"}, + ) + return None + + +""" +======================================================================================= +SESSION ID VALIDATION +======================================================================================= + +Annotated ints are defined here that trigger validation of the session IDs in incoming +requests, verifying that the session is allowed to access the particular visit. + +The 'MurfeySessionID...' types are imported and used in the type hints of the endpoint +functions in the other FastAPI routers, depending on whether requests from the frontend +or the instrument are expected. +""" + + +def get_visit_name(session_id: int) -> str: + with Session(engine) as murfey_db: + return ( murfey_db.exec(select(MurfeySession).where(MurfeySession.id == session_id)) .one() .visit ) - if not validate_visit(visit_name, token): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="You do not have access to this visit", - headers={"WWW-Authenticate": "Bearer"}, + + +async def validate_frontend_session_access( + session_id: int, + token: Annotated[str, Depends(oauth2_scheme)], +) -> int: + """ + Validates whether a frontend request can access information about this session + """ + visit_name = get_visit_name(session_id) + + if auth_url: + headers = ( + {} + if security_config.auth_type == "cookie" + else {"Authorization": f"Bearer {token}"} ) + cookies = ( + {security_config.cookie_key: token} + if security_config.auth_type == "cookie" + else {} + ) + async with aiohttp.ClientSession(cookies=cookies) as session: + async with session.get( + f"{auth_url}/validate_visit_access/{visit_name}", + headers=headers, + ) as response: + success = response.status == 200 + validation_outcome: dict = await response.json() + if not (success and validation_outcome.get("valid")): + logger.warning("Unauthorised visit access request from frontend") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="You do not have access to this visit", + headers={"WWW-Authenticate": "Bearer"}, + ) return session_id -class Token(BaseModel): - access_token: str - token_type: str +async def validate_instrument_session_access( + session_id: int, + token: Annotated[str, Depends(instrument_oauth2_scheme)], +) -> int: + """ + Validates whether an instrument request can access information about this session + """ + visit_name = get_visit_name(session_id) + + if security_config.instrument_auth_url: + async with aiohttp.ClientSession() as session: + headers = ( + {} + if not security_config.instrument_auth_type + else {"Authorization": f"Bearer {token}"} + ) + async with session.get( + f"{security_config.instrument_auth_url}/validate_visit_access/{visit_name}", + headers=headers, + ) as response: + success = response.status == 200 + validation_outcome = await response.json() + if not (success and validation_outcome.get("valid")): + logger.warning("Unauthorised visit access request from instrument") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="You do not have access to this visit", + headers={"WWW-Authenticate": "Bearer"}, + ) + return session_id + + +# Set validation conditions for the session ID based on where the request is from +MurfeySessionIDFrontend = Annotated[int, Depends(validate_frontend_session_access)] +MurfeySessionIDInstrument = Annotated[int, Depends(validate_instrument_session_access)] + + +""" +======================================================================================= +API ENDPOINTS AND HELPER FUNCTIONS/CLASSES +======================================================================================= +""" + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + return pwd_context.verify(plain_password, hashed_password) + + +def validate_user(username: str, password: str) -> bool: + try: + with Session(engine) as murfey_db: + user = murfey_db.exec(select(User).where(User.username == username)).one() + except Exception: + return False + return verify_password(password, user.hashed_password) def create_access_token(data: dict, token: str = "") -> str: - if auth_url and data.get("session"): - session_id = data["session"] - if not isinstance(session_id, int) and session_id > 0: - # check the session ID is alphanumeric for security - raise ValueError("Session ID was invalid (not alphanumeric)") - minted_token_response = requests.get( - f"{auth_url}{url_path_for('auth.router', 'mint_session_token', session_id=session_id)}", - headers={"Authorization": f"Bearer {token}"}, - ) - if minted_token_response.status_code != 200: - raise RuntimeError( - f"Request received status code {minted_token_response.status_code} when trying to create session token" + + # If authenticating with password, auth URL needs a 'mint_session_token' endpoint + if security_config.auth_type == "password": + if auth_url and data.get("session"): + session_id = data["session"] + if not isinstance(session_id, int) and session_id > 0: + # Check the session ID is alphanumeric for security + raise ValueError("Session ID was invalid (not alphanumeric)") + minted_token_response = requests.get( + f"{auth_url}{url_path_for('auth.router', 'mint_session_token', session_id=session_id)}", + headers={"Authorization": f"Bearer {token}"}, ) - return minted_token_response.json()["access_token"] + if minted_token_response.status_code != 200: + raise RuntimeError( + f"Request received status code {minted_token_response.status_code} when trying to create session token" + ) + return minted_token_response.json()["access_token"] to_encode = data.copy() + # Make token for instrument encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt -MurfeySessionID = Annotated[int, Depends(validate_session_access)] - -""" -API ENDPOINTS -""" +class Token(BaseModel): + access_token: str + token_type: str @router.post("/token") async def generate_token( form_data: Annotated[OAuth2PasswordRequestForm, Depends()], ) -> Token: - if auth_url: - data = aiohttp.FormData() - data.add_field("username", form_data.username) - data.add_field("password", form_data.password) - async with aiohttp.ClientSession() as session: - async with session.post( - f"{auth_url}{url_path_for('auth.router', 'generate_token')}", - data=data, - ) as response: - validated = response.status == 200 - token = await response.json() - access_token = token.get("access_token") - else: - validated = validate_user(form_data.username, form_data.password) - if not validated: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password", - headers={"WWW-Authenticate": "Bearer"}, - ) - if not auth_url: - access_token = create_access_token( - data={"user": form_data.username}, - ) - return Token(access_token=access_token, token_type="bearer") + # Only generate a token if it's a password + if security_config.auth_type == "password": + if auth_url: + data = aiohttp.FormData() + data.add_field("username", form_data.username) + data.add_field("password", form_data.password) + async with aiohttp.ClientSession() as session: + async with session.post( + f"{auth_url}{url_path_for('auth.router', 'generate_token')}", + data=data, + ) as response: + validated = response.status == 200 + token = await response.json() + access_token = token.get("access_token") + else: + validated = validate_user(form_data.username, form_data.password) + access_token = create_access_token( + data={"user": form_data.username}, + ) + if not validated: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + return Token(access_token=access_token, token_type="bearer") + + # Return empty token otherwise + return Token(access_token="", token_type="bearer") @router.get("/sessions/{session_id}/token") -async def mint_session_token(session_id: MurfeySessionID, db=murfey_db): +async def mint_session_token(session_id: MurfeySessionIDFrontend, db=murfey_db): visit = ( db.exec(select(MurfeySession).where(MurfeySession.id == session_id)).one().visit ) @@ -298,5 +395,7 @@ async def mint_session_token(session_id: MurfeySessionID, db=murfey_db): @router.get("/validate_token") -async def simple_token_validation(token: Annotated[str, Depends(validate_token)]): +async def simple_token_validation( + token: Annotated[str, Depends(validate_instrument_token)] +): return {"valid": True} diff --git a/src/murfey/server/api/clem.py b/src/murfey/server/api/clem.py index 1bc043fb8..a13b547b7 100644 --- a/src/murfey/server/api/clem.py +++ b/src/murfey/server/api/clem.py @@ -31,7 +31,10 @@ logger = getLogger("murfey.server.api.clem") # Create APIRouter class object -router = APIRouter(tags=["Workflows: CLEM"]) +router = APIRouter( + prefix="/workflow/clem", + tags=["Workflows: CLEM"], +) # Valid file types valid_file_types = ( diff --git a/src/murfey/server/api/file_io_frontend.py b/src/murfey/server/api/file_io_frontend.py new file mode 100644 index 000000000..531fa81c8 --- /dev/null +++ b/src/murfey/server/api/file_io_frontend.py @@ -0,0 +1,26 @@ +from logging import getLogger + +from fastapi import APIRouter, Depends + +from murfey.server.api.auth import MurfeySessionIDFrontend as MurfeySessionID +from murfey.server.api.auth import validate_token +from murfey.server.api.file_io_shared import GainReference +from murfey.server.api.file_io_shared import process_gain as _process_gain +from murfey.server.murfey_db import murfey_db + +logger = getLogger("murfey.server.api.file_io_frontend") + + +router = APIRouter( + prefix="/file_io/frontend", + dependencies=[Depends(validate_token)], + tags=["File I/O: Frontend"], +) + + +@router.post("/sessions/{session_id}/process_gain") +async def process_gain( + session_id: MurfeySessionID, gain_reference_params: GainReference, db=murfey_db +): + result = await _process_gain(session_id, gain_reference_params, db) + return result diff --git a/src/murfey/server/api/file_manip.py b/src/murfey/server/api/file_io_instrument.py similarity index 68% rename from src/murfey/server/api/file_manip.py rename to src/murfey/server/api/file_io_instrument.py index 01724c7eb..ed42ab5ec 100644 --- a/src/murfey/server/api/file_manip.py +++ b/src/murfey/server/api/file_io_instrument.py @@ -8,20 +8,23 @@ from sqlmodel import select from werkzeug.utils import secure_filename -from murfey.server.api.auth import MurfeySessionID, validate_token -from murfey.server.gain import Camera, prepare_eer_gain, prepare_gain +from murfey.server.api.auth import MurfeySessionIDInstrument as MurfeySessionID +from murfey.server.api.auth import validate_instrument_token +from murfey.server.api.file_io_shared import GainReference +from murfey.server.api.file_io_shared import process_gain as _process_gain from murfey.server.murfey_db import murfey_db from murfey.util import sanitise, secure_path from murfey.util.config import get_machine_config from murfey.util.db import Session, SessionProcessingParameters from murfey.util.eer import num_frames -logger = getLogger("murfey.server.api.file_manip") +logger = getLogger("murfey.server.api.file_io_instrument") + router = APIRouter( - prefix="/file_manipulation", - dependencies=[Depends(validate_token)], - tags=["File Manipulation"], + prefix="/file_io/instrument", + dependencies=[Depends(validate_instrument_token)], + tags=["File I/O: Instrument"], ) @@ -106,85 +109,12 @@ def make_rsyncer_destination(session_id: int, destination: Dest, db=murfey_db): return destination -class GainReference(BaseModel): - gain_ref: Path - rescale: bool = True - eer: bool = False - tag: str = "" - - @router.post("/sessions/{session_id}/process_gain") async def process_gain( session_id: MurfeySessionID, gain_reference_params: GainReference, db=murfey_db ): - murfey_session = db.exec(select(Session).where(Session.id == session_id)).one() - visit_name = murfey_session.visit - instrument_name = murfey_session.instrument_name - machine_config = get_machine_config(instrument_name=instrument_name)[ - instrument_name - ] - camera = getattr(Camera, machine_config.camera) - if gain_reference_params.eer: - executables = machine_config.external_executables_eer - else: - executables = machine_config.external_executables - env = machine_config.external_environment - safe_path_name = secure_filename(gain_reference_params.gain_ref.name) - filepath = ( - Path(machine_config.rsync_basepath) - / str(datetime.now().year) - / secure_filename(visit_name) - / machine_config.gain_directory_name - ) - - # Check under previous year if the folder doesn't exist - if not filepath.exists(): - filepath_prev = filepath - filepath = ( - Path(machine_config.rsync_basepath) - / str(datetime.now().year - 1) - / secure_filename(visit_name) - / machine_config.gain_directory_name - ) - # If it's not in the previous year, it's a genuine error - if not filepath.exists(): - log_message = ( - "Unable to find gain reference directory under " - f"{str(filepath_prev)!r} or {str(filepath)}" - ) - logger.error(log_message) - raise FileNotFoundError(log_message) - - if gain_reference_params.eer: - new_gain_ref, new_gain_ref_superres = await prepare_eer_gain( - filepath / safe_path_name, - executables, - env, - tag=gain_reference_params.tag, - ) - else: - new_gain_ref, new_gain_ref_superres = await prepare_gain( - camera, - filepath / safe_path_name, - executables, - env, - rescale=gain_reference_params.rescale, - tag=gain_reference_params.tag, - ) - if new_gain_ref and new_gain_ref_superres: - return { - "gain_ref": new_gain_ref.relative_to(Path(machine_config.rsync_basepath)), - "gain_ref_superres": new_gain_ref_superres.relative_to( - Path(machine_config.rsync_basepath) - ), - } - elif new_gain_ref: - return { - "gain_ref": new_gain_ref.relative_to(Path(machine_config.rsync_basepath)), - "gain_ref_superres": None, - } - else: - return {"gain_ref": str(filepath / safe_path_name), "gain_ref_superres": None} + result = await _process_gain(session_id, gain_reference_params, db) + return result class FractionationParameters(BaseModel): diff --git a/src/murfey/server/api/file_io_shared.py b/src/murfey/server/api/file_io_shared.py new file mode 100644 index 000000000..310c8e4b9 --- /dev/null +++ b/src/murfey/server/api/file_io_shared.py @@ -0,0 +1,94 @@ +from datetime import datetime +from logging import getLogger +from pathlib import Path + +from pydantic import BaseModel +from sqlmodel import select +from werkzeug.utils import secure_filename + +from murfey.server.gain import Camera, prepare_eer_gain, prepare_gain +from murfey.server.murfey_db import murfey_db +from murfey.util.config import get_machine_config +from murfey.util.db import Session + +logger = getLogger("murfey.server.api.file_io_shared") + + +class GainReference(BaseModel): + gain_ref: Path + rescale: bool = True + eer: bool = False + tag: str = "" + + +async def process_gain( + session_id: int, gain_reference_params: GainReference, db=murfey_db +): + murfey_session = db.exec(select(Session).where(Session.id == session_id)).one() + visit_name = murfey_session.visit + instrument_name = murfey_session.instrument_name + machine_config = get_machine_config(instrument_name=instrument_name)[ + instrument_name + ] + camera = getattr(Camera, machine_config.camera) + if gain_reference_params.eer: + executables = machine_config.external_executables_eer + else: + executables = machine_config.external_executables + env = machine_config.external_environment + safe_path_name = secure_filename(gain_reference_params.gain_ref.name) + filepath = ( + Path(machine_config.rsync_basepath) + / str(datetime.now().year) + / secure_filename(visit_name) + / machine_config.gain_directory_name + ) + + # Check under previous year if the folder doesn't exist + if not filepath.exists(): + filepath_prev = filepath + filepath = ( + Path(machine_config.rsync_basepath) + / str(datetime.now().year - 1) + / secure_filename(visit_name) + / machine_config.gain_directory_name + ) + # If it's not in the previous year, it's a genuine error + if not filepath.exists(): + log_message = ( + "Unable to find gain reference directory under " + f"{str(filepath_prev)!r} or {str(filepath)}" + ) + logger.error(log_message) + raise FileNotFoundError(log_message) + + if gain_reference_params.eer: + new_gain_ref, new_gain_ref_superres = await prepare_eer_gain( + filepath / safe_path_name, + executables, + env, + tag=gain_reference_params.tag, + ) + else: + new_gain_ref, new_gain_ref_superres = await prepare_gain( + camera, + filepath / safe_path_name, + executables, + env, + rescale=gain_reference_params.rescale, + tag=gain_reference_params.tag, + ) + if new_gain_ref and new_gain_ref_superres: + return { + "gain_ref": new_gain_ref.relative_to(Path(machine_config.rsync_basepath)), + "gain_ref_superres": new_gain_ref_superres.relative_to( + Path(machine_config.rsync_basepath) + ), + } + elif new_gain_ref: + return { + "gain_ref": new_gain_ref.relative_to(Path(machine_config.rsync_basepath)), + "gain_ref_superres": None, + } + else: + return {"gain_ref": str(filepath / safe_path_name), "gain_ref_superres": None} diff --git a/src/murfey/server/api/instrument.py b/src/murfey/server/api/instrument.py index 59665399a..34e238208 100644 --- a/src/murfey/server/api/instrument.py +++ b/src/murfey/server/api/instrument.py @@ -12,8 +12,8 @@ from sqlmodel import select from werkzeug.utils import secure_filename +from murfey.server.api.auth import MurfeySessionIDFrontend as MurfeySessionID from murfey.server.api.auth import ( - MurfeySessionID, create_access_token, instrument_server_tokens, oauth2_scheme, diff --git a/src/murfey/server/api/processing_parameters.py b/src/murfey/server/api/processing_parameters.py index 11b9f57de..80061a234 100644 --- a/src/murfey/server/api/processing_parameters.py +++ b/src/murfey/server/api/processing_parameters.py @@ -6,7 +6,8 @@ from pydantic import BaseModel from sqlmodel import Session, select -from murfey.server.api.auth import MurfeySessionID, validate_token +from murfey.server.api.auth import MurfeySessionIDFrontend as MurfeySessionID +from murfey.server.api.auth import validate_token from murfey.server.murfey_db import murfey_db from murfey.util.db import SessionProcessingParameters diff --git a/src/murfey/server/api/prometheus.py b/src/murfey/server/api/prometheus.py index 5e8d867ea..6457f25be 100644 --- a/src/murfey/server/api/prometheus.py +++ b/src/murfey/server/api/prometheus.py @@ -8,7 +8,7 @@ from sqlmodel import select import murfey.server.prometheus as prom -from murfey.server.api.auth import validate_token +from murfey.server.api.auth import validate_instrument_token from murfey.server.murfey_db import murfey_db from murfey.util import sanitise from murfey.util.db import RsyncInstance @@ -18,7 +18,7 @@ router = APIRouter( prefix="/prometheus", - dependencies=[Depends(validate_token)], + dependencies=[Depends(validate_instrument_token)], tags=["Prometheus"], ) diff --git a/src/murfey/server/api/session_control.py b/src/murfey/server/api/session_control.py index 418226f87..8be560b40 100644 --- a/src/murfey/server/api/session_control.py +++ b/src/murfey/server/api/session_control.py @@ -7,12 +7,14 @@ from fastapi.responses import FileResponse from ispyb.sqlalchemy import AutoProcProgram as ISPyBAutoProcProgram from pydantic import BaseModel +from sqlalchemy import func from sqlmodel import select from werkzeug.utils import secure_filename import murfey.server.prometheus as prom from murfey.server import _transport_object -from murfey.server.api.auth import MurfeySessionID, validate_token +from murfey.server.api.auth import MurfeySessionIDInstrument as MurfeySessionID +from murfey.server.api.auth import validate_instrument_token from murfey.server.api.shared import get_foil_hole as _get_foil_hole from murfey.server.api.shared import ( get_foil_holes_from_grid_square as _get_foil_holes_from_grid_square, @@ -38,6 +40,7 @@ DataCollectionGroup, FoilHole, GridSquare, + Movie, ProcessingJob, RsyncInstance, Session, @@ -60,7 +63,7 @@ router = APIRouter( prefix="/session_control", - dependencies=[Depends(validate_token)], + dependencies=[Depends(validate_instrument_token)], tags=["Session Control: General"], ) @@ -177,6 +180,14 @@ def register_processing_success_in_ispyb( _transport_object.do_update_processing_status(updated) +@router.get("/num_movies") +def count_number_of_movies(db=murfey_db) -> Dict[str, int]: + res = db.exec( + select(Movie.tag, func.count(Movie.murfey_id)).group_by(Movie.tag) + ).all() + return {r[0]: r[1] for r in res} + + class PostInfo(BaseModel): url: str data: dict @@ -297,7 +308,7 @@ def delete_rsyncer(session_id: int, source: Path, db=murfey_db): spa_router = APIRouter( prefix="/session_control/spa", - dependencies=[Depends(validate_token)], + dependencies=[Depends(validate_instrument_token)], tags=["Session Control: SPA"], ) @@ -355,7 +366,7 @@ def register_foil_hole( correlative_router = APIRouter( prefix="/session_control/correlative", - dependencies=[Depends(validate_token)], + dependencies=[Depends(validate_instrument_token)], tags=["Session Control: Correlative Imaging"], ) diff --git a/src/murfey/server/api/session_info.py b/src/murfey/server/api/session_info.py index a844cb822..804698d87 100644 --- a/src/murfey/server/api/session_info.py +++ b/src/murfey/server/api/session_info.py @@ -6,14 +6,14 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import FileResponse from pydantic import BaseModel -from sqlalchemy import func from sqlmodel import select from werkzeug.utils import secure_filename import murfey.server.api.websocket as ws from murfey.server import _transport_object from murfey.server.api import templates -from murfey.server.api.auth import MurfeySessionID, validate_token +from murfey.server.api.auth import MurfeySessionIDFrontend as MurfeySessionID +from murfey.server.api.auth import validate_token from murfey.server.api.shared import get_foil_hole as _get_foil_hole from murfey.server.api.shared import ( get_foil_holes_from_grid_square as _get_foil_holes_from_grid_square, @@ -241,14 +241,6 @@ async def get_clients(db=murfey_db): return clients -@router.get("/num_movies") -def count_number_of_movies(db=murfey_db) -> Dict[str, int]: - res = db.exec( - select(Movie.tag, func.count(Movie.murfey_id)).group_by(Movie.tag) - ).all() - return {r[0]: r[1] for r in res} - - class CurrentGainRef(BaseModel): path: str diff --git a/src/murfey/server/api/workflow.py b/src/murfey/server/api/workflow.py index 1e47f7981..06432d96b 100644 --- a/src/murfey/server/api/workflow.py +++ b/src/murfey/server/api/workflow.py @@ -26,7 +26,8 @@ import murfey.server.prometheus as prom from murfey.server import _transport_object -from murfey.server.api.auth import MurfeySessionID, validate_token +from murfey.server.api.auth import MurfeySessionIDInstrument as MurfeySessionID +from murfey.server.api.auth import validate_instrument_token from murfey.server.feedback import ( _murfey_id, check_tilt_series_mc, @@ -68,7 +69,7 @@ router = APIRouter( prefix="/workflow", - dependencies=[Depends(validate_token)], + dependencies=[Depends(validate_instrument_token)], tags=["Workflows: General"], ) @@ -288,7 +289,7 @@ def register_proc( spa_router = APIRouter( prefix="/workflow/spa", - dependencies=[Depends(validate_token)], + dependencies=[Depends(validate_instrument_token)], tags=["Workflows: SPA"], ) @@ -496,7 +497,7 @@ async def request_spa_preprocessing( tomo_router = APIRouter( prefix="/workflow/tomo", - dependencies=[Depends(validate_token)], + dependencies=[Depends(validate_instrument_token)], tags=["Workflows: CryoET"], ) @@ -880,7 +881,7 @@ def _add_tilt(): correlative_router = APIRouter( prefix="/workflow/correlative", - dependencies=[Depends(validate_token)], + dependencies=[Depends(validate_instrument_token)], tags=["Workflows: Correlative Imaging"], ) diff --git a/src/murfey/server/main.py b/src/murfey/server/main.py index fb4b595ea..d588476af 100644 --- a/src/murfey/server/main.py +++ b/src/murfey/server/main.py @@ -14,7 +14,8 @@ import murfey.server.api.bootstrap import murfey.server.api.clem import murfey.server.api.display -import murfey.server.api.file_manip +import murfey.server.api.file_io_frontend +import murfey.server.api.file_io_instrument import murfey.server.api.hub import murfey.server.api.instrument import murfey.server.api.mag_table @@ -71,7 +72,8 @@ class Settings(BaseSettings): app.include_router(murfey.server.api.display.router) app.include_router(murfey.server.api.processing_parameters.router) -app.include_router(murfey.server.api.file_manip.router) +app.include_router(murfey.server.api.file_io_frontend.router) +app.include_router(murfey.server.api.file_io_instrument.router) app.include_router(murfey.server.api.instrument.router) diff --git a/src/murfey/util/config.py b/src/murfey/util/config.py index 9b43fc74b..4c385d385 100644 --- a/src/murfey/util/config.py +++ b/src/murfey/util/config.py @@ -128,6 +128,8 @@ class Security(BaseModel): auth_key: str = "" auth_type: Literal["password", "cookie"] = "password" auth_url: str = "" + instrument_auth_type: Literal["token", ""] = "token" + instrument_auth_url: str = "" cookie_key: str = "" session_validation: str = "" session_token_timeout: Optional[int] = None diff --git a/src/murfey/util/route_manifest.yaml b/src/murfey/util/route_manifest.yaml index ae1b3dadc..ae187f251 100644 --- a/src/murfey/util/route_manifest.yaml +++ b/src/murfey/util/route_manifest.yaml @@ -105,17 +105,17 @@ murfey.instrument_server.api.router: methods: - POST murfey.server.api.auth.router: - - path: /token + - path: /auth/token function: generate_token path_params: [] methods: - POST - - path: /sessions/{session_id}/token + - path: /auth/sessions/{session_id}/token function: mint_session_token path_params: [] methods: - GET - - path: /validate_token + - path: /auth/validate_token function: simple_token_validation path_params: [] methods: @@ -315,56 +315,56 @@ murfey.server.api.bootstrap.version: methods: - GET murfey.server.api.clem.router: - - path: /sessions/{session_id}/clem/lif_files + - path: /workflow/clem/sessions/{session_id}/clem/lif_files function: register_lif_file path_params: - name: session_id type: int methods: - POST - - path: /sessions/{session_id}/clem/tiff_files + - path: /workflow/clem/sessions/{session_id}/clem/tiff_files function: register_tiff_file path_params: - name: session_id type: int methods: - POST - - path: /sessions/{session_id}/clem/metadata_files + - path: /workflow/clem/sessions/{session_id}/clem/metadata_files function: register_clem_metadata path_params: - name: session_id type: int methods: - POST - - path: /sessions/{session_id}/clem/image_series + - path: /workflow/clem/sessions/{session_id}/clem/image_series function: register_image_series path_params: - name: session_id type: int methods: - POST - - path: /sessions/{session_id}/clem/image_stacks + - path: /workflow/clem/sessions/{session_id}/clem/image_stacks function: register_image_stack path_params: - name: session_id type: int methods: - POST - - path: /sessions/{session_id}/clem/preprocessing/process_raw_lifs + - path: /workflow/clem/sessions/{session_id}/clem/preprocessing/process_raw_lifs function: process_raw_lifs path_params: - name: session_id type: int methods: - POST - - path: /sessions/{session_id}/clem/preprocessing/process_raw_tiffs + - path: /workflow/clem/sessions/{session_id}/clem/preprocessing/process_raw_tiffs function: process_raw_tiffs path_params: - name: session_id type: int methods: - POST - - path: /sessions/{session_id}/clem/processing/align_and_merge_stacks + - path: /workflow/clem/sessions/{session_id}/clem/processing/align_and_merge_stacks function: align_and_merge_stacks path_params: - name: session_id @@ -410,8 +410,14 @@ murfey.server.api.display.router: type: int methods: - GET -murfey.server.api.file_manip.router: - - path: /file_manipulation/visits/{visit_name}/{session_id}/suggested_path +murfey.server.api.file_io_frontend.router: + - path: /file_io/frontend/sessions/{session_id}/process_gain + function: process_gain + path_params: [] + methods: + - POST +murfey.server.api.file_io_instrument.router: + - path: /file_io/instrument/visits/{visit_name}/{session_id}/suggested_path function: suggest_path path_params: - name: visit_name @@ -420,19 +426,19 @@ murfey.server.api.file_manip.router: type: int methods: - POST - - path: /file_manipulation/sessions/{session_id}/make_rsyncer_destination + - path: /file_io/instrument/sessions/{session_id}/make_rsyncer_destination function: make_rsyncer_destination path_params: - name: session_id type: int methods: - POST - - path: /file_manipulation/sessions/{session_id}/process_gain + - path: /file_io/instrument/sessions/{session_id}/process_gain function: process_gain path_params: [] methods: - POST - - path: /file_manipulation/visits/{visit_name}/{session_id}/eer_fractionation_file + - path: /file_io/instrument/visits/{visit_name}/{session_id}/eer_fractionation_file function: write_eer_fractionation_file path_params: - name: visit_name @@ -700,6 +706,11 @@ murfey.server.api.session_control.router: path_params: [] methods: - POST + - path: /session_control/num_movies + function: count_number_of_movies + path_params: [] + methods: + - GET - path: /session_control/instruments/{instrument_name}/failed_client_post function: failed_client_post path_params: @@ -915,11 +926,6 @@ murfey.server.api.session_info.router: path_params: [] methods: - GET - - path: /session_info/num_movies - function: count_number_of_movies - path_params: [] - methods: - - GET - path: /session_info/sessions/{session_id}/current_gain_ref function: update_current_gain_ref path_params: [] diff --git a/src/murfey/workflows/spa/flush_spa_preprocess.py b/src/murfey/workflows/spa/flush_spa_preprocess.py index bc226e537..b225dfb69 100644 --- a/src/murfey/workflows/spa/flush_spa_preprocess.py +++ b/src/murfey/workflows/spa/flush_spa_preprocess.py @@ -7,7 +7,7 @@ from sqlmodel import Session, select from murfey.server import _transport_object -from murfey.server.api.auth import MurfeySessionID +from murfey.server.api.auth import MurfeySessionIDInstrument as MurfeySessionID from murfey.server.feedback import _murfey_id from murfey.util import sanitise, secure_path from murfey.util.config import get_machine_config, get_microscope diff --git a/tests/conftest.py b/tests/conftest.py index bb1d86f37..ad55f10d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -285,14 +285,15 @@ def ispyb_db_session( ======================================================================================= """ +murfey_db_url = ( + f"postgresql+psycopg2://{os.environ['POSTGRES_USER']}:{os.environ['POSTGRES_PASSWORD']}" + f"@{os.environ['POSTGRES_HOST']}:{os.environ['POSTGRES_PORT']}/{os.environ['POSTGRES_DB']}" +) + @pytest.fixture(scope="session") def murfey_db_engine(): - url = ( - f"postgresql+psycopg2://{os.environ['POSTGRES_USER']}:{os.environ['POSTGRES_PASSWORD']}" - f"@{os.environ['POSTGRES_HOST']}:{os.environ['POSTGRES_PORT']}/{os.environ['POSTGRES_DB']}" - ) - engine = create_engine(url) + engine = create_engine(murfey_db_url) SQLModel.metadata.create_all(engine) yield engine engine.dispose() diff --git a/tests/server/api/test_movies.py b/tests/server/api/test_movies.py index bc633016b..00cc4f464 100644 --- a/tests/server/api/test_movies.py +++ b/tests/server/api/test_movies.py @@ -1,61 +1,81 @@ -from unittest.mock import Mock, create_autospec, patch - -import pytest -from fastapi.testclient import TestClient from sqlmodel import Session -from murfey.server.main import app -from murfey.server.murfey_db import murfey_db_session - - -@pytest.fixture(scope="module") -def test_user(): - return {"username": "testuser", "password": "testpass"} - - -def movies_return(): - return [("Supervisor_1", 2)] - - -expression = Mock() -expression.all = movies_return - -mock_session = create_autospec(Session, instance=True) -mock_session.exec.return_value = expression - - -def override_murfey_db(): - try: - db = mock_session - yield db - finally: - db.close() - - -app.dependency_overrides[murfey_db_session] = override_murfey_db - -client = TestClient(app) - - -def login(test_user): - with patch( - "murfey.server.api.auth.validate_user", return_value=True - ) as mock_validate: - response = client.post("/token", data=test_user) - assert mock_validate.called_once() - assert response.status_code == 200 - token = response.json()["access_token"] - assert token is not None - return token - - -@patch("murfey.server.api.auth.check_user", return_value=True) -def test_movie_count(mock_check, test_user): - token = login(test_user) - response = client.get( - "/session_info/num_movies", headers={"Authorization": f"Bearer {token}"} +from murfey.server.api.session_control import count_number_of_movies +from murfey.util.db import ( + AutoProcProgram, + DataCollection, + DataCollectionGroup, + Movie, + MurfeyLedger, + ProcessingJob, +) +from tests.conftest import ExampleVisit, get_or_create_db_entry + + +def test_movie_count( + murfey_db_session: Session, # From conftest.py +): + + # Insert table dependencies + dcg_entry: DataCollectionGroup = get_or_create_db_entry( + murfey_db_session, + DataCollectionGroup, + lookup_kwargs={ + "id": 0, + "session_id": ExampleVisit.murfey_session_id, + "tag": "test_dcg", + }, + ) + dc_entry: DataCollection = get_or_create_db_entry( + murfey_db_session, + DataCollection, + lookup_kwargs={ + "id": 0, + "tag": "test_dc", + "dcg_id": dcg_entry.id, + }, + ) + processing_job_entry: ProcessingJob = get_or_create_db_entry( + murfey_db_session, + ProcessingJob, + lookup_kwargs={ + "id": 0, + "recipe": "test_recipe", + "dc_id": dc_entry.id, + }, ) - assert mock_check.called_once() - assert response.status_code == 200 - assert len(mock_session.method_calls) == 2 - assert response.json() == {"Supervisor_1": 2} + autoproc_entry: AutoProcProgram = get_or_create_db_entry( + murfey_db_session, + AutoProcProgram, + lookup_kwargs={ + "id": 0, + "pj_id": processing_job_entry.id, + }, + ) + + # Insert test movies and one-to-one dependencies + tag = "test_movie" + num_movies = 5 + for i in range(num_movies): + murfey_ledger_entry: MurfeyLedger = get_or_create_db_entry( + murfey_db_session, + MurfeyLedger, + lookup_kwargs={ + "id": i, + "app_id": autoproc_entry.id, + }, + ) + _: Movie = get_or_create_db_entry( + murfey_db_session, + Movie, + lookup_kwargs={ + "murfey_id": murfey_ledger_entry.id, + "path": "/some/path", + "image_number": i, + "tag": tag, + }, + ) + + # Run function and evaluate result + result = count_number_of_movies(murfey_db_session) + assert result == {tag: num_movies} diff --git a/tests/server/test_main.py b/tests/server/test_main.py index 2018885d3..b212b705d 100644 --- a/tests/server/test_main.py +++ b/tests/server/test_main.py @@ -6,6 +6,7 @@ from fastapi.testclient import TestClient from murfey.server.main import app +from murfey.util.api import url_path_for client = TestClient(app) @@ -19,7 +20,10 @@ def login(test_user): with patch( "murfey.server.api.auth.validate_user", return_value=True ) as mock_validate: - response = client.post("/token", data=test_user) + response = client.post( + f"{url_path_for('auth.router', 'generate_token')}", + data=test_user, + ) assert mock_validate.called_once() assert response.status_code == 200 token = response.json()["access_token"] @@ -31,14 +35,17 @@ def login(test_user): def test_read_main(mock_check, test_user): token = login(test_user) response = client.get( - "/session_info/connections", headers={"Authorization": f"Bearer {token}"} + "/session_info/connections", + headers={"Authorization": f"Bearer {token}"} ) assert mock_check.called_once() assert response.status_code == 200 def test_pypi_proxy(): - response = client.get("/pypi/fastapi") + response = client.get( + f"{url_path_for('bootstrap.pypi', 'get_pypi_package_downloads_list', package='fastapi')}" + ) assert response.status_code == 200 assert "