Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 3 additions & 29 deletions roar/application/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@

from ..core.label_origins import LABEL_ORIGIN_USER, build_current_key_origins
from ..db.context import DatabaseContext
from ..execution.recording.dataset_metadata import AUTO_DATASET_LABEL_KEYS
from .label_rendering import flatten_label_metadata

RESERVED_LABEL_KEYS = set(AUTO_DATASET_LABEL_KEYS)
from .system_labels import is_reserved_system_label_path, strip_reserved_system_labels


@dataclass(frozen=True)
Expand Down Expand Up @@ -195,7 +193,7 @@ def copy_metadata(
destination: LabelTargetRef,
) -> LabelWriteResult:
"""Copy current source metadata into the destination as a patch."""
source_metadata = _remove_reserved_paths(self.current_metadata(source), RESERVED_LABEL_KEYS)
source_metadata = strip_reserved_system_labels(self.current_metadata(source))
destination_metadata = self.current_metadata(destination)
merged = _deep_merge(destination_metadata, source_metadata)
if merged == destination_metadata:
Expand Down Expand Up @@ -229,7 +227,7 @@ def copy_metadata(
@staticmethod
def _reject_reserved_keys(metadata: dict[str, Any]) -> None:
keys = {key for key, _value in flatten_label_metadata(metadata)}
reserved = sorted(keys.intersection(RESERVED_LABEL_KEYS))
reserved = sorted(key for key in keys if is_reserved_system_label_path(key))
if reserved:
joined = ", ".join(reserved)
raise ValueError(f"Reserved label keys cannot be set manually: {joined}")
Expand Down Expand Up @@ -351,27 +349,3 @@ def _deep_merge(current: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any
else:
merged[key] = value
return merged


def _remove_reserved_paths(metadata: dict[str, Any], reserved_paths: set[str]) -> dict[str, Any]:
cleaned = json.loads(json.dumps(metadata))
for path in reserved_paths:
_remove_nested(cleaned, path.split("."))
return cleaned


def _remove_nested(root: dict[str, Any], path: list[str]) -> None:
if not path:
return
key = path[0]
if key not in root:
return
if len(path) == 1:
root.pop(key, None)
return
child = root.get(key)
if not isinstance(child, dict):
return
_remove_nested(child, path[1:])
if not child:
root.pop(key, None)
38 changes: 36 additions & 2 deletions roar/application/publish/put_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
normalize_registration_source_type,
prepare_batch_registration_artifacts,
register_publish_lineage,
sync_publish_labels,
)
from ...application.system_labels import refresh_job_system_labels
from ...core.interfaces.registration import GitContext
from ...core.logging import get_logger
from ...db.context import DatabaseContext
Expand Down Expand Up @@ -398,6 +400,7 @@ def put_prepared(
job_type="put",
exit_code=0,
)
refresh_job_system_labels(self._db, job_id=job_id)
self._logger.debug(
"Job created: id=%s, uid=%s, step=%d",
job_id,
Expand All @@ -414,7 +417,7 @@ def put_prepared(
)
with Spinner("Finalizing lineage links...") as spin:
spin.update("Registering put job...")
self._register_put_job_with_glaas(
put_job_registered = self._register_put_job_with_glaas(
coordinator=coordinator,
command=command,
session_hash=session_hash_value,
Expand All @@ -434,6 +437,15 @@ def put_prepared(
composite_registrations=composite_registrations,
registration_errors=registration_result.errors,
)
if put_job_registered:
spin.update("Syncing put job labels...")
self._sync_put_job_labels_with_glaas(
glaas_client=client,
session_hash=session_hash_value,
job_id=job_id,
job_uid=job_uid,
registration_errors=registration_result.errors,
)

registration_error = (
"; ".join(registration_result.errors) if registration_result.errors else None
Expand Down Expand Up @@ -552,7 +564,7 @@ def _register_put_job_with_glaas(
step_number: int,
metadata_json: str,
registration_errors: list[str],
) -> None:
) -> bool:
"""Create the put sink node in GLaaS."""
self._logger.debug("Registering put job with GLaaS: job_uid=%s, job_type=put", job_uid)
put_job_result = coordinator.job_service.create_job(
Expand All @@ -572,6 +584,28 @@ def _register_put_job_with_glaas(
self._logger.debug("Put job GLaaS registration failed: %s", put_job_result.error)
if put_job_result.error:
registration_errors.append(f"Put job: {put_job_result.error}")
return False
return True

def _sync_put_job_labels_with_glaas(
self,
*,
glaas_client: Any,
session_hash: str,
job_id: int,
job_uid: str,
registration_errors: list[str],
) -> None:
"""Sync the local current label document for the publish-time put job."""
sync_publish_labels(
glaas_client=glaas_client,
db_ctx=self._db,
session_id=None,
session_hash=session_hash,
jobs=[{"id": job_id, "job_uid": job_uid}],
artifacts=[],
errors=registration_errors,
)

def _link_put_job_artifacts_with_glaas(
self,
Expand Down
4 changes: 2 additions & 2 deletions roar/application/publish/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,13 +415,13 @@ def sync_publish_labels(
*,
glaas_client: GlaasClient,
db_ctx: Any,
session_id: int,
session_id: int | None,
session_hash: str,
jobs: list[dict[str, Any]],
artifacts: list[dict[str, Any]],
errors: list[str] | None = None,
) -> None:
"""Sync publish labels to GLaaS and record any error on the supplied list."""
"""Sync current local labels for published entities to GLaaS."""
payloads = collect_label_sync_payloads(
db_ctx,
session_id=session_id,
Expand Down
5 changes: 4 additions & 1 deletion roar/application/query/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
remote_artifact_fallback_enabled,
run_local_then_remote_lookup,
)
from ..system_labels import omit_display_system_labels
from .requests import ShowQueryRequest
from .results import (
ShowArtifactComponentSummary,
Expand Down Expand Up @@ -450,4 +451,6 @@ def _current_label_metadata(
return None

metadata = current.get("metadata")
return metadata if isinstance(metadata, dict) else None
if not isinstance(metadata, dict):
return None
return cast(dict[str, Any] | None, omit_display_system_labels(metadata))
Loading
Loading