diff --git a/alembic/versions/28bee3aa2429_convert_metadata_to_json_column.py b/alembic/versions/28bee3aa2429_convert_metadata_to_json_column.py new file mode 100644 index 00000000..b262fa0b --- /dev/null +++ b/alembic/versions/28bee3aa2429_convert_metadata_to_json_column.py @@ -0,0 +1,160 @@ +"""convert_metadata_to_json_column + +Revision ID: 28bee3aa2429 +Revises: 9e9a4a7cd639 +Create Date: 2026-02-26 17:01:30.925750 + +""" + +import json +import pickle +from typing import Any, Sequence, Union + +import sqlalchemy as sa +from sqlalchemy import text +from sqlalchemy.dialects import postgresql + +from alembic import op + +revision: str = "28bee3aa2429" +down_revision: Union[str, Sequence[str], None] = "9e9a4a7cd639" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _make_json_serializable(value: Any) -> Any: + """Recursively convert a value to something JSON-serializable. + + Numpy arrays fall through to str(), which uses numpy's print threshold and + truncates large arrays — avoiding multi-hundred-MB JSON for array-valued metadata. + """ + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (list, tuple)): + return [_make_json_serializable(v) for v in value] + if isinstance(value, dict): + return {str(k): _make_json_serializable(v) for k, v in value.items()} + # Covers numpy arrays (truncated by numpy's print threshold), datetimes, etc. + return str(value) + + +def upgrade() -> None: + """Upgrade schema.""" + conn = op.get_bind() + inspector = sa.inspect(conn) + + # Add metadata column only if it doesn't already exist (e.g. created via create_all) + existing_columns = [col["name"] for col in inspector.get_columns("simulations")] + if "metadata" not in existing_columns: + if conn.dialect.name == "postgresql": + op.add_column( + "simulations", + sa.Column( + "metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + ) + else: + op.add_column( + "simulations", sa.Column("metadata", sa.JSON(), nullable=True) + ) + + # Migrate existing data from metadata table if it still exists + if "metadata" in inspector.get_table_names(): + result = conn.execute(text("SELECT DISTINCT sim_id FROM metadata")) + sim_ids = [row[0] for row in result] + + for sim_id in sim_ids: + meta_rows = conn.execute( + text("SELECT element, value FROM metadata WHERE sim_id = :sim_id"), + {"sim_id": sim_id}, + ) + + meta_dict = {} + for element, value in meta_rows: + if value is not None: + try: + unpickled = ( + pickle.loads(value) if isinstance(value, bytes) else value + ) + except Exception: + unpickled = repr(value) + meta_dict[element] = _make_json_serializable(unpickled) + else: + meta_dict[element] = None + + if conn.dialect.name == "postgresql": + conn.execute( + text( + "UPDATE simulations SET metadata = :metadata::jsonb" + " WHERE id = :sim_id" + ), + {"metadata": json.dumps(meta_dict), "sim_id": sim_id}, + ) + else: + conn.execute( + text( + "UPDATE simulations SET metadata = :metadata WHERE id = :sim_id" + ), + {"metadata": json.dumps(meta_dict), "sim_id": sim_id}, + ) + + op.drop_index("metadata_index", table_name="metadata") + op.drop_index(op.f("ix_metadata_sim_id"), table_name="metadata") + op.drop_table("metadata") + + +def downgrade() -> None: + """Downgrade schema.""" + conn = op.get_bind() + + # Recreate metadata table + op.create_table( + "metadata", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("sim_id", sa.Integer(), nullable=True), + sa.Column("element", sa.String(length=250), nullable=False), + sa.Column("value", sa.PickleType(), nullable=True), + sa.ForeignKeyConstraint( + ["sim_id"], + ["simulations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_metadata_sim_id"), "metadata", ["sim_id"], unique=False) + op.create_index("metadata_index", "metadata", ["sim_id", "element"], unique=True) + + # Migrate data back from JSON column to metadata table + if conn.dialect.name == "postgresql": + migration_query = text(""" + INSERT INTO metadata (sim_id, element, value) + SELECT s.id, kv.key, kv.value::text + FROM simulations s, json_each_text(s.metadata::json) kv + WHERE s.metadata IS NOT NULL + """) + conn.execute(migration_query) + else: + result = conn.execute( + text("SELECT id, metadata FROM simulations WHERE metadata IS NOT NULL") + ) + for sim_id, metadata_json in result: + if metadata_json: + try: + meta_dict = json.loads(metadata_json) + for element, value in meta_dict.items(): + # Pickle the value for storage + pickled_value = pickle.dumps(value, 0) + conn.execute( + text( + "INSERT INTO metadata (sim_id, element, value) " + "VALUES (:sim_id, :element, :value)" + ), + { + "sim_id": sim_id, + "element": element, + "value": pickled_value, + }, + ) + except Exception: + pass + + op.drop_column("simulations", "metadata") diff --git a/src/simdb/database/database.py b/src/simdb/database/database.py index 37ba0071..5aa25488 100644 --- a/src/simdb/database/database.py +++ b/src/simdb/database/database.py @@ -8,18 +8,17 @@ import appdirs import sqlalchemy.orm -from sqlalchemy import String, Text, asc, create_engine, desc, func, or_ +from sqlalchemy import String, Text, create_engine, func from sqlalchemy import cast as sql_cast from sqlalchemy import or_ as sql_or from sqlalchemy.exc import DBAPIError, IntegrityError, SQLAlchemyError -from sqlalchemy.orm import Bundle, joinedload, scoped_session, sessionmaker +from sqlalchemy.orm import scoped_session, sessionmaker from simdb.config import Config from simdb.query import QueryType, query_compare from .models import Base from .models.file import File -from .models.metadata import MetaData from .models.simulation import Simulation @@ -176,46 +175,78 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def _get_simulation_data(self, limit, query, meta_keys, page) -> Tuple[int, List]: - if limit: - limit = limit * len(meta_keys) if meta_keys else limit - limit_query = query.limit(limit).offset((page - 1) * limit) - else: - limit_query = self.get_simulation_data(query) - data = {} - for row in limit_query: - data.setdefault( - row.simulation.uuid, - { - "alias": row.simulation.alias, - "uuid": row.simulation.uuid, - "datetime": row.simulation.datetime.isoformat(), - "metadata": [], - }, - ) + def _get_simulation_data( + self, query, meta_keys, limit, page, sort_by="", sort_asc=False + ) -> Tuple[int, List]: + """ + Build simulation data from query results with JSON metadata. + + :param query: SQLAlchemy query object + :param meta_keys: List of metadata keys to include + :param limit: Maximum number of results per page + :param page: Page number (1-indexed) + :param sort_by: Field name to sort by (can be alias/uuid/datetime/metadata key) + :param sort_asc: Sort in ascending order if True, descending if False + :return: Tuple of (total_count, list of simulation dicts) + """ + all_rows = query.all() + total_count = len(all_rows) + + results = [] + for row in all_rows: + sim_data = { + "alias": row.alias, + "uuid": row.uuid, + "datetime": row.datetime.isoformat(), + } + + meta_dict = row._metadata or {} + + sim_data["_meta_dict"] = meta_dict + if meta_keys: - data[row.simulation.uuid]["metadata"].append( - {"element": row.metadata.element, "value": row.metadata.value} - ) - if meta_keys: - return query.count() / len(meta_keys), list(data.values()) - else: - return query.count(), list(data.values()) + sim_data["metadata"] = [ + {"element": k, "value": v} + for k, v in meta_dict.items() + if k in meta_keys + ] + + results.append(sim_data) + + if sort_by: + + def get_sort_key(item): + if sort_by in ("alias", "uuid", "datetime"): + val = item.get(sort_by, "") + else: + val = item.get("_meta_dict", {}).get(sort_by, "") + # Handle None values - put them at the end + if val is None: + return "" if sort_asc else "~" + return str(val).lower() if isinstance(val, str) else str(val) + + results.sort(key=get_sort_key, reverse=not sort_asc) + + for sim_data in results: + sim_data.pop("_meta_dict", None) + + if limit: + start_idx = (page - 1) * limit + end_idx = start_idx + limit + results = results[start_idx:end_idx] + + return total_count, results def _find_simulation(self, sim_ref: str) -> "Simulation": try: sim_uuid = uuid.UUID(sim_ref) simulation = ( - self.session.query(Simulation) - .options(joinedload(Simulation.meta)) - .filter_by(uuid=sim_uuid) - .one_or_none() + self.session.query(Simulation).filter_by(uuid=sim_uuid).one_or_none() ) except ValueError: try: simulation = ( self.session.query(Simulation) - .options(joinedload(Simulation.meta)) .filter( sql_or( sql_cast(Simulation.uuid, Text).startswith(sim_ref), @@ -258,22 +289,10 @@ def list_simulations( :return: A list of Simulations. """ - - if meta_keys: - query = ( - self.session.query(Simulation) - .options(joinedload(Simulation.meta)) - .outerjoin(Simulation.meta) - .filter(MetaData.element.in_(meta_keys)) - ) - if limit: - query = query.limit(limit) - return query.all() - else: - query = self.session.query(Simulation) - if limit: - query = query.limit(limit) - return query.all() + query = self.session.query(Simulation) + if limit: + query = query.limit(limit) + return query.all() def list_simulation_data( self, @@ -286,62 +305,13 @@ def list_simulation_data( """ Return a list of all the simulations stored in the database. - :return: A list of Simulations. + :return: A tuple of (total_count, list of simulation data dicts). """ + query = self.session.query(Simulation) - sort_query = None - if sort_by: - sort_dir = asc if sort_asc else desc - sort_query = ( - self.session.query( - Simulation.id, - func.row_number() - .over(order_by=sort_dir(MetaData.value)) - .label("row_num"), - ) - .join(Simulation.meta) - .filter(MetaData.element == sort_by) - .subquery() - ) - - if meta_keys: - s_b = Bundle( - "simulation", Simulation.alias, Simulation.uuid, Simulation.datetime - ) - m_b = Bundle("metadata", MetaData.element, MetaData.value) - query = self.session.query(s_b, m_b).outerjoin(Simulation.meta) - - names_filters = [] - for name in meta_keys: - if name in ("alias", "uuid"): - continue - names_filters.append(m_b.c.element.ilike(name)) # type: ignore[union-attr] - if names_filters: - query = query.filter(or_(*names_filters)) - - if sort_query is not None: - query = query.join( - sort_query, Simulation.id == sort_query.c.id - ).order_by(sort_query.c.row_num) - - return self._get_simulation_data(limit, query, meta_keys, page) - else: - query = self.session.query( - Simulation.alias, Simulation.uuid, Simulation.datetime - ) - - if sort_query is not None: - query = query.join( - sort_query, Simulation.id == sort_query.c.id - ).order_by(sort_query.c.row_num) - - limit_query = ( - query.limit(limit).offset((page - 1) * limit) if limit else query - ) - return query.count(), [ - {"alias": alias, "uuid": uuid, "datetime": datetime.isoformat()} - for alias, uuid, datetime in limit_query - ] + return self._get_simulation_data( + query, meta_keys, limit, page, sort_by, sort_asc + ) def get_simulation_data(self, query): limit_query = query @@ -372,94 +342,80 @@ def delete_simulation(self, sim_ref: str) -> "Simulation": self.session.commit() return simulation - def _get_metadata( + def _get_sim_ids_from_json( self, constraints: List[Tuple[str, str, "QueryType"]] - ) -> Iterable: - m_b = Bundle("metadata", MetaData.element, MetaData.value) - s_b = Bundle("simulation", Simulation.id, Simulation.alias, Simulation.uuid) - query = self.session.query(m_b, s_b).join(Simulation) + ) -> Iterable[int]: + query = self.session.query( + Simulation.id, + Simulation._metadata, + Simulation.alias, + Simulation.uuid, + Simulation.datetime, + ) + + sim_id_sets = {} for name, value, query_type in constraints: - date_time = datetime.now() - if name == "creation_date": - date_time = datetime.strptime( - value.replace("_", ":"), "%Y-%m-%d %H:%M:%S" - ) - if query == QueryType.NONE: - pass - elif query_type == QueryType.EQ: - if name == "alias": + sim_id_sets[(name, value, query_type)] = set() + + for name, value, query_type in constraints: + if name == "alias": + if query_type == QueryType.EQ: query = query.filter(func.lower(Simulation.alias) == value.lower()) - elif name == "uuid": - query = query.filter(Simulation.uuid == uuid.UUID(value)) - elif name == "creation_date": - query = query.filter(Simulation.datetime == date_time) - elif query_type == QueryType.IN: - if name == "alias": + elif query_type == QueryType.IN: query = query.filter(Simulation.alias.ilike(f"%{value}%")) - elif name == "uuid": + elif query_type == QueryType.NI: + query = query.filter(Simulation.alias.notilike(f"%{value}%")) + elif query_type == QueryType.NE: + query = query.filter(func.lower(Simulation.alias) != value.lower()) + elif name == "uuid": + if query_type == QueryType.EQ: + query = query.filter(Simulation.uuid == uuid.UUID(value)) + elif query_type == QueryType.IN: query = query.filter( - func.REPLACE(cast(Simulation.uuid, String), "-", "").ilike( + func.REPLACE(sql_cast(Simulation.uuid, String), "-", "").ilike( "%{}%".format(value.replace("-", "")) ) ) - elif query_type == QueryType.NI: - if name == "alias": - query = query.filter(Simulation.alias.notilike(f"%{value}%")) - elif name == "uuid": + elif query_type == QueryType.NI: query = query.filter( - func.REPLACE(cast(Simulation.uuid, String), "-", "").notilike( - "%{}%".format(value.replace("-", "")) - ) + func.REPLACE( + sql_cast(Simulation.uuid, String), "-", "" + ).notilike("%{}%".format(value.replace("-", ""))) ) - elif query_type == QueryType.GT: - if name == "creation_date": + elif query_type == QueryType.NE: + query = query.filter(Simulation.uuid != uuid.UUID(value)) + elif name == "creation_date": + date_time = datetime.strptime( + value.replace("_", ":"), "%Y-%m-%d %H:%M:%S" + ) + if query_type == QueryType.EQ: + query = query.filter(Simulation.datetime == date_time) + elif query_type == QueryType.GT: query = query.filter(Simulation.datetime > date_time) - elif query_type == QueryType.GE: - if name == "creation_date": + elif query_type == QueryType.GE: query = query.filter(Simulation.datetime >= date_time) - elif query_type == QueryType.LT: - if name == "creation_date": + elif query_type == QueryType.LT: query = query.filter(Simulation.datetime < date_time) - elif query_type == QueryType.LE: - if name == "creation_date": + elif query_type == QueryType.LE: query = query.filter(Simulation.datetime <= date_time) - elif query_type == QueryType.NE: - if name == "creation_date": + elif query_type == QueryType.NE: query = query.filter(Simulation.datetime != date_time) - if name == "alias": - query = query.filter(func.lower(Simulation.alias) != value.lower()) - if name == "uuid": - query = query.filter(Simulation.uuid != uuid.UUID(value)) - elif name in ("uuid", "alias"): - raise ValueError(f"Invalid query type {query_type} for alias or uuid.") - names_filters = [] - for name, _, _ in constraints: - if name in ("alias", "uuid", "creation_date"): - continue - names_filters.append(MetaData.element.ilike(name)) - if names_filters: - query = query.filter(or_(*names_filters)) - - return query - - def _get_sim_ids( - self, constraints: List[Tuple[str, str, "QueryType"]] - ) -> Iterable[int]: - rows = self._get_metadata(constraints) - sim_id_sets = {} - for name, value, query_type in constraints: - sim_id_sets[(name, value, query_type)] = set() + # Execute query and filter on JSON metadata in Python + rows = query.all() for row in rows: + meta_dict = row._metadata or {} + for name, value, query_type in constraints: - if name in ("alias", "uuid", "creation_date"): - sim_id_sets[(name, value, query_type)].add(row.simulation.id) - if row.metadata.element == name and ( - query_type == QueryType.EXIST - or query_compare(query_type, name, row.metadata.value, value) + if name in ("alias", "uuid", "creation_date") or ( + name in meta_dict + and ( + query_type == QueryType.EXIST + or query_compare(query_type, name, meta_dict[name], value) + ) ): - sim_id_sets[(name, value, query_type)].add(row.simulation.id) + sim_id_sets[(name, value, query_type)].add(row.id) if sim_id_sets: return set.intersection(*sim_id_sets.values()) @@ -475,15 +431,11 @@ def query_meta( :return: """ - sim_ids = self._get_sim_ids(constraints) + sim_ids = self._get_sim_ids_from_json(constraints) if not sim_ids: return [] - query = ( - self.session.query(Simulation) - .options(joinedload(Simulation.meta)) - .filter(Simulation.id.in_(sim_ids)) - ) + query = self.session.query(Simulation).filter(Simulation.id.in_(sim_ids)) return query.all() def query_meta_data( @@ -501,49 +453,15 @@ def query_meta_data( :return: """ - sim_ids = self._get_sim_ids(constraints) + sim_ids = self._get_sim_ids_from_json(constraints) if not sim_ids: return 0, [] - sort_query = None - if sort_by: - sort_dir = asc if sort_asc else desc - sort_query = ( - self.session.query( - Simulation.id, - func.row_number() - .over(order_by=sort_dir(MetaData.value)) - .label("row_num"), - ) - .join(Simulation.meta) - .filter(MetaData.element == sort_by) - .subquery() - ) + query = self.session.query(Simulation).filter(Simulation.id.in_(sim_ids)) - s_b = Bundle( - "simulation", - Simulation.id, - Simulation.alias, - Simulation.uuid, - Simulation.datetime, + return self._get_simulation_data( + query, meta_keys, limit, page, sort_by, sort_asc ) - m_b = Bundle("metadata", MetaData.element, MetaData.value) - if meta_keys: - query = ( - self.session.query(s_b, m_b) - .outerjoin(Simulation.meta) - .filter(s_b.c.id.in_(sim_ids)) # type: ignore[union-attr] - ) - query = query.filter(m_b.c.element.in_(meta_keys)) # type: ignore[union-attr] - else: - query = self.session.query(s_b).filter(s_b.c.id.in_(sim_ids)) # type: ignore[union-attr] - - if sort_query is not None: - query = query.join(sort_query, Simulation.id == sort_query.c.id).order_by( - sort_query.c.row_num - ) - - return self._get_simulation_data(limit, query, meta_keys, page) def get_simulation(self, sim_ref: str) -> "Simulation": """ @@ -611,11 +529,11 @@ def get_metadata(self, sim_ref: str, name: str) -> List[str]: :param sim_ref: the simulation identifier :param name: the metadata key - :return: The matching MetaData. + :return: The matching metadata values. """ simulation = self._find_simulation(sim_ref) self.session.commit() - return [m.value for m in simulation.meta if m.element == name] + return simulation.find_meta(name) def add_watcher(self, sim_ref: str, watcher: "Watcher"): sim = self._find_simulation(sim_ref) @@ -635,28 +553,34 @@ def list_watchers(self, sim_ref: str) -> List["Watcher"]: return self._find_simulation(sim_ref).watchers def list_metadata_keys(self) -> List[dict]: - if self.engine.dialect.name == "postgresql": - query = self.session.query(MetaData.element, MetaData.value).distinct( - MetaData.element - ) - else: - query = self.session.query(MetaData.element, MetaData.value).group_by( - MetaData.element - ) - return [{"name": row[0], "type": type(row[1]).__name__} for row in query.all()] + simulations = self.session.query(Simulation._metadata).all() + + keys_dict = {} + for (meta_dict,) in simulations: + if meta_dict: + for key, value in meta_dict.items(): + if key not in keys_dict: + keys_dict[key] = value + + return [{"name": k, "type": type(v).__name__} for k, v in keys_dict.items()] def list_metadata_values(self, name: str) -> List[str]: if name == "alias": query = self.session.query(Simulation.alias).filter( - Simulation.alias is not None + Simulation.alias.isnot(None) ) + data = [row[0] for row in query.all()] else: - query = ( - self.session.query(MetaData.value) - .filter(MetaData.element == name) - .distinct() - ) - data = [row[0] for row in query.all()] + simulations = self.session.query(Simulation._metadata).all() + values_set = set() + + for (meta_dict,) in simulations: + if meta_dict and name in meta_dict: + val = meta_dict[name] + values_set.add(str(val) if val is not None else None) + + data = list(values_set) + try: return sorted(data) except TypeError: diff --git a/src/simdb/database/models/simulation.py b/src/simdb/database/models/simulation.py index 88b8588a..fff57b6e 100644 --- a/src/simdb/database/models/simulation.py +++ b/src/simdb/database/models/simulation.py @@ -1,8 +1,6 @@ import itertools import sys import uuid -from collections import defaultdict -from collections.abc import Iterable from datetime import datetime from enum import Enum from getpass import getuser @@ -12,8 +10,11 @@ if sys.version_info < (3, 11): from backports.datetime_fromisoformat import MonkeyPatch -from sqlalchemy import Column, ForeignKey, Table +from dateutil import parser as date_parser +from sqlalchemy import JSON, Column, ForeignKey, Table from sqlalchemy import types as sql_types +from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import relationship if "sphinx" in sys.modules: @@ -41,7 +42,6 @@ from .base import Base from .file import File -from .metadata import MetaData from .types import UUID from .utils import checked_get, flatten_dict, unflatten_dict from .watcher import Watcher @@ -80,6 +80,17 @@ def _update_legacy_uri(data_object: DataObject): return URI(f"imas:{backend}?path={path}") +class MetaDataWrapper: + """Temporary wrapper to provide backwards compatibility with MetaData interface.""" + + def __init__(self, element: str, value: Any): + self.element = element + self.value = value + + def data(self, recurse: bool = False) -> Dict[str, Any]: + return {"element": self.element, "value": self.value} + + @inherit_docstrings class Simulation(Base): """ @@ -99,19 +110,43 @@ class Status(Enum): uuid = Column(UUID, nullable=False, unique=True, index=True) alias = Column(sql_types.String(250), nullable=True, unique=True, index=True) datetime = Column(sql_types.DateTime, nullable=False) + _metadata = Column( + "metadata", + MutableDict.as_mutable( + postgresql.JSONB(astext_type=sql_types.Text()).with_variant( + JSON(), "sqlite" + ) + ), + nullable=True, + default=dict, + ) inputs: List["File"] = relationship( "File", secondary=simulation_input_files, backref="input_for" ) outputs: List["File"] = relationship( "File", secondary=simulation_output_files, backref="output_of" ) - meta: List["MetaData"] = relationship( - "MetaData", lazy="raise", cascade="all, delete-orphan" - ) watchers: List["Watcher"] = relationship( "Watcher", secondary=simulation_watchers, lazy="dynamic" ) + @property + def meta(self) -> List[MetaDataWrapper]: + """ + Property to provide backwards compatibility. + Returns a list of MetaDataWrapper objects from the JSON metadata. + """ + meta_dict = self._get_metadata_dict() + return [MetaDataWrapper(k, v) for k, v in meta_dict.items()] + + def _get_metadata_dict(self) -> Dict[str, Any]: + if self._metadata is None: + return {} + return self._metadata + + def _set_metadata_dict(self, meta_dict: Dict[str, Any]) -> None: + self._metadata = meta_dict + def __init__( self, manifest: Union[Manifest, None], config: Optional[Config] = None ) -> None: @@ -123,14 +158,16 @@ def __init__( """ if manifest is None: + self._metadata = {} return self.uuid = uuid.uuid1() self.datetime = datetime.now() + self._metadata = {} # For legacy simulation import responsible_name is from manifest else it will be # the user.email if manifest.responsible_name: - self.meta.append(MetaData("uploaded_by", manifest.responsible_name)) + self.set_meta("uploaded_by", manifest.responsible_name) self.user = getuser() @@ -160,9 +197,7 @@ def __init__( self.inputs.append(file) if all_input_idss: - self.meta.append( - MetaData("input_ids", "[{}]".format(", ".join(all_input_idss))) - ) + self.set_meta("input_ids", "[{}]".format(", ".join(all_input_idss))) all_output_idss = [] @@ -184,7 +219,7 @@ def __init__( flatten_dict(flattened_meta, meta) for key, value in flattened_meta.items(): - self.meta.append(MetaData(key, value)) + self.set_meta(key, value) file = File(output.type, output.uri, all_output_idss, config=config) if output.type == DataObject.Type.IMAS and "path" not in output.uri.query: @@ -193,7 +228,7 @@ def __init__( self.outputs.append(file) if all_output_idss: - self.meta.append(MetaData("ids", "[{}]".format(", ".join(all_output_idss)))) + self.set_meta("ids", "[{}]".format(", ".join(all_output_idss))) flattened_dict: Dict[str, str] = {} flatten_dict(flattened_dict, manifest.metadata) @@ -210,9 +245,7 @@ def __init__( def status(self) -> Optional["Simulation.Status"]: result = self.find_meta("status") if result: - value = ( - result[0].value if result[0].value != "invalidated" else "not validated" - ) + value = result[0] if result[0] != "invalidated" else "not validated" return Simulation.Status(value) return None @@ -229,25 +262,22 @@ def __str__(self): getattr(self, name), ) result += "metadata:\n" - for meta in self.meta: - if ( - isinstance(meta.value, Iterable) - and not isinstance(meta.value, np.ndarray) - and "\n" in meta.value - ): + meta_dict = self._get_metadata_dict() + for element, value in meta_dict.items(): + if isinstance(value, str) and "\n" in value: first_line = True - for line in meta.value.split("\n"): + for line in value.split("\n"): if first_line: - result += f" {meta.element}: {line}\n" + result += f" {element}: {line}\n" elif line: - indent = " " * (len(meta.element) + 2) + indent = " " * (len(element) + 2) result += f" {indent}{line}" first_line = False - elif isinstance(meta.value, np.ndarray): - string = np.array2string(meta.value, threshold=10) - result += f" {meta.element}: {string}\n" + elif isinstance(value, np.ndarray): + string = np.array2string(value, threshold=10) + result += f" {element}: {string}\n" else: - result += f" {meta.element}: {meta.value}\n" + result += f" {element}: {value}\n" result += "inputs:\n" for file in self.inputs: result += f"{file}\n" @@ -256,38 +286,32 @@ def __str__(self): result += f"{file}\n" return result - def find_meta(self, name: str) -> List["MetaData"]: - return [m for m in self.meta if m.element == name] + def find_meta(self, name: str) -> List[Any]: + meta_dict = self._get_metadata_dict() + if name in meta_dict: + return [meta_dict[name]] + return [] def remove_meta(self, name: str) -> None: - self.meta = [m for m in self.meta if m.element != name] + if self._metadata is None: + return + if name in self._metadata: + del self._metadata[name] - def set_meta(self, name: str, value: str) -> None: - for m in self.meta: - if m.element == name: - m.value = value - break - else: - self.meta.append(MetaData(name, value)) + def set_meta(self, name: str, value: Any) -> None: + if self._metadata is None: + self._metadata = {} + self._metadata[name] = value def validate_meta(self) -> None: """ - Check the metadata elements for duplicates, throwing and exception if found. + Check the metadata elements for duplicates, throwing an exception if found. - Duplicates should not be possible but if there is an issue causing them to arise - then at least it will be caught early rather than causing an SQL constraint - failure later. + With JSON storage, duplicates are not possible by design (dict keys are unique), + but we keep this method for backwards compatibility. """ - names = [m.element for m in self.meta] - counts = defaultdict(lambda: 0) - for name in names: - counts[name] += 1 - duplicates = [k for (k, v) in counts.items() if v > 1] - if len(duplicates) > 0: - raise ValueError( - f"Duplicate metadata elements {duplicates} found for simulation " - f"{self.uuid}" - ) + # With JSON/dict storage, duplicates are impossible + pass def file_paths(self) -> Set[Path]: def _get_path(file: File) -> Optional[Path]: @@ -321,7 +345,7 @@ def from_data(cls, data: Dict[str, Union[str, Dict, List]]) -> "Simulation": simulation.alias = checked_get(data, "alias", str) if "datetime" not in data: data["datetime"] = datetime.now().isoformat() - simulation.datetime = datetime.fromisoformat(checked_get(data, "datetime", str)) + simulation.datetime = date_parser.parse(checked_get(data, "datetime", str)) if "inputs" in data: inputs = checked_get(data, "inputs", list) simulation.inputs = [File.from_data(el) for el in inputs] @@ -330,10 +354,13 @@ def from_data(cls, data: Dict[str, Union[str, Dict, List]]) -> "Simulation": simulation.outputs = [File.from_data(el) for el in outputs] if "metadata" in data: metadata = checked_get(data, "metadata", list) + meta_dict = {} for el in metadata: if not isinstance(el, dict): raise Exception("corrupted metadata element - expected dictionary") - simulation.meta.append(MetaData.from_data(el)) + if "element" in el and "value" in el: + meta_dict[el["element"]] = el["value"] + simulation._set_metadata_dict(meta_dict) return simulation def data( @@ -347,13 +374,19 @@ def data( if recurse: data["inputs"] = [f.data(recurse=True) for f in self.inputs] data["outputs"] = [f.data(recurse=True) for f in self.outputs] - data["metadata"] = [m.data(recurse=True) for m in self.meta] + meta_dict = self._get_metadata_dict() + data["metadata"] = [ + {"element": k, "value": v} for k, v in meta_dict.items() + ] elif meta_keys: + meta_dict = self._get_metadata_dict() data["metadata"] = [ - m.data(recurse=True) for m in self.meta if m.element in meta_keys + {"element": k, "value": v} + for k, v in meta_dict.items() + if k in meta_keys ] return data def meta_dict(self) -> Dict[str, Union[Dict, Any]]: - meta = {m.element: m.value for m in self.meta} + meta = self._get_metadata_dict() return unflatten_dict(meta) diff --git a/src/simdb/remote/apis/v1/simulations.py b/src/simdb/remote/apis/v1/simulations.py index d171513d..9ecbb34f 100644 --- a/src/simdb/remote/apis/v1/simulations.py +++ b/src/simdb/remote/apis/v1/simulations.py @@ -10,7 +10,6 @@ from flask_restx import Namespace, Resource from simdb.database import DatabaseError -from simdb.database.models import metadata as models_meta from simdb.database.models import simulation as models_sim from simdb.email.server import EmailServer from simdb.query import QueryType, parse_query_arg @@ -183,7 +182,7 @@ def post(self, user: User): alias = data["simulation"]["alias"] (updated_alias, next_id) = _set_alias(alias) if updated_alias: - simulation.meta.append(models_meta.MetaData("seqid", next_id)) + simulation.set_meta("seqid", next_id) simulation.alias = updated_alias else: simulation.alias = alias diff --git a/src/simdb/remote/apis/v1_1/simulations.py b/src/simdb/remote/apis/v1_1/simulations.py index 5bc1eadd..28d946e3 100644 --- a/src/simdb/remote/apis/v1_1/simulations.py +++ b/src/simdb/remote/apis/v1_1/simulations.py @@ -10,7 +10,6 @@ from flask_restx import Namespace, Resource from simdb.database import DatabaseError -from simdb.database.models import metadata as models_meta from simdb.database.models import simulation as models_sim from simdb.email.server import EmailServer from simdb.query import QueryType, parse_query_arg @@ -207,13 +206,13 @@ def post(self, user: User): return error("Simulation data not provided") simulation = models_sim.Simulation.from_data(data["simulation"]) - simulation.meta.append(models_meta.MetaData("uploaded_by", user.name)) + simulation.set_meta("uploaded_by", user.name) if "alias" in data["simulation"]: alias = data["simulation"]["alias"] (updated_alias, next_id) = _set_alias(alias) if updated_alias: - simulation.meta.append(models_meta.MetaData("seqid", next_id)) + simulation.set_meta("seqid", next_id) simulation.alias = updated_alias else: simulation.alias = alias diff --git a/src/simdb/remote/apis/v1_2/simulations.py b/src/simdb/remote/apis/v1_2/simulations.py index 37c4d12e..bd6e0353 100644 --- a/src/simdb/remote/apis/v1_2/simulations.py +++ b/src/simdb/remote/apis/v1_2/simulations.py @@ -12,7 +12,6 @@ from flask_restx import Namespace, Resource from simdb.database import DatabaseError -from simdb.database.models import metadata as models_meta from simdb.database.models import simulation as models_sim from simdb.database.models import watcher as models_watcher from simdb.email.server import EmailServer @@ -130,7 +129,7 @@ def _build_trace(sim_id: str) -> Dict[str, Any]: status = simulation.find_meta("status") if status: - status_value = status[0].value + status_value = status[0] if isinstance(status_value, str): data["status"] = status_value else: @@ -138,19 +137,19 @@ def _build_trace(sim_id: str) -> Dict[str, Any]: status_on_name = str(data["status"]) + "_on" status_on = simulation.find_meta(status_on_name) if status_on: - data[status_on_name] = status_on[0].value + data[status_on_name] = status_on[0] replaces = simulation.find_meta("replaces") if replaces: - data["replaces"] = _build_trace(replaces[0].value) + data["replaces"] = _build_trace(replaces[0]) replaced_on = simulation.find_meta("replaced_on") if replaced_on: - data["deprecated_on"] = replaced_on[0].value + data["deprecated_on"] = replaced_on[0] replaces_reason = simulation.find_meta("replaces_reason") if replaces_reason: - data["replaces_reason"] = replaces_reason[0].value + data["replaces_reason"] = replaces_reason[0] return data @@ -320,7 +319,7 @@ def post(self, user: User): if alias is not None: (updated_alias, next_id) = _set_alias(alias) if updated_alias: - simulation.meta.append(models_meta.MetaData("seqid", next_id)) + simulation.set_meta("seqid", next_id) simulation.alias = updated_alias else: simulation.alias = alias @@ -409,9 +408,9 @@ def post(self, user: User): "development.disable_replaces", default=False ) and replaces - and replaces[0].value + and replaces[0] ): - sim_id = replaces[0].value + sim_id = replaces[0] try: replaces_sim = current_app.db.get_simulation(sim_id) except DatabaseError: @@ -422,7 +421,7 @@ def post(self, user: User): _update_simulation_status( replaces_sim, models_sim.Simulation.Status.DEPRECATED, user ) - replaces_sim.set_meta("replaced_by", simulation.uuid) + replaces_sim.set_meta("replaced_by", simulation.uuid.hex) current_app.db.insert_simulation(replaces_sim) current_app.db.insert_simulation(simulation) @@ -538,7 +537,9 @@ def patch(self, sim_id: str, user: Optional[User] = None): simulation = current_app.db.get_simulation(sim_id) if simulation is None: raise ValueError(f"Simulation {sim_id} not found.") - old_values = [meta.data() for meta in simulation.find_meta(key)] + old_values = [ + {"element": key, "value": v} for v in simulation.find_meta(key) + ] if key.lower() != "status": simulation.set_meta(key, value) else: diff --git a/tests/remote/api/test_simulations.py b/tests/remote/api/test_simulations.py index 2ff30f73..1aae16ab 100644 --- a/tests/remote/api/test_simulations.py +++ b/tests/remote/api/test_simulations.py @@ -170,7 +170,7 @@ def test_post_simulations_with_replaces(client): assert metadata["status"].lower() == "deprecated" # Check replaced_by metadata was added - assert metadata["replaced_by"] == new_simulation_data.simulation.uuid + assert metadata["replaced_by"] == new_simulation_data.simulation.uuid.hex # Verify the new simulation has replaces metadata rv_new_get = client.get(