From 3321482a43c953208d1fb13ed283426f165958a8 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Thu, 26 Feb 2026 09:40:25 +0100 Subject: [PATCH 01/12] Use pydantic to validate request and response data --- src/simdb/database/database.py | 41 +- src/simdb/database/models/file.py | 58 ++- src/simdb/database/models/metadata.py | 9 + src/simdb/database/models/simulation.py | 66 ++++ src/simdb/remote/apis/files.py | 31 +- src/simdb/remote/apis/v1_2/simulations.py | 458 +++++++++------------- src/simdb/remote/app.py | 1 + src/simdb/remote/core/pydantic_utils.py | 291 ++++++++++++++ src/simdb/remote/models.py | 34 +- 9 files changed, 688 insertions(+), 301 deletions(-) create mode 100644 src/simdb/remote/core/pydantic_utils.py diff --git a/src/simdb/database/database.py b/src/simdb/database/database.py index 37ba0071..ff3ea9ca 100644 --- a/src/simdb/database/database.py +++ b/src/simdb/database/database.py @@ -16,6 +16,7 @@ from simdb.config import Config from simdb.query import QueryType, query_compare +from simdb.remote.models import SimulationReference from .models import Base from .models.file import File @@ -226,8 +227,8 @@ def _find_simulation(self, sim_ref: str) -> "Simulation": ) except SQLAlchemyError: simulation = None - if not simulation: - raise DatabaseError(f"Simulation {sim_ref} not found.") from None + if not simulation: + raise DatabaseError(f"Simulation {sim_ref} not found.") from None return simulation def remove(self): @@ -571,6 +572,24 @@ def get_simulation_parents(self, simulation: "Simulation") -> List[dict]: ) return [{"uuid": r.uuid, "alias": r.alias} for r in query.all()] + def get_simulation_parents_ref( + self, simulation: "Simulation" + ) -> List[SimulationReference]: + subquery = ( + self.session.query(File.checksum) + .filter(File.checksum != "") + .filter(File.input_for.contains(simulation)) + .subquery() + ) + query = ( + self.session.query(Simulation.uuid, Simulation.alias) + .join(Simulation.outputs) + 
.filter(File.checksum.in_(subquery)) + .filter(Simulation.alias != simulation.alias) + .distinct() + ) + return [SimulationReference(uuid=r.uuid, alias=r.alias) for r in query.all()] + def get_simulation_children(self, simulation: "Simulation") -> List[dict]: subquery = ( self.session.query(File.checksum) @@ -587,6 +606,24 @@ def get_simulation_children(self, simulation: "Simulation") -> List[dict]: ) return [{"uuid": r.uuid, "alias": r.alias} for r in query.all()] + def get_simulation_children_ref( + self, simulation: "Simulation" + ) -> List[SimulationReference]: + subquery = ( + self.session.query(File.checksum) + .filter(File.checksum != "") + .filter(File.output_of.contains(simulation)) + .subquery() + ) + query = ( + self.session.query(Simulation.uuid, Simulation.alias) + .join(Simulation.inputs) + .filter(File.checksum.in_(subquery)) + .filter(Simulation.alias != simulation.alias) + .distinct() + ) + return [SimulationReference(uuid=r.uuid, alias=r.alias) for r in query.all()] + def get_file(self, file_uuid_str: str) -> "File": """ Get the specified file from the database. 
diff --git a/src/simdb/database/models/file.py b/src/simdb/database/models/file.py index a05aaa59..a43ea96f 100644 --- a/src/simdb/database/models/file.py +++ b/src/simdb/database/models/file.py @@ -13,7 +13,8 @@ from simdb.config.config import Config from simdb.docstrings import inherit_docstrings from simdb.imas.checksum import checksum as imas_checksum -from simdb.imas.utils import imas_timestamp +from simdb.imas.utils import imas_files, imas_timestamp +from simdb.remote.models import FileData, FileGetDataResponse, FileInfo from simdb.uda.checksum import checksum as uda_checksum from .base import Base @@ -125,6 +126,23 @@ def from_data(cls, data: Dict) -> "File": file.datetime = date_parser.parse(checked_get(data, "datetime", str)) return file + @classmethod + def from_data_model(cls, data: FileData) -> "File": + data_type = data.type + uri = data.uri + file = File( + DataObject.Type[data_type], urilib.URI(uri), perform_integrity_check=False + ) + file.uuid = data.uuid + file.usage = data.usage + file.checksum = data.checksum + file.purpose = data.purpose + file.sensitivity = data.sensitivity + file.access = data.access + file.embargo = data.embargo + file.datetime = data.datetime + return file + def data(self, recurse: bool = False) -> Dict[str, str]: data = { "uuid": self.uuid, @@ -139,3 +157,41 @@ def data(self, recurse: bool = False) -> Dict[str, str]: "datetime": self.datetime.isoformat(), } return data + + def to_model(self) -> FileData: + return FileData( + type=self.type.name, + uri=str(self.uri), + uuid=self.uuid, + checksum=self.checksum, + datetime=self.datetime, + usage=self.usage, + purpose=self.purpose, + sensitivity=self.sensitivity, + access=self.access, + embargo=self.embargo, + ) + + def to_model_with_path(self) -> FileGetDataResponse: + if self.type.name == "FILE": + if self.uri.path is None: + raise ValueError("File path not set") + files = [FileInfo(path=self.uri.path, checksum=self.checksum)] + else: + files = [ + FileInfo(path=path, 
checksum=sha1_checksum(URI(f"file:{path}"))) + for path in imas_files(self.uri) + ] + return FileGetDataResponse( + type=self.type.name, + uri=str(self.uri), + uuid=self.uuid, + checksum=self.checksum, + datetime=self.datetime, + usage=self.usage, + purpose=self.purpose, + sensitivity=self.sensitivity, + access=self.access, + embargo=self.embargo, + files=files, + ) diff --git a/src/simdb/database/models/metadata.py b/src/simdb/database/models/metadata.py index 628f158c..81bcde9e 100644 --- a/src/simdb/database/models/metadata.py +++ b/src/simdb/database/models/metadata.py @@ -4,6 +4,7 @@ from sqlalchemy import types as sql_types from simdb.docstrings import inherit_docstrings +from simdb.remote.models import MetadataData from .base import Base @@ -32,6 +33,11 @@ def from_data(cls, data: Dict) -> "MetaData": meta = MetaData(data["element"], data["value"]) return meta + @classmethod + def from_data_model(cls, data: MetadataData) -> "MetaData": + meta = MetaData(data.element, data.value) + return meta + def data(self, recurse: bool = False) -> Dict[str, str]: data = { "element": self.element, @@ -39,5 +45,8 @@ def data(self, recurse: bool = False) -> Dict[str, str]: } return data + def to_model(self) -> MetadataData: + return MetadataData(element=self.element, value=self.value) + Index("metadata_index", MetaData.sim_id, MetaData.element, unique=True) diff --git a/src/simdb/database/models/simulation.py b/src/simdb/database/models/simulation.py index 1ee21ab6..a88b6d6a 100644 --- a/src/simdb/database/models/simulation.py +++ b/src/simdb/database/models/simulation.py @@ -9,6 +9,13 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set, Union +from simdb.remote.models import ( + FileDataList, + MetadataDataList, + SimulationData, + SimulationDataResponse, +) + if sys.version_info < (3, 11): from backports.datetime_fromisoformat import MonkeyPatch @@ -336,6 +343,17 @@ def from_data(cls, data: Dict[str, Union[str, Dict, List]]) -> "Simulation": 
simulation.meta.append(MetaData.from_data(el)) return simulation + @classmethod + def from_data_model(cls, data: SimulationData) -> "Simulation": + simulation = Simulation(None) + simulation.uuid = data.uuid + simulation.alias = data.alias + simulation.datetime = data.datetime + simulation.inputs = [File.from_data_model(el) for el in data.inputs.root] + simulation.outputs = [File.from_data_model(el) for el in data.outputs.root] + simulation.meta = [MetaData.from_data_model(el) for el in data.metadata.root] + return simulation + def data( self, recurse: bool = False, meta_keys: Optional[List[str]] = None ) -> Dict[str, Union[str, List]]: @@ -354,6 +372,54 @@ def data( ] return data + def to_model( + self, recurse: bool = False, meta_keys: Optional[List[str]] = None + ) -> SimulationData: + inputs = FileDataList() + outputs = FileDataList() + metadata = MetadataDataList() + if recurse: + inputs = FileDataList([f.to_model() for f in self.inputs]) + outputs = FileDataList([f.to_model() for f in self.outputs]) + metadata = MetadataDataList([m.to_model() for m in self.meta]) + elif meta_keys: + metadata = MetadataDataList( + [m.to_model() for m in self.meta if m.element in meta_keys] + ) + return SimulationData( + uuid=self.uuid, + alias=self.alias, + datetime=self.datetime, + inputs=inputs, + outputs=outputs, + metadata=metadata, + ) + + def to_model_with_refs( + self, recurse: bool = False, meta_keys: Optional[List[str]] = None + ) -> SimulationDataResponse: + inputs = FileDataList() + outputs = FileDataList() + metadata = MetadataDataList() + if recurse: + inputs = FileDataList([f.to_model() for f in self.inputs]) + outputs = FileDataList([f.to_model() for f in self.outputs]) + metadata = MetadataDataList([m.to_model() for m in self.meta]) + elif meta_keys: + metadata = MetadataDataList( + [m.to_model() for m in self.meta if m.element in meta_keys] + ) + return SimulationDataResponse( + uuid=self.uuid, + alias=self.alias, + datetime=self.datetime, + inputs=inputs, + 
outputs=outputs, + metadata=metadata, + parents=[], + children=[], + ) + def meta_dict(self) -> Dict[str, Union[Dict, Any]]: meta = {m.element: m.value for m in self.meta} return unflatten_dict(meta) diff --git a/src/simdb/remote/apis/files.py b/src/simdb/remote/apis/files.py index 3753ff88..c4047c11 100644 --- a/src/simdb/remote/apis/files.py +++ b/src/simdb/remote/apis/files.py @@ -2,7 +2,7 @@ import json import uuid from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, cast +from typing import Dict, Iterable, List, Optional import magic from flask import Response, jsonify, request, send_file, stream_with_context @@ -18,7 +18,9 @@ from simdb.remote.core.auth import User, requires_auth from simdb.remote.core.errors import error from simdb.remote.core.path import find_common_root, secure_path +from simdb.remote.core.pydantic_utils import PydanticResponse, pydantic_validate from simdb.remote.core.typing import current_app +from simdb.remote.models import ErrorResponse, FileDataList, FileGetDataResponse from simdb.uri import URI api = Namespace("files", path="/") @@ -170,9 +172,10 @@ def _handle_file_upload() -> Response: @api.route("/files") class FileList(Resource): @requires_auth() - def get(self, user: User): + @pydantic_validate(api) + def get(self, user: User) -> PydanticResponse[FileDataList]: files = current_app.db.list_files() - return jsonify([file.data() for file in files]) + return FileDataList.model_validate([file.data() for file in files]) @requires_auth() def post(self, user: User): @@ -189,25 +192,15 @@ def post(self, user: User): @api.route("/file/") class File(Resource): @requires_auth() - def get(self, file_uuid: str, user: Optional[User] = None): + @pydantic_validate(api) + def get( + self, file_uuid: str, user: Optional[User] = None + ) -> PydanticResponse[FileGetDataResponse]: try: file = current_app.db.get_file(file_uuid) - data = cast(Dict[str, Any], file.data(recurse=True)) - if file.type == DataObject.Type.FILE: - 
data["files"] = [ - { - "path": str(file.uri.path), - "checksum": file.checksum, - } - ] - else: - data["files"] = [ - {"path": str(path), "checksum": sha1_checksum(URI(f"file:{path}"))} - for path in imas_files(file.uri) - ] - return jsonify(data) + return file.to_model_with_path() except DatabaseError as err: - return error(str(err)) + return ErrorResponse(error=str(err)) @api.route("/file/download/") diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index 37c4d12e..b9f8a1df 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -1,14 +1,12 @@ import contextlib import datetime -import gzip import itertools import tarfile from io import BytesIO from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Annotated, Any, Dict, List, Optional, Tuple, cast -from flask import json as flask_json # fallback -from flask import jsonify, request, send_file +from flask import request, send_file from flask_restx import Namespace, Resource from simdb.database import DatabaseError @@ -23,7 +21,32 @@ from simdb.remote.core.cache import cache, cache_key, clear_cache from simdb.remote.core.errors import error from simdb.remote.core.path import find_common_root, secure_path +from simdb.remote.core.pydantic_utils import ( + Body, + Header, + PydanticResponse, + pydantic_validate, +) from simdb.remote.core.typing import current_app +from simdb.remote.models import ( + DeletedSimulation, + ErrorResponse, + MetadataDataList, + MetadataDeleteData, + MetadataDeleteResponse, + MetadataPatchData, + PaginatedResponse, + PaginationData, + SimulationDataResponse, + SimulationDeleteResponse, + SimulationListItem, + SimulationPatchResponse, + SimulationPostData, + SimulationPostResponse, + SimulationTraceData, + StatusPatchData, + ValidationResult, +) from simdb.uri import URI from simdb.validation import ValidationError, Validator from 
simdb.validation.file import find_file_validator @@ -55,7 +78,7 @@ def _update_simulation_status( server.send_message(f"Simulation {simulation.alias}", msg, to_addresses) -def _validate(simulation, user) -> Dict: +def _validate(simulation, user) -> ValidationResult: schemas = Validator.validation_schemas(current_app.simdb_config, simulation) try: for schema in schemas: @@ -65,10 +88,7 @@ def _validate(simulation, user) -> Dict: ) except ValidationError as err: _update_simulation_status(simulation, models_sim.Simulation.Status.FAILED, user) - return { - "passed": False, - "error": str(err), - } + return ValidationResult(passed=False, error=str(err)) file_validator_type = current_app.simdb_config.get_string_option( "file_validation.type", default=None @@ -88,16 +108,11 @@ def _validate(simulation, user) -> Dict: _update_simulation_status( simulation, models_sim.Simulation.Status.FAILED, user ) - return { - "passed": False, - "error": str(err), - } + return ValidationResult(passed=False, error=str(err)) else: error("Invalid file validator specified in configuration") - return { - "passed": True, - } + return ValidationResult(passed=True, error=None) def _set_alias(alias: str): @@ -121,11 +136,8 @@ def _set_alias(alias: str): return alias, next_id -def _build_trace(sim_id: str) -> Dict[str, Any]: - try: - simulation = current_app.db.get_simulation(sim_id) - except DatabaseError as err: - return {"error": str(err)} +def _build_trace(sim_id: str) -> SimulationTraceData: + simulation = current_app.db.get_simulation(sim_id) data: Dict[str, Any] = cast(Dict[str, Any], simulation.data(recurse=False)) status = simulation.find_meta("status") @@ -152,99 +164,19 @@ def _build_trace(sim_id: str) -> Dict[str, Any]: if replaces_reason: data["replaces_reason"] = replaces_reason[0].value - return data - - -def _get_json_aware(force: bool = False, silent: bool = False): - """ - Parse JSON like Flask's request.get_json, but handle Content-Encoding: gzip. 
- - force/silent mimic request.get_json behavior. - - Uses Flask's JSON provider to ensure identical types/decoding. - """ - - # Match request.get_json content-type check unless forced - if not force: - mimetype = request.mimetype or "" - if mimetype != "application/json": - if silent: - return None - raise ValueError("Invalid Content-Type (application/json required)") - - raw = request.get_data(cache=False) - enc = (request.headers.get("Content-Encoding") or "").lower() - if "gzip" in enc: - # Only decompress if actually gzipped - with contextlib.suppress(OSError): - raw = gzip.decompress(raw) - - # Use the same charset resolution as Flask (defaults to utf-8) - charset = "utf-8" - try: - params = request.mimetype_params or {} - charset = params.get("charset", "utf-8") - except Exception: - pass - - data = raw.decode(charset, errors="strict") - - # Use Flask's JSON provider for identical behavior - try: - # for Flask >= 2.2 - loads = current_app.json.loads # type: ignore[unresolved-attribute] - except Exception: - loads = flask_json.loads - - try: - return loads(data) - except Exception: - if silent: - return None - raise + return SimulationTraceData.model_validate(data) @api.route("/simulations") class SimulationList(Resource): - LIMIT_HEADER = "simdb-result-limit" - PAGE_HEADER = "simdb-page" - SORT_BY_HEADER = "simdb-sort-by" - SORT_ASC_HEADER = "simdb-sort-asc" - - parser = api.parser() - parser.add_argument( - LIMIT_HEADER, location="headers", type=int, help="Limit returned results" - ) - parser.add_argument( - PAGE_HEADER, - location="headers", - type=int, - help="Specify the page of results to return", - ) - parser.add_argument( - SORT_BY_HEADER, - location="headers", - type=str, - help="Specify the field to sort the results by", - ) - parser.add_argument( - SORT_ASC_HEADER, - location="headers", - type=bool, - help="Specify if the results are sorted ascending or descending", - ) - - @api.expect(parser) - @api.response(200, "Success") - @api.response(401, 
"Unauthorized") @requires_auth() + @pydantic_validate(api) # @cache.cached(key_prefix=cache_key) - def get(self, user: User): - limit = int(request.headers.get(SimulationList.LIMIT_HEADER) or 100) - page = int(request.headers.get(SimulationList.PAGE_HEADER) or 1) - sort_by = request.headers.get(SimulationList.SORT_BY_HEADER, "") - sort_asc = ( - request.headers.get(SimulationList.SORT_ASC_HEADER, "false").lower() - == "true" - ) + def get( + self, + user: User, + pagination: Annotated[PaginationData, Header()], + ) -> PydanticResponse[PaginatedResponse[SimulationListItem]]: names = [] constraints = [] if request.args: @@ -262,70 +194,61 @@ def get(self, user: User): count, data = current_app.db.query_meta_data( constraints, names, - limit=limit, - page=page, - sort_by=sort_by, - sort_asc=sort_asc, + limit=pagination.limit, + page=pagination.page, + sort_by=pagination.sort_by, + sort_asc=pagination.sort_asc, ) else: count, data = current_app.db.list_simulation_data( meta_keys=names, - limit=limit, - page=page, - sort_by=sort_by, - sort_asc=sort_asc, + limit=pagination.limit, + page=pagination.page, + sort_by=pagination.sort_by, + sort_asc=pagination.sort_asc, ) - return jsonify({"count": count, "page": page, "limit": limit, "results": data}) + return PaginatedResponse[SimulationListItem].model_validate( + { + "count": count, + "page": pagination.page, + "limit": pagination.limit, + "results": data, + } + ) @requires_auth() - def post(self, user: User): + @pydantic_validate(api) + def post( + self, + user: User, + body: Annotated[SimulationPostData, Body()], + ) -> PydanticResponse[SimulationPostResponse]: try: - # _get_json_aware is a custom function to handle JSON parsing - # similar to Flask's request.get_json, but with gzip support. - # It returns None if the content type is not application/json. - # If silent=True, it returns None instead of raising an error. - # If force=True, it ignores the content type check. 
- data = _get_json_aware() - if not data: - return error("Invalid or missing JSON data") - - if "simulation" not in data: - return error("Simulation data not provided") - - add_watcher = data.get("add_watcher", True) - - simulation = models_sim.Simulation.from_data(data["simulation"]) + simulation = models_sim.Simulation.from_data_model(body.simulation) # Simulation Upload (Push) Date simulation.datetime = datetime.datetime.now() - if data["uploaded_by"] is not None: - simulation.set_meta("uploaded_by", data["uploaded_by"]) - elif user.email is not None: - simulation.set_meta("uploaded_by", user.email) - elif user.name is not None: - simulation.set_meta("uploaded_by", user.name) - else: - simulation.set_meta("uploaded_by", "anonymous") - if add_watcher: + uploaded_by = body.uploaded_by or user.email or user.name or "anonymous" + + simulation.set_meta("uploaded_by", uploaded_by) + + if body.add_watcher: simulation.watchers.append( models_watcher.Watcher( user.name, user.email, models_watcher.Notification.ALL ) ) - if "alias" in data["simulation"]: - alias = data["simulation"]["alias"] - if alias is not None: - (updated_alias, next_id) = _set_alias(alias) - if updated_alias: - simulation.meta.append(models_meta.MetaData("seqid", next_id)) - simulation.alias = updated_alias - else: - simulation.alias = alias + alias = body.simulation.alias + if alias is not None: + (updated_alias, next_id) = _set_alias(alias) + if updated_alias: + simulation.meta.append(models_meta.MetaData("seqid", next_id)) + simulation.alias = updated_alias else: - simulation.alias = simulation.uuid.hex + simulation.alias = alias else: simulation.alias = simulation.uuid.hex @@ -334,17 +257,23 @@ def post(self, user: User): common_root = find_common_root(sim_file_paths) config = current_app.simdb_config + copy_files = config.get_option("server.copy_files", default=True) + imas_remote_host = config.get_option( + "server.imas_remote_host", default=None + ) - if config.get_option("server.copy_files", 
default=True): + if copy_files or imas_remote_host: staging_dir = ( Path(config.get_string_option("server.upload_folder")) / simulation.uuid.hex ) for sim_file in files: - if sim_file.uri.scheme == "file": - if sim_file.uri.path is None: - raise ValueError("Simulation path not set") + if ( + copy_files + and sim_file.uri.scheme == "file" + and sim_file.uri.path is not None + ): path = secure_path(sim_file.uri.path, common_root, staging_dir) if not path.exists(): raise ValueError( @@ -352,73 +281,54 @@ def post(self, user: User): ) sim_file.uri = URI(scheme="file", path=path) elif sim_file.uri.scheme == "imas": - path = secure_path( - Path(sim_file.uri.query["path"]), - common_root, - staging_dir, - is_file=common_root is not None, - ) - sim_file.uri = convert_uri(sim_file.uri, path, config) - elif config.get_option("server.imas_remote_host", default=None): - staging_dir = ( - Path(config.get_string_option("server.upload_folder")) - / simulation.uuid.hex - ) - - for sim_file in files: - if sim_file.uri.scheme == "imas": - if config.get_option("server.copy_files", default=True): + if copy_files: path = secure_path( Path(sim_file.uri.query["path"]), common_root, staging_dir, is_file=common_root is not None, ) - sim_file.uri = convert_uri(sim_file.uri, path, config) else: path = Path(sim_file.uri.query["path"]) - sim_file.uri = convert_uri(sim_file.uri, path, config) + sim_file.uri = convert_uri(sim_file.uri, path, config) - result = { - "ingested": simulation.uuid.hex, - } + result = SimulationPostResponse( + ingested=simulation.uuid, error=None, validation=None + ) + + error_on_fail = current_app.simdb_config.get_option( + "validation.error_on_fail", default=False + ) if current_app.simdb_config.get_option( "validation.auto_validate", default=False ): - result["validation"] = _validate(simulation, user) + result.validation = _validate(simulation, user) - if current_app.simdb_config.get_option( - "validation.error_on_fail", default=False - ): - if simulation.status == 
models_sim.Simulation.Status.NOT_VALIDATED: - raise Exception( - "Validation config option error_on_fail=True without " - "auto_validate=True." + if not result.validation.passed and error_on_fail: + return ErrorResponse( + error=f"Simulation validation failed and server has " + f"error_on_fail=True.\n{result.validation.error}" ) - elif simulation.status == models_sim.Simulation.Status.FAILED: - result["error"] = f"""Simulation validation failed and server has - error_on_fail=True.\n{result["validation"]["error"]}""" - response = jsonify(result) - response.status_code = 400 - return response + elif error_on_fail: + raise RuntimeError( + "Validation config option error_on_fail=True without " + "auto_validate=True." + ) + disable_replaces = config.get_option( + "development.disable_replaces", default=False + ) replaces = simulation.find_meta("replaces") - if ( - not current_app.simdb_config.get_option( - "development.disable_replaces", default=False - ) - and replaces - and replaces[0].value - ): + + if not disable_replaces and replaces and replaces[0].value: sim_id = replaces[0].value try: replaces_sim = current_app.db.get_simulation(sim_id) except DatabaseError: replaces_sim = None - if replaces_sim is None: - pass - else: + + if replaces_sim is not None: _update_simulation_status( replaces_sim, models_sim.Simulation.Status.DEPRECATED, user ) @@ -431,54 +341,58 @@ def post(self, user: User): with contextlib.suppress(OSError): create_alias_dir(simulation) - return jsonify(result) + return result except (DatabaseError, ValueError) as err: - return error(str(err)) + return ErrorResponse(error=str(err)) @api.route("/simulation/") class Simulation(Resource): @requires_auth() @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] - def get(self, sim_id: str, user: User): + @pydantic_validate(api) + def get(self, sim_id: str, user: User) -> PydanticResponse[SimulationDataResponse]: try: simulation = current_app.db.get_simulation(sim_id) if simulation: - 
sim_data = simulation.data(recurse=True) - sim_data["children"] = current_app.db.get_simulation_children( + sim_data = simulation.to_model_with_refs(recurse=True) + + sim_data.children = current_app.db.get_simulation_children_ref( simulation ) - sim_data["parents"] = current_app.db.get_simulation_parents(simulation) - return jsonify(sim_data) - return error("Simulation not found") + sim_data.parents = current_app.db.get_simulation_parents_ref( + simulation + ) + return sim_data + return ErrorResponse(error="Simulation not found") except DatabaseError as err: - return error(str(err)) + return ErrorResponse(error=str(err)) - parser = api.parser() - parser.add_argument( - "status", type=str, location="json", help="status", required=True - ) - - @api.expect(parser) @requires_auth("admin") - def patch(self, sim_id: str, user: Optional[User] = None): + @pydantic_validate(api) + def patch( + self, + sim_id: str, + user: Optional[User], + body: Annotated[StatusPatchData, Body()], + ) -> PydanticResponse[SimulationPatchResponse]: try: - data = request.get_json() or {} - if "status" not in data: - return error("Status not provided") simulation = current_app.db.get_simulation(sim_id) if simulation is None: raise ValueError(f"Simulation {sim_id} not found.") - status = models_sim.Simulation.Status(data["status"]) + status = models_sim.Simulation.Status(body.status) _update_simulation_status(simulation, status, user) current_app.db.insert_simulation(simulation) clear_cache() - return {} + return SimulationPatchResponse() except DatabaseError as err: - return error(str(err)) + return ErrorResponse(error=str(err)) @requires_auth("admin") - def delete(self, sim_id: str, user: User): + @pydantic_validate(api) + def delete( + self, sim_id: str, user: User + ) -> PydanticResponse[SimulationDeleteResponse]: try: simulation = current_app.db.delete_simulation(sim_id) clear_cache() @@ -497,48 +411,46 @@ def delete(self, sim_id: str, user: User): directory = first_file.uri.path.parent if 
directory != Path() and directory != Path("/"): directory.rmdir() - return jsonify({"deleted": {"simulation": simulation.uuid, "files": files}}) + return SimulationDeleteResponse( + deleted=DeletedSimulation(simulation=simulation.uuid, files=files) + ) except DatabaseError as err: - return error(str(err)) + return ErrorResponse(error=str(err)) @api.route("/simulation/metadata/") class SimulationMeta(Resource): @requires_auth() @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] - def get(self, sim_id: str, user: User): + @pydantic_validate(api) + def get(self, sim_id: str, user: User) -> PydanticResponse[MetadataDataList]: try: simulation = current_app.db.get_simulation(sim_id) if simulation: - return jsonify([meta.data() for meta in simulation.meta]) - return error("Simulation not found") + return MetadataDataList.model_validate( + [meta.data() for meta in simulation.meta] + ) + return ErrorResponse(error="Simulation not found") except DatabaseError as err: - return error(str(err)) - - parser = api.parser() - parser.add_argument("key", type=str, location="json", help="status", required=True) - parser.add_argument( - "value", type=str, location="json", help="status", required=True - ) + return ErrorResponse(error=str(err)) - @api.expect(parser) @requires_auth("admin") - def patch(self, sim_id: str, user: Optional[User] = None): + @pydantic_validate(api) + def patch( + self, + sim_id: str, + user: Optional[User], + body: Annotated[MetadataPatchData, Body()], + ) -> PydanticResponse[MetadataDataList]: try: - data = request.get_json() or {} - - if "key" not in data: - return error("Metadata key not provided") - - if "value" not in data: - return error("New metadata value not provided") - - key = data["key"] - value = data["value"].lower() + key = body.key + value = body.value.lower() simulation = current_app.db.get_simulation(sim_id) if simulation is None: - raise ValueError(f"Simulation {sim_id} not found.") - old_values = [meta.data() for meta 
in simulation.find_meta(key)] + return ErrorResponse(error=f"Simulation {sim_id} not found.") + old_values = MetadataDataList.model_validate( + [meta.data() for meta in simulation.find_meta(key)] + ) if key.lower() != "status": simulation.set_meta(key, value) else: @@ -549,60 +461,54 @@ def patch(self, sim_id: str, user: Optional[User] = None): clear_cache() return old_values except DatabaseError as err: - return error(str(err)) + return ErrorResponse(error=str(err)) - parser_delete = api.parser() - parser_delete.add_argument( - "key", type=str, location="json", help="metadata key", required=True - ) - - @api.expect(parser_delete) @requires_auth("admin") - def delete(self, sim_id: str, user: Optional[User] = None): + @pydantic_validate(api) + def delete( + self, + sim_id: str, + user: Optional[User], + body: Annotated[MetadataDeleteData, Body()], + ) -> PydanticResponse[MetadataDeleteResponse]: try: - data = request.get_json() or {} - - if "key" not in data: - return error("Metadata key not provided") - - key = data["key"] - simulation = current_app.db.get_simulation(sim_id) if simulation is None: - raise ValueError(f"Simulation {sim_id} not found.") + return ErrorResponse(error=f"Simulation {sim_id} not found.") - simulation.remove_meta(key) + simulation.remove_meta(body.key) current_app.db.insert_simulation(simulation) clear_cache() - return {} + return MetadataDeleteResponse() except DatabaseError as err: - return error(str(err)) + return ErrorResponse(error=str(err)) @api.route("/validate/") class ValidateSimulation(Resource): @requires_auth() - def post(self, sim_id, user: User): + @pydantic_validate(api) + def post(self, sim_id, user: User) -> PydanticResponse[ValidationResult]: try: simulation = current_app.db.get_simulation(sim_id) result = _validate(simulation, user) current_app.db.insert_simulation(simulation) clear_cache() - return jsonify(result) + return result except DatabaseError as err: - return error(str(err)) + return ErrorResponse(error=str(err)) 
@api.route("/trace/") class SimulationTrace(Resource): @requires_auth() @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] - def get(self, sim_id: str, user: User): + @pydantic_validate(api) + def get(self, sim_id: str, user: User) -> PydanticResponse[SimulationTraceData]: try: - data = _build_trace(sim_id) - return jsonify(data) + return _build_trace(sim_id) except DatabaseError as err: - return error(str(err)) + return ErrorResponse(error=str(err)) @api.route("/simulation/package/") diff --git a/src/simdb/remote/app.py b/src/simdb/remote/app.py index 7cf0cabf..042a7d50 100644 --- a/src/simdb/remote/app.py +++ b/src/simdb/remote/app.py @@ -31,6 +31,7 @@ def create_app( CORS(app, resources={r"/*": {"origins": "*"}}) app.config["TESTING"] = testing app.config["DEBUG"] = debug + app.config["RESTX_INCLUDE_ALL_MODELS"] = True app.config["PROFILE"] = profile app.json_encoder = cast(Type[JSONEncoder], CustomEncoder) app.json_decoder = cast(Type[JSONDecoder], CustomDecoder) diff --git a/src/simdb/remote/core/pydantic_utils.py b/src/simdb/remote/core/pydantic_utils.py new file mode 100644 index 00000000..aaf561a8 --- /dev/null +++ b/src/simdb/remote/core/pydantic_utils.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import contextlib +import copy +import functools +import gzip +import inspect +import typing +from typing import ( + Annotated, + Any, + TypeVar, + Union, + get_args, + get_origin, +) + +from flask import Response, request +from flask_restx import Namespace +from pydantic import BaseModel, ValidationError + +from simdb.remote.core.errors import error as _error +from simdb.remote.models import ErrorResponse + +M = TypeVar("M", bound=BaseModel) + +PydanticResponse = Union[ErrorResponse, M] +# --------------------------------------------------------------------------- +# Marker classes for Annotated-style parameter declarations +# --------------------------------------------------------------------------- + + +class _ParamSource: 
+ """Base class for parameter source markers.""" + + +class Header(_ParamSource): + """Marker: populate this parameter from ``request.headers``.""" + + +class Body(_ParamSource): + """Marker: populate this parameter from the JSON request body.""" + + +class Query(_ParamSource): + """Marker: populate this parameter from ``request.args`` (query string).""" + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _register_defs(ns: Namespace, defs): + for name, schema in defs.items(): + # Clean the schema: rewrite refs and remove internal $defs + clean_schema = copy.deepcopy(schema) + children = clean_schema.pop("$defs", {}) + ns.schema_model(name, clean_schema) + _register_defs(ns, children) + + +def _collect_and_register(ns: Namespace, model: type[BaseModel]): + """ + Registers a Pydantic model and all nested dependencies to a Flask-RESTX Namespace. + """ + full_schema = model.model_json_schema(ref_template="#/definitions/{model}") + all_defs = full_schema.get("$defs", {}) + + # 1. Register all sub-models found in $defs first + _register_defs(ns, all_defs) + + # 2. Register the root model + root_name = model.__name__ + root_schema = copy.deepcopy(full_schema) + root_schema.pop("$defs", None) + + return ns.schema_model(root_name, root_schema) + + +def _get_annotated_params( + f: Any, +) -> list[tuple[str, type[BaseModel], _ParamSource]]: + """Inspect *f*'s signature and return a list of ``(param_name, model, source)`` + tuples for every parameter annotated as ``Annotated[SomePydanticModel, Source()]``. 
+ """ + results = [] + sig = inspect.signature(f) + + for param_name, param in sig.parameters.items(): + annotation = param.annotation + if annotation is inspect.Parameter.empty: + continue + if get_origin(annotation) is Annotated: + args = get_args(annotation) + if len(args) >= 2: + model_type = args[0] + source = args[1] + if ( + isinstance(source, _ParamSource) + and isinstance(model_type, type) + and issubclass(model_type, BaseModel) + ): + results.append((param_name, model_type, source)) + return results + + +def _validate_param( + model: type[BaseModel], + source: _ParamSource, +) -> BaseModel: + """Validate and return a Pydantic model instance from the appropriate + part of the current Flask request. + + Raises + ------ + ValidationError + If the data does not conform to the model. + ValueError + If the request body is missing or not valid JSON (for Body sources). + """ + if isinstance(source, Header): + # Convert Werkzeug Headers to a plain dict (lowercase keys) + raw = {k.lower(): v for k, v in request.headers.items()} + return model.model_validate(raw) + + elif isinstance(source, Query): + # Convert ImmutableMultiDict to a plain dict (lists for multi-values) + raw = request.args.to_dict(flat=False) + # Flatten single-value lists for convenience + flat = {k: v[0] if len(v) == 1 else v for k, v in raw.items()} + return model.model_validate(flat) + + else: + enc = (request.headers.get("Content-Encoding") or "").lower() + request_data = request.get_data(cache=False) + if request_data is None: + raise ValueError("Invalid or missing JSON body") + + if enc == "gzip": + with contextlib.suppress(OSError): + request_data = gzip.decompress(request_data) + + return model.model_validate_json(request_data) + + +# --------------------------------------------------------------------------- +# FastAPI-style route decorator +# --------------------------------------------------------------------------- + + +def pydantic_validate( + ns: Namespace, + *, + response_model: 
type[BaseModel] | None = None, + error_model: type[BaseModel] | None = None, +) -> Any: + """Decorator factory that wires up Pydantic validation for a Flask-RESTX endpoint. + + Inspects the decorated function's signature for parameters annotated with + ``Annotated[SomePydanticModel, Header()]``, ``Annotated[SomePydanticModel, Body()]`` + or ``Annotated[SomePydanticModel, Query()]``, validates the corresponding parts of + the incoming request, and injects the validated model instances as keyword + arguments. + + All discovered input models are automatically registered with *ns* for + Swagger/OpenAPI documentation. Body models are registered as ``@ns.expect`` + models; header/query models are registered as parser arguments. + + If the function's return annotation is a ``BaseModel`` subclass (or if + *response_model* is provided explicitly) the return value is automatically + serialised with ``model_dump(mode="json")`` and wrapped in ``jsonify``. + + Parameters + ---------- + ns: + The Flask-RESTX ``Namespace`` (or ``Api``) to register models on. + response_model: + Optional explicit response model. If ``None`` the decorator tries to + infer it from the function's return annotation. + + Returns + ------- + A decorator suitable for use on Flask-RESTX ``Resource`` methods. + + Example + ------- + .. code-block:: python + + from typing import Annotated + from simdb.remote.core.pydantic_utils import restx_route, Header, Body, Query + + class SimulationList(Resource): + @restx_route(api) + def get( + self, + user: User, + pagination: Annotated[PaginationData, Header()], + ) -> PydanticResponse[PaginatedResponse]: + ... + + @restx_route(api) + def post( + self, + user: User, + body: Annotated[SimulationPostData, Body()], + ) -> PydanticResponse[SimulationPostResponse]: + ... 
+ """ + + def decorator(f): + annotated_params = _get_annotated_params(f) + + # Determine response model from return annotation if not given explicitly + _response_model = response_model + _error_model = error_model + if _response_model is None: + ret = inspect.signature(f).return_annotation + if ( + ret is not inspect.Parameter.empty + and typing.get_origin(ret) is typing.Union + and issubclass(typing.get_args(ret)[0], ErrorResponse) + and issubclass(typing.get_args(ret)[1], BaseModel) + ): + _response_model = typing.get_args(ret)[1] + _error_model = typing.get_args(ret)[0] + + # Register all input models with the namespace + _registered = {} + _body_schema = None + for _param_name, model_type, source in annotated_params: + schema = _collect_and_register(ns, model_type) + if isinstance(source, Body): + _body_schema = schema + + # Register response model + _resp_schema = None + if _response_model is not None: + _resp_schema = _collect_and_register(ns, _response_model) + + _error_schema = None + if _error_model is not None: + _error_schema = _collect_and_register(ns, _error_model) + + @functools.wraps(f) + def wrapper(*args, **kwargs): + for param_name, model_type, source in annotated_params: + try: + validated = _validate_param(model_type, source) + except ValueError as exc: + return _error(str(exc)) + except ValidationError as exc: + first_error = exc.errors()[0] + loc = " -> ".join(str(loc) for loc in first_error["loc"]) + msg = f"Validation error at '{loc}': {first_error['msg']}" + return _error(msg) + kwargs[param_name] = validated + + try: + result = f(*args, **kwargs) + except Exception as err: + return Response( + response=ErrorResponse(error=str(err)).model_dump_json(), + status=400, + mimetype="application/json", + ) + + if isinstance(result, ErrorResponse): + return Response( + response=result.model_dump_json(), + status=400, + mimetype="application/json", + ) + if isinstance(result, BaseModel): + return Response( + result.model_dump_json(), + 
mimetype="application/json", + ) + + return result + + if _body_schema is not None: + wrapper = ns.expect(_body_schema)(wrapper) + if _resp_schema is not None: + wrapper = ns.response(200, "Success", _resp_schema)(wrapper) + if _error_schema is not None: + wrapper = ns.response(400, "Error", _error_schema)(wrapper) + + return wrapper + + return decorator diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index 2152d0e9..a8db2e24 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -18,13 +18,18 @@ from uuid import UUID, uuid1 from pydantic import ( - BaseModel, + BaseModel as _BaseModel, +) +from pydantic import ( BeforeValidator, + ConfigDict, Field, PlainSerializer, - RootModel, model_validator, ) +from pydantic import ( + RootModel as _RootModel, +) from simdb.cli.manifest import DataObject @@ -32,6 +37,14 @@ """UUID serialized as a hex string.""" +class BaseModel(_BaseModel): + model_config = ConfigDict(use_attribute_docstrings=True) + + +class RootModel(_RootModel): + model_config = ConfigDict(use_attribute_docstrings=True) + + def _deserialize_custom_uuid(v: Any) -> UUID: """Deserialize CustomUUID format back to UUID.""" if isinstance(v, UUID): @@ -64,7 +77,7 @@ class StatusPatchData(BaseModel): class DeletedSimulation(BaseModel): """Reference to a deleted simulation.""" - uuid: UUID + simulation: UUID """UUID of the deleted simulation.""" files: List[str] """List of deleted file paths.""" @@ -171,6 +184,10 @@ def as_querystring(self) -> str: return urlencode(self.as_dict()) +class MetadataDeleteResponse(BaseModel): + pass + + class SimulationReference(BaseModel): """Reference to a simulation.""" @@ -206,6 +223,10 @@ class SimulationDataResponse(SimulationData): """Child simulations.""" +class SimulationPatchResponse(BaseModel): + pass + + class SimulationPostData(BaseModel): """Data for creating a new simulation.""" @@ -469,3 +490,10 @@ class StagingDirectoryResponse(BaseModel): staging_dir: Path """Path to the 
staging dir.""" + + +class ErrorResponse(BaseModel): + """Response model for server errors.""" + + error: str + """Error description.""" From a28cfa823c04c493937483342d5cd100ba3dbdd8 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Thu, 26 Feb 2026 10:17:40 +0100 Subject: [PATCH 02/12] Raise errors instead of returning them --- src/simdb/remote/apis/files.py | 12 ++--- src/simdb/remote/apis/v1_2/simulations.py | 57 +++++++++++------------ src/simdb/remote/core/pydantic_utils.py | 36 ++++++++------ 3 files changed, 54 insertions(+), 51 deletions(-) diff --git a/src/simdb/remote/apis/files.py b/src/simdb/remote/apis/files.py index c4047c11..94417401 100644 --- a/src/simdb/remote/apis/files.py +++ b/src/simdb/remote/apis/files.py @@ -18,9 +18,9 @@ from simdb.remote.core.auth import User, requires_auth from simdb.remote.core.errors import error from simdb.remote.core.path import find_common_root, secure_path -from simdb.remote.core.pydantic_utils import PydanticResponse, pydantic_validate +from simdb.remote.core.pydantic_utils import ResponseException, pydantic_validate from simdb.remote.core.typing import current_app -from simdb.remote.models import ErrorResponse, FileDataList, FileGetDataResponse +from simdb.remote.models import FileDataList, FileGetDataResponse from simdb.uri import URI api = Namespace("files", path="/") @@ -173,7 +173,7 @@ def _handle_file_upload() -> Response: class FileList(Resource): @requires_auth() @pydantic_validate(api) - def get(self, user: User) -> PydanticResponse[FileDataList]: + def get(self, user: User) -> FileDataList: files = current_app.db.list_files() return FileDataList.model_validate([file.data() for file in files]) @@ -193,14 +193,12 @@ def post(self, user: User): class File(Resource): @requires_auth() @pydantic_validate(api) - def get( - self, file_uuid: str, user: Optional[User] = None - ) -> PydanticResponse[FileGetDataResponse]: + def get(self, file_uuid: str, user: Optional[User] = None) -> FileGetDataResponse: try: 
file = current_app.db.get_file(file_uuid) return file.to_model_with_path() except DatabaseError as err: - return ErrorResponse(error=str(err)) + raise ResponseException(str(err)) from err @api.route("/file/download/") diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index b9f8a1df..7feb9d90 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -24,13 +24,12 @@ from simdb.remote.core.pydantic_utils import ( Body, Header, - PydanticResponse, + ResponseException, pydantic_validate, ) from simdb.remote.core.typing import current_app from simdb.remote.models import ( DeletedSimulation, - ErrorResponse, MetadataDataList, MetadataDeleteData, MetadataDeleteResponse, @@ -176,7 +175,7 @@ def get( self, user: User, pagination: Annotated[PaginationData, Header()], - ) -> PydanticResponse[PaginatedResponse[SimulationListItem]]: + ) -> PaginatedResponse[SimulationListItem]: names = [] constraints = [] if request.args: @@ -223,7 +222,7 @@ def post( self, user: User, body: Annotated[SimulationPostData, Body()], - ) -> PydanticResponse[SimulationPostResponse]: + ) -> SimulationPostResponse: try: simulation = models_sim.Simulation.from_data_model(body.simulation) @@ -306,12 +305,12 @@ def post( result.validation = _validate(simulation, user) if not result.validation.passed and error_on_fail: - return ErrorResponse( - error=f"Simulation validation failed and server has " + raise ResponseException( + f"Simulation validation failed and server has " f"error_on_fail=True.\n{result.validation.error}" ) elif error_on_fail: - raise RuntimeError( + raise ResponseException( "Validation config option error_on_fail=True without " "auto_validate=True." 
) @@ -343,7 +342,7 @@ def post( return result except (DatabaseError, ValueError) as err: - return ErrorResponse(error=str(err)) + raise ResponseException(str(err)) from err @api.route("/simulation/") @@ -351,7 +350,7 @@ class Simulation(Resource): @requires_auth() @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] @pydantic_validate(api) - def get(self, sim_id: str, user: User) -> PydanticResponse[SimulationDataResponse]: + def get(self, sim_id: str, user: User) -> SimulationDataResponse: try: simulation = current_app.db.get_simulation(sim_id) if simulation: @@ -364,9 +363,9 @@ def get(self, sim_id: str, user: User) -> PydanticResponse[SimulationDataRespons simulation ) return sim_data - return ErrorResponse(error="Simulation not found") + raise ResponseException("Simulation not found") except DatabaseError as err: - return ErrorResponse(error=str(err)) + raise ResponseException(str(err)) from err @requires_auth("admin") @pydantic_validate(api) @@ -375,7 +374,7 @@ def patch( sim_id: str, user: Optional[User], body: Annotated[StatusPatchData, Body()], - ) -> PydanticResponse[SimulationPatchResponse]: + ) -> SimulationPatchResponse: try: simulation = current_app.db.get_simulation(sim_id) if simulation is None: @@ -386,13 +385,11 @@ def patch( clear_cache() return SimulationPatchResponse() except DatabaseError as err: - return ErrorResponse(error=str(err)) + raise ResponseException(str(err)) from err @requires_auth("admin") @pydantic_validate(api) - def delete( - self, sim_id: str, user: User - ) -> PydanticResponse[SimulationDeleteResponse]: + def delete(self, sim_id: str, user: User) -> SimulationDeleteResponse: try: simulation = current_app.db.delete_simulation(sim_id) clear_cache() @@ -415,7 +412,7 @@ def delete( deleted=DeletedSimulation(simulation=simulation.uuid, files=files) ) except DatabaseError as err: - return ErrorResponse(error=str(err)) + raise ResponseException(str(err)) from err @api.route("/simulation/metadata/") @@ -423,16 
+420,16 @@ class SimulationMeta(Resource): @requires_auth() @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] @pydantic_validate(api) - def get(self, sim_id: str, user: User) -> PydanticResponse[MetadataDataList]: + def get(self, sim_id: str, user: User) -> MetadataDataList: try: simulation = current_app.db.get_simulation(sim_id) if simulation: return MetadataDataList.model_validate( [meta.data() for meta in simulation.meta] ) - return ErrorResponse(error="Simulation not found") + raise ResponseException("Simulation not found") except DatabaseError as err: - return ErrorResponse(error=str(err)) + raise ResponseException(str(err)) from err @requires_auth("admin") @pydantic_validate(api) @@ -441,13 +438,13 @@ def patch( sim_id: str, user: Optional[User], body: Annotated[MetadataPatchData, Body()], - ) -> PydanticResponse[MetadataDataList]: + ) -> MetadataDataList: try: key = body.key value = body.value.lower() simulation = current_app.db.get_simulation(sim_id) if simulation is None: - return ErrorResponse(error=f"Simulation {sim_id} not found.") + raise ResponseException(f"Simulation {sim_id} not found.") old_values = MetadataDataList.model_validate( [meta.data() for meta in simulation.find_meta(key)] ) @@ -461,7 +458,7 @@ def patch( clear_cache() return old_values except DatabaseError as err: - return ErrorResponse(error=str(err)) + raise ResponseException(str(err)) from err @requires_auth("admin") @pydantic_validate(api) @@ -470,25 +467,25 @@ def delete( sim_id: str, user: Optional[User], body: Annotated[MetadataDeleteData, Body()], - ) -> PydanticResponse[MetadataDeleteResponse]: + ) -> MetadataDeleteResponse: try: simulation = current_app.db.get_simulation(sim_id) if simulation is None: - return ErrorResponse(error=f"Simulation {sim_id} not found.") + raise ResponseException(f"Simulation {sim_id} not found.") simulation.remove_meta(body.key) current_app.db.insert_simulation(simulation) clear_cache() return MetadataDeleteResponse() except 
DatabaseError as err: - return ErrorResponse(error=str(err)) + raise ResponseException(str(err)) from err @api.route("/validate/") class ValidateSimulation(Resource): @requires_auth() @pydantic_validate(api) - def post(self, sim_id, user: User) -> PydanticResponse[ValidationResult]: + def post(self, sim_id, user: User) -> ValidationResult: try: simulation = current_app.db.get_simulation(sim_id) result = _validate(simulation, user) @@ -496,7 +493,7 @@ def post(self, sim_id, user: User) -> PydanticResponse[ValidationResult]: clear_cache() return result except DatabaseError as err: - return ErrorResponse(error=str(err)) + raise ResponseException(str(err)) from err @api.route("/trace/") @@ -504,11 +501,11 @@ class SimulationTrace(Resource): @requires_auth() @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] @pydantic_validate(api) - def get(self, sim_id: str, user: User) -> PydanticResponse[SimulationTraceData]: + def get(self, sim_id: str, user: User) -> SimulationTraceData: try: return _build_trace(sim_id) except DatabaseError as err: - return ErrorResponse(error=str(err)) + raise ResponseException(str(err)) from err @api.route("/simulation/package/") diff --git a/src/simdb/remote/core/pydantic_utils.py b/src/simdb/remote/core/pydantic_utils.py index aaf561a8..4c61028a 100644 --- a/src/simdb/remote/core/pydantic_utils.py +++ b/src/simdb/remote/core/pydantic_utils.py @@ -5,12 +5,9 @@ import functools import gzip import inspect -import typing from typing import ( Annotated, Any, - TypeVar, - Union, get_args, get_origin, ) @@ -22,9 +19,16 @@ from simdb.remote.core.errors import error as _error from simdb.remote.models import ErrorResponse -M = TypeVar("M", bound=BaseModel) -PydanticResponse = Union[ErrorResponse, M] +class ResponseException(Exception): + """Raised an error has occurred in the request.""" + + def __init__(self, message, return_code=400): + super().__init__(message) + self.message = message + self.return_code = return_code + + # 
--------------------------------------------------------------------------- # Marker classes for Annotated-style parameter declarations # --------------------------------------------------------------------------- @@ -153,7 +157,7 @@ def pydantic_validate( ns: Namespace, *, response_model: type[BaseModel] | None = None, - error_model: type[BaseModel] | None = None, + error_model: type[BaseModel] = ErrorResponse, ) -> Any: """Decorator factory that wires up Pydantic validation for a Flask-RESTX endpoint. @@ -196,7 +200,7 @@ def get( self, user: User, pagination: Annotated[PaginationData, Header()], - ) -> PydanticResponse[PaginatedResponse]: + ) -> PaginatedResponse: ... @restx_route(api) @@ -204,7 +208,7 @@ def post( self, user: User, body: Annotated[SimulationPostData, Body()], - ) -> PydanticResponse[SimulationPostResponse]: + ) -> SimulationPostResponse: ... """ @@ -218,12 +222,10 @@ def decorator(f): ret = inspect.signature(f).return_annotation if ( ret is not inspect.Parameter.empty - and typing.get_origin(ret) is typing.Union - and issubclass(typing.get_args(ret)[0], ErrorResponse) - and issubclass(typing.get_args(ret)[1], BaseModel) + and inspect.isclass(ret) + and issubclass(ret, BaseModel) ): - _response_model = typing.get_args(ret)[1] - _error_model = typing.get_args(ret)[0] + _response_model = ret # Register all input models with the namespace _registered = {} @@ -258,9 +260,15 @@ def wrapper(*args, **kwargs): try: result = f(*args, **kwargs) + except ResponseException as err: + return Response( + response=_error_model(error=err.message).model_dump_json(), + status=err.return_code, + mimetype="application/json", + ) except Exception as err: return Response( - response=ErrorResponse(error=str(err)).model_dump_json(), + response=_error_model(error=str(err)).model_dump_json(), status=400, mimetype="application/json", ) From 6e0f8b5352fcde9e7124e247d7464f17b447ab8d Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Thu, 26 Feb 2026 14:50:29 +0100 Subject: 
[PATCH 03/12] Cleanup endpoints --- src/simdb/database/models/simulation.py | 24 ++ src/simdb/database/models/watcher.py | 4 + src/simdb/remote/apis/files.py | 11 +- src/simdb/remote/apis/v1_2/simulations.py | 401 ++++++++++------------ src/simdb/remote/apis/watchers.py | 85 +++-- src/simdb/remote/core/pydantic_utils.py | 6 +- 6 files changed, 260 insertions(+), 271 deletions(-) diff --git a/src/simdb/database/models/simulation.py b/src/simdb/database/models/simulation.py index a88b6d6a..b2465558 100644 --- a/src/simdb/database/models/simulation.py +++ b/src/simdb/database/models/simulation.py @@ -14,6 +14,7 @@ MetadataDataList, SimulationData, SimulationDataResponse, + SimulationTraceData, ) if sys.version_info < (3, 11): @@ -420,6 +421,29 @@ def to_model_with_refs( children=[], ) + def to_model_trace( + self, recurse: bool = False, meta_keys: Optional[List[str]] = None + ) -> SimulationTraceData: + inputs = FileDataList() + outputs = FileDataList() + metadata = MetadataDataList() + if recurse: + inputs = FileDataList([f.to_model() for f in self.inputs]) + outputs = FileDataList([f.to_model() for f in self.outputs]) + metadata = MetadataDataList([m.to_model() for m in self.meta]) + elif meta_keys: + metadata = MetadataDataList( + [m.to_model() for m in self.meta if m.element in meta_keys] + ) + return SimulationTraceData( + uuid=self.uuid, + alias=self.alias, + datetime=self.datetime, + inputs=inputs, + outputs=outputs, + metadata=metadata, + ) + def meta_dict(self) -> Dict[str, Union[Dict, Any]]: meta = {m.element: m.value for m in self.meta} return unflatten_dict(meta) diff --git a/src/simdb/database/models/watcher.py b/src/simdb/database/models/watcher.py index 7785dc93..c7532b72 100644 --- a/src/simdb/database/models/watcher.py +++ b/src/simdb/database/models/watcher.py @@ -7,6 +7,7 @@ from simdb.docstrings import inherit_docstrings from simdb.notifications import Notification +from simdb.remote.models import WatcherData from .base import Base from .types 
import ChoiceType @@ -59,3 +60,6 @@ def data(self, recurse: bool = False) -> Dict[str, str]: "notification": str(self.notification), } return data + + def to_model(self) -> WatcherData: + return WatcherData.model_validate(self.data()) diff --git a/src/simdb/remote/apis/files.py b/src/simdb/remote/apis/files.py index 94417401..e846fa57 100644 --- a/src/simdb/remote/apis/files.py +++ b/src/simdb/remote/apis/files.py @@ -18,7 +18,7 @@ from simdb.remote.core.auth import User, requires_auth from simdb.remote.core.errors import error from simdb.remote.core.path import find_common_root, secure_path -from simdb.remote.core.pydantic_utils import ResponseException, pydantic_validate +from simdb.remote.core.pydantic_utils import pydantic_validate from simdb.remote.core.typing import current_app from simdb.remote.models import FileDataList, FileGetDataResponse from simdb.uri import URI @@ -175,7 +175,7 @@ class FileList(Resource): @pydantic_validate(api) def get(self, user: User) -> FileDataList: files = current_app.db.list_files() - return FileDataList.model_validate([file.data() for file in files]) + return FileDataList.model_validate([file.to_model() for file in files]) @requires_auth() def post(self, user: User): @@ -194,11 +194,8 @@ class File(Resource): @requires_auth() @pydantic_validate(api) def get(self, file_uuid: str, user: Optional[User] = None) -> FileGetDataResponse: - try: - file = current_app.db.get_file(file_uuid) - return file.to_model_with_path() - except DatabaseError as err: - raise ResponseException(str(err)) from err + file = current_app.db.get_file(file_uuid) + return file.to_model_with_path() @api.route("/file/download/") diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index 7feb9d90..419e19e9 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -4,7 +4,7 @@ import tarfile from io import BytesIO from pathlib import Path -from typing import Annotated, 
Any, Dict, List, Optional, Tuple, cast +from typing import Annotated, List, Optional, Tuple from flask import request, send_file from flask_restx import Namespace, Resource @@ -137,33 +137,29 @@ def _set_alias(alias: str): def _build_trace(sim_id: str) -> SimulationTraceData: simulation = current_app.db.get_simulation(sim_id) - data: Dict[str, Any] = cast(Dict[str, Any], simulation.data(recurse=False)) + data = simulation.to_model_trace(recurse=False) - status = simulation.find_meta("status") - if status: - status_value = status[0].value - if isinstance(status_value, str): - data["status"] = status_value - else: - data["status"] = status_value.value - status_on_name = str(data["status"]) + "_on" - status_on = simulation.find_meta(status_on_name) - if status_on: - data[status_on_name] = status_on[0].value + def get_meta_val(key, default=None): + meta = simulation.find_meta(key) + return meta[0].value if meta else default + + status_val = get_meta_val("status") + if status_val: + data.status = status_val if isinstance(status_val, str) else status_val.value - replaces = simulation.find_meta("replaces") - if replaces: - data["replaces"] = _build_trace(replaces[0].value) + status_on_key = f"{data.status}_on" + status_on_val = get_meta_val(status_on_key) + if status_on_val: + setattr(data, status_on_key, status_on_val) - replaced_on = simulation.find_meta("replaced_on") - if replaced_on: - data["deprecated_on"] = replaced_on[0].value + replaces_id = get_meta_val("replaces") + if replaces_id: + data.replaces = _build_trace(replaces_id) - replaces_reason = simulation.find_meta("replaces_reason") - if replaces_reason: - data["replaces_reason"] = replaces_reason[0].value + data.deprecated_on = get_meta_val("replaced_on") + data.replaces_reason = get_meta_val("replaces_reason") - return SimulationTraceData.model_validate(data) + return data @api.route("/simulations") @@ -223,126 +219,121 @@ def post( user: User, body: Annotated[SimulationPostData, Body()], ) -> 
SimulationPostResponse: - try: - simulation = models_sim.Simulation.from_data_model(body.simulation) + simulation = models_sim.Simulation.from_data_model(body.simulation) - # Simulation Upload (Push) Date - simulation.datetime = datetime.datetime.now() + # Simulation Upload (Push) Date + simulation.datetime = datetime.datetime.now() - uploaded_by = body.uploaded_by or user.email or user.name or "anonymous" + uploaded_by = body.uploaded_by or user.email or user.name or "anonymous" - simulation.set_meta("uploaded_by", uploaded_by) + simulation.set_meta("uploaded_by", uploaded_by) - if body.add_watcher: - simulation.watchers.append( - models_watcher.Watcher( - user.name, user.email, models_watcher.Notification.ALL - ) + if body.add_watcher: + simulation.watchers.append( + models_watcher.Watcher( + user.name, user.email, models_watcher.Notification.ALL ) + ) - alias = body.simulation.alias - if alias is not None: - (updated_alias, next_id) = _set_alias(alias) - if updated_alias: - simulation.meta.append(models_meta.MetaData("seqid", next_id)) - simulation.alias = updated_alias - else: - simulation.alias = alias + alias = body.simulation.alias + if alias is not None: + (updated_alias, next_id) = _set_alias(alias) + if updated_alias: + simulation.meta.append(models_meta.MetaData("seqid", next_id)) + simulation.alias = updated_alias else: - simulation.alias = simulation.uuid.hex + simulation.alias = alias + else: + simulation.alias = simulation.uuid.hex - files = list(itertools.chain(simulation.inputs, simulation.outputs)) - sim_file_paths = simulation.file_paths() - common_root = find_common_root(sim_file_paths) + files = list(itertools.chain(simulation.inputs, simulation.outputs)) + sim_file_paths = simulation.file_paths() + common_root = find_common_root(sim_file_paths) - config = current_app.simdb_config - copy_files = config.get_option("server.copy_files", default=True) - imas_remote_host = config.get_option( - "server.imas_remote_host", default=None - ) + config = 
current_app.simdb_config + copy_files = config.get_option("server.copy_files", default=True) + imas_remote_host = config.get_option("server.imas_remote_host", default=None) - if copy_files or imas_remote_host: - staging_dir = ( - Path(config.get_string_option("server.upload_folder")) - / simulation.uuid.hex - ) - - for sim_file in files: - if ( - copy_files - and sim_file.uri.scheme == "file" - and sim_file.uri.path is not None - ): - path = secure_path(sim_file.uri.path, common_root, staging_dir) - if not path.exists(): - raise ValueError( - f"simulation file {sim_file.uuid} not uploaded" - ) - sim_file.uri = URI(scheme="file", path=path) - elif sim_file.uri.scheme == "imas": - if copy_files: - path = secure_path( - Path(sim_file.uri.query["path"]), - common_root, - staging_dir, - is_file=common_root is not None, - ) - else: - path = Path(sim_file.uri.query["path"]) - sim_file.uri = convert_uri(sim_file.uri, path, config) - - result = SimulationPostResponse( - ingested=simulation.uuid, error=None, validation=None + if copy_files or imas_remote_host: + staging_dir = ( + Path(config.get_string_option("server.upload_folder")) + / simulation.uuid.hex ) - error_on_fail = current_app.simdb_config.get_option( - "validation.error_on_fail", default=False - ) + for sim_file in files: + if ( + copy_files + and sim_file.uri.scheme == "file" + and sim_file.uri.path is not None + ): + path = secure_path(sim_file.uri.path, common_root, staging_dir) + if not path.exists(): + raise ResponseException( + f"simulation file {sim_file.uuid} not uploaded" + ) + sim_file.uri = URI(scheme="file", path=path) + elif sim_file.uri.scheme == "imas": + if copy_files: + path = secure_path( + Path(sim_file.uri.query["path"]), + common_root, + staging_dir, + is_file=common_root is not None, + ) + else: + path = Path(sim_file.uri.query["path"]) + sim_file.uri = convert_uri(sim_file.uri, path, config) + + result = SimulationPostResponse( + ingested=simulation.uuid, error=None, validation=None + ) - 
if current_app.simdb_config.get_option( - "validation.auto_validate", default=False - ): - result.validation = _validate(simulation, user) + error_on_fail = current_app.simdb_config.get_option( + "validation.error_on_fail", default=False + ) - if not result.validation.passed and error_on_fail: - raise ResponseException( - f"Simulation validation failed and server has " - f"error_on_fail=True.\n{result.validation.error}" - ) - elif error_on_fail: + if current_app.simdb_config.get_option( + "validation.auto_validate", default=False + ): + result.validation = _validate(simulation, user) + + if not result.validation.passed and error_on_fail: raise ResponseException( - "Validation config option error_on_fail=True without " - "auto_validate=True." + f"Simulation validation failed and server has " + f"error_on_fail=True.\n{result.validation.error}" ) - - disable_replaces = config.get_option( - "development.disable_replaces", default=False + elif error_on_fail: + raise ResponseException( + "Validation config option error_on_fail=True without " + "auto_validate=True." 
) - replaces = simulation.find_meta("replaces") - - if not disable_replaces and replaces and replaces[0].value: - sim_id = replaces[0].value - try: - replaces_sim = current_app.db.get_simulation(sim_id) - except DatabaseError: - replaces_sim = None - if replaces_sim is not None: - _update_simulation_status( - replaces_sim, models_sim.Simulation.Status.DEPRECATED, user - ) - replaces_sim.set_meta("replaced_by", simulation.uuid) - current_app.db.insert_simulation(replaces_sim) + disable_replaces = config.get_option( + "development.disable_replaces", default=False + ) + replaces = simulation.find_meta("replaces") + + if not disable_replaces and replaces and replaces[0].value: + sim_id = replaces[0].value + try: + replaces_sim = current_app.db.get_simulation(sim_id) + except DatabaseError: + replaces_sim = None + + if replaces_sim is not None: + _update_simulation_status( + replaces_sim, models_sim.Simulation.Status.DEPRECATED, user + ) + replaces_sim.set_meta("replaced_by", simulation.uuid) + current_app.db.insert_simulation(replaces_sim) - current_app.db.insert_simulation(simulation) - clear_cache() + current_app.db.insert_simulation(simulation) + clear_cache() - with contextlib.suppress(OSError): - create_alias_dir(simulation) + with contextlib.suppress(OSError): + create_alias_dir(simulation) - return result - except (DatabaseError, ValueError) as err: - raise ResponseException(str(err)) from err + return result @api.route("/simulation/") @@ -351,21 +342,14 @@ class Simulation(Resource): @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] @pydantic_validate(api) def get(self, sim_id: str, user: User) -> SimulationDataResponse: - try: - simulation = current_app.db.get_simulation(sim_id) - if simulation: - sim_data = simulation.to_model_with_refs(recurse=True) + simulation = current_app.db.get_simulation(sim_id) + if simulation: + sim_data = simulation.to_model_with_refs(recurse=True) - sim_data.children = 
current_app.db.get_simulation_children_ref( - simulation - ) - sim_data.parents = current_app.db.get_simulation_children_ref( - simulation - ) - return sim_data - raise ResponseException("Simulation not found") - except DatabaseError as err: - raise ResponseException(str(err)) from err + sim_data.children = current_app.db.get_simulation_children_ref(simulation) + sim_data.parents = current_app.db.get_simulation_children_ref(simulation) + return sim_data + raise ResponseException("Simulation not found") @requires_auth("admin") @pydantic_validate(api) @@ -375,44 +359,38 @@ def patch( user: Optional[User], body: Annotated[StatusPatchData, Body()], ) -> SimulationPatchResponse: - try: - simulation = current_app.db.get_simulation(sim_id) - if simulation is None: - raise ValueError(f"Simulation {sim_id} not found.") - status = models_sim.Simulation.Status(body.status) - _update_simulation_status(simulation, status, user) - current_app.db.insert_simulation(simulation) - clear_cache() - return SimulationPatchResponse() - except DatabaseError as err: - raise ResponseException(str(err)) from err + simulation = current_app.db.get_simulation(sim_id) + if simulation is None: + raise ValueError(f"Simulation {sim_id} not found.") + status = models_sim.Simulation.Status(body.status) + _update_simulation_status(simulation, status, user) + current_app.db.insert_simulation(simulation) + clear_cache() + return SimulationPatchResponse() @requires_auth("admin") @pydantic_validate(api) def delete(self, sim_id: str, user: User) -> SimulationDeleteResponse: - try: - simulation = current_app.db.delete_simulation(sim_id) - clear_cache() - files = [] - for file in itertools.chain(simulation.inputs, simulation.outputs): - if file.uri.scheme == "file": - if file.uri.path is None: - raise ValueError("File path not set") - files.append(f"{file.uuid} ({file.uri.path.name})") - file.uri.path.unlink() - if simulation.inputs or simulation.outputs: - first_file = ( - simulation.inputs[0] if 
simulation.inputs else simulation.outputs[0] - ) - if first_file.uri.path is not None: - directory = first_file.uri.path.parent - if directory != Path() and directory != Path("/"): - directory.rmdir() - return SimulationDeleteResponse( - deleted=DeletedSimulation(simulation=simulation.uuid, files=files) + simulation = current_app.db.delete_simulation(sim_id) + clear_cache() + files = [] + for file in itertools.chain(simulation.inputs, simulation.outputs): + if file.uri.scheme == "file": + if file.uri.path is None: + raise ValueError("File path not set") + files.append(f"{file.uuid} ({file.uri.path.name})") + file.uri.path.unlink() + if simulation.inputs or simulation.outputs: + first_file = ( + simulation.inputs[0] if simulation.inputs else simulation.outputs[0] ) - except DatabaseError as err: - raise ResponseException(str(err)) from err + if first_file.uri.path is not None: + directory = first_file.uri.path.parent + if directory != Path() and directory != Path("/"): + directory.rmdir() + return SimulationDeleteResponse( + deleted=DeletedSimulation(simulation=simulation.uuid, files=files) + ) @api.route("/simulation/metadata/") @@ -421,15 +399,12 @@ class SimulationMeta(Resource): @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] @pydantic_validate(api) def get(self, sim_id: str, user: User) -> MetadataDataList: - try: - simulation = current_app.db.get_simulation(sim_id) - if simulation: - return MetadataDataList.model_validate( - [meta.data() for meta in simulation.meta] - ) - raise ResponseException("Simulation not found") - except DatabaseError as err: - raise ResponseException(str(err)) from err + simulation = current_app.db.get_simulation(sim_id) + if simulation: + return MetadataDataList.model_validate( + [meta.data() for meta in simulation.meta] + ) + raise ResponseException("Simulation not found") @requires_auth("admin") @pydantic_validate(api) @@ -439,26 +414,23 @@ def patch( user: Optional[User], body: Annotated[MetadataPatchData, 
Body()], ) -> MetadataDataList: - try: - key = body.key - value = body.value.lower() - simulation = current_app.db.get_simulation(sim_id) - if simulation is None: - raise ResponseException(f"Simulation {sim_id} not found.") - old_values = MetadataDataList.model_validate( - [meta.data() for meta in simulation.find_meta(key)] - ) - if key.lower() != "status": - simulation.set_meta(key, value) - else: - status = models_sim.Simulation.Status(value) - _update_simulation_status(simulation, status, user) + key = body.key + value = body.value.lower() + simulation = current_app.db.get_simulation(sim_id) + if simulation is None: + raise ResponseException(f"Simulation {sim_id} not found.") + old_values = MetadataDataList.model_validate( + [meta.data() for meta in simulation.find_meta(key)] + ) + if key.lower() != "status": + simulation.set_meta(key, value) + else: + status = models_sim.Simulation.Status(value) + _update_simulation_status(simulation, status, user) - current_app.db.insert_simulation(simulation) - clear_cache() - return old_values - except DatabaseError as err: - raise ResponseException(str(err)) from err + current_app.db.insert_simulation(simulation) + clear_cache() + return old_values @requires_auth("admin") @pydantic_validate(api) @@ -468,17 +440,14 @@ def delete( user: Optional[User], body: Annotated[MetadataDeleteData, Body()], ) -> MetadataDeleteResponse: - try: - simulation = current_app.db.get_simulation(sim_id) - if simulation is None: - raise ResponseException(f"Simulation {sim_id} not found.") + simulation = current_app.db.get_simulation(sim_id) + if simulation is None: + raise ResponseException(f"Simulation {sim_id} not found.") - simulation.remove_meta(body.key) - current_app.db.insert_simulation(simulation) - clear_cache() - return MetadataDeleteResponse() - except DatabaseError as err: - raise ResponseException(str(err)) from err + simulation.remove_meta(body.key) + current_app.db.insert_simulation(simulation) + clear_cache() + return 
MetadataDeleteResponse() @api.route("/validate/") @@ -486,14 +455,11 @@ class ValidateSimulation(Resource): @requires_auth() @pydantic_validate(api) def post(self, sim_id, user: User) -> ValidationResult: - try: - simulation = current_app.db.get_simulation(sim_id) - result = _validate(simulation, user) - current_app.db.insert_simulation(simulation) - clear_cache() - return result - except DatabaseError as err: - raise ResponseException(str(err)) from err + simulation = current_app.db.get_simulation(sim_id) + result = _validate(simulation, user) + current_app.db.insert_simulation(simulation) + clear_cache() + return result @api.route("/trace/") @@ -502,10 +468,7 @@ class SimulationTrace(Resource): @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] @pydantic_validate(api) def get(self, sim_id: str, user: User) -> SimulationTraceData: - try: - return _build_trace(sim_id) - except DatabaseError as err: - raise ResponseException(str(err)) from err + return _build_trace(sim_id) @api.route("/simulation/package/") diff --git a/src/simdb/remote/apis/watchers.py b/src/simdb/remote/apis/watchers.py index 53e0634c..51940aad 100644 --- a/src/simdb/remote/apis/watchers.py +++ b/src/simdb/remote/apis/watchers.py @@ -1,12 +1,20 @@ -from flask import jsonify, request +from typing import Annotated + from flask_restx import Namespace, Resource -from simdb.database import DatabaseError, models +from simdb.database import models from simdb.notifications import Notification from simdb.remote.core.auth import User, requires_auth from simdb.remote.core.cache import clear_cache -from simdb.remote.core.errors import error +from simdb.remote.core.pydantic_utils import Body, pydantic_validate from simdb.remote.core.typing import current_app +from simdb.remote.models import ( + WatcherDeleteRequest, + WatcherDeleteResponse, + WatcherGetResponse, + WatcherPostRequest, + WatcherPostResponse, +) api = Namespace("watchers", path="/") @@ -14,52 +22,43 @@ 
@api.route("/watchers/") class Watcher(Resource): @requires_auth() - def post(self, sim_id: str, user: User): - try: - data = request.get_json() - if data is None: - return error("No data provided") - - username = data.get("user", user.name) - email = data.get("email", user.email) - - if "notification" not in data: - return error("Watcher notification not provided") + @pydantic_validate(api) + def post( + self, sim_id: str, user: User, data: Annotated[WatcherPostRequest, Body()] + ) -> WatcherPostResponse: + username = data.user or user.name + email = data.email or user.email - notification = getattr(Notification, data["notification"]) + notification = getattr(Notification, data.notification) - watcher = models.Watcher(username, email, notification) - current_app.db.add_watcher(sim_id, watcher) - clear_cache() + watcher = models.Watcher(username, email, notification) + current_app.db.add_watcher(sim_id, watcher) + clear_cache() - if username != user.name: - # TODO: send email to notify user that they have been added as a watcher - pass + if username != user.name: + # TODO: send email to notify user that they have been added as a watcher + pass - return jsonify({"added": {"simulation": sim_id, "watcher": data["user"]}}) - except DatabaseError as err: - return error(str(err)) + return WatcherPostResponse.model_validate( + {"added": {"simulation": sim_id, "watcher": username}} + ) @requires_auth() - def delete(self, sim_id: str, user: User): - try: - data = request.get_json() or {} - username = data.get("user", user.name) + @pydantic_validate(api) + def delete( + self, sim_id: str, user: User, data: Annotated[WatcherDeleteRequest, Body()] + ) -> WatcherDeleteResponse: + username = data.user or user.name - current_app.db.remove_watcher(sim_id, username) - clear_cache() - return jsonify({"removed": {"simulation": sim_id, "watcher": username}}) - except DatabaseError as err: - return error(str(err)) + current_app.db.remove_watcher(sim_id, username) + clear_cache() + 
return WatcherDeleteResponse.model_validate( + {"removed": {"simulation": sim_id, "watcher": username}} + ) @requires_auth() - def get(self, sim_id: str, user: User): - try: - return jsonify( - [ - watcher.data(recurse=True) - for watcher in current_app.db.list_watchers(sim_id) - ] - ) - except DatabaseError as err: - return error(str(err)) + @pydantic_validate(api) + def get(self, sim_id: str, user: User) -> WatcherGetResponse: + return WatcherGetResponse( + [watcher.to_model() for watcher in current_app.db.list_watchers(sim_id)] + ) diff --git a/src/simdb/remote/core/pydantic_utils.py b/src/simdb/remote/core/pydantic_utils.py index 4c61028a..63dc0dd3 100644 --- a/src/simdb/remote/core/pydantic_utils.py +++ b/src/simdb/remote/core/pydantic_utils.py @@ -158,6 +158,7 @@ def pydantic_validate( *, response_model: type[BaseModel] | None = None, error_model: type[BaseModel] = ErrorResponse, + error_codes: tuple[int] = (400,), ) -> Any: """Decorator factory that wires up Pydantic validation for a Flask-RESTX endpoint. 
@@ -269,7 +270,7 @@ def wrapper(*args, **kwargs): except Exception as err: return Response( response=_error_model(error=str(err)).model_dump_json(), - status=400, + status=400, # HTTP status code 500 would make more sense mimetype="application/json", ) @@ -292,7 +293,8 @@ def wrapper(*args, **kwargs): if _resp_schema is not None: wrapper = ns.response(200, "Success", _resp_schema)(wrapper) if _error_schema is not None: - wrapper = ns.response(400, "Error", _error_schema)(wrapper) + for error_code in error_codes: + wrapper = ns.response(error_code, "Error", _error_schema)(wrapper) return wrapper From 6c41588d6f1d20fdd91f723feff47bc575d161d2 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Thu, 26 Feb 2026 15:13:36 +0100 Subject: [PATCH 04/12] Add header and query params to docs --- src/simdb/remote/core/pydantic_utils.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/simdb/remote/core/pydantic_utils.py b/src/simdb/remote/core/pydantic_utils.py index 63dc0dd3..8e9a1070 100644 --- a/src/simdb/remote/core/pydantic_utils.py +++ b/src/simdb/remote/core/pydantic_utils.py @@ -289,13 +289,30 @@ def wrapper(*args, **kwargs): return result if _body_schema is not None: - wrapper = ns.expect(_body_schema)(wrapper) + wrapper = ns.expect(_body_schema, validate=False)(wrapper) if _resp_schema is not None: wrapper = ns.response(200, "Success", _resp_schema)(wrapper) if _error_schema is not None: for error_code in error_codes: wrapper = ns.response(error_code, "Error", _error_schema)(wrapper) + for _param_name, model_type, source in annotated_params: + if isinstance(source, (Query, Header)): + location = "query" if isinstance(source, Query) else "header" + schema = model_type.model_json_schema() + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) + + for field_name, field_props in properties.items(): + wrapper = ns.param( + name=field_name, + description=field_props.get("description", ""), + 
_in=location, + required=(field_name in required_fields), + type=field_props.get("type", "string"), + default=field_props.get("default"), + )(wrapper) + return wrapper return decorator From 138c5565c710d5d3c6fa8e5ea0d1ed41c361d552 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Thu, 12 Mar 2026 10:45:13 +0100 Subject: [PATCH 05/12] Add validation to staging directory endpoint --- src/simdb/remote/apis/v1_2/__init__.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/simdb/remote/apis/v1_2/__init__.py b/src/simdb/remote/apis/v1_2/__init__.py index c94b2925..920f303b 100644 --- a/src/simdb/remote/apis/v1_2/__init__.py +++ b/src/simdb/remote/apis/v1_2/__init__.py @@ -1,13 +1,14 @@ from pathlib import Path -from flask import jsonify from flask_restx import Api, Resource from simdb.remote.apis.files import api as file_ns from simdb.remote.apis.metadata import api as metadata_ns from simdb.remote.apis.watchers import api as watcher_ns from simdb.remote.core.auth import TokenAuthenticator, User, requires_auth +from simdb.remote.core.pydantic_utils import pydantic_validate from simdb.remote.core.typing import current_app +from simdb.remote.models import StagingDirectoryResponse from .simulations import api as sim_ns @@ -37,7 +38,8 @@ @api.route("/staging_dir/") class StagingDirectory(Resource): @requires_auth() - def get(self, sim_hex: str, user: User): + @pydantic_validate(api) + def get(self, sim_hex: str, user: User) -> StagingDirectoryResponse: upload_dir = current_app.simdb_config.get_string_option( "server.user_upload_folder", default=None ) @@ -49,7 +51,7 @@ def get(self, sim_hex: str, user: User): user_folder = False if not sim_hex: - return jsonify({"staging_dir": upload_dir}) + return StagingDirectoryResponse(staging_dir=Path(upload_dir)) staging_dir = ( Path(current_app.simdb_config.get_string_option("server.upload_folder")) @@ -61,4 +63,4 @@ def get(self, sim_hex: str, user: User): # directory. 
if user_folder: staging_dir.chmod(0o777) - return jsonify({"staging_dir": str(Path(upload_dir) / sim_hex)}) + return StagingDirectoryResponse(staging_dir=Path(upload_dir) / sim_hex) From 6c8ea68e766c3f515d89e9662d813065b1902f87 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Thu, 12 Mar 2026 10:45:33 +0100 Subject: [PATCH 06/12] Add validation to metadata endpoints --- src/simdb/remote/apis/metadata.py | 23 ++++++------- src/simdb/remote/models.py | 54 ++++++++++++++++++++----------- 2 files changed, 45 insertions(+), 32 deletions(-) diff --git a/src/simdb/remote/apis/metadata.py b/src/simdb/remote/apis/metadata.py index 8b750125..dc47cf47 100644 --- a/src/simdb/remote/apis/metadata.py +++ b/src/simdb/remote/apis/metadata.py @@ -1,10 +1,9 @@ -from flask import jsonify from flask_restx import Namespace, Resource -from simdb.database import DatabaseError from simdb.remote.core.cache import cache, cache_key -from simdb.remote.core.errors import error +from simdb.remote.core.pydantic_utils import pydantic_validate from simdb.remote.core.typing import current_app +from simdb.remote.models import MetadataKeyInfoList, MetadataValueList api = Namespace("metadata", path="/") @@ -12,18 +11,16 @@ @api.route("/metadata") class MetaData(Resource): @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] - def get(self): - try: - return jsonify(current_app.db.list_metadata_keys()) - except DatabaseError as err: - return error(str(err)) + @pydantic_validate(api) + def get(self) -> MetadataKeyInfoList: + return MetadataKeyInfoList.model_validate(current_app.db.list_metadata_keys()) @api.route("/metadata/") class MetaDataValues(Resource): @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] - def get(self, name): - try: - return jsonify(current_app.db.list_metadata_values(name)) - except DatabaseError as err: - return error(str(err)) + @pydantic_validate(api) + def get(self, name: str) -> MetadataValueList: + return 
MetadataValueList.model_validate( + current_app.db.list_metadata_values(name) + ) diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index a8db2e24..bb011fa7 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -188,6 +188,27 @@ class MetadataDeleteResponse(BaseModel): pass +class MetadataKeyInfo(BaseModel): + """Information about a metadata key.""" + + name: str + """Metadata key name.""" + type: str + """Python type name of the metadata value.""" + + +class MetadataKeyInfoList(RootModel): + """List of metadata key info items.""" + + root: List[MetadataKeyInfo] = [] + + +class MetadataValueList(RootModel): + """List of metadata values for a given key.""" + + root: List[Any] = [] + + class SimulationReference(BaseModel): """Reference to a simulation.""" @@ -289,31 +310,26 @@ class PaginatedResponse(BaseModel, Generic[T]): class PaginationData(BaseModel): - """Pagination parameters from request headers.""" + """Pagination parameters from request headers. - limit: int + Fields are populated from HTTP headers. The field aliases match the + lowercased header names as provided by Werkzeug / ``_validate_param``. + Use ``model_validate`` with ``by_alias=False`` (the default) or pass a + dict with the alias keys; Pydantic will resolve them via the + ``model_config`` ``populate_by_name=True`` setting. 
+ """ + + model_config = ConfigDict(populate_by_name=True, use_attribute_docstrings=True) + + limit: int = Field(100, alias="simdb-result-limit") """Number of items per page.""" - page: int + page: int = Field(1, alias="simdb-page") """Current page number.""" - sort_by: str + sort_by: str = Field("", alias="simdb-sort-by") """Field to sort by.""" - sort_asc: bool + sort_asc: bool = Field(False, alias="simdb-sort-asc") """Whether to sort ascending.""" - @model_validator(mode="before") - @classmethod - def parse_headers(cls, data: Any): - """Parse pagination from HTTP headers.""" - if not isinstance(data, dict): - return data - new_data = { - "limit": data.get("simdb-result-limit", 100), - "page": data.get("simdb-page", 1), - "sort_by": data.get("simdb-sort-by", ""), - "sort_asc": data.get("simdb-sort-asc", False), - } - return new_data - class SimulationTraceData(SimulationData): """Simulation data with status history.""" From c686e473cb71a493db9d8f6c95523cca10b6d829 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Thu, 12 Mar 2026 10:46:59 +0100 Subject: [PATCH 07/12] Fix issue with sim children and parents --- src/simdb/remote/apis/v1_2/simulations.py | 20 +++++---- tests/remote/api/test_simulations.py | 55 +++++++++++++++++++++++ 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index 419e19e9..76f9bf5e 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -342,14 +342,18 @@ class Simulation(Resource): @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] @pydantic_validate(api) def get(self, sim_id: str, user: User) -> SimulationDataResponse: - simulation = current_app.db.get_simulation(sim_id) - if simulation: - sim_data = simulation.to_model_with_refs(recurse=True) + try: + simulation = current_app.db.get_simulation(sim_id) + except DatabaseError: + raise ResponseException( + f"Simulation 
with id {sim_id} could not be found" + ) from None - sim_data.children = current_app.db.get_simulation_children_ref(simulation) - sim_data.parents = current_app.db.get_simulation_children_ref(simulation) - return sim_data - raise ResponseException("Simulation not found") + sim_data = simulation.to_model_with_refs(recurse=True) + + sim_data.children = current_app.db.get_simulation_children_ref(simulation) + sim_data.parents = current_app.db.get_simulation_parents_ref(simulation) + return sim_data @requires_auth("admin") @pydantic_validate(api) @@ -361,7 +365,7 @@ def patch( ) -> SimulationPatchResponse: simulation = current_app.db.get_simulation(sim_id) if simulation is None: - raise ValueError(f"Simulation {sim_id} not found.") + raise ResponseException(f"Simulation {sim_id} not found.") status = models_sim.Simulation.Status(body.status) _update_simulation_status(simulation, status, user) current_app.db.insert_simulation(simulation) diff --git a/tests/remote/api/test_simulations.py b/tests/remote/api/test_simulations.py index 2ff30f73..71e7cf50 100644 --- a/tests/remote/api/test_simulations.py +++ b/tests/remote/api/test_simulations.py @@ -736,6 +736,61 @@ def test_delete_simulation_metadata(client): assert data == simulation_data.simulation.metadata +def test_get_simulation_parents_and_children(client): + """GET /v1.2/simulation/{id} must return correct parents and + children. 
+ """ + shared_checksum = "shared_checksum_for_parent_child_test" + + # Parent simulation: produces an output file with a known checksum + parent_output = generate_simulation_file() + parent_output.checksum = shared_checksum + parent_data = generate_simulation_data( + alias="parent-sim-pc-test", + outputs=[parent_output], + ) + rv_parent = post_simulation(client, parent_data) + assert rv_parent.status_code == 200 + + # Child simulation: consumes an input file with the same checksum + child_input = generate_simulation_file() + child_input.checksum = shared_checksum + child_data = generate_simulation_data( + alias="child-sim-pc-test", + inputs=[child_input], + ) + rv_child = post_simulation(client, child_data) + assert rv_child.status_code == 200 + + # Fetch the parent and verify its children list contains the child + # and its parents list is empty (the parent has no parents of its own). + rv = client.get( + f"/v1.2/simulation/{parent_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv.status_code == 200 + parent_response = SimulationDataResponse.model_validate(rv.json) + + parent_children_uuids = [ref.uuid for ref in parent_response.children] + parent_parents_uuids = [ref.uuid for ref in parent_response.parents] + + assert child_data.simulation.uuid in parent_children_uuids + assert child_data.simulation.uuid not in parent_parents_uuids + + # Fetch the child and verify its parents list contains the parent + # and its children list is empty. 
+ rv = client.get( + f"/v1.2/simulation/{child_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv.status_code == 200 + child_response = SimulationDataResponse.model_validate(rv.json) + + child_parents_uuids = [ref.uuid for ref in child_response.parents] + child_children_uuids = [ref.uuid for ref in child_response.children] + + assert parent_data.simulation.uuid in child_parents_uuids + assert parent_data.simulation.uuid not in child_children_uuids + + def test_trace_endpoint(client): """Test trace endpoint returns valid SimulationTraceData and handles replacement chains.""" From af97aca845064fc01b008c8d3a74d11d34987bff Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Thu, 12 Mar 2026 10:47:41 +0100 Subject: [PATCH 08/12] Add separate return codes for client and server errors --- src/simdb/remote/core/pydantic_utils.py | 59 ++++++++++++++++++++----- 1 file changed, 47 insertions(+), 12 deletions(-) diff --git a/src/simdb/remote/core/pydantic_utils.py b/src/simdb/remote/core/pydantic_utils.py index 8e9a1070..0b7ea3a9 100644 --- a/src/simdb/remote/core/pydantic_utils.py +++ b/src/simdb/remote/core/pydantic_utils.py @@ -5,6 +5,7 @@ import functools import gzip import inspect +import logging from typing import ( Annotated, Any, @@ -13,15 +14,17 @@ ) from flask import Response, request -from flask_restx import Namespace +from flask_restx import Api, Namespace from pydantic import BaseModel, ValidationError from simdb.remote.core.errors import error as _error from simdb.remote.models import ErrorResponse +logger = logging.getLogger(__name__) + class ResponseException(Exception): - """Raised an error has occurred in the request.""" + """Raised when a client error has occurred in the request (HTTP 4xx).""" def __init__(self, message, return_code=400): super().__init__(message) @@ -29,6 +32,15 @@ def __init__(self, message, return_code=400): self.return_code = return_code +class ServerException(Exception): + """Raised when an unexpected server-side error has
occurred (HTTP 500).""" + + def __init__(self, message, return_code=500): + super().__init__(message) + self.message = message + self.return_code = return_code + + # --------------------------------------------------------------------------- # Marker classes for Annotated-style parameter declarations # --------------------------------------------------------------------------- @@ -55,7 +67,7 @@ class Query(_ParamSource): # --------------------------------------------------------------------------- -def _register_defs(ns: Namespace, defs): +def _register_defs(ns: Namespace | Api, defs): for name, schema in defs.items(): # Clean the schema: rewrite refs and remove internal $defs clean_schema = copy.deepcopy(schema) @@ -64,7 +76,7 @@ def _register_defs(ns: Namespace, defs): _register_defs(ns, children) -def _collect_and_register(ns: Namespace, model: type[BaseModel]): +def _collect_and_register(ns: Namespace | Api, model: type[BaseModel]): """ Registers a Pydantic model and all nested dependencies to a Flask-RESTX Namespace. """ @@ -154,11 +166,11 @@ def _validate_param( def pydantic_validate( - ns: Namespace, + ns: Namespace | Api, *, response_model: type[BaseModel] | None = None, error_model: type[BaseModel] = ErrorResponse, - error_codes: tuple[int] = (400,), + client_error_codes: tuple[int, ...] = (400,), ) -> Any: """Decorator factory that wires up Pydantic validation for a Flask-RESTX endpoint. @@ -176,6 +188,14 @@ def pydantic_validate( *response_model* is provided explicitly) the return value is automatically serialised with ``model_dump(mode="json")`` and wrapped in ``jsonify``. + Error handling distinguishes between client errors and server errors: + + - :class:`ResponseException` (and request validation errors) → HTTP 4xx + (default 400). Use ``return_code`` to customise (e.g. 404, 422). + - :class:`ServerException` → HTTP 5xx (default 500). Use for explicit + server-side failures. + - Any other unhandled :class:`Exception` → HTTP 500, logged as an error. 
+ Parameters ---------- ns: @@ -183,6 +203,10 @@ def pydantic_validate( response_model: Optional explicit response model. If ``None`` the decorator tries to infer it from the function's return annotation. + error_model: + Pydantic model used to serialise error responses. + client_error_codes: + HTTP status codes to document as client error responses in Swagger. Returns ------- @@ -193,10 +217,10 @@ def pydantic_validate( .. code-block:: python from typing import Annotated - from simdb.remote.core.pydantic_utils import restx_route, Header, Body, Query + from simdb.remote.core.pydantic_utils import pydantic_validate, Header, Body class SimulationList(Resource): - @restx_route(api) + @pydantic_validate(api) def get( self, user: User, @@ -204,7 +228,7 @@ def get( ) -> PaginatedResponse: ... - @restx_route(api) + @pydantic_validate(api) def post( self, user: User, @@ -267,10 +291,18 @@ def wrapper(*args, **kwargs): status=err.return_code, mimetype="application/json", ) + except ServerException as err: + logger.error("Server error in %s: %s", f.__qualname__, err.message) + return Response( + response=_error_model(error=err.message).model_dump_json(), + status=err.return_code, + mimetype="application/json", + ) except Exception as err: + logger.exception("Unhandled exception in %s", f.__qualname__) return Response( response=_error_model(error=str(err)).model_dump_json(), - status=400, # HTTP status code 500 would make more sense + status=500, mimetype="application/json", ) @@ -293,8 +325,11 @@ def wrapper(*args, **kwargs): if _resp_schema is not None: wrapper = ns.response(200, "Success", _resp_schema)(wrapper) if _error_schema is not None: - for error_code in error_codes: - wrapper = ns.response(error_code, "Error", _error_schema)(wrapper) + for error_code in client_error_codes: + wrapper = ns.response(error_code, "Client error", _error_schema)( + wrapper + ) + wrapper = ns.response(500, "Server error", _error_schema)(wrapper) for _param_name, model_type, source in 
annotated_params: if isinstance(source, (Query, Header)): From 19959e1022702250e0d468f709469d3f9b429e7d Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 17 Mar 2026 15:54:31 +0100 Subject: [PATCH 09/12] Fix numpy array serialization --- src/simdb/remote/models.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index bb011fa7..930ba7d1 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -17,6 +17,7 @@ from urllib.parse import urlencode from uuid import UUID, uuid1 +import numpy as np from pydantic import ( BaseModel as _BaseModel, ) @@ -25,6 +26,7 @@ ConfigDict, Field, PlainSerializer, + field_serializer, model_validator, ) from pydantic import ( @@ -125,14 +127,42 @@ def __getitem__(self, item) -> FileData: return self.root[item] +MetadataValue = Union[ + CustomUUID, + str, + int, + float, + bool, + list, + dict, + np.ndarray, + np.generic, + None, +] +"""Supported types for simulation metadata values.""" + + class MetadataData(BaseModel): """Key-value pair for simulation metadata.""" + model_config = ConfigDict( + use_attribute_docstrings=True, arbitrary_types_allowed=True + ) + element: str """Metadata key/name.""" - value: Union[CustomUUID, Any] + value: MetadataValue """Metadata value.""" + @field_serializer("value") + def serialize_value(self, value: Any, _info: Any) -> Any: + """Serialize numpy arrays and scalars to JSON-compatible types.""" + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, np.generic): + return value.item() + return value + def as_dict(self) -> dict: """Convert to dictionary.""" return {self.element: self.value} From ffae6947a7f32e7addd50a6b138cd99c6bbdc5c1 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 17 Mar 2026 16:10:05 +0100 Subject: [PATCH 10/12] Fix uuid serialization --- src/simdb/remote/models.py | 49 ++++++++++++++++++-------------------- 1 file 
changed, 23 insertions(+), 26 deletions(-) diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index 930ba7d1..b670a925 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -26,7 +26,6 @@ ConfigDict, Field, PlainSerializer, - field_serializer, model_validator, ) from pydantic import ( @@ -127,42 +126,40 @@ def __getitem__(self, item) -> FileData: return self.root[item] -MetadataValue = Union[ - CustomUUID, - str, - int, - float, - bool, - list, - dict, - np.ndarray, - np.generic, - None, +def _coerce_numpy(v: Any) -> Any: + """Convert numpy arrays and scalars to plain Python types before validation.""" + if isinstance(v, np.ndarray): + return v.tolist() + if isinstance(v, np.generic): + return v.item() + return v + + +MetadataValue = Annotated[ + Union[ + CustomUUID, + str, + int, + float, + bool, + list, + dict, + None, + ], + BeforeValidator(_coerce_numpy), ] -"""Supported types for simulation metadata values.""" +"""Supported types for simulation metadata values. 
Numpy arrays and scalars are +automatically converted to their plain Python equivalents before validation.""" class MetadataData(BaseModel): """Key-value pair for simulation metadata.""" - model_config = ConfigDict( - use_attribute_docstrings=True, arbitrary_types_allowed=True - ) - element: str """Metadata key/name.""" value: MetadataValue """Metadata value.""" - @field_serializer("value") - def serialize_value(self, value: Any, _info: Any) -> Any: - """Serialize numpy arrays and scalars to JSON-compatible types.""" - if isinstance(value, np.ndarray): - return value.tolist() - if isinstance(value, np.generic): - return value.item() - return value - def as_dict(self) -> dict: """Convert to dictionary.""" return {self.element: self.value} From 38d1ed418a0ae04b4bd7cb6d03151485e6ff0d95 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Tue, 17 Mar 2026 16:17:24 +0100 Subject: [PATCH 11/12] Exclude none values in JSON --- src/simdb/remote/core/pydantic_utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/simdb/remote/core/pydantic_utils.py b/src/simdb/remote/core/pydantic_utils.py index 0b7ea3a9..3cd40f90 100644 --- a/src/simdb/remote/core/pydantic_utils.py +++ b/src/simdb/remote/core/pydantic_utils.py @@ -287,34 +287,40 @@ def wrapper(*args, **kwargs): result = f(*args, **kwargs) except ResponseException as err: return Response( - response=_error_model(error=err.message).model_dump_json(), + response=_error_model(error=err.message).model_dump_json( + exclude_none=False + ), status=err.return_code, mimetype="application/json", ) except ServerException as err: logger.error("Server error in %s: %s", f.__qualname__, err.message) return Response( - response=_error_model(error=err.message).model_dump_json(), + response=_error_model(error=err.message).model_dump_json( + exclude_none=False + ), status=err.return_code, mimetype="application/json", ) except Exception as err: logger.exception("Unhandled exception in %s", 
f.__qualname__) return Response( - response=_error_model(error=str(err)).model_dump_json(), + response=_error_model(error=str(err)).model_dump_json( + exclude_none=False + ), status=500, mimetype="application/json", ) if isinstance(result, ErrorResponse): return Response( - response=result.model_dump_json(), + response=result.model_dump_json(exclude_none=False), status=400, mimetype="application/json", ) if isinstance(result, BaseModel): return Response( - result.model_dump_json(), + result.model_dump_json(exclude_none=False), mimetype="application/json", ) From f63c2de7c203aff42881d46a83e52c0941cd85f7 Mon Sep 17 00:00:00 2001 From: Yannick de Jong Date: Thu, 19 Mar 2026 17:51:34 +0100 Subject: [PATCH 12/12] Fix ndarray handling correctly --- src/simdb/remote/models.py | 60 ++++++++++++++++++++----------- tests/remote/api/test_metadata.py | 28 +++++++++++++++ 2 files changed, 67 insertions(+), 21 deletions(-) diff --git a/src/simdb/remote/models.py b/src/simdb/remote/models.py index b670a925..62cee413 100644 --- a/src/simdb/remote/models.py +++ b/src/simdb/remote/models.py @@ -1,5 +1,6 @@ """Pydantic models for the SimDB remote API.""" +import base64 from datetime import datetime as dt from datetime import timezone from pathlib import Path @@ -25,6 +26,7 @@ BeforeValidator, ConfigDict, Field, + InstanceOf, PlainSerializer, model_validator, ) @@ -126,27 +128,43 @@ def __getitem__(self, item) -> FileData: return self.root[item] -def _coerce_numpy(v: Any) -> Any: - """Convert numpy arrays and scalars to plain Python types before validation.""" +def _deserialize_numpy(v: Any) -> Any: if isinstance(v, np.ndarray): - return v.tolist() - if isinstance(v, np.generic): - return v.item() - return v - - -MetadataValue = Annotated[ - Union[ - CustomUUID, - str, - int, - float, - bool, - list, - dict, - None, - ], - BeforeValidator(_coerce_numpy), + return v + if isinstance(v, dict) and v.get("_type") == "numpy.ndarray": + np_bytes = base64.b64decode(v["bytes"].encode()) + 
return np.frombuffer(np_bytes, dtype=v["dtype"]).reshape(v["shape"]) + raise ValueError(f"Cannot deserialize {v} to np.ndarray") + + +def _serialize_numpy(o: np.ndarray) -> dict: + """Serialize numpy arrays to dict format for the web dashboard.""" + encoded_bytes = base64.b64encode(o.data).decode() + return { + "_type": "numpy.ndarray", + "dtype": o.dtype.name, + "shape": o.shape, + "bytes": encoded_bytes, + } + + +NumpyArray = Annotated[ + InstanceOf[np.ndarray], + BeforeValidator(_deserialize_numpy), + PlainSerializer(_serialize_numpy, return_type=dict), +] + + +MetadataValue = Union[ + CustomUUID, + str, + int, + float, + bool, + list, + dict, + NumpyArray, + None, ] """Supported types for simulation metadata values. Numpy arrays and scalars are automatically converted to their plain Python equivalents before validation.""" @@ -233,7 +251,7 @@ class MetadataKeyInfoList(RootModel): class MetadataValueList(RootModel): """List of metadata values for a given key.""" - root: List[Any] = [] + root: List[MetadataValue] = [] class SimulationReference(BaseModel): diff --git a/tests/remote/api/test_metadata.py b/tests/remote/api/test_metadata.py index 0d66b730..afb31fa7 100644 --- a/tests/remote/api/test_metadata.py +++ b/tests/remote/api/test_metadata.py @@ -1,9 +1,12 @@ +import numpy as np from conftest import ( HEADERS, generate_simulation_data, post_simulation, ) +from simdb.remote.models import MetadataKeyInfoList, MetadataValueList + def test_get_metadata_keys(client): """Test GET /v1.2/metadata endpoint - list all metadata keys.""" @@ -49,6 +52,31 @@ def test_get_metadata_values(client): assert "machine-a" in rv.json or "machine-b" in rv.json +def test_get_metadata_array_value(client): + """Test metadata ndarray storage""" + # Create a simulation with array metadata + array_data = np.array([1, 2, 3]) + simulation_data_1 = generate_simulation_data(metadata={"array_machine": array_data}) + rv_post_1 = post_simulation(client, simulation_data_1) + assert 
rv_post_1.status_code == 200 + + rv = client.get("/v1.2/metadata", headers=HEADERS) + assert rv.status_code == 200 + mkeys = MetadataKeyInfoList.model_validate_json(rv.data) + for k in mkeys.root: + if k.name == "array_machine": + mkey = k + assert mkey.type == "ndarray" + + rv = client.get("/v1.2/metadata/array_machine", headers=HEADERS) + + assert rv.status_code == 200 + mdata = MetadataValueList.model_validate_json(rv.data) + assert len(mdata.root) == 1 + a = mdata.root[0] + assert isinstance(a, np.ndarray) + + def test_get_metadata_values_nonexistent_key(client): """Test GET /v1.2/metadata/{name} endpoint - non-existent key.""" # Get values for a metadata key that doesn't exist