diff --git a/src/murfey/client/__init__.py b/src/murfey/client/__init__.py index 52a866c50..4f5a64029 100644 --- a/src/murfey/client/__init__.py +++ b/src/murfey/client/__init__.py @@ -2,8 +2,6 @@ import argparse import configparser - -# import json import logging import os import platform @@ -12,18 +10,14 @@ import time import webbrowser from datetime import datetime -from functools import partial from pathlib import Path from queue import Queue from typing import List, Literal from urllib.parse import ParseResult, urlparse import requests - -# from multiprocessing import Process, Queue from rich.prompt import Confirm -import murfey.client.rsync import murfey.client.update import murfey.client.watchdir import murfey.client.websocket @@ -31,38 +25,11 @@ from murfey.client.instance_environment import MurfeyInstanceEnvironment from murfey.client.tui.app import MurfeyTUI from murfey.client.tui.status_bar import StatusBar -from murfey.util import _get_visit_list - -# from asyncio import Queue - - -# from rich.prompt import Prompt - +from murfey.util.client import _get_visit_list, authorised_requests, read_config log = logging.getLogger("murfey.client") - -def read_config() -> configparser.ConfigParser: - config = configparser.ConfigParser() - try: - mcch = os.environ.get("MURFEY_CLIENT_CONFIG_HOME") - murfey_client_config_home = Path(mcch) if mcch else Path.home() - with open(murfey_client_config_home / ".murfey") as configfile: - config.read_file(configfile) - except FileNotFoundError: - log.warning( - f"Murfey client configuration file {murfey_client_config_home / '.murfey'} not found" - ) - if "Murfey" not in config: - config["Murfey"] = {} - return config - - -token = read_config()["Murfey"].get("token", "") - -requests.get = partial(requests.get, headers={"Authorization": f"Bearer {token}"}) -requests.post = partial(requests.post, headers={"Authorization": f"Bearer {token}"}) -requests.delete = partial(requests.delete, headers={"Authorization": f"Bearer {token}"}) +requests.get, requests.post, requests.put, requests.delete = authorised_requests() def write_config(config: configparser.ConfigParser): diff --git a/src/murfey/client/analyser.py b/src/murfey/client/analyser.py index 877ed5026..40e8d9068 100644 --- a/src/murfey/client/analyser.py +++ b/src/murfey/client/analyser.py @@ -22,7 +22,7 @@ from murfey.client.instance_environment import MurfeyInstanceEnvironment from murfey.client.rsync import RSyncerUpdate, TransferResult from murfey.client.tui.forms import FormDependency -from murfey.util import Observer, get_machine_config_client +from murfey.util.client import Observer, get_machine_config_client from murfey.util.mdoc import get_block from murfey.util.models import PreprocessingParametersTomo, ProcessingParametersSPA diff --git a/src/murfey/client/contexts/clem.py b/src/murfey/client/contexts/clem.py index 6cb0c6bb7..37d03f789 100644 --- a/src/murfey/client/contexts/clem.py +++ b/src/murfey/client/contexts/clem.py @@ -14,7 +14,7 @@ from murfey.client.context import Context from murfey.client.instance_environment import MurfeyInstanceEnvironment -from murfey.util import capture_post, get_machine_config_client +from murfey.util.client import capture_post, get_machine_config_client # Create logger object logger = logging.getLogger("murfey.client.contexts.clem") diff --git a/src/murfey/client/contexts/fib.py b/src/murfey/client/contexts/fib.py index 12afdb6a7..393b849b2 100644 --- a/src/murfey/client/contexts/fib.py +++ b/src/murfey/client/contexts/fib.py @@ -10,7 +10,7 @@ from murfey.client.context import Context from murfey.client.instance_environment import MurfeyInstanceEnvironment -from murfey.util import authorised_requests +from murfey.util.client import authorised_requests logger = logging.getLogger("murfey.client.contexts.fib") diff --git a/src/murfey/client/contexts/spa.py b/src/murfey/client/contexts/spa.py index c61b414fd..c37036f9f 100644 --- a/src/murfey/client/contexts/spa.py +++ b/src/murfey/client/contexts/spa.py @@ -15,7 +15,7 @@ MurfeyID, MurfeyInstanceEnvironment, ) -from murfey.util import ( +from murfey.util.client import ( authorised_requests, capture_get, capture_post, diff --git a/src/murfey/client/contexts/spa_metadata.py b/src/murfey/client/contexts/spa_metadata.py index 8e5317343..79546a7ca 100644 --- a/src/murfey/client/contexts/spa_metadata.py +++ b/src/murfey/client/contexts/spa_metadata.py @@ -8,7 +8,11 @@ from murfey.client.context import Context from murfey.client.contexts.spa import _file_transferred_to, _get_source from murfey.client.instance_environment import MurfeyInstanceEnvironment, SampleInfo -from murfey.util import authorised_requests, capture_post, get_machine_config_client +from murfey.util.client import ( + authorised_requests, + capture_post, + get_machine_config_client, +) from murfey.util.spa_metadata import ( FoilHoleInfo, get_grid_square_atlas_positions, diff --git a/src/murfey/client/contexts/tomo.py b/src/murfey/client/contexts/tomo.py index 072c0bf36..138f3e9ef 100644 --- a/src/murfey/client/contexts/tomo.py +++ b/src/murfey/client/contexts/tomo.py @@ -17,7 +17,11 @@ MurfeyID, MurfeyInstanceEnvironment, ) -from murfey.util import authorised_requests, capture_post, get_machine_config_client +from murfey.util.client import ( + authorised_requests, + capture_post, + get_machine_config_client, +) from murfey.util.mdoc import get_block, get_global_data, get_num_blocks from murfey.util.tomo import midpoint diff --git a/src/murfey/client/multigrid_control.py b/src/murfey/client/multigrid_control.py index ab1bade3d..316bf8bed 100644 --- a/src/murfey/client/multigrid_control.py +++ b/src/murfey/client/multigrid_control.py @@ -20,7 +20,8 @@ from murfey.client.rsync import RSyncer, RSyncerUpdate, TransferResult from murfey.client.tui.screens import determine_default_destination from murfey.client.watchdir import DirWatcher -from murfey.util import capture_post, get_machine_config_client, posix_path +from murfey.util import posix_path +from murfey.util.client import capture_post, get_machine_config_client log = logging.getLogger("murfey.client.mutligrid_control") diff --git a/src/murfey/client/rsync.py b/src/murfey/client/rsync.py index f53d40440..2f4036e5c 100644 --- a/src/murfey/client/rsync.py +++ b/src/murfey/client/rsync.py @@ -19,7 +19,7 @@ from urllib.parse import ParseResult from murfey.client.tui.status_bar import StatusBar -from murfey.util import Observer +from murfey.util.client import Observer logger = logging.getLogger("murfey.client.rsync") diff --git a/src/murfey/client/tui/app.py b/src/murfey/client/tui/app.py index 09bff5271..0d5060a55 100644 --- a/src/murfey/client/tui/app.py +++ b/src/murfey/client/tui/app.py @@ -33,10 +33,10 @@ from murfey.client.tui.status_bar import StatusBar from murfey.client.watchdir import DirWatcher from murfey.client.watchdir_multigrid import MultigridDirWatcher -from murfey.util import ( +from murfey.util import posix_path +from murfey.util.client import ( capture_post, get_machine_config_client, - posix_path, read_config, set_default_acquisition_output, ) diff --git a/src/murfey/client/tui/screens.py b/src/murfey/client/tui/screens.py index 0ba3c30a8..0ff7d839b 100644 --- a/src/murfey/client/tui/screens.py +++ b/src/murfey/client/tui/screens.py @@ -56,7 +56,8 @@ ) from murfey.client.rsync import RSyncer from murfey.client.tui.forms import FormDependency -from murfey.util import capture_post, get_machine_config_client, posix_path, read_config +from murfey.util import posix_path +from murfey.util.client import capture_post, get_machine_config_client, read_config from murfey.util.models import PreprocessingParametersTomo, ProcessingParametersSPA log = logging.getLogger("murfey.tui.screens") diff --git a/src/murfey/client/watchdir.py b/src/murfey/client/watchdir.py index 57b5aac6d..c1473304d 100644 --- a/src/murfey/client/watchdir.py +++ b/src/murfey/client/watchdir.py @@ -14,8 +14,8 @@ from pathlib import Path from typing import List, NamedTuple, Optional -import murfey.util from murfey.client.tui.status_bar import StatusBar +from murfey.util.client import Observer log = logging.getLogger("murfey.client.watchdir") @@ -26,7 +26,7 @@ class _FileInfo(NamedTuple): settling_time: Optional[float] = None -class DirWatcher(murfey.util.Observer): +class DirWatcher(Observer): def __init__( self, path: str | os.PathLike, diff --git a/src/murfey/client/watchdir_multigrid.py b/src/murfey/client/watchdir_multigrid.py index cf3a95c90..46a9a2a11 100644 --- a/src/murfey/client/watchdir_multigrid.py +++ b/src/murfey/client/watchdir_multigrid.py @@ -7,12 +7,12 @@ from pathlib import Path from typing import List -import murfey.util +from murfey.util.client import Observer log = logging.getLogger("murfey.client.watchdir_multigrid") -class MultigridDirWatcher(murfey.util.Observer): +class MultigridDirWatcher(Observer): def __init__( self, path: str | os.PathLike, diff --git a/src/murfey/instrument_server/__init__.py b/src/murfey/instrument_server/__init__.py index bfbdc77c6..be0fc49fc 100644 --- a/src/murfey/instrument_server/__init__.py +++ b/src/murfey/instrument_server/__init__.py @@ -5,9 +5,9 @@ from rich.logging import RichHandler import murfey -from murfey.client import read_config from murfey.client.customlogging import CustomHandler from murfey.util import LogFilter +from murfey.util.client import read_config logger = logging.getLogger("murfey.instrument_server") diff --git a/src/murfey/instrument_server/api.py b/src/murfey/instrument_server/api.py index 55f5c3982..4cb6fbeae 100644 --- a/src/murfey/instrument_server/api.py +++ b/src/murfey/instrument_server/api.py @@ -17,11 +17,11 @@ from pydantic import BaseModel from werkzeug.utils import secure_filename -from murfey.client import read_config from murfey.client.multigrid_control import MultigridController from murfey.client.rsync import RSyncer from murfey.client.watchdir_multigrid import MultigridDirWatcher from murfey.util import posix_path, sanitise, sanitise_nonpath, secure_path +from murfey.util.client import read_config from murfey.util.instrument_models import MultigridWatcherSpec from murfey.util.models import File, Token diff --git a/src/murfey/util/__init__.py b/src/murfey/util/__init__.py index ff500389b..994043292 100644 --- a/src/murfey/util/__init__.py +++ b/src/murfey/util/__init__.py @@ -1,67 +1,17 @@ from __future__ import annotations -import asyncio -import configparser -import copy -import inspect -import json import logging -import os -import shutil -from functools import lru_cache, partial from pathlib import Path from queue import Queue from threading import Thread -from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union -from urllib.parse import ParseResult, urlparse, urlunparse +from typing import Optional from uuid import uuid4 -import requests from werkzeug.utils import secure_filename -from murfey.util.models import Visit - logger = logging.getLogger("murfey.util") -def read_config() -> configparser.ConfigParser: - config = configparser.ConfigParser() - try: - mcch = os.environ.get("MURFEY_CLIENT_CONFIG_HOME") - murfey_client_config_home = Path(mcch) if mcch else Path.home() - with open(murfey_client_config_home / ".murfey") as configfile: - config.read_file(configfile) - except FileNotFoundError: - logger.warning( - f"Murfey client configuration file {murfey_client_config_home / '.murfey'} not found" - ) - if "Murfey" not in config: - config["Murfey"] = {} - return config - - -@lru_cache(maxsize=1) -def get_machine_config_client( - url: str, instrument_name: str = "", demo: bool = False -) -> dict: - _instrument_name: str | None = instrument_name or os.getenv("BEAMLINE") - if not _instrument_name: - return {} - return requests.get(f"{url}/instruments/{_instrument_name}/machine").json() - - -def authorised_requests() -> Tuple[Callable, Callable, Callable, Callable]: - token = read_config()["Murfey"].get("token", "") - _get = partial(requests.get, headers={"Authorization": f"Bearer {token}"}) - _post = partial(requests.post, headers={"Authorization": f"Bearer {token}"}) - _put = partial(requests.put, headers={"Authorization": f"Bearer {token}"}) - _delete = partial(requests.delete, headers={"Authorization": f"Bearer {token}"}) - return _get, _post, _put, _delete - - -requests.get, requests.post, requests.put, requests.delete = authorised_requests() - - def sanitise(in_string: str) -> str: return in_string.replace("\r\n", "").replace("\n", "") @@ -113,175 +63,6 @@ def posix_path(path: Path) -> str: return str(path) -def _get_visit_list(api_base: ParseResult, instrument_name: str): - get_visits_url = api_base._replace( - path=f"/instruments/{instrument_name}/visits_raw" - ) - server_reply = requests.get(get_visits_url.geturl()) - if server_reply.status_code != 200: - raise ValueError(f"Server unreachable ({server_reply.status_code})") - return [Visit.parse_obj(v) for v in server_reply.json()] - - -def capture_post(url: str, json: dict | list = {}) -> requests.Response | None: - try: - response = requests.post(url, json=json) - except Exception as e: - logger.error(f"Exception encountered in post to {url}: {e}") - response = requests.Response() - if response.status_code != 200: - logger.warning( - f"Response to post to {url} with data {json} had status code " - f"{response.status_code}. The reason given was {response.reason}" - ) - split_url = urlparse(url) - client_config = read_config() - failure_url = urlunparse( - split_url._replace( - path=f"/instruments/{client_config['Murfey']['instrument_name']}/failed_client_post" - ) - ) - try: - resend_response = requests.post( - failure_url, json={"url": url, "data": json} - ) - except Exception as e: - logger.error(f"Exception encountered in post to {failure_url}: {e}") - resend_response = requests.Response() - if resend_response.status_code != 200: - logger.warning( - f"Response to post to {failure_url} failed with {resend_response.reason}" - ) - - return response - - -def capture_get(url: str) -> requests.Response | None: - try: - response = requests.get(url) - except Exception as e: - logger.error(f"Exception encountered in get from {url}: {e}") - response = None - if response and response.status_code != 200: - logger.warning( - f"Response to get from {url} had status code {response.status_code}. " - f"The reason given was {response.reason}" - ) - return response - - -def set_default_acquisition_output( - new_output_dir: Path, - software_settings_output_directories: Dict[str, List[str]], - safe: bool = True, -): - for p, keys in software_settings_output_directories.items(): - if safe: - settings_copy_path = Path(p) - settings_copy_path = settings_copy_path.parent / ( - "_murfey_" + settings_copy_path.name - ) - shutil.copy(p, str(settings_copy_path)) - with open(p, "r") as for_parsing: - settings = json.load(for_parsing) - # for safety - settings_copy = copy.deepcopy(settings) - - def _set(d: dict, keys_list: List[str], value: str) -> dict: - if len(keys_list) > 1: - tmp_value: Union[dict, str] = _set( - d[keys_list[0]], keys_list[1:], value - ) - else: - tmp_value = value - return {_k: tmp_value if _k == keys_list[0] else _v for _k, _v in d.items()} - - settings_copy = _set(settings_copy, keys, str(new_output_dir)) - - def _check_dict_structure(d1: dict, d2: dict) -> bool: - if set(d1.keys()) != set(d2.keys()): - return False - for k in d1.keys(): - if isinstance(d1[k], dict): - if not isinstance(d2[k], dict): - return False - _check_dict_structure(d1[k], d2[k]) - return True - - if _check_dict_structure(settings, settings_copy): - with open(p, "w") as sf: - json.dump(settings_copy, sf) - - -class Observer: - """ - A helper class implementing the observer pattern supporting both - synchronous and asynchronous notification calls and both synchronous and - asynchronous callback functions. - """ - - # The class here should be derived from typing.Generic[P] - # with P = ParamSpec("P"), and the notify/anotify functions should use - # *args: P.args, **kwargs: P.kwargs. - # However, ParamSpec is Python 3.10+ (PEP 612), so we can't use that yet. - - def __init__(self): - self._listeners: list[Callable[..., Awaitable[None] | None]] = [] - self._secondary_listeners: list[Callable[..., Awaitable[None] | None]] = [] - self._final_listeners: list[Callable[..., Awaitable[None] | None]] = [] - super().__init__() - - def subscribe( - self, - fn: Callable[..., Awaitable[None] | None], - secondary: bool = False, - final: bool = False, - ): - if final: - self._final_listeners.append(fn) - elif secondary: - self._secondary_listeners.append(fn) - else: - self._listeners.append(fn) - - async def anotify( - self, *args, secondary: bool = False, final: bool = False, **kwargs - ) -> None: - awaitables: list[Awaitable] = [] - listeners = ( - self._secondary_listeners - if secondary - else self._final_listeners if final else self._listeners - ) - for notify_function in listeners: - result = notify_function(*args, **kwargs) - if result is not None and inspect.isawaitable(result): - awaitables.append(result) - if awaitables: - await self._await_all(awaitables) - - @staticmethod - async def _await_all(awaitables: list[Awaitable]): - for awaitable in asyncio.as_completed(awaitables): - await awaitable - - def notify( - self, *args, secondary: bool = False, final: bool = False, **kwargs - ) -> None: - awaitables: list[Awaitable] = [] - listeners = ( - self._secondary_listeners - if secondary - else self._final_listeners if final else self._listeners - ) - for notify_function in listeners: - result = notify_function(*args, **kwargs) - if result is not None and inspect.isawaitable(result): - awaitables.append(result) - if awaitables: - asyncio.run(self._await_all(awaitables)) - - class Processor: def __init__(self, name: Optional[str] = None): self._in: Queue = Queue() diff --git a/src/murfey/util/client.py b/src/murfey/util/client.py new file mode 100644 index 000000000..0e9bd3c4c --- /dev/null +++ b/src/murfey/util/client.py @@ -0,0 +1,244 @@ +""" +Utility functions used solely by the Murfey client. They help set up its +configuration, communicate with the backend server using the correct credentials, +and set default directories to work with. +""" + +from __future__ import annotations + +import asyncio +import configparser +import copy +import inspect +import json +import logging +import os +import shutil +from functools import lru_cache, partial +from pathlib import Path +from typing import Awaitable, Callable, Optional, Union +from urllib.parse import ParseResult, urlparse, urlunparse + +import requests + +from murfey.util.models import Visit + +logger = logging.getLogger("murfey.util.client") + + +def read_config() -> configparser.ConfigParser: + config = configparser.ConfigParser() + + # Look for 'MURFEY_CLIENT_CONFIGURATION' environment variable first + mcc = os.environ.get("MURFEY_CLIENT_CONFIGURATION") + if mcc: + config_file = Path(mcc) + # If not set, look for 'MURFEY_CLIENT_CONFIG_HOME' or '~' and then for '.murfey' + else: + mcch = os.environ.get("MURFEY_CLIENT_CONFIG_HOME") + murfey_client_config_home = Path(mcch) if mcch else Path.home() + config_file = murfey_client_config_home / ".murfey" + + # Attempt to read the file and return the config + try: + with open(config_file) as file: + config.read_file(file) + except FileNotFoundError: + logger.warning( + f"Murfey client configuration file {str(config_file)!r} not found" + ) + if "Murfey" not in config: + config["Murfey"] = {} + return config + + +@lru_cache(maxsize=1) +def get_machine_config_client( + url: str, instrument_name: str = "", demo: bool = False +) -> dict: + _instrument_name: Optional[str] = instrument_name or os.getenv("BEAMLINE") + if not _instrument_name: + return {} + return requests.get(f"{url}/instruments/{_instrument_name}/machine").json() + + +def authorised_requests() -> tuple[Callable, Callable, Callable, Callable]: + token = read_config()["Murfey"].get("token", "") + _get = partial(requests.get, headers={"Authorization": f"Bearer {token}"}) + _post = partial(requests.post, headers={"Authorization": f"Bearer {token}"}) + _put = partial(requests.put, headers={"Authorization": f"Bearer {token}"}) + _delete = partial(requests.delete, headers={"Authorization": f"Bearer {token}"}) + return _get, _post, _put, _delete + + +requests.get, requests.post, requests.put, requests.delete = authorised_requests() + + +def _get_visit_list(api_base: ParseResult, instrument_name: str): + proxy_path = api_base.path.rstrip("/") + get_visits_url = api_base._replace( + path=f"{proxy_path}/instruments/{instrument_name}/visits_raw" + ) + server_reply = requests.get(get_visits_url.geturl()) + if server_reply.status_code != 200: + raise ValueError(f"Server unreachable ({server_reply.status_code})") + return [Visit.parse_obj(v) for v in server_reply.json()] + + +def capture_post(url: str, json: Union[dict, list] = {}) -> Optional[requests.Response]: + try: + response = requests.post(url, json=json) + except Exception as e: + logger.error(f"Exception encountered in post to {url}: {e}") + response = requests.Response() + if response.status_code != 200: + logger.warning( + f"Response to post to {url} with data {json} had status code " + f"{response.status_code}. The reason given was {response.reason}" + ) + split_url = urlparse(url) + client_config = read_config() + failure_url = urlunparse( + split_url._replace( + path=f"/instruments/{client_config['Murfey']['instrument_name']}/failed_client_post" + ) + ) + try: + resend_response = requests.post( + failure_url, json={"url": url, "data": json} + ) + except Exception as e: + logger.error(f"Exception encountered in post to {failure_url}: {e}") + resend_response = requests.Response() + if resend_response.status_code != 200: + logger.warning( + f"Response to post to {failure_url} failed with {resend_response.reason}" + ) + + return response + + +def capture_get(url: str) -> Optional[requests.Response]: + try: + response = requests.get(url) + except Exception as e: + logger.error(f"Exception encountered in get from {url}: {e}") + response = None + if response and response.status_code != 200: + logger.warning( + f"Response to get from {url} had status code {response.status_code}. " + f"The reason given was {response.reason}" + ) + return response + + +def set_default_acquisition_output( + new_output_dir: Path, + software_settings_output_directories: dict[str, list[str]], + safe: bool = True, +): + for p, keys in software_settings_output_directories.items(): + if safe: + settings_copy_path = Path(p) + settings_copy_path = settings_copy_path.parent / ( + "_murfey_" + settings_copy_path.name + ) + shutil.copy(p, str(settings_copy_path)) + with open(p, "r") as for_parsing: + settings = json.load(for_parsing) + # for safety + settings_copy = copy.deepcopy(settings) + + def _set(d: dict, keys_list: list[str], value: str) -> dict: + if len(keys_list) > 1: + tmp_value: Union[dict, str] = _set( + d[keys_list[0]], keys_list[1:], value + ) + else: + tmp_value = value + return {_k: tmp_value if _k == keys_list[0] else _v for _k, _v in d.items()} + + settings_copy = _set(settings_copy, keys, str(new_output_dir)) + + def _check_dict_structure(d1: dict, d2: dict) -> bool: + if set(d1.keys()) != set(d2.keys()): + return False + for k in d1.keys(): + if isinstance(d1[k], dict): + if not isinstance(d2[k], dict): + return False + _check_dict_structure(d1[k], d2[k]) + return True + + if _check_dict_structure(settings, settings_copy): + with open(p, "w") as sf: + json.dump(settings_copy, sf) + + +class Observer: + """ + A helper class implementing the observer pattern supporting both + synchronous and asynchronous notification calls and both synchronous and + asynchronous callback functions. + """ + + # The class here should be derived from typing.Generic[P] + # with P = ParamSpec("P"), and the notify/anotify functions should use + # *args: P.args, **kwargs: P.kwargs. + # However, ParamSpec is Python 3.10+ (PEP 612), so we can't use that yet. + + def __init__(self): + self._listeners: list[Callable[..., Awaitable[None] | None]] = [] + self._secondary_listeners: list[Callable[..., Awaitable[None] | None]] = [] + self._final_listeners: list[Callable[..., Awaitable[None] | None]] = [] + super().__init__() + + def subscribe( + self, + fn: Callable[..., Awaitable[None] | None], + secondary: bool = False, + final: bool = False, + ): + if final: + self._final_listeners.append(fn) + elif secondary: + self._secondary_listeners.append(fn) + else: + self._listeners.append(fn) + + async def anotify( + self, *args, secondary: bool = False, final: bool = False, **kwargs + ) -> None: + awaitables: list[Awaitable] = [] + listeners = ( + self._secondary_listeners + if secondary + else self._final_listeners if final else self._listeners + ) + for notify_function in listeners: + result = notify_function(*args, **kwargs) + if result is not None and inspect.isawaitable(result): + awaitables.append(result) + if awaitables: + await self._await_all(awaitables) + + @staticmethod + async def _await_all(awaitables: list[Awaitable]): + for awaitable in asyncio.as_completed(awaitables): + await awaitable + + def notify( + self, *args, secondary: bool = False, final: bool = False, **kwargs + ) -> None: + awaitables: list[Awaitable] = [] + listeners = ( + self._secondary_listeners + if secondary + else self._final_listeners if final else self._listeners + ) + for notify_function in listeners: + result = notify_function(*args, **kwargs) + if result is not None and inspect.isawaitable(result): + awaitables.append(result) + if awaitables: + asyncio.run(self._await_all(awaitables)) diff --git a/tests/util/test_client.py b/tests/util/test_client.py new file mode 100644 index 000000000..98e920f88 --- /dev/null +++ b/tests/util/test_client.py @@ -0,0 +1,170 @@ +import json +import os +from pathlib import Path +from unittest.mock import Mock, patch +from urllib.parse import urlparse + +from pytest import mark + +from murfey.util.client import ( + _get_visit_list, + read_config, + set_default_acquisition_output, +) +from murfey.util.models import Visit + +test_read_config_params_matrix = ( + # Environment variable to set | Append to tmp_path + ( + "MURFEY_CLIENT_CONFIGURATION", + "config/murfey-client-config.cfg", + ), + ( + "MURFEY_CLIENT_CONFIG_HOME", + "config", + ), + ( + "", + "", + ), # Test default home directory +) + + +@mark.parametrize("test_params", test_read_config_params_matrix) +def test_read_config( + test_params: tuple[str, str], + tmp_path, + mock_client_configuration, +): + # Unpack test params + env_var, partial_path = test_params + + # Construct the environment variable and the expected config file path + env_var_dict: dict[str, str] = {} + if env_var: + full_path = tmp_path / partial_path + env_var_dict[env_var] = str(full_path) + file_path = full_path if full_path.suffix else full_path / ".murfey" + else: + file_path = Path().home() / ".murfey" + + # Make directories all the way to the requested place + file_path.parent.mkdir(parents=True, exist_ok=True) + + # Write the client config fixture to the specified file + with open(file_path, "w") as file: + mock_client_configuration.write(file) + + # Patch the OS environment variable and run the function + with patch.dict(os.environ, env_var_dict, clear=False): + config = read_config() + + # Compare returned config with mock one + assert dict(config["Murfey"]) == dict(mock_client_configuration["Murfey"]) + + +test_get_visit_list_params_matrix = ( + ("http://0.0.0.0:8000",), + ("http://0.0.0.0:8000/api",), + ("http://murfey_server",), + ("http://murfey_server/api",), + ("http://murfey_server.com",), +) + + +@mark.parametrize("test_params", test_get_visit_list_params_matrix) +@patch("murfey.util.client.requests") +def test_get_visit_list( + mock_request, + test_params: tuple[str], + mock_client_configuration, +): + # Unpack test params and set up other params + (server_url,) = test_params + instrument_name = mock_client_configuration["Murfey"]["instrument_name"] + + # Construct the expected request response + example_visits = [ + { + "start": "1999-09-09T09:00:00", + "end": "1999-09-11T09:00:00", + "session_id": 123456789, + "name": "cm12345-0", + "beamline": "murfey", + "proposal_title": "Commissioning Session 1", + }, + { + "start": "1999-09-09T09:00:00", + "end": "1999-09-11T09:00:00", + "session_id": 246913578, + "name": "cm23456-1", + "beamline": "murfey", + "proposal_title": "Cryo-cycle 1999", + }, + ] + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = example_visits + mock_request.get.return_value = mock_response + + # read_config() has to be patched using fixture, so has to be done in function + with patch("murfey.util.client.read_config", mock_client_configuration): + visits = _get_visit_list(urlparse(server_url), instrument_name) + + # Check that request was sent with the correct URL + expected_url = f"{server_url}/instruments/{instrument_name}/visits_raw" + mock_request.get.assert_called_once_with(expected_url) + + # Check that expected outputs are correct (order-sensitive) + for v, visit in enumerate(visits): + assert visit.dict() == Visit.parse_obj(example_visits[v]).dict() + + +def test_set_default_acquisition_output_normal_operation(tmp_path): + output_dir = tmp_path / "settings.json" + settings_json = { + "a": { + "b": {"data_dir": str(tmp_path)}, + "c": { + "d": 1, + }, + } + } + with open(output_dir, "w") as sf: + json.dump(settings_json, sf) + set_default_acquisition_output( + tmp_path / "visit", {str(tmp_path / "settings.json"): ["a", "b", "data_dir"]} + ) + assert (tmp_path / "_murfey_settings.json").is_file() + with open(output_dir, "r") as sf: + data = json.load(sf) + assert data["a"]["b"]["data_dir"] == str(tmp_path / "visit") + assert data["a"]["c"]["d"] == 1 + with open(output_dir.parent / "_murfey_settings.json", "r") as sf: + data = json.load(sf) + assert data["a"]["b"]["data_dir"] == str(tmp_path) + assert data["a"]["c"]["d"] == 1 + + +def test_set_default_acquisition_output_no_file_copy(tmp_path): + output_dir = tmp_path / "settings.json" + settings_json = { + "a": { + "b": {"data_dir": str(tmp_path)}, + "c": { + "d": 1, + }, + } + } + with open(output_dir, "w") as sf: + json.dump(settings_json, sf) + set_default_acquisition_output( + tmp_path / "visit", + {str(tmp_path / "settings.json"): ["a", "b", "data_dir"]}, + safe=False, + ) + assert not (tmp_path / "_murfey_settings.json").is_file() + with open(output_dir, "r") as sf: + data = json.load(sf) + assert data["a"]["b"]["data_dir"] == str(tmp_path / "visit") + assert data["a"]["c"]["d"] == 1 diff --git a/tests/util/test_lif.py b/tests/util/test_lif.py deleted file mode 100644 index b8f224939..000000000 --- a/tests/util/test_lif.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -Contains unit tests for lif.py. -""" - -from __future__ import annotations diff --git a/tests/util/test_set_default_acquisition_output.py b/tests/util/test_set_default_acquisition_output.py deleted file mode 100644 index 8b28cf7d5..000000000 --- a/tests/util/test_set_default_acquisition_output.py +++ /dev/null @@ -1,53 +0,0 @@ -import json - -from murfey.util import set_default_acquisition_output - - -def test_set_default_acquisition_output_normal_operation(tmp_path): - output_dir = tmp_path / "settings.json" - settings_json = { - "a": { - "b": {"data_dir": str(tmp_path)}, - "c": { - "d": 1, - }, - } - } - with open(output_dir, "w") as sf: - json.dump(settings_json, sf) - set_default_acquisition_output( - tmp_path / "visit", {str(tmp_path / "settings.json"): ["a", "b", "data_dir"]} - ) - assert (tmp_path / "_murfey_settings.json").is_file() - with open(output_dir, "r") as sf: - data = json.load(sf) - assert data["a"]["b"]["data_dir"] == str(tmp_path / "visit") - assert data["a"]["c"]["d"] == 1 - with open(output_dir.parent / "_murfey_settings.json", "r") as sf: - data = json.load(sf) - assert data["a"]["b"]["data_dir"] == str(tmp_path) - assert data["a"]["c"]["d"] == 1 - - -def test_set_default_acquisition_output_no_file_copy(tmp_path): - output_dir = tmp_path / "settings.json" - settings_json = { - "a": { - "b": {"data_dir": str(tmp_path)}, - "c": { - "d": 1, - }, - } - } - with open(output_dir, "w") as sf: - json.dump(settings_json, sf) - set_default_acquisition_output( - tmp_path / "visit", - {str(tmp_path / "settings.json"): ["a", "b", "data_dir"]}, - safe=False, - ) - assert not (tmp_path / "_murfey_settings.json").is_file() - with open(output_dir, "r") as sf: - data = json.load(sf) - assert data["a"]["b"]["data_dir"] == str(tmp_path / "visit") - assert data["a"]["c"]["d"] == 1