From b47d1e28dbb4b54ee4ec660e08210d469fe6d058 Mon Sep 17 00:00:00 2001 From: JACOB STANLEY Date: Sat, 27 Jun 2026 03:23:13 +0100 Subject: [PATCH 1/3] https://github.com/Menjay7/Traqora.git --- .github/dependabot.yml | 65 ++++++++++++++++++++++++++++++++++ .github/workflows/security.yml | 33 +++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..1ff4e03 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,65 @@ +version: 2 +updates: + # Python dependencies + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + day: "monday" + time: "06:00" + open-pull-requests-limit: 10 + labels: + - "dependencies" + - "python" + commit-message: + prefix: "deps(python)" + include: "scope" + groups: + # Group production dependencies together + production-dependencies: + patterns: + - "*" + exclude-patterns: + - "pytest*" + - "mypy*" + - "ruff*" + - "black*" + - "coverage*" + # Group development/test dependencies together + development-dependencies: + patterns: + - "pytest*" + - "mypy*" + - "ruff*" + - "black*" + - "coverage*" + + # Rust dependencies + - package-ecosystem: "cargo" + directory: "/" + schedule: + interval: "weekly" + day: "monday" + time: "06:00" + open-pull-requests-limit: 10 + labels: + - "dependencies" + - "rust" + commit-message: + prefix: "deps(rust)" + include: "scope" + + # GitHub Actions dependencies + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + day: "monday" + time: "06:00" + open-pull-requests-limit: 5 + labels: + - "dependencies" + - "github-actions" + commit-message: + prefix: "deps(actions)" + include: "scope" diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index cd95c26..bc32880 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -10,6 +10,39 @@ on: - cron: "0 6 * * 1" jobs: + # 
--------------------------------------------------------------------------- + # CodeQL static analysis + # --------------------------------------------------------------------------- + codeql: + name: CodeQL Analysis + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: ["python", "javascript"] + + steps: + - uses: actions/checkout@v4 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: ${{ matrix.language }} + queries: security-extended + + - name: Autobuild + uses: github/codeql-action/autobuild@v3 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:${{matrix.language}}" + # --------------------------------------------------------------------------- # Python security tests (pytest) # --------------------------------------------------------------------------- From 93db1085ef64cb1cdafa6749fb8b3e954da350a5 Mon Sep 17 00:00:00 2001 From: JACOB STANLEY Date: Sun, 28 Jun 2026 19:35:16 +0100 Subject: [PATCH 2/3] https://github.com/Menjay7/astroml.git --- astroml/db/schema.py | 120 ++++++ astroml/tracking/__init__.py | 3 +- astroml/tracking/ab_testing.py | 701 +++++++++++++++++++++++++++++++++ requirements.txt | 1 + tests/test_schema.py | 143 +++++++ 5 files changed, 967 insertions(+), 1 deletion(-) create mode 100644 astroml/tracking/ab_testing.py diff --git a/astroml/db/schema.py b/astroml/db/schema.py index f81ced8..5e563a1 100644 --- a/astroml/db/schema.py +++ b/astroml/db/schema.py @@ -615,3 +615,123 @@ class ModelVersion(Base): name="ck_model_versions_status", ), ) + + +# --------------------------------------------------------------------------- +# A/B Testing Framework +# --------------------------------------------------------------------------- + +class Experiment(Base): + """A/B test experiment for comparing models or prompts.""" + + __tablename__ = "experiments" + + id: 
Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(String(128), nullable=False, unique=True) + description: Mapped[Optional[str]] = mapped_column(Text) + experiment_type: Mapped[str] = mapped_column(String(32), nullable=False) # 'model', 'prompt' + status: Mapped[str] = mapped_column(String(32), nullable=False, server_default="draft") + traffic_allocation: Mapped[float] = mapped_column(Numeric, nullable=False, server_default="1.0") + start_at: Mapped[Optional[datetime]] = mapped_column() + end_at: Mapped[Optional[datetime]] = mapped_column() + created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + nullable=False, server_default=func.now(), onupdate=func.now() + ) + + # Relationships + variants: Mapped[list[Variant]] = relationship( + back_populates="experiment", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + Index("ix_experiments_type", "experiment_type"), + Index("ix_experiments_status", "status"), + Index("ix_experiments_start_at", "start_at"), + CheckConstraint( + "experiment_type IN ('model', 'prompt')", + name="ck_experiments_type", + ), + CheckConstraint( + "status IN ('draft', 'running', 'paused', 'completed', 'archived')", + name="ck_experiments_status", + ), + CheckConstraint( + "traffic_allocation >= 0 AND traffic_allocation <= 1", + name="ck_experiments_traffic_allocation", + ), + ) + + +class Variant(Base): + """A variant in an A/B test experiment.""" + + __tablename__ = "variants" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + experiment_id: Mapped[int] = mapped_column( + Integer, ForeignKey("experiments.id"), nullable=False + ) + name: Mapped[str] = mapped_column(String(128), nullable=False) + description: Mapped[Optional[str]] = mapped_column(Text) + traffic_weight: Mapped[float] = mapped_column(Numeric, nullable=False, server_default="0.5") + is_control: 
class ExperimentResult(Base):
    """Individual result from an A/B test experiment.

    Each row stores the metric values observed for one variant, optionally
    tagged with the user and/or session that produced them.

    The free-form context column is exposed as ``result_metadata`` on the
    Python side because ``metadata`` is a reserved attribute name on
    Declarative classes (it holds the table registry). The database column
    is still called ``metadata``.
    """

    __tablename__ = "experiment_results"

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
    variant_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("variants.id"), nullable=False
    )
    user_id: Mapped[Optional[str]] = mapped_column(String(128))  # Optional user identifier
    session_id: Mapped[Optional[str]] = mapped_column(String(128))  # For session-based analysis
    metrics: Mapped[dict] = mapped_column(
        JSON().with_variant(JSONB(), "postgresql"), nullable=False
    )  # e.g., {"accuracy": 0.95, "latency_ms": 100}
    # Declaring an attribute literally named ``metadata`` makes SQLAlchemy
    # raise InvalidRequestError while the class is being built, so only the
    # column keeps that name.
    result_metadata: Mapped[Optional[dict]] = mapped_column(
        "metadata", JSON().with_variant(JSONB(), "postgresql")
    )  # Additional context
    created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now())

    # Relationships
    variant: Mapped[Variant] = relationship(back_populates="results")

    __table_args__ = (
        Index("ix_experiment_results_variant_id", "variant_id"),
        Index("ix_experiment_results_user_id", "user_id"),
        Index("ix_experiment_results_session_id", "session_id"),
        Index("ix_experiment_results_created_at", "created_at"),
    )

    def __init__(self, **kwargs) -> None:
        """Build a result row, accepting ``metadata=`` as an alias.

        Callers naturally pass ``metadata=...``. Without this alias the
        default constructor would set a plain instance attribute that is
        never persisted, so the value is routed to ``result_metadata``.
        """
        if "metadata" in kwargs:
            kwargs["result_metadata"] = kwargs.pop("metadata")
        super().__init__(**kwargs)
+class InvalidExperimentStatusError(ValueError): + """Raised when an invalid experiment status transition is attempted.""" + + pass + + +class ABTestingFramework: + """Core class for managing A/B tests for models and prompts. + + Provides experiment management, variant assignment, and result tracking + with statistical analysis capabilities. + """ + + def __init__(self, session: Optional[Session] = None): + """Initialize the A/B testing framework. + + Args: + session: Optional SQLAlchemy session. If not provided, creates a new session. + """ + self._session = session + self._owns_session = session is None + + @property + def session(self) -> Session: + """Get the SQLAlchemy session, creating one if needed.""" + if self._session is None: + self._session = get_session() + return self._session + + def close(self) -> None: + """Close the session if we own it.""" + if self._owns_session and self._session is not None: + self._session.close() + self._session = None + + def __enter__(self) -> "ABTestingFramework": + return self + + def __exit__(self, *_: Any) -> None: + self.close() + + # ------------------------------------------------------------------ + # Experiment CRUD operations + # ------------------------------------------------------------------ + + def create_experiment( + self, + name: str, + experiment_type: str, + description: Optional[str] = None, + traffic_allocation: float = 1.0, + ) -> Experiment: + """Create a new A/B test experiment. 
+ + Args: + name: Unique experiment name + experiment_type: Type of experiment ('model' or 'prompt') + description: Optional experiment description + traffic_allocation: Fraction of traffic to allocate (0.0 to 1.0) + + Returns: + Created Experiment instance + + Raises: + ValueError: If experiment with same name exists or invalid parameters + """ + if experiment_type not in ("model", "prompt"): + raise ValueError(f"experiment_type must be 'model' or 'prompt', got '{experiment_type}'") + + if not 0.0 <= traffic_allocation <= 1.0: + raise ValueError(f"traffic_allocation must be between 0.0 and 1.0, got {traffic_allocation}") + + existing = self.get_experiment_by_name(name) + if existing: + raise ValueError(f"Experiment with name '{name}' already exists") + + experiment = Experiment( + name=name, + description=description, + experiment_type=experiment_type, + traffic_allocation=traffic_allocation, + ) + self.session.add(experiment) + self.session.commit() + self.session.refresh(experiment) + logger.info("Created experiment: %s (id=%d, type=%s)", name, experiment.id, experiment_type) + return experiment + + def get_experiment(self, experiment_id: int) -> Optional[Experiment]: + """Get an experiment by ID. + + Args: + experiment_id: Experiment ID + + Returns: + Experiment instance or None if not found + """ + return self.session.get(Experiment, experiment_id) + + def get_experiment_by_name(self, name: str) -> Optional[Experiment]: + """Get an experiment by name. + + Args: + name: Experiment name + + Returns: + Experiment instance or None if not found + """ + stmt = select(Experiment).where(Experiment.name == name) + return self.session.execute(stmt).scalar_one_or_none() + + def list_experiments( + self, + experiment_type: Optional[str] = None, + status: Optional[str] = None, + ) -> List[Experiment]: + """List experiments with optional filters. 
+ + Args: + experiment_type: Filter by experiment type + status: Filter by status + + Returns: + List of Experiment instances + """ + stmt = select(Experiment) + if experiment_type: + stmt = stmt.where(Experiment.experiment_type == experiment_type) + if status: + stmt = stmt.where(Experiment.status == status) + stmt = stmt.order_by(Experiment.created_at.desc()) + return list(self.session.execute(stmt).scalars().all()) + + def update_experiment( + self, + experiment_id: int, + description: Optional[str] = None, + traffic_allocation: Optional[float] = None, + start_at: Optional[datetime] = None, + end_at: Optional[datetime] = None, + ) -> Optional[Experiment]: + """Update an experiment. + + Args: + experiment_id: Experiment ID + description: New description + traffic_allocation: New traffic allocation + start_at: Start timestamp + end_at: End timestamp + + Returns: + Updated Experiment instance or None if not found + """ + experiment = self.get_experiment(experiment_id) + if not experiment: + return None + + if description is not None: + experiment.description = description + if traffic_allocation is not None: + if not 0.0 <= traffic_allocation <= 1.0: + raise ValueError(f"traffic_allocation must be between 0.0 and 1.0, got {traffic_allocation}") + experiment.traffic_allocation = traffic_allocation + if start_at is not None: + experiment.start_at = start_at + if end_at is not None: + experiment.end_at = end_at + + self.session.commit() + self.session.refresh(experiment) + logger.info("Updated experiment: %s (id=%d)", experiment.name, experiment_id) + return experiment + + def delete_experiment(self, experiment_id: int) -> bool: + """Delete an experiment and all its variants and results. 
+ + Args: + experiment_id: Experiment ID + + Returns: + True if deleted, False if not found + """ + experiment = self.get_experiment(experiment_id) + if not experiment: + return False + + self.session.delete(experiment) + self.session.commit() + logger.info("Deleted experiment: %s (id=%d)", experiment.name, experiment_id) + return True + + # ------------------------------------------------------------------ + # Variant CRUD operations + # ------------------------------------------------------------------ + + def create_variant( + self, + experiment_id: int, + name: str, + traffic_weight: float = 0.5, + is_control: bool = False, + model_version_id: Optional[int] = None, + config: Optional[Dict[str, Any]] = None, + description: Optional[str] = None, + ) -> Variant: + """Create a new variant for an experiment. + + Args: + experiment_id: Parent experiment ID + name: Variant name + traffic_weight: Traffic weight (0.0 to 1.0) + is_control: Whether this is the control variant + model_version_id: Optional model version ID for model experiments + config: Configuration dict (for prompts or model config) + description: Optional variant description + + Returns: + Created Variant instance + + Raises: + ValueError: If variant with same name exists or invalid parameters + """ + if not 0.0 <= traffic_weight <= 1.0: + raise ValueError(f"traffic_weight must be between 0.0 and 1.0, got {traffic_weight}") + + experiment = self.get_experiment(experiment_id) + if not experiment: + raise ValueError(f"Experiment with id {experiment_id} not found") + + existing = self.get_variant(experiment_id, name) + if existing: + raise ValueError(f"Variant '{name}' already exists for experiment {experiment_id}") + + variant = Variant( + experiment_id=experiment_id, + name=name, + description=description, + traffic_weight=traffic_weight, + is_control=is_control, + model_version_id=model_version_id, + config=config, + ) + self.session.add(variant) + self.session.commit() + self.session.refresh(variant) + 
def start_experiment(self, experiment_id: int) -> Optional[Experiment]:
    """Start an experiment.

    Moves the experiment to ``running`` and stamps ``start_at`` with the
    current timezone-aware UTC time.

    Args:
        experiment_id: Experiment ID

    Returns:
        Updated Experiment or None if not found

    Raises:
        InvalidExperimentStatusError: If experiment cannot be started
    """
    # ``datetime`` in this module is the class, which has no ``UTC``
    # attribute; the UTC tzinfo has to come from ``timezone``.
    from datetime import timezone

    experiment = self.get_experiment(experiment_id)
    if not experiment:
        return None

    self._validate_experiment_status_transition(experiment.status, "running")
    experiment.status = "running"
    experiment.start_at = datetime.now(timezone.utc)

    self.session.commit()
    self.session.refresh(experiment)
    logger.info("Started experiment: %s (id=%d)", experiment.name, experiment_id)
    return experiment


def pause_experiment(self, experiment_id: int) -> Optional[Experiment]:
    """Pause an experiment.

    Args:
        experiment_id: Experiment ID

    Returns:
        Updated Experiment or None if not found

    Raises:
        InvalidExperimentStatusError: If experiment cannot be paused
    """
    experiment = self.get_experiment(experiment_id)
    if not experiment:
        return None

    self._validate_experiment_status_transition(experiment.status, "paused")
    experiment.status = "paused"

    self.session.commit()
    self.session.refresh(experiment)
    logger.info("Paused experiment: %s (id=%d)", experiment.name, experiment_id)
    return experiment


def complete_experiment(self, experiment_id: int) -> Optional[Experiment]:
    """Complete an experiment.

    Moves the experiment to ``completed`` and stamps ``end_at`` with the
    current timezone-aware UTC time.

    Args:
        experiment_id: Experiment ID

    Returns:
        Updated Experiment or None if not found

    Raises:
        InvalidExperimentStatusError: If experiment cannot be completed
    """
    # See start_experiment: the ``datetime`` class has no ``UTC`` attribute.
    from datetime import timezone

    experiment = self.get_experiment(experiment_id)
    if not experiment:
        return None

    self._validate_experiment_status_transition(experiment.status, "completed")
    experiment.status = "completed"
    experiment.end_at = datetime.now(timezone.utc)

    self.session.commit()
    self.session.refresh(experiment)
    logger.info("Completed experiment: %s (id=%d)", experiment.name, experiment_id)
    return experiment


def archive_experiment(self, experiment_id: int) -> Optional[Experiment]:
    """Archive an experiment.

    Args:
        experiment_id: Experiment ID

    Returns:
        Updated Experiment or None if not found

    Raises:
        InvalidExperimentStatusError: If experiment cannot be archived
    """
    experiment = self.get_experiment(experiment_id)
    if not experiment:
        return None

    self._validate_experiment_status_transition(experiment.status, "archived")
    experiment.status = "archived"

    self.session.commit()
    self.session.refresh(experiment)
    logger.info("Archived experiment: %s (id=%d)", experiment.name, experiment_id)
    return experiment


@staticmethod
def _validate_experiment_status_transition(from_status: str, to_status: str) -> None:
    """Validate that an experiment status transition is allowed.

    Args:
        from_status: Current status
        to_status: Target status

    Raises:
        InvalidExperimentStatusError: If transition is not allowed
    """
    if to_status not in VALID_EXPERIMENT_STATUSES:
        raise InvalidExperimentStatusError(f"Invalid target status: '{to_status}'")

    if from_status == to_status:
        return  # No-op transition is allowed

    allowed_transitions = VALID_EXPERIMENT_TRANSITIONS.get(from_status, [])
    if to_status not in allowed_transitions:
        raise InvalidExperimentStatusError(
            f"Cannot transition from '{from_status}' to '{to_status}'. "
            f"Allowed transitions from '{from_status}': {allowed_transitions}"
        )


# ------------------------------------------------------------------
# Variant assignment
# ------------------------------------------------------------------

def assign_variant(
    self,
    experiment_id: int,
    user_id: Optional[str] = None,
    session_id: Optional[str] = None,
) -> Optional[Variant]:
    """Assign a variant to a user/session based on traffic weights.

    Uses deterministic hashing for consistent assignment across requests:
    the identifier is hashed into one of 10000 buckets and mapped onto the
    variants' normalized cumulative weights, walking the variants in
    ascending ``id`` order.

    Args:
        experiment_id: Experiment ID
        user_id: Optional user identifier
        session_id: Optional session identifier

    Returns:
        Assigned Variant, or None if the experiment is not found, is not
        running, has no variants, or all weights are zero

    Raises:
        ValueError: If neither user_id nor session_id provided
    """
    if not user_id and not session_id:
        raise ValueError("Either user_id or session_id must be provided")

    experiment = self.get_experiment(experiment_id)
    if not experiment or experiment.status != "running":
        return None

    variants = self.list_variants(experiment_id)
    if not variants:
        return None

    # The listing query has no ORDER BY, so pin the order here; otherwise
    # the same identifier could land on a different variant between calls.
    variants = sorted(variants, key=lambda v: v.id)

    # Numeric columns can come back as Decimal, which cannot be added to
    # the float accumulator below; coerce once up front.
    weights = [float(v.traffic_weight) for v in variants]
    total_weight = sum(weights)
    if total_weight == 0:
        return None

    # user_id takes precedence over session_id when both are given.
    identifier = user_id or session_id
    digest = hashlib.md5(f"{experiment_id}:{identifier}".encode()).hexdigest()
    hash_float = (int(digest, 16) % 10000) / 10000.0

    # Select the first variant whose cumulative share exceeds the bucket.
    cumulative = 0.0
    for variant, weight in zip(variants, weights):
        cumulative += weight / total_weight
        if hash_float < cumulative:
            return variant

    return variants[-1]  # Guards against float rounding leaving a gap below 1.0
def compare_variants(
    self,
    experiment_id: int,
    metric_name: str,
    control_variant_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Compare variants using statistical tests.

    Every non-control variant is compared against the control with an
    independent two-sample t-test and a Cohen's d effect size. The winner
    is the variant with the highest mean of ``metric_name``.

    Args:
        experiment_id: Experiment ID
        metric_name: Metric to compare
        control_variant_name: Optional control variant name (auto-detect if not provided)

    Returns:
        Dictionary with comparison results including:
            - metric: The metric name that was compared
            - variant_stats: Statistics for each variant
            - pairwise_tests: Statistical test results between variants
            - winner: Best performing variant (None if no variant has data)

    Raises:
        ValueError: If the experiment has fewer than 2 variants
    """
    variants = self.list_variants(experiment_id)
    if len(variants) < 2:
        raise ValueError("Experiment must have at least 2 variants to compare")

    # Identify control variant: explicit name first, then the is_control
    # flag, and finally the first listed variant as a fallback.
    if control_variant_name:
        control = next((v for v in variants if v.name == control_variant_name), None)
    else:
        control = next((v for v in variants if v.is_control), None)
    if not control:
        control = variants[0]

    # Collect metrics for each variant
    variant_data = {}
    for variant in variants:
        results = self.get_variant_results(variant.id, metric_name)
        variant_data[variant.name] = [r.metrics[metric_name] for r in results]

    # Calculate statistics for each variant
    variant_stats = {}
    for name, values in variant_data.items():
        if values:
            variant_stats[name] = {
                "count": len(values),
                "mean": np.mean(values),
                "std": np.std(values),
                "min": np.min(values),
                "max": np.max(values),
                "median": np.median(values),
            }
        else:
            variant_stats[name] = {
                "count": 0,
                "mean": None,
                "std": None,
                "min": None,
                "max": None,
                "median": None,
            }

    # Perform pairwise tests (each treatment against the control)
    pairwise_tests = []
    control_values = variant_data.get(control.name, [])

    for variant in variants:
        if variant.name == control.name:
            continue

        variant_values = variant_data.get(variant.name, [])
        # A t-test needs at least two observations on each side.
        if len(control_values) > 1 and len(variant_values) > 1:
            t_stat, p_value = stats.ttest_ind(control_values, variant_values)

            # Calculate effect size (Cohen's d)
            pooled_std = np.sqrt(
                (np.std(control_values) ** 2 + np.std(variant_values) ** 2) / 2
            )
            if pooled_std > 0:
                effect_size = (np.mean(variant_values) - np.mean(control_values)) / pooled_std
            else:
                effect_size = 0

            pairwise_tests.append(
                {
                    "control": control.name,
                    "treatment": variant.name,
                    "t_statistic": t_stat,
                    "p_value": p_value,
                    "effect_size": effect_size,
                    # Plain bool rather than numpy.bool_, which json cannot serialize.
                    "significant": bool(p_value < 0.05),
                }
            )

    # Determine winner (highest mean). The loop variable must not be named
    # ``stats``: that would make ``stats`` local to the whole function and
    # break the ``stats.ttest_ind`` call above with UnboundLocalError.
    winner = None
    best_mean = -float("inf")
    for name, summary in variant_stats.items():
        if summary["mean"] is not None and summary["mean"] > best_mean:
            best_mean = summary["mean"]
            winner = name

    return {
        "metric": metric_name,
        "variant_stats": variant_stats,
        "pairwise_tests": pairwise_tests,
        "winner": winner,
    }
assert Variant.__tablename__ == "variants" + assert ExperimentResult.__tablename__ == "experiment_results" # --------------------------------------------------------------------------- @@ -276,6 +288,79 @@ def test_model_version_columns(engine): ) +def test_experiment_columns(engine): + inspector = inspect(engine) + cols = {c["name"] for c in inspector.get_columns("experiments")} + expected = { + "id", + "name", + "description", + "experiment_type", + "status", + "traffic_allocation", + "start_at", + "end_at", + "created_at", + "updated_at", + } + assert expected <= cols + + +def test_variant_columns(engine): + inspector = inspect(engine) + cols = {c["name"] for c in inspector.get_columns("variants")} + expected = { + "id", + "experiment_id", + "name", + "description", + "traffic_weight", + "is_control", + "model_version_id", + "config", + "created_at", + "updated_at", + } + assert expected <= cols + + # FK to experiments + fks = inspector.get_foreign_keys("variants") + assert any( + fk["referred_table"] == "experiments" + and fk["referred_columns"] == ["id"] + for fk in fks + ) + # FK to model_versions + assert any( + fk["referred_table"] == "model_versions" + and fk["referred_columns"] == ["id"] + for fk in fks + ) + + +def test_experiment_result_columns(engine): + inspector = inspect(engine) + cols = {c["name"] for c in inspector.get_columns("experiment_results")} + expected = { + "id", + "variant_id", + "user_id", + "session_id", + "metrics", + "metadata", + "created_at", + } + assert expected <= cols + + # FK to variants + fks = inspector.get_foreign_keys("experiment_results") + assert any( + fk["referred_table"] == "variants" + and fk["referred_columns"] == ["id"] + for fk in fks + ) + + # --------------------------------------------------------------------------- # Relationships # --------------------------------------------------------------------------- @@ -406,6 +491,64 @@ def test_model_registry_relationships(session): assert version2.model is model +def 
test_ab_testing_relationships(session): + """Experiment.variants cascade deletes Variant and ExperimentResult rows.""" + now = datetime.now(timezone.utc) + + experiment = Experiment( + name="test-experiment", + experiment_type="model", + description="Test experiment", + ) + session.add(experiment) + session.flush() + + variant1 = Variant( + experiment_id=experiment.id, + name="control", + traffic_weight=0.5, + is_control=True, + ) + variant2 = Variant( + experiment_id=experiment.id, + name="treatment", + traffic_weight=0.5, + is_control=False, + ) + session.add_all([variant1, variant2]) + session.flush() + + result1 = ExperimentResult( + variant_id=variant1.id, + metrics={"accuracy": 0.9}, + ) + result2 = ExperimentResult( + variant_id=variant1.id, + metrics={"accuracy": 0.85}, + ) + result3 = ExperimentResult( + variant_id=variant2.id, + metrics={"accuracy": 0.92}, + ) + session.add_all([result1, result2, result3]) + session.flush() + + session.refresh(experiment) + session.refresh(variant1) + session.refresh(variant2) + + assert len(experiment.variants) == 2 + assert variant1 in experiment.variants + assert variant2 in experiment.variants + assert variant1.experiment is experiment + assert variant2.experiment is experiment + assert len(variant1.results) == 2 + assert len(variant2.results) == 1 + assert result1.variant is variant1 + assert result2.variant is variant1 + assert result3.variant is variant2 + + # --------------------------------------------------------------------------- # Round-trip insert & query # --------------------------------------------------------------------------- From 25809ccb10434a29303431ca813d57a553009372 Mon Sep 17 00:00:00 2001 From: JACOB STANLEY Date: Mon, 29 Jun 2026 12:05:38 +0100 Subject: [PATCH 3/3] https://github.com/Menjay7/astroml.git --- astroml/db/schema.py | 94 ++++ astroml/tracking/__init__.py | 3 +- astroml/tracking/golden_dataset.py | 762 +++++++++++++++++++++++++++++ tests/test_schema.py | 90 ++++ 4 files changed, 948 
insertions(+), 1 deletion(-) create mode 100644 astroml/tracking/golden_dataset.py diff --git a/astroml/db/schema.py b/astroml/db/schema.py index 5e563a1..98b2704 100644 --- a/astroml/db/schema.py +++ b/astroml/db/schema.py @@ -735,3 +735,97 @@ class ExperimentResult(Base): Index("ix_experiment_results_session_id", "session_id"), Index("ix_experiment_results_created_at", "created_at"), ) + + +# --------------------------------------------------------------------------- +# Golden Dataset Framework +# --------------------------------------------------------------------------- + +class GoldenDataset(Base): + """Golden dataset for model evaluation and benchmarking.""" + + __tablename__ = "golden_datasets" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(String(128), nullable=False, unique=True) + description: Mapped[Optional[str]] = mapped_column(Text) + dataset_type: Mapped[str] = mapped_column(String(32), nullable=False) # 'classification', 'regression', 'anomaly_detection', etc. 
+ task_type: Mapped[str] = mapped_column(String(32), nullable=False) + version: Mapped[str] = mapped_column(String(32), nullable=False) + source: Mapped[Optional[str]] = mapped_column(String(256)) # Data source identifier + size: Mapped[int] = mapped_column(Integer, nullable=False, server_default="0") # Number of entries + status: Mapped[str] = mapped_column(String(32), nullable=False, server_default="draft") + quality_score: Mapped[Optional[float]] = mapped_column(Numeric) # Overall quality metric (0-1) + metadata: Mapped[Optional[dict]] = mapped_column( + JSON().with_variant(JSONB(), "postgresql") + ) # Additional dataset metadata + created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + nullable=False, server_default=func.now(), onupdate=func.now() + ) + + # Relationships + entries: Mapped[list[GoldenDatasetEntry]] = relationship( + back_populates="dataset", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + Index("ix_golden_datasets_type", "dataset_type"), + Index("ix_golden_datasets_task_type", "task_type"), + Index("ix_golden_datasets_version", "version"), + Index("ix_golden_datasets_status", "status"), + UniqueConstraint("name", "version", name="uq_golden_datasets_name_version"), + CheckConstraint( + "dataset_type IN ('classification', 'regression', 'anomaly_detection', 'clustering', 'custom')", + name="ck_golden_datasets_type", + ), + CheckConstraint( + "status IN ('draft', 'review', 'approved', 'archived')", + name="ck_golden_datasets_status", + ), + CheckConstraint( + "quality_score IS NULL OR (quality_score >= 0 AND quality_score <= 1)", + name="ck_golden_datasets_quality_score", + ), + ) + + +class GoldenDatasetEntry(Base): + """Individual entry in a golden dataset with ground truth labels.""" + + __tablename__ = "golden_dataset_entries" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True) + dataset_id: Mapped[int] = 
mapped_column( + Integer, ForeignKey("golden_datasets.id"), nullable=False + ) + input_data: Mapped[dict] = mapped_column( + JSON().with_variant(JSONB(), "postgresql"), nullable=False + ) # Model input features + output_data: Mapped[dict] = mapped_column( + JSON().with_variant(JSONB(), "postgresql"), nullable=False + ) # Ground truth labels + metadata: Mapped[Optional[dict]] = mapped_column( + JSON().with_variant(JSONB(), "postgresql") + ) # Entry-specific metadata + difficulty: Mapped[Optional[float]] = mapped_column(Numeric) # Difficulty score (0-1) + confidence: Mapped[Optional[float]] = mapped_column(Numeric) # Label confidence (0-1) + created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now()) + + # Relationships + dataset: Mapped[GoldenDataset] = relationship(back_populates="entries") + + __table_args__ = ( + Index("ix_golden_dataset_entries_dataset_id", "dataset_id"), + Index("ix_golden_dataset_entries_difficulty", "difficulty"), + Index("ix_golden_dataset_entries_confidence", "confidence"), + CheckConstraint( + "difficulty IS NULL OR (difficulty >= 0 AND difficulty <= 1)", + name="ck_golden_dataset_entries_difficulty", + ), + CheckConstraint( + "confidence IS NULL OR (confidence >= 0 AND confidence <= 1)", + name="ck_golden_dataset_entries_confidence", + ), + ) diff --git a/astroml/tracking/__init__.py b/astroml/tracking/__init__.py index 8da1aaa..30a7a79 100644 --- a/astroml/tracking/__init__.py +++ b/astroml/tracking/__init__.py @@ -1,5 +1,6 @@ from .ab_testing import ABTestingFramework +from .golden_dataset import GoldenDatasetGenerator from .mlflow_tracker import MLflowTracker from .model_registry import ModelRegistry -__all__ = ["MLflowTracker", "ModelRegistry", "ABTestingFramework"] +__all__ = ["MLflowTracker", "ModelRegistry", "ABTestingFramework", "GoldenDatasetGenerator"] diff --git a/astroml/tracking/golden_dataset.py b/astroml/tracking/golden_dataset.py new file mode 100644 index 0000000..917358e --- /dev/null +++ 
b/astroml/tracking/golden_dataset.py @@ -0,0 +1,762 @@ +"""Golden dataset generation and management for model evaluation.""" +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from astroml.db.schema import GoldenDataset, GoldenDatasetEntry +from astroml.db.session import get_session + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Dataset Status State Machine +# --------------------------------------------------------------------------- + +VALID_DATASET_TRANSITIONS = { + "draft": ["review", "archived"], + "review": ["approved", "draft", "archived"], + "approved": ["archived"], + "archived": [], # Terminal state +} + +VALID_DATASET_STATUSES = set(VALID_DATASET_TRANSITIONS.keys()) + + +class InvalidDatasetStatusError(ValueError): + """Raised when an invalid dataset status transition is attempted.""" + + pass + + +class GoldenDatasetGenerator: + """Core class for generating and managing golden datasets. + + Provides dataset creation, entry management, validation, + and quality assessment capabilities. + """ + + def __init__(self, session: Optional[Session] = None): + """Initialize the golden dataset generator. + + Args: + session: Optional SQLAlchemy session. If not provided, creates a new session. 
+ """ + self._session = session + self._owns_session = session is None + + @property + def session(self) -> Session: + """Get the SQLAlchemy session, creating one if needed.""" + if self._session is None: + self._session = get_session() + return self._session + + def close(self) -> None: + """Close the session if we own it.""" + if self._owns_session and self._session is not None: + self._session.close() + self._session = None + + def __enter__(self) -> "GoldenDatasetGenerator": + return self + + def __exit__(self, *_: Any) -> None: + self.close() + + # ------------------------------------------------------------------ + # Dataset CRUD operations + # ------------------------------------------------------------------ + + def create_dataset( + self, + name: str, + dataset_type: str, + task_type: str, + version: str = "1.0.0", + description: Optional[str] = None, + source: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> GoldenDataset: + """Create a new golden dataset. + + Args: + name: Dataset name + dataset_type: Type of dataset (classification, regression, etc.) 
+ task_type: ML task type + version: Dataset version + description: Optional description + source: Optional data source identifier + metadata: Optional additional metadata + + Returns: + Created GoldenDataset instance + + Raises: + ValueError: If dataset with same name/version exists or invalid parameters + """ + if dataset_type not in ( + "classification", + "regression", + "anomaly_detection", + "clustering", + "custom", + ): + raise ValueError(f"Invalid dataset_type: '{dataset_type}'") + + existing = self.get_dataset_by_name_version(name, version) + if existing: + raise ValueError(f"Dataset '{name}' version '{version}' already exists") + + dataset = GoldenDataset( + name=name, + description=description, + dataset_type=dataset_type, + task_type=task_type, + version=version, + source=source, + metadata=metadata, + ) + self.session.add(dataset) + self.session.commit() + self.session.refresh(dataset) + logger.info( + "Created golden dataset: %s (id=%d, type=%s, version=%s)", + name, + dataset.id, + dataset_type, + version, + ) + return dataset + + def get_dataset(self, dataset_id: int) -> Optional[GoldenDataset]: + """Get a dataset by ID. + + Args: + dataset_id: Dataset ID + + Returns: + GoldenDataset instance or None if not found + """ + return self.session.get(GoldenDataset, dataset_id) + + def get_dataset_by_name_version( + self, name: str, version: str + ) -> Optional[GoldenDataset]: + """Get a dataset by name and version. + + Args: + name: Dataset name + version: Dataset version + + Returns: + GoldenDataset instance or None if not found + """ + stmt = select(GoldenDataset).where( + GoldenDataset.name == name, GoldenDataset.version == version + ) + return self.session.execute(stmt).scalar_one_or_none() + + def list_datasets( + self, + dataset_type: Optional[str] = None, + task_type: Optional[str] = None, + status: Optional[str] = None, + ) -> List[GoldenDataset]: + """List datasets with optional filters. 
+ + Args: + dataset_type: Filter by dataset type + task_type: Filter by task type + status: Filter by status + + Returns: + List of GoldenDataset instances + """ + stmt = select(GoldenDataset) + if dataset_type: + stmt = stmt.where(GoldenDataset.dataset_type == dataset_type) + if task_type: + stmt = stmt.where(GoldenDataset.task_type == task_type) + if status: + stmt = stmt.where(GoldenDataset.status == status) + stmt = stmt.order_by(GoldenDataset.created_at.desc()) + return list(self.session.execute(stmt).scalars().all()) + + def update_dataset( + self, + dataset_id: int, + description: Optional[str] = None, + source: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + quality_score: Optional[float] = None, + ) -> Optional[GoldenDataset]: + """Update a dataset. + + Args: + dataset_id: Dataset ID + description: New description + source: New source identifier + metadata: New metadata + quality_score: New quality score (0-1) + + Returns: + Updated GoldenDataset instance or None if not found + + Raises: + ValueError: If quality_score is out of range + """ + dataset = self.get_dataset(dataset_id) + if not dataset: + return None + + if description is not None: + dataset.description = description + if source is not None: + dataset.source = source + if metadata is not None: + dataset.metadata = metadata + if quality_score is not None: + if not 0.0 <= quality_score <= 1.0: + raise ValueError(f"quality_score must be between 0.0 and 1.0, got {quality_score}") + dataset.quality_score = quality_score + + self.session.commit() + self.session.refresh(dataset) + logger.info("Updated golden dataset: %s (id=%d)", dataset.name, dataset_id) + return dataset + + def delete_dataset(self, dataset_id: int) -> bool: + """Delete a dataset and all its entries. 
+ + Args: + dataset_id: Dataset ID + + Returns: + True if deleted, False if not found + """ + dataset = self.get_dataset(dataset_id) + if not dataset: + return False + + self.session.delete(dataset) + self.session.commit() + logger.info("Deleted golden dataset: %s (id=%d)", dataset.name, dataset_id) + return True + + # ------------------------------------------------------------------ + # Dataset lifecycle management + # ------------------------------------------------------------------ + + def submit_for_review(self, dataset_id: int) -> Optional[GoldenDataset]: + """Submit a dataset for review. + + Args: + dataset_id: Dataset ID + + Returns: + Updated GoldenDataset or None if not found + + Raises: + InvalidDatasetStatusError: If dataset cannot be submitted + """ + dataset = self.get_dataset(dataset_id) + if not dataset: + return None + + self._validate_dataset_status_transition(dataset.status, "review") + dataset.status = "review" + + self.session.commit() + self.session.refresh(dataset) + logger.info("Submitted dataset for review: %s (id=%d)", dataset.name, dataset_id) + return dataset + + def approve_dataset(self, dataset_id: int) -> Optional[GoldenDataset]: + """Approve a dataset. + + Args: + dataset_id: Dataset ID + + Returns: + Updated GoldenDataset or None if not found + + Raises: + InvalidDatasetStatusError: If dataset cannot be approved + """ + dataset = self.get_dataset(dataset_id) + if not dataset: + return None + + self._validate_dataset_status_transition(dataset.status, "approved") + dataset.status = "approved" + + self.session.commit() + self.session.refresh(dataset) + logger.info("Approved dataset: %s (id=%d)", dataset.name, dataset_id) + return dataset + + def reject_dataset(self, dataset_id: int) -> Optional[GoldenDataset]: + """Reject a dataset (return to draft). 
+ + Args: + dataset_id: Dataset ID + + Returns: + Updated GoldenDataset or None if not found + + Raises: + InvalidDatasetStatusError: If dataset cannot be rejected + """ + dataset = self.get_dataset(dataset_id) + if not dataset: + return None + + self._validate_dataset_status_transition(dataset.status, "draft") + dataset.status = "draft" + + self.session.commit() + self.session.refresh(dataset) + logger.info("Rejected dataset (returned to draft): %s (id=%d)", dataset.name, dataset_id) + return dataset + + def archive_dataset(self, dataset_id: int) -> Optional[GoldenDataset]: + """Archive a dataset. + + Args: + dataset_id: Dataset ID + + Returns: + Updated GoldenDataset or None if not found + + Raises: + InvalidDatasetStatusError: If dataset cannot be archived + """ + dataset = self.get_dataset(dataset_id) + if not dataset: + return None + + self._validate_dataset_status_transition(dataset.status, "archived") + dataset.status = "archived" + + self.session.commit() + self.session.refresh(dataset) + logger.info("Archived dataset: %s (id=%d)", dataset.name, dataset_id) + return dataset + + @staticmethod + def _validate_dataset_status_transition(from_status: str, to_status: str) -> None: + """Validate that a dataset status transition is allowed. + + Args: + from_status: Current status + to_status: Target status + + Raises: + InvalidDatasetStatusError: If transition is not allowed + """ + if to_status not in VALID_DATASET_STATUSES: + raise InvalidDatasetStatusError(f"Invalid target status: '{to_status}'") + + if from_status == to_status: + return # No-op transition is allowed + + allowed_transitions = VALID_DATASET_TRANSITIONS.get(from_status, []) + if to_status not in allowed_transitions: + raise InvalidDatasetStatusError( + f"Cannot transition from '{from_status}' to '{to_status}'. 
" + f"Allowed transitions from '{from_status}': {allowed_transitions}" + ) + + # ------------------------------------------------------------------ + # Entry management + # ------------------------------------------------------------------ + + def add_entry( + self, + dataset_id: int, + input_data: Dict[str, Any], + output_data: Dict[str, Any], + metadata: Optional[Dict[str, Any]] = None, + difficulty: Optional[float] = None, + confidence: Optional[float] = None, + ) -> GoldenDatasetEntry: + """Add an entry to a dataset. + + Args: + dataset_id: Dataset ID + input_data: Model input features + output_data: Ground truth labels + metadata: Optional entry-specific metadata + difficulty: Optional difficulty score (0-1) + confidence: Optional label confidence (0-1) + + Returns: + Created GoldenDatasetEntry instance + + Raises: + ValueError: If dataset not found or invalid parameters + """ + dataset = self.get_dataset(dataset_id) + if not dataset: + raise ValueError(f"Dataset with id {dataset_id} not found") + + if difficulty is not None and not 0.0 <= difficulty <= 1.0: + raise ValueError(f"difficulty must be between 0.0 and 1.0, got {difficulty}") + + if confidence is not None and not 0.0 <= confidence <= 1.0: + raise ValueError(f"confidence must be between 0.0 and 1.0, got {confidence}") + + entry = GoldenDatasetEntry( + dataset_id=dataset_id, + input_data=input_data, + output_data=output_data, + metadata=metadata, + difficulty=difficulty, + confidence=confidence, + ) + self.session.add(entry) + self.session.flush() + + # Update dataset size + dataset.size += 1 + self.session.commit() + self.session.refresh(entry) + logger.debug( + "Added entry to dataset: %s (entry_id=%d, dataset_id=%d)", + dataset.name, + entry.id, + dataset_id, + ) + return entry + + def add_entries_batch( + self, + dataset_id: int, + entries: List[Dict[str, Any]], + ) -> List[GoldenDatasetEntry]: + """Add multiple entries to a dataset in a single transaction. 
+ + Args: + dataset_id: Dataset ID + entries: List of entry dicts with keys: input_data, output_data, metadata, difficulty, confidence + + Returns: + List of created GoldenDatasetEntry instances + + Raises: + ValueError: If dataset not found or invalid parameters + """ + dataset = self.get_dataset(dataset_id) + if not dataset: + raise ValueError(f"Dataset with id {dataset_id} not found") + + created_entries = [] + for entry_data in entries: + entry = GoldenDatasetEntry( + dataset_id=dataset_id, + input_data=entry_data["input_data"], + output_data=entry_data["output_data"], + metadata=entry_data.get("metadata"), + difficulty=entry_data.get("difficulty"), + confidence=entry_data.get("confidence"), + ) + self.session.add(entry) + created_entries.append(entry) + + # Update dataset size + dataset.size += len(entries) + self.session.commit() + + for entry in created_entries: + self.session.refresh(entry) + + logger.info( + "Added %d entries to dataset: %s (dataset_id=%d)", + len(entries), + dataset.name, + dataset_id, + ) + return created_entries + + def get_entry(self, entry_id: int) -> Optional[GoldenDatasetEntry]: + """Get an entry by ID. + + Args: + entry_id: Entry ID + + Returns: + GoldenDatasetEntry instance or None if not found + """ + return self.session.get(GoldenDatasetEntry, entry_id) + + def list_entries( + self, + dataset_id: int, + min_difficulty: Optional[float] = None, + max_difficulty: Optional[float] = None, + min_confidence: Optional[float] = None, + ) -> List[GoldenDatasetEntry]: + """List entries for a dataset with optional filters. 
+ + Args: + dataset_id: Dataset ID + min_difficulty: Filter by minimum difficulty + max_difficulty: Filter by maximum difficulty + min_confidence: Filter by minimum confidence + + Returns: + List of GoldenDatasetEntry instances + """ + stmt = select(GoldenDatasetEntry).where( + GoldenDatasetEntry.dataset_id == dataset_id + ) + if min_difficulty is not None: + stmt = stmt.where(GoldenDatasetEntry.difficulty >= min_difficulty) + if max_difficulty is not None: + stmt = stmt.where(GoldenDatasetEntry.difficulty <= max_difficulty) + if min_confidence is not None: + stmt = stmt.where(GoldenDatasetEntry.confidence >= min_confidence) + stmt = stmt.order_by(GoldenDatasetEntry.created_at) + return list(self.session.execute(stmt).scalars().all()) + + def delete_entry(self, entry_id: int) -> bool: + """Delete an entry. + + Args: + entry_id: Entry ID + + Returns: + True if deleted, False if not found + """ + entry = self.get_entry(entry_id) + if not entry: + return False + + dataset_id = entry.dataset_id + self.session.delete(entry) + self.session.flush() + + # Update dataset size + dataset = self.get_dataset(dataset_id) + if dataset and dataset.size > 0: + dataset.size -= 1 + + self.session.commit() + logger.debug("Deleted entry (entry_id=%d)", entry_id) + return True + + # ------------------------------------------------------------------ + # Dataset validation and quality metrics + # ------------------------------------------------------------------ + + def validate_dataset(self, dataset_id: int) -> Dict[str, Any]: + """Validate a dataset and return quality metrics. 
+ + Args: + dataset_id: Dataset ID + + Returns: + Dictionary with validation results and quality metrics + """ + dataset = self.get_dataset(dataset_id) + if not dataset: + raise ValueError(f"Dataset with id {dataset_id} not found") + + entries = self.list_entries(dataset_id) + + validation_results = { + "dataset_id": dataset_id, + "dataset_name": dataset.name, + "total_entries": len(entries), + "is_valid": True, + "issues": [], + "quality_metrics": {}, + } + + # Check if dataset has entries + if len(entries) == 0: + validation_results["is_valid"] = False + validation_results["issues"].append("Dataset has no entries") + return validation_results + + # Check for missing data + entries_with_difficulty = [e for e in entries if e.difficulty is not None] + entries_with_confidence = [e for e in entries if e.confidence is not None] + + if len(entries_with_difficulty) < len(entries): + validation_results["issues"].append( + f"Only {len(entries_with_difficulty)}/{len(entries)} entries have difficulty scores" + ) + + if len(entries_with_confidence) < len(entries): + validation_results["issues"].append( + f"Only {len(entries_with_confidence)}/{len(entries)} entries have confidence scores" + ) + + # Calculate quality metrics + if entries_with_difficulty: + difficulties = [e.difficulty for e in entries_with_difficulty] + validation_results["quality_metrics"]["difficulty"] = { + "mean": sum(difficulties) / len(difficulties), + "min": min(difficulties), + "max": max(difficulties), + "count": len(difficulties), + } + + if entries_with_confidence: + confidences = [e.confidence for e in entries_with_confidence] + validation_results["quality_metrics"]["confidence"] = { + "mean": sum(confidences) / len(confidences), + "min": min(confidences), + "max": max(confidences), + "count": len(confidences), + } + + # Calculate overall quality score + quality_score = self._calculate_quality_score(validation_results) + validation_results["quality_score"] = quality_score + + # Update dataset with 
quality score + dataset.quality_score = quality_score + self.session.commit() + + logger.info( + "Validated dataset: %s (id=%d, quality_score=%.2f)", + dataset.name, + dataset_id, + quality_score, + ) + + return validation_results + + def _calculate_quality_score(self, validation_results: Dict[str, Any]) -> float: + """Calculate overall quality score from validation results. + + Args: + validation_results: Validation results dictionary + + Returns: + Quality score between 0 and 1 + """ + score = 1.0 + + # Penalize for missing entries + if validation_results["total_entries"] == 0: + return 0.0 + + # Penalize for missing difficulty/confidence scores + quality_metrics = validation_results.get("quality_metrics", {}) + total_entries = validation_results["total_entries"] + + if "difficulty" in quality_metrics: + diff_coverage = quality_metrics["difficulty"]["count"] / total_entries + score *= (0.5 + 0.5 * diff_coverage) # Multiplier in (0.5, 1.0]; 1.0 at full coverage + + if "confidence" in quality_metrics: + conf_coverage = quality_metrics["confidence"]["count"] / total_entries + score *= (0.5 + 0.5 * conf_coverage) # Multiplier in (0.5, 1.0]; 1.0 at full coverage + + # Penalize for issues + num_issues = len(validation_results["issues"]) + score *= max(0.0, 1.0 - (num_issues * 0.1)) + + return min(1.0, max(0.0, score)) + + # ------------------------------------------------------------------ + # Dataset export/import + # ------------------------------------------------------------------ + + def export_dataset(self, dataset_id: int) -> Dict[str, Any]: + """Export a dataset with all entries. 
+ + Args: + dataset_id: Dataset ID + + Returns: + Dictionary with dataset metadata and entries + """ + dataset = self.get_dataset(dataset_id) + if not dataset: + raise ValueError(f"Dataset with id {dataset_id} not found") + + entries = self.list_entries(dataset_id) + + export_data = { + "dataset": { + "name": dataset.name, + "description": dataset.description, + "dataset_type": dataset.dataset_type, + "task_type": dataset.task_type, + "version": dataset.version, + "source": dataset.source, + "quality_score": dataset.quality_score, + "metadata": dataset.metadata, + }, + "entries": [ + { + "input_data": entry.input_data, + "output_data": entry.output_data, + "metadata": entry.metadata, + "difficulty": entry.difficulty, + "confidence": entry.confidence, + } + for entry in entries + ], + } + + logger.info("Exported dataset: %s (id=%d, entries=%d)", dataset.name, dataset_id, len(entries)) + return export_data + + def import_dataset( + self, + export_data: Dict[str, Any], + new_name: Optional[str] = None, + new_version: Optional[str] = None, + ) -> GoldenDataset: + """Import a dataset from exported data. 
+ + Args: + export_data: Exported dataset dictionary + new_name: Optional new name (uses original if not provided) + new_version: Optional new version (uses original if not provided) + + Returns: + Created GoldenDataset instance + """ + dataset_data = export_data["dataset"] + entries_data = export_data["entries"] + + dataset = self.create_dataset( + name=new_name or dataset_data["name"], + dataset_type=dataset_data["dataset_type"], + task_type=dataset_data["task_type"], + version=new_version or dataset_data["version"], + description=dataset_data.get("description"), + source=dataset_data.get("source"), + metadata=dataset_data.get("metadata"), + ) + + # Import entries in batch + self.add_entries_batch( + dataset.id, + entries_data, + ) + + # Restore quality score if available + if dataset_data.get("quality_score"): + dataset.quality_score = dataset_data["quality_score"] + self.session.commit() + + logger.info( + "Imported dataset: %s (id=%d, entries=%d)", + dataset.name, + dataset.id, + len(entries_data), + ) + return dataset diff --git a/tests/test_schema.py b/tests/test_schema.py index 3bfcbee..cb52a5f 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -17,6 +17,8 @@ Base, Experiment, ExperimentResult, + GoldenDataset, + GoldenDatasetEntry, GraphAccount, GraphClaimDetail, GraphEdge, @@ -72,6 +74,8 @@ def test_models_importable(): Experiment, Variant, ExperimentResult, + GoldenDataset, + GoldenDatasetEntry, ): assert hasattr(cls, "__tablename__") @@ -86,6 +90,8 @@ def test_create_all_tables(engine): "effects", "experiment_results", "experiments", + "golden_dataset_entries", + "golden_datasets", "graph_accounts", "graph_claim_details", "graph_edges", @@ -115,6 +121,8 @@ def test_table_names(): assert Experiment.__tablename__ == "experiments" assert Variant.__tablename__ == "variants" assert ExperimentResult.__tablename__ == "experiment_results" + assert GoldenDataset.__tablename__ == "golden_datasets" + assert GoldenDatasetEntry.__tablename__ == 
"golden_dataset_entries" # --------------------------------------------------------------------------- @@ -361,6 +369,51 @@ def test_experiment_result_columns(engine): ) +def test_golden_dataset_columns(engine): + inspector = inspect(engine) + cols = {c["name"] for c in inspector.get_columns("golden_datasets")} + expected = { + "id", + "name", + "description", + "dataset_type", + "task_type", + "version", + "source", + "size", + "status", + "quality_score", + "metadata", + "created_at", + "updated_at", + } + assert expected <= cols + + +def test_golden_dataset_entry_columns(engine): + inspector = inspect(engine) + cols = {c["name"] for c in inspector.get_columns("golden_dataset_entries")} + expected = { + "id", + "dataset_id", + "input_data", + "output_data", + "metadata", + "difficulty", + "confidence", + "created_at", + } + assert expected <= cols + + # FK to golden_datasets + fks = inspector.get_foreign_keys("golden_dataset_entries") + assert any( + fk["referred_table"] == "golden_datasets" + and fk["referred_columns"] == ["id"] + for fk in fks + ) + + # --------------------------------------------------------------------------- # Relationships # --------------------------------------------------------------------------- @@ -549,6 +602,43 @@ def test_ab_testing_relationships(session): assert result3.variant is variant2 +def test_golden_dataset_relationships(session): + """GoldenDataset.entries cascade deletes GoldenDatasetEntry rows.""" + dataset = GoldenDataset( + name="test-dataset", + dataset_type="classification", + task_type="classification", + version="1.0.0", + ) + session.add(dataset) + session.flush() + + entry1 = GoldenDatasetEntry( + dataset_id=dataset.id, + input_data={"feature1": 1.0, "feature2": 2.0}, + output_data={"label": 0}, + difficulty=0.5, + confidence=0.9, + ) + entry2 = GoldenDatasetEntry( + dataset_id=dataset.id, + input_data={"feature1": 3.0, "feature2": 4.0}, + output_data={"label": 1}, + difficulty=0.7, + confidence=0.95, + ) + 
session.add_all([entry1, entry2]) + session.flush() + + session.refresh(dataset) + + assert len(dataset.entries) == 2 + assert entry1 in dataset.entries + assert entry2 in dataset.entries + assert entry1.dataset is dataset + assert entry2.dataset is dataset + + # --------------------------------------------------------------------------- # Round-trip insert & query # ---------------------------------------------------------------------------