diff --git a/README.md b/README.md index 9041e73..f31cd15 100644 --- a/README.md +++ b/README.md @@ -278,6 +278,10 @@ roar label show artifact ./outputs/model.pt # Show label history (all versions) roar label history dag current roar label history artifact + +# Push current local user-managed labels to GLaaS +roar label push job @2 +roar label push artifact ./outputs/model.pt ``` **Entity targets:** @@ -286,7 +290,7 @@ roar label history artifact - `job`: step ref (`@N` or `@BN`) or job UID - `artifact`: file path or artifact hash -Labels are stored locally and included in lineage registration/publish flows to GLaaS when supported by the configured server. +Labels are stored locally by default. You can explicitly push the current local user-managed labels for one target to GLaaS with `roar label push ...`, and labels are also included in lineage registration/publish flows when supported by the configured server. ### `roar register` diff --git a/roar/application/labels.py b/roar/application/labels.py index 398d192..be42553 100644 --- a/roar/application/labels.py +++ b/roar/application/labels.py @@ -16,6 +16,7 @@ from ..core.label_origins import LABEL_ORIGIN_USER, build_current_key_origins from ..db.context import DatabaseContext from .label_rendering import flatten_label_metadata +from .publish.lineage import LineageCollector from .system_labels import is_reserved_system_label_path, strip_reserved_system_labels @@ -44,6 +45,17 @@ class _LabelSyncDatabaseContext(Protocol): def labels(self) -> Any: ... +class _RemoteLabelMutationDatabaseContext(Protocol): + @property + def sessions(self) -> Any: ... + + @property + def jobs(self) -> Any: ... + + @property + def artifacts(self) -> Any: ... + + def parse_label_pairs(pairs: tuple[str, ...]) -> dict[str, Any]: """Parse ``key=value`` pairs into nested metadata.""" metadata: dict[str, Any] = {} @@ -233,6 +245,92 @@ def _reject_reserved_keys(metadata: dict[str, Any]) -> None: raise ValueError(f"Reserved label keys cannot be set manually: {joined}") +def build_remote_label_mutation_payload( + db_ctx: _RemoteLabelMutationDatabaseContext, + *, + roar_dir: Path, + target: LabelTargetRef, + metadata: dict[str, Any], +) -> dict[str, Any]: + """Build a GLaaS label-mutation payload for one local target.""" + if target.entity_type == "dag": + if target.session_id is None: + raise ValueError("DAG target is missing a local session id.") + return { + "entity_type": "dag", + "session_hash": _canonical_remote_session_hash( + db_ctx, + roar_dir=roar_dir, + session_id=int(target.session_id), + ), + "metadata": metadata, + } + + if target.entity_type == "job": + if target.job_id is None: + raise ValueError("Job target is missing a local job id.") + job = db_ctx.jobs.get(int(target.job_id)) + if not isinstance(job, dict): + raise ValueError("Job not found.") + session_id = job.get("session_id") + job_uid = job.get("job_uid") + if not isinstance(session_id, int): + raise ValueError("Job target is missing a local session id.") + if not isinstance(job_uid, str) or not job_uid: + raise ValueError("Job target is missing a job UID.") + return { + "entity_type": "job", + "session_hash": _canonical_remote_session_hash( + db_ctx, + roar_dir=roar_dir, + session_id=session_id, + ), + "job_uid": job_uid, + "metadata": metadata, + } + + if target.entity_type == "artifact": + if target.artifact_id is None: + raise ValueError("Artifact target is missing a local artifact id.") + artifact = db_ctx.artifacts.get(str(target.artifact_id)) + if not isinstance(artifact, dict): + raise ValueError("Artifact not found.") + artifact_hash = artifact.get("hash") + if not isinstance(artifact_hash, str) or not artifact_hash: + hashes = db_ctx.artifacts.get_hashes(str(target.artifact_id)) + artifact_hash = next( + ( + item.get("digest") + for item in hashes + if isinstance(item, dict) + and item.get("algorithm") == "blake3" + and isinstance(item.get("digest"), str) + and item.get("digest") + ), + None, + ) + if not isinstance(artifact_hash, str) or not artifact_hash: + artifact_hash = next( + ( + item.get("digest") + for item in hashes + if isinstance(item, dict) + and isinstance(item.get("digest"), str) + and item.get("digest") + ), + None, + ) + if not isinstance(artifact_hash, str) or not artifact_hash: + raise ValueError("Artifact target is missing a content hash.") + return { + "entity_type": "artifact", + "artifact_hash": artifact_hash, + "metadata": metadata, + } + + raise ValueError(f"Unsupported label entity type: {target.entity_type}") + + def collect_label_sync_payloads( db_ctx: _LabelSyncDatabaseContext, *, @@ -308,6 +406,42 @@ def collect_label_sync_payloads( return payloads +def _canonical_remote_session_hash( + db_ctx: _RemoteLabelMutationDatabaseContext, + *, + roar_dir: Path, + session_id: int, +) -> str: + from ..core.logging import get_logger + from .publish.runtime import build_publish_runtime + from .publish.service import prepare_register_preview_execution + + session = db_ctx.sessions.get(session_id) + if not isinstance(session, dict): + raise ValueError("Session not found.") + + lineage = LineageCollector().collect_session(session_id, roar_dir) + if not getattr(lineage, "jobs", None): + raise ValueError( + "Selected session has no tracked steps. Run 'roar run' or 'roar build' first." + ) + + runtime = build_publish_runtime( + start_dir=str(roar_dir.parent), + allow_public_without_binding=True, + ) + prepared = prepare_register_preview_execution( + runtime=runtime, + roar_dir=roar_dir, + cwd=roar_dir.parent, + session_id=session_id, + session_hash_override=None, + logger=get_logger(), + lineage=lineage, + ) + return str(prepared.session_hash) + + def _assign_nested(root: dict[str, Any], path: list[str], value: Any) -> None: cursor = root for key in path[:-1]: diff --git a/roar/application/query/__init__.py b/roar/application/query/__init__.py index 50f847a..64cb8ee 100644 --- a/roar/application/query/__init__.py +++ b/roar/application/query/__init__.py @@ -9,6 +9,7 @@ "DagQueryRequest": ".requests", "LabelCopyRequest": ".requests", "LabelHistoryRequest": ".requests", + "LabelPushRequest": ".requests", "LabelSetRequest": ".requests", "LabelShowRequest": ".requests", "LineageQueryRequest": ".requests", @@ -23,6 +24,7 @@ "StatusSummary": ".results", "build_copy_labels_summary": ".label", "build_label_history_summary": ".label", + "build_push_labels_summary": ".label", "build_set_labels_summary": ".label", "build_show_labels_summary": ".label", "copy_labels": ".label", @@ -32,6 +34,7 @@ "render_log": ".log", "render_show": ".show", "render_status": ".status", + "push_labels": ".label", "set_labels": ".label", "show_labels": ".label", } diff --git a/roar/application/query/label.py b/roar/application/query/label.py index 78fa13d..533a51d 100644 --- a/roar/application/query/label.py +++ b/roar/application/query/label.py @@ -3,11 +3,14 @@ from __future__ import annotations from ...db.context import create_database_context +from ...integrations.glaas import GlaasClient from ..label_rendering import flatten_label_metadata -from ..labels import LabelService, parse_label_pairs +from ..labels import LabelService, build_remote_label_mutation_payload, parse_label_pairs +from ..system_labels import strip_reserved_system_labels from .requests import ( LabelCopyRequest, LabelHistoryRequest, + LabelPushRequest, LabelSetRequest, LabelShowRequest, ) @@ -64,6 +67,52 @@ def build_copy_labels_summary(request: LabelCopyRequest) -> LabelCurrentSummary: return _build_current_summary(result.metadata, heading=heading) +def push_labels(request: LabelPushRequest) -> str: + """Push the current local user-managed label document for a target to GLaaS.""" + return build_push_labels_summary(request).render() + + +def build_push_labels_summary(request: LabelPushRequest) -> LabelCurrentSummary: + """Build the typed summary for a remote label push operation.""" + with create_database_context(request.roar_dir) as db_ctx: + service = LabelService(db_ctx, request.cwd) + resolved = service.resolve_target(request.entity_type, request.target) + metadata = strip_reserved_system_labels(service.current_metadata(resolved)) + if not metadata: + raise ValueError(f"No local user-managed labels to push for {request.target}.") + payload = build_remote_label_mutation_payload( + db_ctx, + roar_dir=request.roar_dir, + target=resolved, + metadata=metadata, + ) + + client = GlaasClient(start_dir=str(request.cwd), allow_public_without_binding=True) + result, error = client.patch_current_label(payload) + if error: + raise ValueError(f"Remote label push failed: {error}") + + remote_metadata = metadata + version: int | None = None + if isinstance(result, dict): + returned_metadata = result.get("metadata") + if isinstance(returned_metadata, dict): + remote_metadata = strip_reserved_system_labels(returned_metadata) + raw_version = result.get("version") + if raw_version is not None: + try: + version = int(raw_version) + except (TypeError, ValueError): + version = None + + heading = ( + f"Pushed remote labels (version {version}):" + if version is not None + else "Pushed remote labels:" + ) + return _build_current_summary(remote_metadata, heading=heading) + + def show_labels(request: LabelShowRequest) -> str: """Show the current local label document for a target.""" return build_show_labels_summary(request).render() diff --git a/roar/application/query/requests.py b/roar/application/query/requests.py index 9b20be7..ee387bf 100644 --- a/roar/application/query/requests.py +++ b/roar/application/query/requests.py @@ -84,6 +84,14 @@ class LabelHistoryRequest: target: str +@dataclass(frozen=True) +class LabelPushRequest: + roar_dir: Path + cwd: Path + entity_type: str + target: str + + DiffFormat = Literal["summary", "category", "dag"] diff --git a/roar/cli/commands/label.py b/roar/cli/commands/label.py index a886505..0708153 100644 --- a/roar/cli/commands/label.py +++ b/roar/cli/commands/label.py @@ -6,6 +6,7 @@ roar label cp roar label show roar label history + roar label push """ from __future__ import annotations @@ -14,6 +15,7 @@ from ...application.query.label import ( copy_labels, + push_labels, set_labels, show_labels, ) @@ -23,6 +25,7 @@ from ...application.query.requests import ( LabelCopyRequest, LabelHistoryRequest, + LabelPushRequest, LabelSetRequest, LabelShowRequest, ) @@ -35,7 +38,7 @@ @click.group("label", invoke_without_command=True) @click.pass_context def label(ctx: click.Context) -> None: - """Manage local labels for DAGs, jobs, and artifacts.""" + """Manage local labels and push user-managed label updates to GLaaS.""" if ctx.invoked_subcommand is None: click.echo(ctx.get_help()) @@ -115,6 +118,27 @@ def label_show(ctx: RoarContext, entity_type: str, target: str) -> None: click.echo(rendered) +@label.command("push") +@click.argument("entity_type", type=_ENTITY_TYPE) +@click.argument("target") +@click.pass_obj +@require_init +def label_push(ctx: RoarContext, entity_type: str, target: str) -> None: + """Push the current local user-managed labels for a target to GLaaS.""" + try: + rendered = push_labels( + LabelPushRequest( + roar_dir=ctx.roar_dir, + cwd=ctx.cwd, + entity_type=entity_type, + target=target, + ) + ) + except (ValueError, RuntimeError) as exc: + raise click.ClickException(str(exc)) from exc + click.echo(rendered) + + @label.command("history") @click.argument("entity_type", type=_ENTITY_TYPE) @click.argument("target") diff --git a/roar/integrations/glaas/client.py b/roar/integrations/glaas/client.py index 0b692d8..bfc4095 100644 --- a/roar/integrations/glaas/client.py +++ b/roar/integrations/glaas/client.py @@ -458,6 +458,13 @@ def sync_labels( return {"created": 0, "updated": 0, "unchanged": 0}, None return self._request("POST", "/api/v1/labels/sync", {"labels": labels}) + def patch_current_label( + self, + label: dict[str, Any], + ) -> tuple[dict | None, str | None]: + """Patch the current remote label document for one lineage target.""" + return self._request("PATCH", "/api/v1/labels/current", label) + def register_job_inputs( self, session_hash: str, diff --git a/tests/application/query/test_label.py b/tests/application/query/test_label.py index f60685d..411c9d1 100644 --- a/tests/application/query/test_label.py +++ b/tests/application/query/test_label.py @@ -6,12 +6,14 @@ from roar.application.query import ( LabelCopyRequest, LabelHistoryRequest, + LabelPushRequest, LabelSetRequest, LabelShowRequest, ) from roar.application.query.label import ( build_copy_labels_summary, build_label_history_summary, + build_push_labels_summary, build_set_labels_summary, build_show_labels_summary, ) @@ -60,6 +62,16 @@ def _history_request(tmp_path: Path, **overrides) -> LabelHistoryRequest: ) +def _push_request(tmp_path: Path, **overrides) -> LabelPushRequest: + return LabelPushRequest( + roar_dir=overrides.pop("roar_dir", tmp_path / ".roar"), + cwd=overrides.pop("cwd", tmp_path), + entity_type=overrides.pop("entity_type", "artifact"), + target=overrides.pop("target", "processed.csv"), + **overrides, + ) + + def test_build_set_labels_summary_returns_typed_summary(tmp_path: Path) -> None: db_ctx = MagicMock() db_ctx.__enter__.return_value = db_ctx @@ -127,6 +139,75 @@ def test_build_show_labels_summary_renders_no_labels_when_empty(tmp_path: Path) assert summary.render() == "No labels." +def test_build_push_labels_summary_returns_remote_versioned_summary(tmp_path: Path) -> None: + db_ctx = MagicMock() + db_ctx.__enter__.return_value = db_ctx + db_ctx.__exit__.return_value = None + service = MagicMock() + resolved_target = object() + service.resolve_target.return_value = resolved_target + service.current_metadata.return_value = {"owner": "ml", "stage": "gold"} + client = MagicMock() + client.patch_current_label.return_value = ( + { + "version": 2, + "metadata": {"owner": "ml", "stage": "gold"}, + }, + None, + ) + + with ( + patch("roar.application.query.label.create_database_context", return_value=db_ctx), + patch("roar.application.query.label.LabelService", return_value=service), + patch( + "roar.application.query.label.build_remote_label_mutation_payload", + return_value={ + "entity_type": "artifact", + "artifact_hash": "a" * 64, + "metadata": {"owner": "ml", "stage": "gold"}, + }, + ) as build_payload, + patch("roar.application.query.label.GlaasClient", return_value=client), + ): + summary = build_push_labels_summary(_push_request(tmp_path)) + + build_payload.assert_called_once_with( + db_ctx, + roar_dir=tmp_path / ".roar", + target=resolved_target, + metadata={"owner": "ml", "stage": "gold"}, + ) + client.patch_current_label.assert_called_once_with( + { + "entity_type": "artifact", + "artifact_hash": "a" * 64, + "metadata": {"owner": "ml", "stage": "gold"}, + } + ) + assert summary.heading == "Pushed remote labels (version 2):" + assert summary.render() == "Pushed remote labels (version 2):\n owner=ml\n stage=gold" + + +def test_build_push_labels_summary_rejects_missing_user_managed_labels(tmp_path: Path) -> None: + db_ctx = MagicMock() + db_ctx.__enter__.return_value = db_ctx + db_ctx.__exit__.return_value = None + service = MagicMock() + service.resolve_target.return_value = object() + service.current_metadata.return_value = {"roar": {"operation": {"kind": "run"}}} + + with ( + patch("roar.application.query.label.create_database_context", return_value=db_ctx), + patch("roar.application.query.label.LabelService", return_value=service), + ): + try: + build_push_labels_summary(_push_request(tmp_path, target="@1", entity_type="job")) + except ValueError as exc: + assert str(exc) == "No local user-managed labels to push for @1." + else: # pragma: no cover - defensive assertion style + raise AssertionError("Expected ValueError for missing user-managed labels") + + def test_build_label_history_summary_returns_versioned_entries(tmp_path: Path) -> None: db_ctx = MagicMock() db_ctx.__enter__.return_value = db_ctx diff --git a/tests/integration/fake_glaas.py b/tests/integration/fake_glaas.py index b38dd5e..3792db9 100644 --- a/tests/integration/fake_glaas.py +++ b/tests/integration/fake_glaas.py @@ -21,6 +21,9 @@ def __init__(self) -> None: self.input_links: list[dict[str, Any]] = [] self.output_links: list[dict[str, Any]] = [] self.label_syncs: list[list[dict[str, Any]]] = [] + self.label_mutation_attempts: list[dict[str, Any]] = [] + self.label_mutations: list[dict[str, Any]] = [] + self.current_labels_by_target: dict[str, dict[str, Any]] = {} self.composite_registrations: list[dict[str, Any]] = [] self.artifacts_by_digest: dict[str, dict[str, Any]] = {} self.artifact_dags_by_digest: dict[str, dict[str, Any]] = {} @@ -125,6 +128,38 @@ def do_GET(self) -> None: self._write_json(404, {"error": f"Unhandled GET path: {self.path}"}) + def do_PATCH(self) -> None: + payload = self._read_json() + authorization = self.headers.get("Authorization") + self.server.auth_headers.append({"path": self.path, "authorization": authorization}) + if not authorization or not authorization.startswith("Bearer "): + self._write_json(401, {"error": "Missing or invalid bearer auth"}) + return + + if self.path == "/api/v1/labels/current": + metadata = payload.get("metadata") + target_key = _label_target_key(payload) + self.server.label_mutation_attempts.append(payload) + current = self.server.current_labels_by_target.get(target_key) + if not current: + self._write_json(404, {"error": {"message": "Label not found"}}) + return + self.server.label_mutations.append(payload) + merged_metadata = _deep_merge( + current["metadata"], metadata if isinstance(metadata, dict) else {} + ) + version = int(current.get("version", 0)) + 1 + updated = { + **current, + "version": version, + "metadata": merged_metadata, + } + self.server.current_labels_by_target[target_key] = updated + self._write_json(200, updated) + return + + self._write_json(404, {"error": f"Unhandled PATCH path: {self.path}"}) + def do_POST(self) -> None: payload = self._read_json() authorization = self.headers.get("Authorization") @@ -167,6 +202,29 @@ def do_POST(self) -> None: labels = payload.get("labels", []) if isinstance(labels, list): self.server.label_syncs.append(labels) + for label in labels: + if not isinstance(label, dict): + continue + target_key = _label_target_key(label) + current = self.server.current_labels_by_target.get(target_key) + version = int(current.get("version", 0)) + 1 if isinstance(current, dict) else 1 + current_label = { + "id": f"label-{len(self.server.current_labels_by_target) + 1}", + "entityType": label.get("entity_type"), + "version": version, + "metadata": label.get("metadata") + if isinstance(label.get("metadata"), dict) + else {}, + "createdAt": "2026-01-01T00:00:00Z", + } + if label.get("entity_type") == "dag": + current_label["sessionHash"] = label.get("session_hash") + elif label.get("entity_type") == "job": + current_label["sessionHash"] = label.get("session_hash") + current_label["jobUid"] = label.get("job_uid") + elif label.get("entity_type") == "artifact": + current_label["artifactHash"] = label.get("artifact_hash") + self.server.current_labels_by_target[target_key] = current_label self._write_json( 200, {"created": 0, "updated": 0, "unchanged": len(labels)}, @@ -235,6 +293,26 @@ def log_message(self, format: str, *args: object) -> None: """Suppress default stderr logging for integration tests.""" +def _label_target_key(payload: dict[str, Any]) -> str: + entity_type = str(payload.get("entity_type") or "") + if entity_type == "dag": + return f"dag:{payload.get('session_hash', '')}" + if entity_type == "job": + return f"job:{payload.get('session_hash', '')}:{payload.get('job_uid', '')}" + return f"artifact:{payload.get('artifact_hash', '')}" + + +def _deep_merge(current: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]: + merged = json.loads(json.dumps(current)) + for key, value in patch.items(): + existing = merged.get(key) + if isinstance(existing, dict) and isinstance(value, dict): + merged[key] = _deep_merge(existing, value) + else: + merged[key] = value + return merged + + class FakeGlaasServer: """Context-managed fake GLaaS server.""" @@ -282,6 +360,14 @@ def output_links(self) -> list[dict[str, Any]]: def label_syncs(self) -> list[list[dict[str, Any]]]: return self._server.label_syncs + @property + def label_mutation_attempts(self) -> list[dict[str, Any]]: + return self._server.label_mutation_attempts + + @property + def label_mutations(self) -> list[dict[str, Any]]: + return self._server.label_mutations + def set_artifact_dag(self, digest: str, dag: dict[str, Any]) -> None: self._server.artifact_dags_by_digest[digest] = dag diff --git a/tests/integration/test_label_push_cli_integration.py b/tests/integration/test_label_push_cli_integration.py new file mode 100644 index 0000000..3c197fc --- /dev/null +++ b/tests/integration/test_label_push_cli_integration.py @@ -0,0 +1,258 @@ +"""Product-path coverage for explicit remote label push flows.""" + +from __future__ import annotations + +import json +import re +import sqlite3 +import subprocess +from pathlib import Path + +import pytest + +from .fake_glaas import FakeGlaasServer + +pytestmark = pytest.mark.integration + + +@pytest.fixture +def fake_glaas_publish_server() -> FakeGlaasServer: + with FakeGlaasServer() as server: + yield server + + +def _configure_label_push_repo(repo: Path, roar_cli, fake_glaas_url: str) -> dict[str, str]: + subprocess.run( + ["git", "remote", "add", "origin", "https://github.com/test/repo.git"], + cwd=repo, + capture_output=True, + check=True, + ) + xdg_config_home = repo / ".xdg" + token_file = repo / "token-file.json" + token_file.write_text( + json.dumps( + { + "version": 1, + "provider": "treqs-cognito", + "access_token": "test-access-token", + "user": { + "sub": "treqs-user-123", + "db_user_id": "user-123", + "email": "trevor@example.com", + "username": "trevor", + }, + } + ), + encoding="utf-8", + ) + env = { + "XDG_CONFIG_HOME": str(xdg_config_home), + "GLAAS_API_URL": fake_glaas_url, + "ROAR_ENABLE_EXPERIMENTAL_ACCOUNT_COMMANDS": "1", + } + roar_cli("login", "--token-file", str(token_file), env_overrides=env) + roar_cli("projects", "link", "proj-test", env_overrides=env) + roar_cli("config", "set", "glaas.url", fake_glaas_url, env_overrides=env) + roar_cli("config", "set", "glaas.web_url", fake_glaas_url, env_overrides=env) + return env + + +def _create_tracked_output( + repo: Path, + *, + roar_cli, + git_commit, + python_exe: str, + env_overrides: dict[str, str] | None = None, +) -> None: + (repo / "input.csv").write_text("id,value\n1,foo\n2,bar\n", encoding="utf-8") + (repo / "preprocess.py").write_text( + "import sys\n" + "from pathlib import Path\n\n" + "src = Path(sys.argv[1]).read_text(encoding='utf-8').upper()\n" + "Path(sys.argv[2]).write_text(src, encoding='utf-8')\n", + encoding="utf-8", + ) + git_commit("Add preprocess script") + + run_result = roar_cli( + "run", + python_exe, + "preprocess.py", + "input.csv", + "processed.csv", + env_overrides=env_overrides, + ) + assert run_result.returncode == 0 + git_commit("Track processed output") + + +def _artifact_hash_for(repo: Path, roar_cli, target: str) -> str: + payload = json.loads(roar_cli("lineage", target).stdout) + return str(payload["artifact"]["hash"]) + + +def _job_uid_for(repo: Path, roar_cli, target: str) -> str: + payload = json.loads(roar_cli("lineage", target).stdout) + return str(payload["jobs"][0]["job_uid"]) + + +def _status_dag_hash(roar_cli, *, env_overrides: dict[str, str]) -> str: + result = roar_cli("status", env_overrides=env_overrides) + assert result.returncode == 0 + match = re.search(r"DAG hash:\s+([0-9a-f]{64})", result.stdout) + assert match is not None, result.stdout + return match.group(1) + + +def _active_local_session_hash(repo: Path) -> str: + db_path = repo / ".roar" / "roar.db" + with sqlite3.connect(db_path) as conn: + row = conn.execute( + "SELECT hash FROM sessions WHERE is_active = 1 ORDER BY id DESC LIMIT 1" + ).fetchone() + assert row is not None and isinstance(row[0], str) and row[0] + return str(row[0]) + + +def test_label_push_artifact_patches_existing_remote_labels( + temp_git_repo: Path, + roar_cli, + git_commit, + python_exe: str, + fake_glaas_publish_server: FakeGlaasServer, +) -> None: + env = _configure_label_push_repo(temp_git_repo, roar_cli, fake_glaas_publish_server.base_url) + _create_tracked_output( + temp_git_repo, + roar_cli=roar_cli, + git_commit=git_commit, + python_exe=python_exe, + env_overrides=env, + ) + + roar_cli( + "label", + "set", + "artifact", + "processed.csv", + "owner=ml", + "stage=raw", + env_overrides=env, + ) + roar_cli("register", "processed.csv", "--yes", env_overrides=env) + roar_cli( + "label", + "set", + "artifact", + "processed.csv", + "stage=gold", + env_overrides=env, + ) + + artifact_hash = _artifact_hash_for(temp_git_repo, roar_cli, "processed.csv") + result = roar_cli("label", "push", "artifact", "processed.csv", env_overrides=env) + + assert result.returncode == 0 + assert "Pushed remote labels (version 2):" in result.stdout + assert "owner=ml" in result.stdout + assert "stage=gold" in result.stdout + assert fake_glaas_publish_server.label_mutations == [ + { + "entity_type": "artifact", + "artifact_hash": artifact_hash, + "metadata": {"owner": "ml", "stage": "gold"}, + } + ] + + +def test_label_push_job_uses_canonical_remote_session_hash_and_omits_system_labels( + temp_git_repo: Path, + roar_cli, + git_commit, + python_exe: str, + fake_glaas_publish_server: FakeGlaasServer, +) -> None: + env = _configure_label_push_repo(temp_git_repo, roar_cli, fake_glaas_publish_server.base_url) + _create_tracked_output( + temp_git_repo, + roar_cli=roar_cli, + git_commit=git_commit, + python_exe=python_exe, + env_overrides=env, + ) + + roar_cli("label", "set", "job", "@1", "phase=preprocess", env_overrides=env) + session_ref = _active_local_session_hash(temp_git_repo) + roar_cli("register", session_ref, "--yes", env_overrides=env) + expected_session_hash = fake_glaas_publish_server.session_registrations[0]["hash"] + roar_cli("label", "set", "job", "@1", "phase=train", env_overrides=env) + + job_uid = _job_uid_for(temp_git_repo, roar_cli, "processed.csv") + synced_job_labels = [ + label + for batch in fake_glaas_publish_server.label_syncs + for label in batch + if label.get("entity_type") == "job" and label.get("job_uid") == job_uid + ] + assert len(synced_job_labels) == 1 + assert synced_job_labels[0]["session_hash"] == expected_session_hash + result = roar_cli("label", "push", "job", "@1", check=False, env_overrides=env) + + assert fake_glaas_publish_server.label_mutation_attempts == [ + { + "entity_type": "job", + "session_hash": expected_session_hash, + "job_uid": job_uid, + "metadata": {"phase": "train"}, + } + ] + assert result.returncode == 0 + assert "Pushed remote labels (version 2):" in result.stdout + assert "phase=train" in result.stdout + assert fake_glaas_publish_server.label_mutations == [ + { + "entity_type": "job", + "session_hash": expected_session_hash, + "job_uid": job_uid, + "metadata": {"phase": "train"}, + } + ] + assert any( + entry.get("path") == "/api/v1/labels/current" + and entry.get("authorization") == "Bearer test-access-token" + for entry in fake_glaas_publish_server.auth_headers + ) + + +def test_label_push_requires_existing_remote_current_labels( + temp_git_repo: Path, + roar_cli, + git_commit, + python_exe: str, + fake_glaas_publish_server: FakeGlaasServer, +) -> None: + env = _configure_label_push_repo(temp_git_repo, roar_cli, fake_glaas_publish_server.base_url) + _create_tracked_output( + temp_git_repo, + roar_cli=roar_cli, + git_commit=git_commit, + python_exe=python_exe, + env_overrides=env, + ) + + roar_cli( + "label", + "set", + "artifact", + "processed.csv", + "owner=ml", + "stage=gold", + env_overrides=env, + ) + result = roar_cli("label", "push", "artifact", "processed.csv", check=False, env_overrides=env) + + assert result.returncode != 0 + assert "Remote label push failed: HTTP 404: Label not found" in result.stderr + assert fake_glaas_publish_server.label_mutations == []