diff --git a/src/simdb/database/database.py b/src/simdb/database/database.py index 37ba007..ff3ea9c 100644 --- a/src/simdb/database/database.py +++ b/src/simdb/database/database.py @@ -16,6 +16,7 @@ from simdb.config import Config from simdb.query import QueryType, query_compare +from simdb.remote.models import SimulationReference from .models import Base from .models.file import File @@ -226,8 +227,8 @@ def _find_simulation(self, sim_ref: str) -> "Simulation": ) except SQLAlchemyError: simulation = None - if not simulation: - raise DatabaseError(f"Simulation {sim_ref} not found.") from None + if not simulation: + raise DatabaseError(f"Simulation {sim_ref} not found.") from None return simulation def remove(self): @@ -571,6 +572,24 @@ def get_simulation_parents(self, simulation: "Simulation") -> List[dict]: ) return [{"uuid": r.uuid, "alias": r.alias} for r in query.all()] + def get_simulation_parents_ref( + self, simulation: "Simulation" + ) -> List[SimulationReference]: + subquery = ( + self.session.query(File.checksum) + .filter(File.checksum != "") + .filter(File.input_for.contains(simulation)) + .subquery() + ) + query = ( + self.session.query(Simulation.uuid, Simulation.alias) + .join(Simulation.outputs) + .filter(File.checksum.in_(subquery)) + .filter(Simulation.alias != simulation.alias) + .distinct() + ) + return [SimulationReference(uuid=r.uuid, alias=r.alias) for r in query.all()] + def get_simulation_children(self, simulation: "Simulation") -> List[dict]: subquery = ( self.session.query(File.checksum) @@ -587,6 +606,24 @@ def get_simulation_children(self, simulation: "Simulation") -> List[dict]: ) return [{"uuid": r.uuid, "alias": r.alias} for r in query.all()] + def get_simulation_children_ref( + self, simulation: "Simulation" + ) -> List[SimulationReference]: + subquery = ( + self.session.query(File.checksum) + .filter(File.checksum != "") + .filter(File.output_of.contains(simulation)) + .subquery() + ) + query = ( + 
self.session.query(Simulation.uuid, Simulation.alias) + .join(Simulation.inputs) + .filter(File.checksum.in_(subquery)) + .filter(Simulation.alias != simulation.alias) + .distinct() + ) + return [SimulationReference(uuid=r.uuid, alias=r.alias) for r in query.all()] + def get_file(self, file_uuid_str: str) -> "File": """ Get the specified file from the database. diff --git a/src/simdb/database/models/file.py b/src/simdb/database/models/file.py index a05aaa5..a43ea96 100644 --- a/src/simdb/database/models/file.py +++ b/src/simdb/database/models/file.py @@ -13,7 +13,8 @@ from simdb.config.config import Config from simdb.docstrings import inherit_docstrings from simdb.imas.checksum import checksum as imas_checksum -from simdb.imas.utils import imas_timestamp +from simdb.imas.utils import imas_files, imas_timestamp +from simdb.remote.models import FileData, FileGetDataResponse, FileInfo from simdb.uda.checksum import checksum as uda_checksum from .base import Base @@ -125,6 +126,23 @@ def from_data(cls, data: Dict) -> "File": file.datetime = date_parser.parse(checked_get(data, "datetime", str)) return file + @classmethod + def from_data_model(cls, data: FileData) -> "File": + data_type = data.type + uri = data.uri + file = File( + DataObject.Type[data_type], urilib.URI(uri), perform_integrity_check=False + ) + file.uuid = data.uuid + file.usage = data.usage + file.checksum = data.checksum + file.purpose = data.purpose + file.sensitivity = data.sensitivity + file.access = data.access + file.embargo = data.embargo + file.datetime = data.datetime + return file + def data(self, recurse: bool = False) -> Dict[str, str]: data = { "uuid": self.uuid, @@ -139,3 +157,41 @@ def data(self, recurse: bool = False) -> Dict[str, str]: "datetime": self.datetime.isoformat(), } return data + + def to_model(self) -> FileData: + return FileData( + type=self.type.name, + uri=str(self.uri), + uuid=self.uuid, + checksum=self.checksum, + datetime=self.datetime, + usage=self.usage, + 
purpose=self.purpose, + sensitivity=self.sensitivity, + access=self.access, + embargo=self.embargo, + ) + + def to_model_with_path(self) -> FileGetDataResponse: + if self.type.name == "FILE": + if self.uri.path is None: + raise ValueError("File path not set") + files = [FileInfo(path=self.uri.path, checksum=self.checksum)] + else: + files = [ + FileInfo(path=path, checksum=sha1_checksum(URI(f"file:{path}"))) + for path in imas_files(self.uri) + ] + return FileGetDataResponse( + type=self.type.name, + uri=str(self.uri), + uuid=self.uuid, + checksum=self.checksum, + datetime=self.datetime, + usage=self.usage, + purpose=self.purpose, + sensitivity=self.sensitivity, + access=self.access, + embargo=self.embargo, + files=files, + ) diff --git a/src/simdb/database/models/metadata.py b/src/simdb/database/models/metadata.py index 628f158..81bcde9 100644 --- a/src/simdb/database/models/metadata.py +++ b/src/simdb/database/models/metadata.py @@ -4,6 +4,7 @@ from sqlalchemy import types as sql_types from simdb.docstrings import inherit_docstrings +from simdb.remote.models import MetadataData from .base import Base @@ -32,6 +33,11 @@ def from_data(cls, data: Dict) -> "MetaData": meta = MetaData(data["element"], data["value"]) return meta + @classmethod + def from_data_model(cls, data: MetadataData) -> "MetaData": + meta = MetaData(data.element, data.value) + return meta + def data(self, recurse: bool = False) -> Dict[str, str]: data = { "element": self.element, @@ -39,5 +45,8 @@ def data(self, recurse: bool = False) -> Dict[str, str]: } return data + def to_model(self) -> MetadataData: + return MetadataData(element=self.element, value=self.value) + Index("metadata_index", MetaData.sim_id, MetaData.element, unique=True) diff --git a/src/simdb/database/models/simulation.py b/src/simdb/database/models/simulation.py index 1ee21ab..b246555 100644 --- a/src/simdb/database/models/simulation.py +++ b/src/simdb/database/models/simulation.py @@ -9,6 +9,14 @@ from pathlib import Path from 
typing import Any, Dict, List, Optional, Set, Union +from simdb.remote.models import ( + FileDataList, + MetadataDataList, + SimulationData, + SimulationDataResponse, + SimulationTraceData, +) + if sys.version_info < (3, 11): from backports.datetime_fromisoformat import MonkeyPatch @@ -336,6 +344,17 @@ def from_data(cls, data: Dict[str, Union[str, Dict, List]]) -> "Simulation": simulation.meta.append(MetaData.from_data(el)) return simulation + @classmethod + def from_data_model(cls, data: SimulationData) -> "Simulation": + simulation = Simulation(None) + simulation.uuid = data.uuid + simulation.alias = data.alias + simulation.datetime = data.datetime + simulation.inputs = [File.from_data_model(el) for el in data.inputs.root] + simulation.outputs = [File.from_data_model(el) for el in data.outputs.root] + simulation.meta = [MetaData.from_data_model(el) for el in data.metadata.root] + return simulation + def data( self, recurse: bool = False, meta_keys: Optional[List[str]] = None ) -> Dict[str, Union[str, List]]: @@ -354,6 +373,77 @@ def data( ] return data + def to_model( + self, recurse: bool = False, meta_keys: Optional[List[str]] = None + ) -> SimulationData: + inputs = FileDataList() + outputs = FileDataList() + metadata = MetadataDataList() + if recurse: + inputs = FileDataList([f.to_model() for f in self.inputs]) + outputs = FileDataList([f.to_model() for f in self.outputs]) + metadata = MetadataDataList([m.to_model() for m in self.meta]) + elif meta_keys: + metadata = MetadataDataList( + [m.to_model() for m in self.meta if m.element in meta_keys] + ) + return SimulationData( + uuid=self.uuid, + alias=self.alias, + datetime=self.datetime, + inputs=inputs, + outputs=outputs, + metadata=metadata, + ) + + def to_model_with_refs( + self, recurse: bool = False, meta_keys: Optional[List[str]] = None + ) -> SimulationDataResponse: + inputs = FileDataList() + outputs = FileDataList() + metadata = MetadataDataList() + if recurse: + inputs = FileDataList([f.to_model() 
for f in self.inputs]) + outputs = FileDataList([f.to_model() for f in self.outputs]) + metadata = MetadataDataList([m.to_model() for m in self.meta]) + elif meta_keys: + metadata = MetadataDataList( + [m.to_model() for m in self.meta if m.element in meta_keys] + ) + return SimulationDataResponse( + uuid=self.uuid, + alias=self.alias, + datetime=self.datetime, + inputs=inputs, + outputs=outputs, + metadata=metadata, + parents=[], + children=[], + ) + + def to_model_trace( + self, recurse: bool = False, meta_keys: Optional[List[str]] = None + ) -> SimulationTraceData: + inputs = FileDataList() + outputs = FileDataList() + metadata = MetadataDataList() + if recurse: + inputs = FileDataList([f.to_model() for f in self.inputs]) + outputs = FileDataList([f.to_model() for f in self.outputs]) + metadata = MetadataDataList([m.to_model() for m in self.meta]) + elif meta_keys: + metadata = MetadataDataList( + [m.to_model() for m in self.meta if m.element in meta_keys] + ) + return SimulationTraceData( + uuid=self.uuid, + alias=self.alias, + datetime=self.datetime, + inputs=inputs, + outputs=outputs, + metadata=metadata, + ) + def meta_dict(self) -> Dict[str, Union[Dict, Any]]: meta = {m.element: m.value for m in self.meta} return unflatten_dict(meta) diff --git a/src/simdb/database/models/watcher.py b/src/simdb/database/models/watcher.py index 7785dc9..c7532b7 100644 --- a/src/simdb/database/models/watcher.py +++ b/src/simdb/database/models/watcher.py @@ -7,6 +7,7 @@ from simdb.docstrings import inherit_docstrings from simdb.notifications import Notification +from simdb.remote.models import WatcherData from .base import Base from .types import ChoiceType @@ -59,3 +60,6 @@ def data(self, recurse: bool = False) -> Dict[str, str]: "notification": str(self.notification), } return data + + def to_model(self) -> WatcherData: + return WatcherData.model_validate(self.data()) diff --git a/src/simdb/remote/apis/files.py b/src/simdb/remote/apis/files.py index 3753ff8..e846fa5 100644 
--- a/src/simdb/remote/apis/files.py +++ b/src/simdb/remote/apis/files.py @@ -2,7 +2,7 @@ import json import uuid from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, cast +from typing import Dict, Iterable, List, Optional import magic from flask import Response, jsonify, request, send_file, stream_with_context @@ -18,7 +18,9 @@ from simdb.remote.core.auth import User, requires_auth from simdb.remote.core.errors import error from simdb.remote.core.path import find_common_root, secure_path +from simdb.remote.core.pydantic_utils import pydantic_validate from simdb.remote.core.typing import current_app +from simdb.remote.models import FileDataList, FileGetDataResponse from simdb.uri import URI api = Namespace("files", path="/") @@ -170,9 +172,10 @@ def _handle_file_upload() -> Response: @api.route("/files") class FileList(Resource): @requires_auth() - def get(self, user: User): + @pydantic_validate(api) + def get(self, user: User) -> FileDataList: files = current_app.db.list_files() - return jsonify([file.data() for file in files]) + return FileDataList.model_validate([file.to_model() for file in files]) @requires_auth() def post(self, user: User): @@ -189,25 +192,10 @@ def post(self, user: User): @api.route("/file/") class File(Resource): @requires_auth() - def get(self, file_uuid: str, user: Optional[User] = None): - try: - file = current_app.db.get_file(file_uuid) - data = cast(Dict[str, Any], file.data(recurse=True)) - if file.type == DataObject.Type.FILE: - data["files"] = [ - { - "path": str(file.uri.path), - "checksum": file.checksum, - } - ] - else: - data["files"] = [ - {"path": str(path), "checksum": sha1_checksum(URI(f"file:{path}"))} - for path in imas_files(file.uri) - ] - return jsonify(data) - except DatabaseError as err: - return error(str(err)) + @pydantic_validate(api) + def get(self, file_uuid: str, user: Optional[User] = None) -> FileGetDataResponse: + file = current_app.db.get_file(file_uuid) + return 
file.to_model_with_path() @api.route("/file/download/") diff --git a/src/simdb/remote/apis/metadata.py b/src/simdb/remote/apis/metadata.py index 8b75012..dc47cf4 100644 --- a/src/simdb/remote/apis/metadata.py +++ b/src/simdb/remote/apis/metadata.py @@ -1,10 +1,9 @@ -from flask import jsonify from flask_restx import Namespace, Resource -from simdb.database import DatabaseError from simdb.remote.core.cache import cache, cache_key -from simdb.remote.core.errors import error +from simdb.remote.core.pydantic_utils import pydantic_validate from simdb.remote.core.typing import current_app +from simdb.remote.models import MetadataKeyInfoList, MetadataValueList api = Namespace("metadata", path="/") @@ -12,18 +11,16 @@ @api.route("/metadata") class MetaData(Resource): @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] - def get(self): - try: - return jsonify(current_app.db.list_metadata_keys()) - except DatabaseError as err: - return error(str(err)) + @pydantic_validate(api) + def get(self) -> MetadataKeyInfoList: + return MetadataKeyInfoList.model_validate(current_app.db.list_metadata_keys()) @api.route("/metadata/") class MetaDataValues(Resource): @cache.cached(key_prefix=cache_key) # type: ignore[invalid-argument-type] - def get(self, name): - try: - return jsonify(current_app.db.list_metadata_values(name)) - except DatabaseError as err: - return error(str(err)) + @pydantic_validate(api) + def get(self, name: str) -> MetadataValueList: + return MetadataValueList.model_validate( + current_app.db.list_metadata_values(name) + ) diff --git a/src/simdb/remote/apis/v1_2/__init__.py b/src/simdb/remote/apis/v1_2/__init__.py index c94b292..920f303 100644 --- a/src/simdb/remote/apis/v1_2/__init__.py +++ b/src/simdb/remote/apis/v1_2/__init__.py @@ -1,13 +1,14 @@ from pathlib import Path -from flask import jsonify from flask_restx import Api, Resource from simdb.remote.apis.files import api as file_ns from simdb.remote.apis.metadata import api as metadata_ns 
from simdb.remote.apis.watchers import api as watcher_ns from simdb.remote.core.auth import TokenAuthenticator, User, requires_auth +from simdb.remote.core.pydantic_utils import pydantic_validate from simdb.remote.core.typing import current_app +from simdb.remote.models import StagingDirectoryResponse from .simulations import api as sim_ns @@ -37,7 +38,8 @@ @api.route("/staging_dir/") class StagingDirectory(Resource): @requires_auth() - def get(self, sim_hex: str, user: User): + @pydantic_validate(api) + def get(self, sim_hex: str, user: User) -> StagingDirectoryResponse: upload_dir = current_app.simdb_config.get_string_option( "server.user_upload_folder", default=None ) @@ -49,7 +51,7 @@ def get(self, sim_hex: str, user: User): user_folder = False if not sim_hex: - return jsonify({"staging_dir": upload_dir}) + return StagingDirectoryResponse(staging_dir=Path(upload_dir)) staging_dir = ( Path(current_app.simdb_config.get_string_option("server.upload_folder")) @@ -61,4 +63,4 @@ def get(self, sim_hex: str, user: User): # directory. 
if user_folder: staging_dir.chmod(0o777) - return jsonify({"staging_dir": str(Path(upload_dir) / sim_hex)}) + return StagingDirectoryResponse(staging_dir=Path(upload_dir) / sim_hex) diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index 37c4d12..76f9bf5 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -1,14 +1,12 @@ import contextlib import datetime -import gzip import itertools import tarfile from io import BytesIO from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Annotated, List, Optional, Tuple -from flask import json as flask_json # fallback -from flask import jsonify, request, send_file +from flask import request, send_file from flask_restx import Namespace, Resource from simdb.database import DatabaseError @@ -23,7 +21,31 @@ from simdb.remote.core.cache import cache, cache_key, clear_cache from simdb.remote.core.errors import error from simdb.remote.core.path import find_common_root, secure_path +from simdb.remote.core.pydantic_utils import ( + Body, + Header, + ResponseException, + pydantic_validate, +) from simdb.remote.core.typing import current_app +from simdb.remote.models import ( + DeletedSimulation, + MetadataDataList, + MetadataDeleteData, + MetadataDeleteResponse, + MetadataPatchData, + PaginatedResponse, + PaginationData, + SimulationDataResponse, + SimulationDeleteResponse, + SimulationListItem, + SimulationPatchResponse, + SimulationPostData, + SimulationPostResponse, + SimulationTraceData, + StatusPatchData, + ValidationResult, +) from simdb.uri import URI from simdb.validation import ValidationError, Validator from simdb.validation.file import find_file_validator @@ -55,7 +77,7 @@ def _update_simulation_status( server.send_message(f"Simulation {simulation.alias}", msg, to_addresses) -def _validate(simulation, user) -> Dict: +def _validate(simulation, user) -> ValidationResult: schemas = 
Validator.validation_schemas(current_app.simdb_config, simulation) try: for schema in schemas: @@ -65,10 +87,7 @@ def _validate(simulation, user) -> Dict: ) except ValidationError as err: _update_simulation_status(simulation, models_sim.Simulation.Status.FAILED, user) - return { - "passed": False, - "error": str(err), - } + return ValidationResult(passed=False, error=str(err)) file_validator_type = current_app.simdb_config.get_string_option( "file_validation.type", default=None @@ -88,16 +107,11 @@ def _validate(simulation, user) -> Dict: _update_simulation_status( simulation, models_sim.Simulation.Status.FAILED, user ) - return { - "passed": False, - "error": str(err), - } + return ValidationResult(passed=False, error=str(err)) else: error("Invalid file validator specified in configuration") - return { - "passed": True, - } + return ValidationResult(passed=True, error=None) def _set_alias(alias: str): @@ -121,130 +135,43 @@ def _set_alias(alias: str): return alias, next_id -def _build_trace(sim_id: str) -> Dict[str, Any]: - try: - simulation = current_app.db.get_simulation(sim_id) - except DatabaseError as err: - return {"error": str(err)} - data: Dict[str, Any] = cast(Dict[str, Any], simulation.data(recurse=False)) - - status = simulation.find_meta("status") - if status: - status_value = status[0].value - if isinstance(status_value, str): - data["status"] = status_value - else: - data["status"] = status_value.value - status_on_name = str(data["status"]) + "_on" - status_on = simulation.find_meta(status_on_name) - if status_on: - data[status_on_name] = status_on[0].value - - replaces = simulation.find_meta("replaces") - if replaces: - data["replaces"] = _build_trace(replaces[0].value) - - replaced_on = simulation.find_meta("replaced_on") - if replaced_on: - data["deprecated_on"] = replaced_on[0].value - - replaces_reason = simulation.find_meta("replaces_reason") - if replaces_reason: - data["replaces_reason"] = replaces_reason[0].value - - return data +def 
_build_trace(sim_id: str) -> SimulationTraceData: + simulation = current_app.db.get_simulation(sim_id) + data = simulation.to_model_trace(recurse=False) + def get_meta_val(key, default=None): + meta = simulation.find_meta(key) + return meta[0].value if meta else default -def _get_json_aware(force: bool = False, silent: bool = False): - """ - Parse JSON like Flask's request.get_json, but handle Content-Encoding: gzip. - - force/silent mimic request.get_json behavior. - - Uses Flask's JSON provider to ensure identical types/decoding. - """ - - # Match request.get_json content-type check unless forced - if not force: - mimetype = request.mimetype or "" - if mimetype != "application/json": - if silent: - return None - raise ValueError("Invalid Content-Type (application/json required)") - - raw = request.get_data(cache=False) - enc = (request.headers.get("Content-Encoding") or "").lower() - if "gzip" in enc: - # Only decompress if actually gzipped - with contextlib.suppress(OSError): - raw = gzip.decompress(raw) + status_val = get_meta_val("status") + if status_val: + data.status = status_val if isinstance(status_val, str) else status_val.value - # Use the same charset resolution as Flask (defaults to utf-8) - charset = "utf-8" - try: - params = request.mimetype_params or {} - charset = params.get("charset", "utf-8") - except Exception: - pass + status_on_key = f"{data.status}_on" + status_on_val = get_meta_val(status_on_key) + if status_on_val: + setattr(data, status_on_key, status_on_val) - data = raw.decode(charset, errors="strict") + replaces_id = get_meta_val("replaces") + if replaces_id: + data.replaces = _build_trace(replaces_id) - # Use Flask's JSON provider for identical behavior - try: - # for Flask >= 2.2 - loads = current_app.json.loads # type: ignore[unresolved-attribute] - except Exception: - loads = flask_json.loads + data.deprecated_on = get_meta_val("replaced_on") + data.replaces_reason = get_meta_val("replaces_reason") - try: - return loads(data) - 
except Exception: - if silent: - return None - raise + return data @api.route("/simulations") class SimulationList(Resource): - LIMIT_HEADER = "simdb-result-limit" - PAGE_HEADER = "simdb-page" - SORT_BY_HEADER = "simdb-sort-by" - SORT_ASC_HEADER = "simdb-sort-asc" - - parser = api.parser() - parser.add_argument( - LIMIT_HEADER, location="headers", type=int, help="Limit returned results" - ) - parser.add_argument( - PAGE_HEADER, - location="headers", - type=int, - help="Specify the page of results to return", - ) - parser.add_argument( - SORT_BY_HEADER, - location="headers", - type=str, - help="Specify the field to sort the results by", - ) - parser.add_argument( - SORT_ASC_HEADER, - location="headers", - type=bool, - help="Specify if the results are sorted ascending or descending", - ) - - @api.expect(parser) - @api.response(200, "Success") - @api.response(401, "Unauthorized") @requires_auth() + @pydantic_validate(api) # @cache.cached(key_prefix=cache_key) - def get(self, user: User): - limit = int(request.headers.get(SimulationList.LIMIT_HEADER) or 100) - page = int(request.headers.get(SimulationList.PAGE_HEADER) or 1) - sort_by = request.headers.get(SimulationList.SORT_BY_HEADER, "") - sort_asc = ( - request.headers.get(SimulationList.SORT_ASC_HEADER, "false").lower() - == "true" - ) + def get( + self, + user: User, + pagination: Annotated[PaginationData, Header()], + ) -> PaginatedResponse[SimulationListItem]: names = [] constraints = [] if request.args: @@ -262,347 +189,290 @@ def get(self, user: User): count, data = current_app.db.query_meta_data( constraints, names, - limit=limit, - page=page, - sort_by=sort_by, - sort_asc=sort_asc, + limit=pagination.limit, + page=pagination.page, + sort_by=pagination.sort_by, + sort_asc=pagination.sort_asc, ) else: count, data = current_app.db.list_simulation_data( meta_keys=names, - limit=limit, - page=page, - sort_by=sort_by, - sort_asc=sort_asc, + limit=pagination.limit, + page=pagination.page, + 
sort_by=pagination.sort_by, + sort_asc=pagination.sort_asc, ) - return jsonify({"count": count, "page": page, "limit": limit, "results": data}) + return PaginatedResponse[SimulationListItem].model_validate( + { + "count": count, + "page": pagination.page, + "limit": pagination.limit, + "results": data, + } + ) @requires_auth() - def post(self, user: User): - try: - # _get_json_aware is a custom function to handle JSON parsing - # similar to Flask's request.get_json, but with gzip support. - # It returns None if the content type is not application/json. - # If silent=True, it returns None instead of raising an error. - # If force=True, it ignores the content type check. - data = _get_json_aware() - if not data: - return error("Invalid or missing JSON data") - - if "simulation" not in data: - return error("Simulation data not provided") - - add_watcher = data.get("add_watcher", True) - - simulation = models_sim.Simulation.from_data(data["simulation"]) - - # Simulation Upload (Push) Date - simulation.datetime = datetime.datetime.now() - - if data["uploaded_by"] is not None: - simulation.set_meta("uploaded_by", data["uploaded_by"]) - elif user.email is not None: - simulation.set_meta("uploaded_by", user.email) - elif user.name is not None: - simulation.set_meta("uploaded_by", user.name) - else: - simulation.set_meta("uploaded_by", "anonymous") - if add_watcher: - simulation.watchers.append( - models_watcher.Watcher( - user.name, user.email, models_watcher.Notification.ALL - ) + @pydantic_validate(api) + def post( + self, + user: User, + body: Annotated[SimulationPostData, Body()], + ) -> SimulationPostResponse: + simulation = models_sim.Simulation.from_data_model(body.simulation) + + # Simulation Upload (Push) Date + simulation.datetime = datetime.datetime.now() + + uploaded_by = body.uploaded_by or user.email or user.name or "anonymous" + + simulation.set_meta("uploaded_by", uploaded_by) + + if body.add_watcher: + simulation.watchers.append( + models_watcher.Watcher( 
+ user.name, user.email, models_watcher.Notification.ALL ) + ) - if "alias" in data["simulation"]: - alias = data["simulation"]["alias"] - if alias is not None: - (updated_alias, next_id) = _set_alias(alias) - if updated_alias: - simulation.meta.append(models_meta.MetaData("seqid", next_id)) - simulation.alias = updated_alias - else: - simulation.alias = alias - else: - simulation.alias = simulation.uuid.hex + alias = body.simulation.alias + if alias is not None: + (updated_alias, next_id) = _set_alias(alias) + if updated_alias: + simulation.meta.append(models_meta.MetaData("seqid", next_id)) + simulation.alias = updated_alias else: - simulation.alias = simulation.uuid.hex + simulation.alias = alias + else: + simulation.alias = simulation.uuid.hex - files = list(itertools.chain(simulation.inputs, simulation.outputs)) - sim_file_paths = simulation.file_paths() - common_root = find_common_root(sim_file_paths) + files = list(itertools.chain(simulation.inputs, simulation.outputs)) + sim_file_paths = simulation.file_paths() + common_root = find_common_root(sim_file_paths) - config = current_app.simdb_config + config = current_app.simdb_config + copy_files = config.get_option("server.copy_files", default=True) + imas_remote_host = config.get_option("server.imas_remote_host", default=None) - if config.get_option("server.copy_files", default=True): - staging_dir = ( - Path(config.get_string_option("server.upload_folder")) - / simulation.uuid.hex - ) + if copy_files or imas_remote_host: + staging_dir = ( + Path(config.get_string_option("server.upload_folder")) + / simulation.uuid.hex + ) - for sim_file in files: - if sim_file.uri.scheme == "file": - if sim_file.uri.path is None: - raise ValueError("Simulation path not set") - path = secure_path(sim_file.uri.path, common_root, staging_dir) - if not path.exists(): - raise ValueError( - f"simulation file {sim_file.uuid} not uploaded" - ) - sim_file.uri = URI(scheme="file", path=path) - elif sim_file.uri.scheme == "imas": + for 
sim_file in files: + if ( + copy_files + and sim_file.uri.scheme == "file" + and sim_file.uri.path is not None + ): + path = secure_path(sim_file.uri.path, common_root, staging_dir) + if not path.exists(): + raise ResponseException( + f"simulation file {sim_file.uuid} not uploaded" + ) + sim_file.uri = URI(scheme="file", path=path) + elif sim_file.uri.scheme == "imas": + if copy_files: path = secure_path( Path(sim_file.uri.query["path"]), common_root, staging_dir, is_file=common_root is not None, ) - sim_file.uri = convert_uri(sim_file.uri, path, config) - elif config.get_option("server.imas_remote_host", default=None): - staging_dir = ( - Path(config.get_string_option("server.upload_folder")) - / simulation.uuid.hex - ) + else: + path = Path(sim_file.uri.query["path"]) + sim_file.uri = convert_uri(sim_file.uri, path, config) - for sim_file in files: - if sim_file.uri.scheme == "imas": - if config.get_option("server.copy_files", default=True): - path = secure_path( - Path(sim_file.uri.query["path"]), - common_root, - staging_dir, - is_file=common_root is not None, - ) - sim_file.uri = convert_uri(sim_file.uri, path, config) - else: - path = Path(sim_file.uri.query["path"]) - sim_file.uri = convert_uri(sim_file.uri, path, config) - - result = { - "ingested": simulation.uuid.hex, - } + result = SimulationPostResponse( + ingested=simulation.uuid, error=None, validation=None + ) - if current_app.simdb_config.get_option( - "validation.auto_validate", default=False - ): - result["validation"] = _validate(simulation, user) - - if current_app.simdb_config.get_option( - "validation.error_on_fail", default=False - ): - if simulation.status == models_sim.Simulation.Status.NOT_VALIDATED: - raise Exception( - "Validation config option error_on_fail=True without " - "auto_validate=True." 
- ) - elif simulation.status == models_sim.Simulation.Status.FAILED: - result["error"] = f"""Simulation validation failed and server has - error_on_fail=True.\n{result["validation"]["error"]}""" - response = jsonify(result) - response.status_code = 400 - return response - - replaces = simulation.find_meta("replaces") - if ( - not current_app.simdb_config.get_option( - "development.disable_replaces", default=False + error_on_fail = current_app.simdb_config.get_option( + "validation.error_on_fail", default=False + ) + + if current_app.simdb_config.get_option( + "validation.auto_validate", default=False + ): + result.validation = _validate(simulation, user) + + if not result.validation.passed and error_on_fail: + raise ResponseException( + f"Simulation validation failed and server has " + f"error_on_fail=True.\n{result.validation.error}" ) - and replaces - and replaces[0].value - ): - sim_id = replaces[0].value - try: - replaces_sim = current_app.db.get_simulation(sim_id) - except DatabaseError: - replaces_sim = None - if replaces_sim is None: - pass - else: - _update_simulation_status( - replaces_sim, models_sim.Simulation.Status.DEPRECATED, user - ) - replaces_sim.set_meta("replaced_by", simulation.uuid) - current_app.db.insert_simulation(replaces_sim) + elif error_on_fail: + raise ResponseException( + "Validation config option error_on_fail=True without " + "auto_validate=True." 
# NOTE(review): route converters (e.g. "<string:sim_id>") appear to have been
# stripped from this extract; routes are kept literal — confirm against the
# original routing table.
@api.route("/simulation/")
class Simulation(Resource):
    """CRUD operations on a single simulation."""

    @requires_auth()
    @cache.cached(key_prefix=cache_key)  # type: ignore[invalid-argument-type]
    @pydantic_validate(api)
    def get(self, sim_id: str, user: User) -> SimulationDataResponse:
        """Return the simulation together with parent/child references.

        Raises ResponseException (HTTP 4xx) when the simulation cannot be
        found.
        """
        try:
            simulation = current_app.db.get_simulation(sim_id)
        except DatabaseError:
            raise ResponseException(
                f"Simulation with id {sim_id} could not be found"
            ) from None
        # get_simulation may also return None (patch/delete below guard for
        # this); without this check a missing simulation would surface as an
        # AttributeError -> HTTP 500 instead of a clean client error.
        if simulation is None:
            raise ResponseException(
                f"Simulation with id {sim_id} could not be found"
            )

        sim_data = simulation.to_model_with_refs(recurse=True)
        sim_data.children = current_app.db.get_simulation_children_ref(simulation)
        sim_data.parents = current_app.db.get_simulation_parents_ref(simulation)
        return sim_data

    @requires_auth("admin")
    @pydantic_validate(api)
    def patch(
        self,
        sim_id: str,
        user: Optional[User],
        body: Annotated[StatusPatchData, Body()],
    ) -> SimulationPatchResponse:
        """Update the status of a simulation (admin only)."""
        simulation = current_app.db.get_simulation(sim_id)
        if simulation is None:
            raise ResponseException(f"Simulation {sim_id} not found.")
        status = models_sim.Simulation.Status(body.status)
        _update_simulation_status(simulation, status, user)
        current_app.db.insert_simulation(simulation)
        clear_cache()
        return SimulationPatchResponse()

    @requires_auth("admin")
    @pydantic_validate(api)
    def delete(self, sim_id: str, user: User) -> SimulationDeleteResponse:
        """Delete a simulation and unlink any local ("file" scheme) files.

        Also removes the (now empty) directory that held the files, guarding
        against removing the CWD or the filesystem root.
        """
        simulation = current_app.db.delete_simulation(sim_id)
        clear_cache()
        files = []
        for file in itertools.chain(simulation.inputs, simulation.outputs):
            if file.uri.scheme == "file":
                if file.uri.path is None:
                    raise ValueError("File path not set")
                files.append(f"{file.uuid} ({file.uri.path.name})")
                file.uri.path.unlink()
        if simulation.inputs or simulation.outputs:
            first_file = (
                simulation.inputs[0] if simulation.inputs else simulation.outputs[0]
            )
            if first_file.uri.path is not None:
                directory = first_file.uri.path.parent
                # Never rmdir the CWD (Path()) or the filesystem root.
                if directory != Path() and directory != Path("/"):
                    directory.rmdir()
        return SimulationDeleteResponse(
            deleted=DeletedSimulation(simulation=simulation.uuid, files=files)
        )


@api.route("/simulation/metadata/")
class SimulationMeta(Resource):
    """Read and modify the metadata key/value pairs of a simulation."""

    @requires_auth()
    @cache.cached(key_prefix=cache_key)  # type: ignore[invalid-argument-type]
    @pydantic_validate(api)
    def get(self, sim_id: str, user: User) -> MetadataDataList:
        """Return all metadata entries for the simulation."""
        simulation = current_app.db.get_simulation(sim_id)
        if simulation:
            return MetadataDataList.model_validate(
                [meta.data() for meta in simulation.meta]
            )
        raise ResponseException("Simulation not found")

    @requires_auth("admin")
    @pydantic_validate(api)
    def patch(
        self,
        sim_id: str,
        user: Optional[User],
        body: Annotated[MetadataPatchData, Body()],
    ) -> MetadataDataList:
        """Set a metadata key to a new value and return the previous values.

        A key of "status" (case-insensitive) is routed through the status
        update machinery instead of plain set_meta.
        """
        key = body.key
        value = body.value.lower()
        simulation = current_app.db.get_simulation(sim_id)
        if simulation is None:
            raise ResponseException(f"Simulation {sim_id} not found.")
        # Snapshot the old values before mutating, so they can be returned.
        old_values = MetadataDataList.model_validate(
            [meta.data() for meta in simulation.find_meta(key)]
        )
        if key.lower() != "status":
            simulation.set_meta(key, value)
        else:
            status = models_sim.Simulation.Status(value)
            _update_simulation_status(simulation, status, user)

        current_app.db.insert_simulation(simulation)
        clear_cache()
        return old_values

    @requires_auth("admin")
    @pydantic_validate(api)
    def delete(
        self,
        sim_id: str,
        user: Optional[User],
        body: Annotated[MetadataDeleteData, Body()],
    ) -> MetadataDeleteResponse:
        """Remove a metadata key from the simulation."""
        simulation = current_app.db.get_simulation(sim_id)
        if simulation is None:
            raise ResponseException(f"Simulation {sim_id} not found.")

        simulation.remove_meta(body.key)
        current_app.db.insert_simulation(simulation)
        clear_cache()
        return MetadataDeleteResponse()


@api.route("/validate/")
class ValidateSimulation(Resource):
    """Run validation on a simulation and persist the result."""

    @requires_auth()
    @pydantic_validate(api)
    def post(self, sim_id, user: User) -> ValidationResult:
        """Validate the simulation and store the updated record."""
        simulation = current_app.db.get_simulation(sim_id)
        result = _validate(simulation, user)
        current_app.db.insert_simulation(simulation)
        clear_cache()
        return result


@api.route("/trace/")
class SimulationTrace(Resource):
    """Expose the provenance/status trace of a simulation."""

    @requires_auth()
    @cache.cached(key_prefix=cache_key)  # type: ignore[invalid-argument-type]
    @pydantic_validate(api)
    def get(self, sim_id: str, user: User) -> SimulationTraceData:
        """Return the trace built by _build_trace for this simulation."""
        return _build_trace(sim_id)
# NOTE(review): the route converter (e.g. "<string:sim_id>") appears stripped
# by the extraction; kept literal — confirm against the original routing.
@api.route("/watchers/")
class Watcher(Resource):
    """Manage the watcher list of a single simulation."""

    @requires_auth()
    @pydantic_validate(api)
    def post(
        self, sim_id: str, user: User, data: Annotated[WatcherPostRequest, Body()]
    ) -> WatcherPostResponse:
        """Add a watcher to the simulation, defaulting to the calling user."""
        # Fall back to the authenticated user's identity when the request
        # omits either field.
        username = data.user if data.user else user.name
        email = data.email if data.email else user.email

        notification = getattr(Notification, data.notification)

        new_watcher = models.Watcher(username, email, notification)
        current_app.db.add_watcher(sim_id, new_watcher)
        clear_cache()

        if username != user.name:
            # TODO: send email to notify user that they have been added as a watcher
            pass

        payload = {"added": {"simulation": sim_id, "watcher": username}}
        return WatcherPostResponse.model_validate(payload)

    @requires_auth()
    @pydantic_validate(api)
    def delete(
        self, sim_id: str, user: User, data: Annotated[WatcherDeleteRequest, Body()]
    ) -> WatcherDeleteResponse:
        """Remove a watcher from the simulation (defaults to the caller)."""
        username = data.user if data.user else user.name

        current_app.db.remove_watcher(sim_id, username)
        clear_cache()
        payload = {"removed": {"simulation": sim_id, "watcher": username}}
        return WatcherDeleteResponse.model_validate(payload)

    @requires_auth()
    @pydantic_validate(api)
    def get(self, sim_id: str, user: User) -> WatcherGetResponse:
        """List all watchers registered on the simulation."""
        watchers = current_app.db.list_watchers(sim_id)
        return WatcherGetResponse([watcher.to_model() for watcher in watchers])
logger = logging.getLogger(__name__)


class ResponseException(Exception):
    """Raised when a client error has occurred in the request (HTTP 4xx)."""

    def __init__(self, message, return_code=400):
        super().__init__(message)
        self.message = message
        self.return_code = return_code


class ServerException(Exception):
    """Raised when an unexpected server-side error has occurred (HTTP 500)."""

    def __init__(self, message, return_code=500):
        super().__init__(message)
        self.message = message
        self.return_code = return_code


# ---------------------------------------------------------------------------
# Marker classes for Annotated-style parameter declarations
# ---------------------------------------------------------------------------


class _ParamSource:
    """Base class for parameter source markers."""


class Header(_ParamSource):
    """Marker: populate this parameter from ``request.headers``."""


class Body(_ParamSource):
    """Marker: populate this parameter from the JSON request body."""


class Query(_ParamSource):
    """Marker: populate this parameter from ``request.args`` (query string)."""


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _register_defs(ns: Namespace | Api, defs):
    """Recursively register every JSON-schema definition in *defs* on *ns*."""
    for name, schema in defs.items():
        # Register each model without its internal $defs; the nested
        # definitions are registered separately via the recursive call.
        clean_schema = copy.deepcopy(schema)
        children = clean_schema.pop("$defs", {})
        ns.schema_model(name, clean_schema)
        _register_defs(ns, children)


def _collect_and_register(ns: Namespace | Api, model: type[BaseModel]):
    """
    Registers a Pydantic model and all nested dependencies to a Flask-RESTX Namespace.
    """
    full_schema = model.model_json_schema(ref_template="#/definitions/{model}")
    all_defs = full_schema.get("$defs", {})

    # 1. Register all sub-models found in $defs first
    _register_defs(ns, all_defs)

    # 2. Register the root model
    root_name = model.__name__
    root_schema = copy.deepcopy(full_schema)
    root_schema.pop("$defs", None)

    return ns.schema_model(root_name, root_schema)


def _get_annotated_params(
    f: Any,
) -> list[tuple[str, type[BaseModel], _ParamSource]]:
    """Inspect *f*'s signature and return a list of ``(param_name, model, source)``
    tuples for every parameter annotated as ``Annotated[SomePydanticModel, Source()]``.
    """
    results = []
    sig = inspect.signature(f)

    for param_name, param in sig.parameters.items():
        annotation = param.annotation
        if annotation is inspect.Parameter.empty:
            continue
        if get_origin(annotation) is Annotated:
            args = get_args(annotation)
            if len(args) >= 2:
                model_type = args[0]
                source = args[1]
                if (
                    isinstance(source, _ParamSource)
                    and isinstance(model_type, type)
                    and issubclass(model_type, BaseModel)
                ):
                    results.append((param_name, model_type, source))
    return results


def _validate_param(
    model: type[BaseModel],
    source: _ParamSource,
) -> BaseModel:
    """Validate and return a Pydantic model instance from the appropriate
    part of the current Flask request.

    Raises
    ------
    ValidationError
        If the data does not conform to the model.
    ValueError
        If the request body is missing or empty (for Body sources).
    """
    if isinstance(source, Header):
        # Convert Werkzeug Headers to a plain dict (lowercase keys)
        raw = {k.lower(): v for k, v in request.headers.items()}
        return model.model_validate(raw)

    if isinstance(source, Query):
        # Convert ImmutableMultiDict to a plain dict (lists for multi-values)
        raw = request.args.to_dict(flat=False)
        # Flatten single-value lists for convenience
        flat = {k: v[0] if len(v) == 1 else v for k, v in raw.items()}
        return model.model_validate(flat)

    enc = (request.headers.get("Content-Encoding") or "").lower()
    request_data = request.get_data(cache=False)
    # get_data() returns bytes, never None — the previous `is None` check was
    # dead code, so an empty body produced an opaque pydantic JSON parse
    # error instead of this intended message.
    if not request_data:
        raise ValueError("Invalid or missing JSON body")

    if enc == "gzip":
        # Best-effort: if decompression fails, fall through with the raw
        # bytes and let pydantic report the validation error.
        with contextlib.suppress(OSError):
            request_data = gzip.decompress(request_data)

    return model.model_validate_json(request_data)


def _json_error(model: type[BaseModel], message: str, status: int) -> Response:
    """Serialise *message* into a JSON error ``Response`` using *model*.

    *model* is expected to expose an ``error`` field (see ``ErrorResponse``).
    """
    return Response(
        response=model(error=message).model_dump_json(exclude_none=False),
        status=status,
        mimetype="application/json",
    )


# ---------------------------------------------------------------------------
# FastAPI-style route decorator
# ---------------------------------------------------------------------------


def pydantic_validate(
    ns: Namespace | Api,
    *,
    response_model: type[BaseModel] | None = None,
    error_model: type[BaseModel] = ErrorResponse,
    client_error_codes: tuple[int, ...] = (400,),
) -> Any:
    """Decorator factory that wires up Pydantic validation for a Flask-RESTX endpoint.

    Inspects the decorated function's signature for parameters annotated with
    ``Annotated[SomePydanticModel, Header()]``, ``Annotated[SomePydanticModel, Body()]``
    or ``Annotated[SomePydanticModel, Query()]``, validates the corresponding parts of
    the incoming request, and injects the validated model instances as keyword
    arguments. All discovered input models are automatically registered with
    *ns* for Swagger/OpenAPI documentation.

    If the function's return annotation is a ``BaseModel`` subclass (or if
    *response_model* is provided explicitly) the return value is automatically
    serialised to a JSON ``Response``.

    Error handling distinguishes client from server errors:

    - :class:`ResponseException` (and request validation errors) → HTTP 4xx
      (default 400; customise via ``return_code``).
    - :class:`ServerException` → HTTP 5xx (default 500).
    - Any other unhandled :class:`Exception` → HTTP 500, logged with traceback.

    Parameters
    ----------
    ns:
        The Flask-RESTX ``Namespace`` (or ``Api``) to register models on.
    response_model:
        Optional explicit response model; inferred from the return annotation
        when ``None``.
    error_model:
        Pydantic model used to serialise error responses.
    client_error_codes:
        HTTP status codes to document as client error responses in Swagger.
    """

    def decorator(f):
        annotated_params = _get_annotated_params(f)

        # Determine response model from return annotation if not given explicitly
        _response_model = response_model
        _error_model = error_model
        if _response_model is None:
            ret = inspect.signature(f).return_annotation
            if (
                ret is not inspect.Parameter.empty
                and inspect.isclass(ret)
                and issubclass(ret, BaseModel)
            ):
                _response_model = ret

        # Register all input models with the namespace; remember the (last)
        # Body schema so it can be attached via ns.expect below.
        _body_schema = None
        for _param_name, model_type, source in annotated_params:
            schema = _collect_and_register(ns, model_type)
            if isinstance(source, Body):
                _body_schema = schema

        # Register response model
        _resp_schema = None
        if _response_model is not None:
            _resp_schema = _collect_and_register(ns, _response_model)

        _error_schema = None
        if _error_model is not None:
            _error_schema = _collect_and_register(ns, _error_model)

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            for param_name, model_type, source in annotated_params:
                try:
                    validated = _validate_param(model_type, source)
                except ValueError as exc:
                    return _error(str(exc))
                except ValidationError as exc:
                    first_error = exc.errors()[0]
                    loc = " -> ".join(str(loc) for loc in first_error["loc"])
                    msg = f"Validation error at '{loc}': {first_error['msg']}"
                    return _error(msg)
                kwargs[param_name] = validated

            try:
                result = f(*args, **kwargs)
            except ResponseException as err:
                return _json_error(_error_model, err.message, err.return_code)
            except ServerException as err:
                logger.error("Server error in %s: %s", f.__qualname__, err.message)
                return _json_error(_error_model, err.message, err.return_code)
            except Exception as err:
                logger.exception("Unhandled exception in %s", f.__qualname__)
                return _json_error(_error_model, str(err), 500)

            if isinstance(result, ErrorResponse):
                return Response(
                    response=result.model_dump_json(exclude_none=False),
                    status=400,
                    mimetype="application/json",
                )
            if isinstance(result, BaseModel):
                return Response(
                    result.model_dump_json(exclude_none=False),
                    mimetype="application/json",
                )

            return result

        if _body_schema is not None:
            wrapper = ns.expect(_body_schema, validate=False)(wrapper)
        if _resp_schema is not None:
            wrapper = ns.response(200, "Success", _resp_schema)(wrapper)
        if _error_schema is not None:
            for error_code in client_error_codes:
                wrapper = ns.response(error_code, "Client error", _error_schema)(
                    wrapper
                )
            wrapper = ns.response(500, "Server error", _error_schema)(wrapper)

        # Document header/query model fields as individual Swagger parameters.
        for _param_name, model_type, source in annotated_params:
            if isinstance(source, (Query, Header)):
                location = "query" if isinstance(source, Query) else "header"
                # Named param_schema to avoid shadowing `schema` above.
                param_schema = model_type.model_json_schema()
                properties = param_schema.get("properties", {})
                required_fields = param_schema.get("required", [])

                for field_name, field_props in properties.items():
                    wrapper = ns.param(
                        name=field_name,
                        description=field_props.get("description", ""),
                        _in=location,
                        required=(field_name in required_fields),
                        type=field_props.get("type", "string"),
                        default=field_props.get("default"),
                    )(wrapper)

        return wrapper

    return decorator
+ BaseModel as _BaseModel, +) from pydantic import ( - BaseModel, BeforeValidator, + ConfigDict, Field, + InstanceOf, PlainSerializer, - RootModel, model_validator, ) +from pydantic import ( + RootModel as _RootModel, +) from simdb.cli.manifest import DataObject @@ -32,6 +40,14 @@ """UUID serialized as a hex string.""" +class BaseModel(_BaseModel): + model_config = ConfigDict(use_attribute_docstrings=True) + + +class RootModel(_RootModel): + model_config = ConfigDict(use_attribute_docstrings=True) + + def _deserialize_custom_uuid(v: Any) -> UUID: """Deserialize CustomUUID format back to UUID.""" if isinstance(v, UUID): @@ -64,7 +80,7 @@ class StatusPatchData(BaseModel): class DeletedSimulation(BaseModel): """Reference to a deleted simulation.""" - uuid: UUID + simulation: UUID """UUID of the deleted simulation.""" files: List[str] """List of deleted file paths.""" @@ -112,12 +128,54 @@ def __getitem__(self, item) -> FileData: return self.root[item] +def _deserialize_numpy(v: Any) -> Any: + if isinstance(v, np.ndarray): + return v + if isinstance(v, dict) and v.get("_type") == "numpy.ndarray": + np_bytes = base64.b64decode(v["bytes"].encode()) + return np.frombuffer(np_bytes, dtype=v["dtype"]).reshape(v["shape"]) + raise ValueError(f"Cannot deserialize {v} to np.ndarray") + + +def _serialize_numpy(o: np.ndarray) -> dict: + """Serialize numpy arrays to dict format for the web dashboard.""" + encoded_bytes = base64.b64encode(o.data).decode() + return { + "_type": "numpy.ndarray", + "dtype": o.dtype.name, + "shape": o.shape, + "bytes": encoded_bytes, + } + + +NumpyArray = Annotated[ + InstanceOf[np.ndarray], + BeforeValidator(_deserialize_numpy), + PlainSerializer(_serialize_numpy, return_type=dict), +] + + +MetadataValue = Union[ + CustomUUID, + str, + int, + float, + bool, + list, + dict, + NumpyArray, + None, +] +"""Supported types for simulation metadata values. 
Numpy arrays and scalars are +automatically converted to their plain Python equivalents before validation.""" + + class MetadataData(BaseModel): """Key-value pair for simulation metadata.""" element: str """Metadata key/name.""" - value: Union[CustomUUID, Any] + value: MetadataValue """Metadata value.""" def as_dict(self) -> dict: @@ -171,6 +229,31 @@ def as_querystring(self) -> str: return urlencode(self.as_dict()) +class MetadataDeleteResponse(BaseModel): + pass + + +class MetadataKeyInfo(BaseModel): + """Information about a metadata key.""" + + name: str + """Metadata key name.""" + type: str + """Python type name of the metadata value.""" + + +class MetadataKeyInfoList(RootModel): + """List of metadata key info items.""" + + root: List[MetadataKeyInfo] = [] + + +class MetadataValueList(RootModel): + """List of metadata values for a given key.""" + + root: List[MetadataValue] = [] + + class SimulationReference(BaseModel): """Reference to a simulation.""" @@ -206,6 +289,10 @@ class SimulationDataResponse(SimulationData): """Child simulations.""" +class SimulationPatchResponse(BaseModel): + pass + + class SimulationPostData(BaseModel): """Data for creating a new simulation.""" @@ -268,31 +355,26 @@ class PaginatedResponse(BaseModel, Generic[T]): class PaginationData(BaseModel): - """Pagination parameters from request headers.""" + """Pagination parameters from request headers. - limit: int + Fields are populated from HTTP headers. The field aliases match the + lowercased header names as provided by Werkzeug / ``_validate_param``. + Use ``model_validate`` with ``by_alias=False`` (the default) or pass a + dict with the alias keys; Pydantic will resolve them via the + ``model_config`` ``populate_by_name=True`` setting. 
+ """ + + model_config = ConfigDict(populate_by_name=True, use_attribute_docstrings=True) + + limit: int = Field(100, alias="simdb-result-limit") """Number of items per page.""" - page: int + page: int = Field(1, alias="simdb-page") """Current page number.""" - sort_by: str + sort_by: str = Field("", alias="simdb-sort-by") """Field to sort by.""" - sort_asc: bool + sort_asc: bool = Field(False, alias="simdb-sort-asc") """Whether to sort ascending.""" - @model_validator(mode="before") - @classmethod - def parse_headers(cls, data: Any): - """Parse pagination from HTTP headers.""" - if not isinstance(data, dict): - return data - new_data = { - "limit": data.get("simdb-result-limit", 100), - "page": data.get("simdb-page", 1), - "sort_by": data.get("simdb-sort-by", ""), - "sort_asc": data.get("simdb-sort-asc", False), - } - return new_data - class SimulationTraceData(SimulationData): """Simulation data with status history.""" @@ -469,3 +551,10 @@ class StagingDirectoryResponse(BaseModel): staging_dir: Path """Path to the staging dir.""" + + +class ErrorResponse(BaseModel): + """Response model for server errors.""" + + error: str + """Error description.""" diff --git a/tests/remote/api/test_metadata.py b/tests/remote/api/test_metadata.py index 0d66b73..afb31fa 100644 --- a/tests/remote/api/test_metadata.py +++ b/tests/remote/api/test_metadata.py @@ -1,9 +1,12 @@ +import numpy as np from conftest import ( HEADERS, generate_simulation_data, post_simulation, ) +from simdb.remote.models import MetadataKeyInfoList, MetadataValueList + def test_get_metadata_keys(client): """Test GET /v1.2/metadata endpoint - list all metadata keys.""" @@ -49,6 +52,31 @@ def test_get_metadata_values(client): assert "machine-a" in rv.json or "machine-b" in rv.json +def test_get_metadata_array_value(client): + """Test metadata ndarray storage""" + # Create a simulation with array metadata + array_data = np.array([1, 2, 3]) + simulation_data_1 = generate_simulation_data(metadata={"array_machine": 
array_data}) + rv_post_1 = post_simulation(client, simulation_data_1) + assert rv_post_1.status_code == 200 + + rv = client.get("/v1.2/metadata", headers=HEADERS) + assert rv.status_code == 200 + mkeys = MetadataKeyInfoList.model_validate_json(rv.data) + for k in mkeys.root: + if k.name == "array_machine": + mkey = k + assert mkey.type == "ndarray" + + rv = client.get("/v1.2/metadata/array_machine", headers=HEADERS) + + assert rv.status_code == 200 + mdata = MetadataValueList.model_validate_json(rv.data) + assert len(mdata.root) == 1 + a = mdata.root[0] + assert isinstance(a, np.ndarray) + + def test_get_metadata_values_nonexistent_key(client): """Test GET /v1.2/metadata/{name} endpoint - non-existent key.""" # Get values for a metadata key that doesn't exist diff --git a/tests/remote/api/test_simulations.py b/tests/remote/api/test_simulations.py index 2ff30f7..71e7cf5 100644 --- a/tests/remote/api/test_simulations.py +++ b/tests/remote/api/test_simulations.py @@ -736,6 +736,61 @@ def test_delete_simulation_metadata(client): assert data == simulation_data.simulation.metadata +def test_get_simulation_parents_and_children(client): + """GET /v1.2/simulation/{id} must return correct parents and + children. 
+ """ + shared_checksum = "shared_checksum_for_parent_child_test" + + # Parent simulation: produces an output file with a known checksum + parent_output = generate_simulation_file() + parent_output.checksum = shared_checksum + parent_data = generate_simulation_data( + alias="parent-sim-pc-test", + outputs=[parent_output], + ) + rv_parent = post_simulation(client, parent_data) + assert rv_parent.status_code == 200 + + # Child simulation: consumes an input file with the same checksum + child_input = generate_simulation_file() + child_input.checksum = shared_checksum + child_data = generate_simulation_data( + alias="child-sim-pc-test", + inputs=[child_input], + ) + rv_child = post_simulation(client, child_data) + assert rv_child.status_code == 200 + + # Fetch the parent and verify its children list contains the child + # and its parents list is empty (the parent has no parents of its own). + rv = client.get( + f"/v1.2/simulation/{parent_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv.status_code == 200 + parent_response = SimulationDataResponse.model_validate(rv.json) + + parent_children_uuids = [ref.uuid for ref in parent_response.children] + parent_parents_uuids = [ref.uuid for ref in parent_response.parents] + + assert child_data.simulation.uuid in parent_children_uuids + assert child_data.simulation.uuid not in parent_parents_uuids + + # Fetch the child and verify its parents list contains the parent + # and its children list is empty. 
+ rv = client.get( + f"/v1.2/simulation/{child_data.simulation.uuid.hex}", headers=HEADERS + ) + assert rv.status_code == 200 + child_response = SimulationDataResponse.model_validate(rv.json) + + child_parents_uuids = [ref.uuid for ref in child_response.parents] + child_children_uuids = [ref.uuid for ref in child_response.children] + + assert parent_data.simulation.uuid in child_parents_uuids + assert parent_data.simulation.uuid not in child_children_uuids + + def test_trace_endpoint(client): """Test trace endpoint returns valid SimulationTraceData and handles replacement chains."""