Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ roar label show artifact ./outputs/model.pt
# Show label history (all versions)
roar label history dag current
roar label history artifact <artifact-hash>

# Push current local user-managed labels to GLaaS
roar label push job @2
roar label push artifact ./outputs/model.pt
```

**Entity targets:**
Expand All @@ -286,7 +290,7 @@ roar label history artifact <artifact-hash>
- `job`: step ref (`@N` or `@BN`) or job UID
- `artifact`: file path or artifact hash

Labels are stored locally and included in lineage registration/publish flows to GLaaS when supported by the configured server.
Labels are stored locally by default. You can explicitly push the current local user-managed labels for one target to GLaaS with `roar label push ...`, and labels are also included in lineage registration/publish flows when supported by the configured server.

### `roar register`

Expand Down
134 changes: 134 additions & 0 deletions roar/application/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..core.label_origins import LABEL_ORIGIN_USER, build_current_key_origins
from ..db.context import DatabaseContext
from .label_rendering import flatten_label_metadata
from .publish.lineage import LineageCollector
from .system_labels import is_reserved_system_label_path, strip_reserved_system_labels


Expand Down Expand Up @@ -44,6 +45,17 @@ class _LabelSyncDatabaseContext(Protocol):
def labels(self) -> Any: ...


class _RemoteLabelMutationDatabaseContext(Protocol):
@property
def sessions(self) -> Any: ...

@property
def jobs(self) -> Any: ...

@property
def artifacts(self) -> Any: ...


def parse_label_pairs(pairs: tuple[str, ...]) -> dict[str, Any]:
"""Parse ``key=value`` pairs into nested metadata."""
metadata: dict[str, Any] = {}
Expand Down Expand Up @@ -233,6 +245,92 @@ def _reject_reserved_keys(metadata: dict[str, Any]) -> None:
raise ValueError(f"Reserved label keys cannot be set manually: {joined}")


def build_remote_label_mutation_payload(
db_ctx: _RemoteLabelMutationDatabaseContext,
*,
roar_dir: Path,
target: LabelTargetRef,
metadata: dict[str, Any],
) -> dict[str, Any]:
"""Build a GLaaS label-mutation payload for one local target."""
if target.entity_type == "dag":
if target.session_id is None:
raise ValueError("DAG target is missing a local session id.")
return {
"entity_type": "dag",
"session_hash": _canonical_remote_session_hash(
db_ctx,
roar_dir=roar_dir,
session_id=int(target.session_id),
),
"metadata": metadata,
}

if target.entity_type == "job":
if target.job_id is None:
raise ValueError("Job target is missing a local job id.")
job = db_ctx.jobs.get(int(target.job_id))
if not isinstance(job, dict):
raise ValueError("Job not found.")
session_id = job.get("session_id")
job_uid = job.get("job_uid")
if not isinstance(session_id, int):
raise ValueError("Job target is missing a local session id.")
if not isinstance(job_uid, str) or not job_uid:
raise ValueError("Job target is missing a job UID.")
return {
"entity_type": "job",
"session_hash": _canonical_remote_session_hash(
db_ctx,
roar_dir=roar_dir,
session_id=session_id,
),
"job_uid": job_uid,
"metadata": metadata,
}

if target.entity_type == "artifact":
if target.artifact_id is None:
raise ValueError("Artifact target is missing a local artifact id.")
artifact = db_ctx.artifacts.get(str(target.artifact_id))
if not isinstance(artifact, dict):
raise ValueError("Artifact not found.")
artifact_hash = artifact.get("hash")
if not isinstance(artifact_hash, str) or not artifact_hash:
hashes = db_ctx.artifacts.get_hashes(str(target.artifact_id))
artifact_hash = next(
(
item.get("digest")
for item in hashes
if isinstance(item, dict)
and item.get("algorithm") == "blake3"
and isinstance(item.get("digest"), str)
and item.get("digest")
),
None,
)
if not isinstance(artifact_hash, str) or not artifact_hash:
artifact_hash = next(
(
item.get("digest")
for item in hashes
if isinstance(item, dict)
and isinstance(item.get("digest"), str)
and item.get("digest")
),
None,
)
if not isinstance(artifact_hash, str) or not artifact_hash:
raise ValueError("Artifact target is missing a content hash.")
return {
"entity_type": "artifact",
"artifact_hash": artifact_hash,
"metadata": metadata,
}

raise ValueError(f"Unsupported label entity type: {target.entity_type}")


def collect_label_sync_payloads(
db_ctx: _LabelSyncDatabaseContext,
*,
Expand Down Expand Up @@ -308,6 +406,42 @@ def collect_label_sync_payloads(
return payloads


def _canonical_remote_session_hash(
db_ctx: _RemoteLabelMutationDatabaseContext,
*,
roar_dir: Path,
session_id: int,
) -> str:
from ..core.logging import get_logger
from .publish.runtime import build_publish_runtime
from .publish.service import prepare_register_preview_execution

session = db_ctx.sessions.get(session_id)
if not isinstance(session, dict):
raise ValueError("Session not found.")

lineage = LineageCollector().collect_session(session_id, roar_dir)
if not getattr(lineage, "jobs", None):
raise ValueError(
"Selected session has no tracked steps. Run 'roar run' or 'roar build' first."
)

runtime = build_publish_runtime(
start_dir=str(roar_dir.parent),
allow_public_without_binding=True,
)
prepared = prepare_register_preview_execution(
runtime=runtime,
roar_dir=roar_dir,
cwd=roar_dir.parent,
session_id=session_id,
session_hash_override=None,
logger=get_logger(),
lineage=lineage,
)
return str(prepared.session_hash)


def _assign_nested(root: dict[str, Any], path: list[str], value: Any) -> None:
cursor = root
for key in path[:-1]:
Expand Down
3 changes: 3 additions & 0 deletions roar/application/query/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"DagQueryRequest": ".requests",
"LabelCopyRequest": ".requests",
"LabelHistoryRequest": ".requests",
"LabelPushRequest": ".requests",
"LabelSetRequest": ".requests",
"LabelShowRequest": ".requests",
"LineageQueryRequest": ".requests",
Expand All @@ -23,6 +24,7 @@
"StatusSummary": ".results",
"build_copy_labels_summary": ".label",
"build_label_history_summary": ".label",
"build_push_labels_summary": ".label",
"build_set_labels_summary": ".label",
"build_show_labels_summary": ".label",
"copy_labels": ".label",
Expand All @@ -32,6 +34,7 @@
"render_log": ".log",
"render_show": ".show",
"render_status": ".status",
"push_labels": ".label",
"set_labels": ".label",
"show_labels": ".label",
}
Expand Down
51 changes: 50 additions & 1 deletion roar/application/query/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from __future__ import annotations

from ...db.context import create_database_context
from ...integrations.glaas import GlaasClient
from ..label_rendering import flatten_label_metadata
from ..labels import LabelService, parse_label_pairs
from ..labels import LabelService, build_remote_label_mutation_payload, parse_label_pairs
from ..system_labels import strip_reserved_system_labels
from .requests import (
LabelCopyRequest,
LabelHistoryRequest,
LabelPushRequest,
LabelSetRequest,
LabelShowRequest,
)
Expand Down Expand Up @@ -64,6 +67,52 @@ def build_copy_labels_summary(request: LabelCopyRequest) -> LabelCurrentSummary:
return _build_current_summary(result.metadata, heading=heading)


def push_labels(request: LabelPushRequest) -> str:
"""Push the current local user-managed label document for a target to GLaaS."""
return build_push_labels_summary(request).render()


def build_push_labels_summary(request: LabelPushRequest) -> LabelCurrentSummary:
"""Build the typed summary for a remote label push operation."""
with create_database_context(request.roar_dir) as db_ctx:
service = LabelService(db_ctx, request.cwd)
resolved = service.resolve_target(request.entity_type, request.target)
metadata = strip_reserved_system_labels(service.current_metadata(resolved))
if not metadata:
raise ValueError(f"No local user-managed labels to push for {request.target}.")
payload = build_remote_label_mutation_payload(
db_ctx,
roar_dir=request.roar_dir,
target=resolved,
metadata=metadata,
)

client = GlaasClient(start_dir=str(request.cwd), allow_public_without_binding=True)
result, error = client.patch_current_label(payload)
if error:
raise ValueError(f"Remote label push failed: {error}")

remote_metadata = metadata
version: int | None = None
if isinstance(result, dict):
returned_metadata = result.get("metadata")
if isinstance(returned_metadata, dict):
remote_metadata = strip_reserved_system_labels(returned_metadata)
raw_version = result.get("version")
if raw_version is not None:
try:
version = int(raw_version)
except (TypeError, ValueError):
version = None

heading = (
f"Pushed remote labels (version {version}):"
if version is not None
else "Pushed remote labels:"
)
return _build_current_summary(remote_metadata, heading=heading)


def show_labels(request: LabelShowRequest) -> str:
"""Show the current local label document for a target."""
return build_show_labels_summary(request).render()
Expand Down
8 changes: 8 additions & 0 deletions roar/application/query/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ class LabelHistoryRequest:
target: str


@dataclass(frozen=True)
class LabelPushRequest:
roar_dir: Path
cwd: Path
entity_type: str
target: str


DiffFormat = Literal["summary", "category", "dag"]


Expand Down
26 changes: 25 additions & 1 deletion roar/cli/commands/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
roar label cp <dag|job|artifact> <source> <dag|job|artifact> <dest>
roar label show <dag|job|artifact> <target>
roar label history <dag|job|artifact> <target>
roar label push <dag|job|artifact> <target>
"""

from __future__ import annotations
Expand All @@ -14,6 +15,7 @@

from ...application.query.label import (
copy_labels,
push_labels,
set_labels,
show_labels,
)
Expand All @@ -23,6 +25,7 @@
from ...application.query.requests import (
LabelCopyRequest,
LabelHistoryRequest,
LabelPushRequest,
LabelSetRequest,
LabelShowRequest,
)
Expand All @@ -35,7 +38,7 @@
@click.group("label", invoke_without_command=True)
@click.pass_context
def label(ctx: click.Context) -> None:
"""Manage local labels for DAGs, jobs, and artifacts."""
"""Manage local labels and push user-managed label updates to GLaaS."""
if ctx.invoked_subcommand is None:
click.echo(ctx.get_help())

Expand Down Expand Up @@ -115,6 +118,27 @@ def label_show(ctx: RoarContext, entity_type: str, target: str) -> None:
click.echo(rendered)


@label.command("push")
@click.argument("entity_type", type=_ENTITY_TYPE)
@click.argument("target")
@click.pass_obj
@require_init
def label_push(ctx: RoarContext, entity_type: str, target: str) -> None:
"""Push the current local user-managed labels for a target to GLaaS."""
try:
rendered = push_labels(
LabelPushRequest(
roar_dir=ctx.roar_dir,
cwd=ctx.cwd,
entity_type=entity_type,
target=target,
)
)
except (ValueError, RuntimeError) as exc:
raise click.ClickException(str(exc)) from exc
click.echo(rendered)


@label.command("history")
@click.argument("entity_type", type=_ENTITY_TYPE)
@click.argument("target")
Expand Down
7 changes: 7 additions & 0 deletions roar/integrations/glaas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,13 @@ def sync_labels(
return {"created": 0, "updated": 0, "unchanged": 0}, None
return self._request("POST", "/api/v1/labels/sync", {"labels": labels})

def patch_current_label(
self,
label: dict[str, Any],
) -> tuple[dict | None, str | None]:
"""Patch the current remote label document for one lineage target."""
return self._request("PATCH", "/api/v1/labels/current", label)

def register_job_inputs(
self,
session_hash: str,
Expand Down
Loading
Loading