From 605caf1d0304f748f74a0cb1e923728b19c13d21 Mon Sep 17 00:00:00 2001 From: Trevor Basinger Date: Fri, 17 Apr 2026 20:25:28 +0000 Subject: [PATCH 1/7] feat(publish): stage remote writes via registration sessions --- roar/application/publish/put_execution.py | 349 ++++++++++++++++++ roar/application/publish/put_preparation.py | 2 + .../application/publish/register_execution.py | 180 ++++++--- .../publish/register_preparation.py | 2 + roar/application/publish/session.py | 30 ++ roar/core/interfaces/registration.py | 21 +- roar/integrations/glaas/client.py | 140 +++++++ .../glaas/registration/coordinator.py | 132 +++++++ roar/integrations/glaas/registration/job.py | 285 ++++++++++++++ .../glaas/registration/session.py | 79 ++++ .../publish/test_put_preparation.py | 12 +- .../publish/test_register_preparation.py | 18 +- tests/application/publish/test_session.py | 57 ++- tests/integration/fake_glaas.py | 244 +++++++++++- .../test_public_publish_intent_cli.py | 310 ++++++++++++---- tests/integration/test_put_cli_integration.py | 28 +- .../integration/test_register_dry_run_cli.py | 17 +- 17 files changed, 1741 insertions(+), 165 deletions(-) diff --git a/roar/application/publish/put_execution.py b/roar/application/publish/put_execution.py index 60c3ace7..1f74b6e9 100644 --- a/roar/application/publish/put_execution.py +++ b/roar/application/publish/put_execution.py @@ -20,6 +20,7 @@ from ...application.publish.metadata import build_put_operation_metadata_json from ...application.publish.put_preparation import PreparedPutExecution from ...application.publish.registration import ( + normalize_registration_hashes, normalize_registration_source_type, prepare_batch_registration_artifacts, register_publish_lineage, @@ -164,6 +165,7 @@ def put_prepared( client = prepared.glaas_client session_id = prepared.session_id session_hash = prepared.session_hash + registration_session_id = prepared.registration_session_id git_context = prepared.git_context resolved = prepared.resolved_sources destination_type = prepared.destination_type @@ -263,6 +265,30 @@ def put_prepared( # Register lineage with GLaaS (session already registered above) coordinator = self._registration_coordinator or RegistrationCoordinator() + if registration_session_id: + return self._put_prepared_with_registration_session( + client=client, + coordinator=coordinator, + registration_session_id=registration_session_id, + session_id=session_id, + fallback_session_hash=session_hash or "", + git_context=git_context, + sources=sources, + message=message, + uploads=uploads, + uploaded_files=uploaded_files, + artifacts_info=artifacts_info, + lineage=lineage, + resolved=resolved, + hashes_by_path=hashes_by_path, + destination_type=destination_type, + composite_source_type=composite_source_type, + dataset_identifiers=dataset_identifiers, + additional_composite_roots=additional_composite_roots, + git_commit=git_commit, + git_tag=git_tag, + ) + with Spinner("Publishing lineage to GLaaS...") as spin: pre_registration_errors: list[str] = [] spin.update("Registering lineage composites...") @@ -491,6 +517,329 @@ def put_prepared( composites_registered=composite_result_items, ) + def _put_prepared_with_registration_session( + self, + *, + client: Any, + coordinator: RegistrationCoordinator, + registration_session_id: str, + session_id: int, + fallback_session_hash: str, + git_context: GitContext, + sources: list[str], + message: str, + uploads: list[_UploadedArtifact], + uploaded_files: list[PutUploadedFile], + artifacts_info: list[tuple[str, str]], + lineage: Any, + resolved: list[Any], + hashes_by_path: dict[str, str], + destination_type: str, + composite_source_type: str | None, + dataset_identifiers: list[dict[str, Any]], + additional_composite_roots: dict[Path, list[Any]], + git_commit: str | None, + git_tag: str | None, + ) -> PutResult: + """Execute the authenticated Phase 5 publish flow via registration sessions.""" + session_hash = fallback_session_hash + session_url: str | None = None + registration_errors: list[str] = [] + lineage_composite_registrations: list[dict[str, Any]] = [] + composite_registrations: list[dict[str, Any]] = [] + source_str = " ".join(sources) if sources else "(session outputs)" + command = f'roar put {source_str} -m "{message}"' + source_type = normalize_registration_source_type(destination_type) + + provisional_metadata_json = build_put_operation_metadata_json( + message=message, + destination=self._destination, + destination_type=destination_type, + artifact_urls={u.artifact_id: u.remote_url for u in uploads}, + composite_registrations=[], + lineage_composite_registrations=[], + dataset_identifiers=dataset_identifiers, + git_commit=git_commit, + git_tag=git_tag, + timestamp=time.time(), + ) + + step_number = self._db.sessions.get_next_step_number(session_id) + job_id, job_uid = self._db.jobs.create( + command=command, + timestamp=time.time(), + session_id=session_id, + step_number=step_number, + metadata=provisional_metadata_json, + execution_backend="local", + execution_role="host", + job_type="put", + exit_code=0, + ) + self._logger.debug( + "Put job created before registration-session finalize: id=%s, uid=%s, step=%d", + job_id, + job_uid, + step_number, + ) + + composite_results_for_linking = build_publish_composite_results( + resolved_sources=resolved, + hashes_by_path=hashes_by_path, + session_hash=fallback_session_hash, + source_type=composite_source_type, + additional_composite_roots=additional_composite_roots, + composite_builder=self._composite_builder, + ) + + put_job_registered = False + put_job_links_succeeded = False + with Spinner("Publishing lineage to GLaaS...") as spin: + spin.update("Staging lineage jobs and artifacts...") + registration_result = coordinator.register_lineage_under_registration_session( + registration_session_id=registration_session_id, + git_context=git_context, + jobs=lineage.jobs, + ) + registration_errors.extend(registration_result.errors) + + if registration_result.jobs_failed == 0 and registration_result.links_failed == 0: + spin.update("Staging put job...") + put_job_result = coordinator.job_service.create_job_under_registration_session( + command=command, + timestamp=time.time(), + registration_session_id=registration_session_id, + job_uid=job_uid, + git_commit=git_context.commit or "", + git_branch=git_context.branch or "", + duration_seconds=0.0, + exit_code=0, + job_type="put", + step_number=step_number, + metadata=provisional_metadata_json, + ) + if not put_job_result.success: + if put_job_result.error: + registration_errors.append(f"Put job: {put_job_result.error}") + else: + put_job_registered = True + + if put_job_registered: + spin.update("Linking put job artifacts...") + put_inputs = self._build_registration_session_put_job_inputs( + uploads=uploads, + lineage_artifacts=lineage.artifacts, + source_type=source_type, + ) + put_outputs = self._build_registration_session_put_job_outputs( + composite_results_for_linking + ) + link_result = coordinator.job_service.link_job_artifacts_under_registration_session( + registration_session_id=registration_session_id, + job_uid=job_uid, + inputs=put_inputs, + outputs=put_outputs, + ) + put_job_links_succeeded = link_result.success + if not link_result.success and link_result.error: + registration_errors.append(f"Put job links: {link_result.error}") + + if put_job_registered and put_job_links_succeeded: + spin.update("Finalizing lineage...") + finalize_result = coordinator.session_service.finalize_registration_session( + registration_session_id=registration_session_id, + git_context=git_context, + ) + if not finalize_result.success: + registration_errors.append( + f"Registration session finalize failed: {finalize_result.error}" + ) + else: + session_hash = finalize_result.session_hash + session_url = finalize_result.session_url + + spin.update("Registering lineage composites...") + lineage_composite_registrations = ( + preregister_put_lineage_composites_with_glaas( + db_ctx=self._db, + glaas_client=client, + lineage_artifacts=lineage.artifacts, + session_hash=session_hash, + registration_errors=registration_errors, + dataset_identifiers=dataset_identifiers, + composite_builder=self._composite_builder, + logger=self._logger, + ) + ) + + composite_results = build_publish_composite_results( + resolved_sources=resolved, + hashes_by_path=hashes_by_path, + session_hash=session_hash, + source_type=composite_source_type, + additional_composite_roots=additional_composite_roots, + composite_builder=self._composite_builder, + ) + spin.update("Registering output composites...") + composite_registrations = register_put_composites_with_glaas( + db_ctx=self._db, + glaas_client=client, + composite_results=composite_results, + registration_errors=registration_errors, + dataset_identifiers=dataset_identifiers, + logger=self._logger, + ) + + metadata_json = build_put_operation_metadata_json( + message=message, + destination=self._destination, + destination_type=destination_type, + artifact_urls={u.artifact_id: u.remote_url for u in uploads}, + composite_registrations=composite_registrations, + lineage_composite_registrations=lineage_composite_registrations, + dataset_identifiers=dataset_identifiers, + git_commit=git_commit, + git_tag=git_tag, + timestamp=time.time(), + ) + self._db.jobs.update_metadata(job_id, metadata_json) + refresh_job_system_labels(self._db, job_id=job_id) + self._link_local_put_job_artifacts( + job_id=job_id, + artifacts_info=artifacts_info, + lineage_artifacts=lineage.artifacts, + composite_registrations=composite_registrations, + ) + + if session_hash and put_job_registered: + self._sync_put_job_labels_with_glaas( + glaas_client=client, + session_hash=session_hash, + job_id=job_id, + job_uid=job_uid, + registration_errors=registration_errors, + ) + + composite_result_items = [ + PutCompositeRegistration( + root_path=(str(item["root_path"]) if item.get("root_path") is not None else None), + hash=str(item["hash"]) if item.get("hash") is not None else None, + component_count_stored=( + int(item["component_count_stored"]) + if item.get("component_count_stored") is not None + else None + ), + component_count_total=( + int(item["component_count_total"]) + if item.get("component_count_total") is not None + else None + ), + artifact_id=( + str(item["artifact_id"]) if item.get("artifact_id") is not None else None + ), + registered=bool(item["registered"]) if item.get("registered") is not None else None, + local_persisted=( + bool(item["local_persisted"]) + if item.get("local_persisted") is not None + else None + ), + local_error=( + str(item["local_error"]) if item.get("local_error") is not None else None + ), + ) + for item in composite_registrations + ] + + registration_error = "; ".join(registration_errors) if registration_errors else None + if registration_error: + self._logger.debug("Registration errors: %s", registration_error) + return PutResult( + success=False, + job_id=job_id, + job_uid=job_uid, + session_hash=session_hash, + session_url=session_url, + uploaded_files=uploaded_files, + composites_registered=composite_result_items, + error=registration_error, + ) + + return PutResult( + success=True, + job_id=job_id, + job_uid=job_uid, + session_hash=session_hash, + session_url=session_url, + uploaded_files=uploaded_files, + composites_registered=composite_result_items, + ) + + @staticmethod + def _build_registration_session_put_job_inputs( + *, + uploads: list[_UploadedArtifact], + lineage_artifacts: list[dict[str, Any]], + source_type: str | None, + ) -> list[dict[str, Any]]: + """Build direct hash-addressed inputs for a staged put job.""" + inputs: list[dict[str, Any]] = [] + seen_rows: set[tuple[str, str]] = set() + + for uploaded in uploads: + try: + size = max(0, int(Path(uploaded.local_path).stat().st_size)) + except (OSError, TypeError, ValueError): + size = 0 + row = { + "artifact_hash": uploaded.hash, + "path": uploaded.local_path, + "size": size, + "source_type": source_type, + } + row_key = (uploaded.hash, uploaded.local_path) + if row_key not in seen_rows: + seen_rows.add(row_key) + inputs.append(row) + + for artifact in lineage_artifacts: + if artifact.get("kind") != "composite": + continue + artifact_path = artifact.get("first_seen_path") or artifact.get("path") + if not isinstance(artifact_path, str) or not artifact_path: + continue + hashes = normalize_registration_hashes(artifact) + digest = hashes[0]["digest"] if hashes else None + if not digest: + continue + row_key = (digest, artifact_path) + if row_key in seen_rows: + continue + seen_rows.add(row_key) + inputs.append({"artifact_hash": digest, "path": artifact_path}) + + return inputs + + @staticmethod + def _build_registration_session_put_job_outputs( + composite_results: list[Any], + ) -> list[dict[str, Any]]: + """Build direct hash-addressed outputs for a staged put job.""" + outputs: list[dict[str, Any]] = [] + seen_rows: set[tuple[str, str]] = set() + for result in composite_results: + digest = getattr(result, "digest", None) + root_path = getattr(result, "root_path", None) + if not isinstance(digest, str) or not digest: + continue + if not isinstance(root_path, str) or not root_path: + continue + row_key = (digest, root_path) + if row_key in seen_rows: + continue + seen_rows.add(row_key) + outputs.append({"artifact_hash": digest, "path": root_path}) + return outputs + @staticmethod def _build_uploaded_artifacts_for_registration( uploads: list[_UploadedArtifact], diff --git a/roar/application/publish/put_preparation.py b/roar/application/publish/put_preparation.py index 26f55d6c..e8a2d590 100644 --- a/roar/application/publish/put_preparation.py +++ b/roar/application/publish/put_preparation.py @@ -35,6 +35,7 @@ class PreparedPutExecution: resolved_sources: list[ResolvedSource] destination_type: str composite_source_type: str | None + registration_session_id: str | None = None dataset_identifiers: list[dict[str, Any]] = field(default_factory=list) additional_composite_roots: dict[Path, list[ResolvedSource]] = field(default_factory=dict) @@ -104,6 +105,7 @@ def prepare_put_execution( session_hash=publish_session.session_hash, session_url=publish_session.session_url, git_context=git_context, + registration_session_id=publish_session.registration_session_id, resolved_sources=resolved_sources, destination_type=destination_type, composite_source_type=composite_source_type, diff --git a/roar/application/publish/register_execution.py b/roar/application/publish/register_execution.py index 8144adb5..92f19aa7 100644 --- a/roar/application/publish/register_execution.py +++ b/roar/application/publish/register_execution.py @@ -89,6 +89,13 @@ def register_publish_lineage(*args: Any, **kwargs: Any) -> Any: return _register_publish_lineage(*args, **kwargs) +def sync_publish_labels(*args: Any, **kwargs: Any) -> Any: + """Load label sync helpers only for real registration paths.""" + from ...application.publish.registration import sync_publish_labels as _sync_publish_labels + + return _sync_publish_labels(*args, **kwargs) + + def normalize_registration_hashes(*args: Any, **kwargs: Any) -> Any: """Load hash normalization only when extracting registration payloads.""" from ...application.publish.registration import ( @@ -213,6 +220,7 @@ def register_prepared_lineage( git_context = prepared.git_context session_hash = prepared.session_hash session_id = prepared.session_id + registration_session_id = prepared.registration_session_id omit_filter = self.omit_filter detected_secrets: list[str] = [] @@ -273,36 +281,125 @@ def register_prepared_lineage( ) composite_registrations: list[dict[str, Any]] = [] - pre_registration_errors: list[str] = [] + registration_errors: list[str] = [] + finalized_session_hash = session_hash + finalize_failed = False with Spinner("Publishing lineage to GLaaS...") as spin: - if has_lineage_composites(lineage.artifacts): - spin.update("Registering composite artifacts...") - try: - with create_database_context(roar_dir) as db_ctx: - composite_registrations = preregister_lineage_composites_with_glaas( - glaas_client=self.glaas_client, - db_ctx=db_ctx, - lineage_artifacts=lineage.artifacts, + refresh_job_artifact_references(lineage.jobs, lineage.artifacts) + + if registration_session_id: + spin.update("Staging jobs and artifacts...") + batch_result = self.coordinator.register_lineage_under_registration_session( + registration_session_id=registration_session_id, + git_context=git_context, + jobs=registration_jobs, + ) + registration_errors.extend(batch_result.errors) + + if batch_result.jobs_failed == 0 and batch_result.links_failed == 0: + spin.update("Finalizing lineage...") + finalize_result = self.coordinator.session_service.finalize_registration_session( + registration_session_id=registration_session_id, + git_context=git_context, + ) + if not finalize_result.success: + finalize_failed = True + registration_errors.append( + f"Registration session finalize failed: {finalize_result.error}" + ) + else: + finalized_session_hash = finalize_result.session_hash + + if has_lineage_composites(lineage.artifacts): + spin.update("Registering composite artifacts...") + try: + with create_database_context(roar_dir) as db_ctx: + composite_registrations = ( + preregister_lineage_composites_with_glaas( + glaas_client=self.glaas_client, + db_ctx=db_ctx, + lineage_artifacts=lineage.artifacts, + session_hash=finalized_session_hash, + registration_errors=registration_errors, + composite_builder=self.composite_builder, + logger=self._logger, + ) + ) + if session_id is not None: + sync_publish_labels( + glaas_client=self.glaas_client, + db_ctx=db_ctx, + session_id=session_id, + session_hash=finalized_session_hash, + jobs=registration_jobs, + artifacts=lineage.artifacts, + errors=registration_errors, + ) + except Exception as e: + return RegisterResult( + success=False, + session_hash=finalized_session_hash, + artifact_hash=artifact_hash, + error=f"Composite artifact registration failed: {e}", + secrets_detected=detected_secrets, + secrets_redacted=bool(detected_secrets), + ) + elif session_id is not None: + with create_database_context(roar_dir) as db_ctx: + sync_publish_labels( + glaas_client=self.glaas_client, + db_ctx=db_ctx, + session_id=session_id, + session_hash=finalized_session_hash, + jobs=registration_jobs, + artifacts=lineage.artifacts, + errors=registration_errors, + ) + else: + if has_lineage_composites(lineage.artifacts): + spin.update("Registering composite artifacts...") + try: + with create_database_context(roar_dir) as db_ctx: + composite_registrations = preregister_lineage_composites_with_glaas( + glaas_client=self.glaas_client, + db_ctx=db_ctx, + lineage_artifacts=lineage.artifacts, + session_hash=session_hash, + registration_errors=registration_errors, + composite_builder=self.composite_builder, + logger=self._logger, + ) + except Exception as e: + return RegisterResult( + success=False, session_hash=session_hash, - registration_errors=pre_registration_errors, - composite_builder=self.composite_builder, - logger=self._logger, + artifact_hash=artifact_hash, + error=f"Composite artifact registration failed: {e}", + secrets_detected=detected_secrets, + secrets_redacted=bool(detected_secrets), ) - except Exception as e: - return RegisterResult( - success=False, - session_hash=session_hash, - artifact_hash=artifact_hash, - error=f"Composite artifact registration failed: {e}", - secrets_detected=detected_secrets, - secrets_redacted=bool(detected_secrets), - ) - refresh_job_artifact_references(lineage.jobs, lineage.artifacts) - spin.update("Registering jobs, artifacts, and links...") + spin.update("Registering jobs, artifacts, and links...") - if session_id is not None: - with create_database_context(roar_dir) as db_ctx: + if session_id is not None: + with create_database_context(roar_dir) as db_ctx: + batch_result = register_publish_lineage( + coordinator=self.coordinator, + glaas_client=self.glaas_client, + session_hash=session_hash, + git_context=git_context, + jobs=registration_jobs, + artifacts=prepare_batch_registration_artifacts( + lineage.artifacts, + session_hash, + fallback_to_hash=True, + prefer_blake3_first=True, + ), + db_ctx=db_ctx, + session_id=session_id, + label_artifacts=lineage.artifacts, + ) + else: batch_result = register_publish_lineage( coordinator=self.coordinator, glaas_client=self.glaas_client, @@ -315,45 +412,28 @@ def register_prepared_lineage( fallback_to_hash=True, prefer_blake3_first=True, ), - db_ctx=db_ctx, - session_id=session_id, + db_ctx=None, + session_id=None, label_artifacts=lineage.artifacts, ) - else: - batch_result = register_publish_lineage( - coordinator=self.coordinator, - glaas_client=self.glaas_client, - session_hash=session_hash, - git_context=git_context, - jobs=registration_jobs, - artifacts=prepare_batch_registration_artifacts( - lineage.artifacts, - session_hash, - fallback_to_hash=True, - prefer_blake3_first=True, - ), - db_ctx=None, - session_id=None, - label_artifacts=lineage.artifacts, - ) + registration_errors.extend(batch_result.errors) composite_registered = sum(1 for item in composite_registrations if item.get("registered")) composite_failed = sum(1 for item in composite_registrations if not item.get("registered")) total_artifacts_registered = batch_result.artifacts_registered + composite_registered total_artifacts_failed = batch_result.artifacts_failed + composite_failed - all_errors = pre_registration_errors + batch_result.errors - if all_errors: - self._logger.warning("Registration completed with errors: %s", all_errors) + if registration_errors: + self._logger.warning("Registration completed with errors: %s", registration_errors) return RegisterResult( - success=batch_result.jobs_failed == 0 and total_artifacts_failed == 0, - session_hash=session_hash, + success=batch_result.jobs_failed == 0 and total_artifacts_failed == 0 and not finalize_failed, + session_hash=finalized_session_hash, artifact_hash=artifact_hash, jobs_registered=batch_result.jobs_created, artifacts_registered=total_artifacts_registered, links_created=batch_result.links_created, - error="; ".join(all_errors) if all_errors else None, + error="; ".join(registration_errors) if registration_errors else None, secrets_detected=detected_secrets, secrets_redacted=bool(detected_secrets), ) diff --git a/roar/application/publish/register_preparation.py b/roar/application/publish/register_preparation.py index 72af1a7d..5ec09782 100644 --- a/roar/application/publish/register_preparation.py +++ b/roar/application/publish/register_preparation.py @@ -25,6 +25,7 @@ class PreparedRegisterExecution: session_url: str | None git_tag_name: str | None git_tag_repo_root: Path | None + registration_session_id: str | None = None def prepare_register_execution( @@ -88,4 +89,5 @@ def prepare_register_execution( session_url=publish_session.session_url, git_tag_name=git_tag_name, git_tag_repo_root=git_tag_repo_root, + registration_session_id=publish_session.registration_session_id, ) diff --git a/roar/application/publish/session.py b/roar/application/publish/session.py index 50785cea..77d50c1e 100644 --- a/roar/application/publish/session.py +++ b/roar/application/publish/session.py @@ -26,6 +26,12 @@ def register( ) -> SessionRegistrationResult: """Register a session with GLaaS.""" + def create_registration_session( + self, + client_session_id: str | None = None, + ) -> SessionRegistrationResult: + """Create or resume a remote registration session.""" + @dataclass(frozen=True) class PreparedPublishSession: @@ -33,6 +39,7 @@ class PreparedPublishSession: session_hash: str session_url: str | None = None + registration_session_id: str | None = None def build_canonical_session_payload( @@ -209,6 +216,29 @@ def prepare_publish_session( logger.debug("GLaaS health check failed: %s", exc) raise ValueError(f"GLaaS health check failed: {exc}") from exc + publish_auth = getattr(glaas_client, "publish_auth", None) + access_token = getattr(publish_auth, "access_token", None) + should_use_registration_sessions = isinstance(access_token, str) and bool(access_token.strip()) + + if should_use_registration_sessions: + logger.debug("Creating remote registration session with GLaaS") + session_result = session_service.create_registration_session(client_session_id=None) + if not session_result.success: + logger.debug("Registration session creation failed: %s", session_result.error) + raise ValueError( + f"Registration session creation failed: {session_result.error}" + ) + + logger.debug( + "Registration session ready: %s", + session_result.registration_session_id, + ) + return PreparedPublishSession( + session_hash=session_hash, + session_url=None, + registration_session_id=session_result.registration_session_id, + ) + logger.debug("Registering session with GLaaS") session_result = resolved_session_service.register(session_hash, git_context) if not session_result.success: diff --git a/roar/core/interfaces/registration.py b/roar/core/interfaces/registration.py index 9f062051..7da0a31a 100644 --- a/roar/core/interfaces/registration.py +++ b/roar/core/interfaces/registration.py @@ -24,11 +24,14 @@ class GitContext: @dataclass class SessionRegistrationResult: - """Result of session registration.""" + """Result of session registration or registration-session lifecycle operations.""" success: bool session_hash: str session_url: str | None = None + registration_session_id: str | None = None + created: bool | None = None + status: str | None = None error: str | None = None @@ -59,6 +62,7 @@ class JobLinkResult: job_uid: str inputs_linked: int = 0 outputs_linked: int = 0 + artifacts_registered: int = 0 error: str | None = None @@ -97,6 +101,21 @@ def register( """Register session with GLaaS.""" ... + def create_registration_session( + self, + client_session_id: str | None = None, + ) -> SessionRegistrationResult: + """Create or resume a remote registration session.""" + ... + + def finalize_registration_session( + self, + registration_session_id: str, + git_context: GitContext, + ) -> SessionRegistrationResult: + """Finalize a remote registration session into an immutable lineage hash.""" + ... + @runtime_checkable class IArtifactRegistrar(Protocol): diff --git a/roar/integrations/glaas/client.py b/roar/integrations/glaas/client.py index b1b26d75..08c5f35c 100644 --- a/roar/integrations/glaas/client.py +++ b/roar/integrations/glaas/client.py @@ -454,6 +454,37 @@ def register_session( result, error = self._request("POST", "/api/v1/sessions", body) return result, error + def create_registration_session( + self, + client_session_id: str | None = None, + ) -> tuple[dict | None, str | None]: + """Create or resume a durable registration session.""" + body: dict[str, Any] = {} + if client_session_id: + body["client_session_id"] = client_session_id + return self._request("POST", "/api/v1/registration-sessions", body) + + def finalize_registration_session( + self, + registration_session_id: str, + git_repo: str, + git_commit: str, + git_branch: str, + ) -> tuple[dict | None, str | None]: + """Finalize a registration session into a lineage hash.""" + body: dict[str, Any] = { + "git_repo": git_repo, + "git_commit": git_commit, + "git_branch": git_branch, + } + if self._publish_auth.scope_request: + body["scope_request"] = dict(self._publish_auth.scope_request) + return self._request( + "POST", + f"/api/v1/registration-sessions/{registration_session_id}/finalize", + body, + ) + def sync_labels( self, labels: list[dict[str, Any]], @@ -470,6 +501,79 @@ def patch_current_label( """Patch the current remote label document for one lineage target.""" return self._request("PATCH", "/api/v1/labels/current", label) + def register_job_under_registration_session( + self, + registration_session_id: str, + command: str, + timestamp: float, + job_uid: str, + git_commit: str, + git_branch: str, + duration_seconds: float, + exit_code: int, + job_type: str | None, + step_number: int, + metadata: str | None = None, + parent_job_uid: str | None = None, + ) -> tuple[int | None, str | None]: + """Register a staged job under a registration session.""" + body = { + "command": command, + "timestamp": timestamp, + "job_uid": job_uid, + "git_commit": git_commit, + "git_branch": git_branch, + "duration_seconds": duration_seconds, + "exit_code": exit_code, + "job_type": job_type, + "step_number": step_number, + } + if metadata: + body["metadata"] = metadata + if parent_job_uid is not None: + body["parent_job_uid"] = parent_job_uid + + result, error = self._request( + "POST", + f"/api/v1/registration-sessions/{registration_session_id}/jobs", + body, + ) + if error: + return None, error + if result is None: + return None, None + return result.get("id"), None + + def register_jobs_batch_under_registration_session( + self, + registration_session_id: str, + jobs: list, + ) -> tuple[list, list, str | None]: + """Register multiple staged jobs under a registration session.""" + if not jobs: + return [], [], None + + body_jobs: list[dict[str, Any]] = [] + for job in jobs: + if not isinstance(job, dict): + continue + payload = dict(job) + if payload.get("parent_job_uid") is None: + payload.pop("parent_job_uid", None) + body_jobs.append(payload) + + body = {"jobs": body_jobs} + result, error = self._request( + "POST", + f"/api/v1/registration-sessions/{registration_session_id}/jobs/batch", + body, + ) + if error: + return [], [error] * len(jobs), error + if result is None: + return [], [], None + return result.get("job_ids", []), result.get("errors", []), None + def register_job_inputs( self, session_hash: str, @@ -495,6 +599,21 @@ def register_job_inputs( ) return result, error + def register_job_inputs_under_registration_session( + self, + registration_session_id: str, + job_uid: str, + artifacts: list[dict], + ) -> tuple[dict | None, str | None]: + """Register staged input artifacts for a registration-session job.""" + body: dict[str, Any] = {"artifacts": self._map_artifacts_for_api(artifacts)} + result, error = self._request( + "POST", + f"/api/v1/registration-sessions/{registration_session_id}/jobs/{job_uid}/inputs", + body, + ) + return result, error + def register_job_outputs( self, session_hash: str, @@ -520,6 +639,21 @@ def register_job_outputs( ) return result, error + def register_job_outputs_under_registration_session( + self, + registration_session_id: str, + job_uid: str, + artifacts: list[dict], + ) -> tuple[dict | None, str | None]: + """Register staged output artifacts for a registration-session job.""" + body: dict[str, Any] = {"artifacts": self._map_artifacts_for_api(artifacts)} + result, error = self._request( + "POST", + f"/api/v1/registration-sessions/{registration_session_id}/jobs/{job_uid}/outputs", + body, + ) + return result, error + @staticmethod def _map_artifacts_for_api(artifacts: list[dict]) -> list[dict]: """Map internal artifact dicts to the API schema. @@ -535,6 +669,12 @@ def _map_artifacts_for_api(artifacts: list[dict]) -> list[dict]: "hash": artifact_hash, "path": a["path"], } + if "size" in a and a["size"] is not None: + entry["size"] = a["size"] + if "source_type" in a and a["source_type"] is not None: + entry["source_type"] = a["source_type"] + if "metadata" in a and a["metadata"] is not None: + entry["metadata"] = a["metadata"] if "byte_ranges" in a and a["byte_ranges"] is not None: entry["byte_ranges"] = a["byte_ranges"] mapped.append(entry) diff --git a/roar/integrations/glaas/registration/coordinator.py b/roar/integrations/glaas/registration/coordinator.py index 70a2fff1..907a2135 100644 --- a/roar/integrations/glaas/registration/coordinator.py +++ b/roar/integrations/glaas/registration/coordinator.py @@ -229,6 +229,93 @@ def register_lineage( errors=errors, ) + def register_lineage_under_registration_session( + self, + registration_session_id: str, + git_context: GitContext, + jobs: list[dict], + ) -> BatchRegistrationResult: + """Stage complete lineage under a remote registration session.""" + errors: list[str] = [] + jobs_created = 0 + jobs_failed = 0 + artifacts_registered = 0 + artifacts_failed = 0 + links_created = 0 + links_failed = 0 + + self._logger.debug( + "Starting registration-session lineage staging: registration_session_id=%s, jobs=%d", + registration_session_id, + len(jobs), + ) + + job_uids_created: list[str] = [] + if jobs: + batch_results = self.job_service.create_jobs_batch_under_registration_session( + jobs=jobs, + registration_session_id=registration_session_id, + git_context=git_context, + ) + for result in batch_results: + if result.success: + jobs_created += 1 + job_uids_created.append(result.job_uid) + else: + jobs_failed += 1 + if result.error: + errors.append(f"Job {result.job_uid}: {result.error}") + + self._logger.debug( + "Registration-session job staging complete: %d created, %d failed", + jobs_created, + jobs_failed, + ) + + for job in jobs: + job_uid = job.get("job_uid") + if not job_uid or job_uid not in job_uids_created: + continue + + inputs = self._extract_staged_io_list(job, "_inputs", "_input_hashes") + outputs = self._extract_staged_io_list(job, "_outputs", "_output_hashes") + if not inputs and not outputs: + continue + + link_result = self.job_service.link_job_artifacts_under_registration_session( + registration_session_id=registration_session_id, + job_uid=job_uid, + inputs=inputs, + outputs=outputs, + ) + if link_result.success: + links_created += link_result.inputs_linked + link_result.outputs_linked + artifacts_registered += link_result.artifacts_registered + else: + links_failed += 1 + artifacts_failed += link_result.artifacts_registered + if link_result.error: + errors.append(f"Link {job_uid}: {link_result.error}") + + self._logger.debug( + "Registration-session lineage staging complete: jobs=%d/%d, artifacts=%d, links=%d", + jobs_created, + jobs_created + jobs_failed, + artifacts_registered, + links_created, + ) + + return BatchRegistrationResult( + session_registered=True, + jobs_created=jobs_created, + jobs_failed=jobs_failed, + artifacts_registered=artifacts_registered, + artifacts_failed=artifacts_failed, + links_created=links_created, + links_failed=links_failed, + errors=errors, + ) + def _resolve_io_artifact_hashes( self, items: list[dict[str, Any]], @@ -324,6 +411,51 @@ def _extract_io_list( return [] + def _extract_staged_io_list( + self, + job: dict, + structured_key: str, + hash_list_key: str, + ) -> list[dict[str, Any]]: + """Extract staged I/O rows with direct artifact hashes for registration-session writes.""" + structured = job.get(structured_key, []) + if structured: + result: list[dict[str, Any]] = [] + for item in structured: + artifact_hash = _artifact_ref.extract_digest(item) + path = item.get("path") + if not artifact_hash or not path: + preview = _artifact_ref.preview(item) + if preview: + self._logger.warning( + "Dropping staged I/O item %s: missing hash or path", + preview, + ) + else: + self._logger.warning("Dropping staged I/O item: missing hash or path") + continue + + normalized: dict[str, Any] = { + "artifact_hash": artifact_hash, + "path": path, + } + if item.get("size") is not None: + normalized["size"] = item["size"] + if item.get("source_type") is not None: + normalized["source_type"] = item["source_type"] + if item.get("metadata") is not None: + normalized["metadata"] = item["metadata"] + if item.get("byte_ranges") is not None: + normalized["byte_ranges"] = item["byte_ranges"] + result.append(normalized) + return result + + hash_list = job.get(hash_list_key, []) + if hash_list: + return [{"artifact_hash": h, "path": ""} for h in hash_list if h] + + return [] + def register_uploaded_artifacts( self, artifacts: list[tuple[str, int, str, str]], diff --git a/roar/integrations/glaas/registration/job.py b/roar/integrations/glaas/registration/job.py index 2fb8627e..fcb97764 100644 --- a/roar/integrations/glaas/registration/job.py +++ b/roar/integrations/glaas/registration/job.py @@ -198,6 +198,70 @@ def create_job( job_id=str(job_id) if job_id else None, ) + def create_job_under_registration_session( + self, + command: str, + timestamp: float, + registration_session_id: str, + job_uid: str, + git_commit: str, + git_branch: str, + duration_seconds: float, + exit_code: int, + job_type: str | None, + step_number: int, + metadata: str | None = None, + parent_job_uid: str | None = None, + ) -> JobRegistrationResult: + """Create a staged job under a remote registration session.""" + filtered_command, _, filtered_metadata = self._filter_job_data(command, None, metadata) + + validation = validate_job_registration( + command=filtered_command, + timestamp=timestamp, + session_hash="pending-registration-session", + job_uid=job_uid, + git_commit=git_commit, + git_branch=git_branch, + job_type=job_type, + step_number=step_number, + ) + if not validation: + error_msg = "; ".join(validation.errors) + self._logger.warning("Job validation failed for %s: %s", job_uid, error_msg) + return JobRegistrationResult(success=False, job_uid=job_uid, error=error_msg) + + job_id, error = self.client.register_job_under_registration_session( + registration_session_id=registration_session_id, + command=filtered_command, + timestamp=timestamp, + job_uid=job_uid, + git_commit=git_commit, + git_branch=git_branch, + duration_seconds=duration_seconds, + exit_code=exit_code, + job_type=job_type, + step_number=step_number, + metadata=filtered_metadata, + parent_job_uid=parent_job_uid, + ) + if error: + self._logger.debug( + "Registration-session job creation failed for %s: %s", job_uid, error + ) + return JobRegistrationResult(success=False, job_uid=job_uid, error=error) + + self._logger.debug( + "Registration-session job created: %s -> server_id=%s", + job_uid, + job_id, + ) + return JobRegistrationResult( + success=True, + job_uid=job_uid, + job_id=str(job_id) if job_id else None, + ) + def create_jobs_batch( self, jobs: list[dict], @@ -321,6 +385,113 @@ def create_jobs_batch( return [r for r in results if r is not None] + def create_jobs_batch_under_registration_session( + self, + jobs: list[dict], + registration_session_id: str, + git_context: GitContext, + ) -> list[JobRegistrationResult]: + """Create multiple staged jobs under a remote registration session.""" + if not jobs: + return [] + + results: list[JobRegistrationResult | None] = [None] * len(jobs) + payloads: list[dict] = [] + payload_indices: list[int] = [] + + for i, job in enumerate(jobs): + job_uid = job.get("job_uid") + if not job_uid: + self._logger.warning("Skipping job without job_uid") + results[i] = JobRegistrationResult( + success=False, job_uid="", error="Job missing job_uid" + ) + continue + + command = job.get("command", "") + git_commit = job.get("git_commit") or git_context.commit or "" + git_branch = job.get("git_branch") or git_context.branch or "" + metadata = job.get("metadata") + filtered_command, _, filtered_metadata = self._filter_job_data(command, None, metadata) + + validation = validate_job_registration( + command=filtered_command, + timestamp=job.get("timestamp", 0.0), + session_hash="pending-registration-session", + job_uid=job_uid, + git_commit=git_commit, + git_branch=git_branch, + job_type=job.get("job_type"), + step_number=job.get("step_number", 0), + ) + if not validation: + error_msg = "; ".join(validation.errors) + self._logger.warning("Job validation failed for %s: %s", job_uid, error_msg) + results[i] = JobRegistrationResult(success=False, job_uid=job_uid, error=error_msg) + continue + + payload: dict[str, Any] = { + "command": filtered_command, + "timestamp": job.get("timestamp", 0.0), + "job_uid": job_uid, + "git_commit": git_commit, + "git_branch": git_branch, + "duration_seconds": job.get("duration_seconds", 0.0), + "exit_code": job.get("exit_code", 0), + "job_type": job.get("job_type") or "run", + "step_number": job.get("step_number", 0), + } + if filtered_metadata: + payload["metadata"] = filtered_metadata + if job.get("parent_job_uid") is not None: + payload["parent_job_uid"] = job.get("parent_job_uid") + + payloads.append(payload) + payload_indices.append(i) + + if not payloads: + return [r for r in results if r is not None] + + self._logger.debug( + "Batch registering %d jobs under registration session %s", + len(payloads), + registration_session_id, + ) + + job_ids, errors, overall_error = self.client.register_jobs_batch_under_registration_session( + registration_session_id=registration_session_id, + jobs=payloads, + ) + + if overall_error: + for idx in payload_indices: + job_uid = jobs[idx].get("job_uid", "") + results[idx] = JobRegistrationResult( + success=False, job_uid=job_uid, error=overall_error + ) + else: + for pos, idx in enumerate(payload_indices): + job_uid = payloads[pos]["job_uid"] + if pos < len(errors) and errors[pos]: + results[idx] = JobRegistrationResult( + success=False, job_uid=job_uid, error=errors[pos] + ) + else: + job_id = job_ids[pos] if pos < len(job_ids) else None + results[idx] = JobRegistrationResult( + success=True, + job_uid=job_uid, + job_id=str(job_id) if job_id else None, + ) + + self._logger.debug( + "Registration-session batch job registration complete: %d succeeded, %d failed", + sum(1 for r in results if r and r.success), + sum(1 for r in results if r and not r.success), + ) + + return [r for r in results if r is not None] + def link_job_artifacts( self, session_hash: str, @@ -432,6 +603,114 @@ def link_job_artifacts( outputs_linked=outputs_linked, ) + def link_job_artifacts_under_registration_session( + self, + registration_session_id: str, + job_uid: str, + inputs: list[dict[str, Any]] | None, + outputs: list[dict[str, Any]] | None, + ) -> JobLinkResult: + """Link artifacts to a staged job under a remote registration session.""" + valid_inputs = self._normalize_link_artifacts(inputs or [], "input") + valid_outputs = self._normalize_link_artifacts(outputs or [], "output") + + if not valid_inputs and not valid_outputs: + self._logger.debug( + "No staged artifacts to link for registration-session job %s", + job_uid, + ) + return JobLinkResult(success=True, job_uid=job_uid) + + self._logger.debug( + "Linking staged artifacts to registration-session job %s: %d inputs, %d outputs", + job_uid, + len(valid_inputs), + len(valid_outputs), + ) + + inputs_linked = 0 + outputs_linked = 0 + artifacts_registered = 0 + errors = [] + + if valid_inputs: + input_batches = _batch_artifacts(valid_inputs, MAX_ARTIFACTS_PER_REQUEST) + for batch_idx, batch in enumerate(input_batches): + self._logger.debug( + "Sending registration-session input batch %d/%d for job %s: %d artifacts", + batch_idx + 1, + len(input_batches), + job_uid, + len(batch), + ) + result, error = self.client.register_job_inputs_under_registration_session( + registration_session_id=registration_session_id, + job_uid=job_uid, + artifacts=batch, + ) + if error: + self._logger.debug( + "Registration-session input linking failed for %s: %s", job_uid, error + ) + errors.append(f"inputs: {error}") + break + inputs_linked += result.get("inputs_linked", len(batch)) if result else len(batch) + artifacts_registered += ( + result.get("artifacts_registered", len(batch)) if result else len(batch) + ) + + if valid_outputs: + output_batches = _batch_artifacts(valid_outputs, MAX_ARTIFACTS_PER_REQUEST) + for batch_idx, batch in enumerate(output_batches): + self._logger.debug( + "Sending registration-session output batch %d/%d for job %s: %d artifacts", + batch_idx + 1, + len(output_batches), + job_uid, + len(batch), + ) + result, error = self.client.register_job_outputs_under_registration_session( + registration_session_id=registration_session_id, + job_uid=job_uid, + artifacts=batch, + ) + if error: + self._logger.debug( + "Registration-session output linking failed for %s: %s", job_uid, error + ) + errors.append(f"outputs: {error}") + break + outputs_linked += ( + result.get("outputs_linked", len(batch)) if result else len(batch) + ) + artifacts_registered += ( + result.get("artifacts_registered", len(batch)) if result else len(batch) + ) + + if errors: + return JobLinkResult( + success=False, + job_uid=job_uid, + inputs_linked=inputs_linked, + outputs_linked=outputs_linked, + artifacts_registered=artifacts_registered, + error="; ".join(errors), + ) + + self._logger.debug( + "Linked staged artifacts to registration-session job %s: %d inputs, %d outputs", + job_uid, + inputs_linked, + outputs_linked, + ) + return JobLinkResult( + success=True, + job_uid=job_uid, + inputs_linked=inputs_linked, + outputs_linked=outputs_linked, + artifacts_registered=artifacts_registered, + ) + def _normalize_link_artifacts( self, items: list[dict[str, Any]], @@ -467,6 +746,12 @@ def _normalize_link_artifacts( "artifact_hash": artifact_hash, "path": path, } + if item.get("size") is not None: + normalized["size"] = item["size"] + if item.get("source_type") is not None: + normalized["source_type"] = item["source_type"] + if item.get("metadata") is not None: + normalized["metadata"] = item["metadata"] if item.get("byte_ranges") is not None: normalized["byte_ranges"] = item["byte_ranges"] diff --git a/roar/integrations/glaas/registration/session.py b/roar/integrations/glaas/registration/session.py index 92c49951..b2129a6d 100644 --- a/roar/integrations/glaas/registration/session.py +++ b/roar/integrations/glaas/registration/session.py @@ -139,4 +139,83 @@ def register( success=True, session_hash=session_hash, session_url=session_url, + created=result.get("created") if result else None, + ) + + def create_registration_session( + self, + client_session_id: str | None = None, + ) -> SessionRegistrationResult: + """Create or resume a durable remote registration session.""" + result, error = self.client.create_registration_session(client_session_id=client_session_id) + if error: + self._logger.warning("Registration session creation failed: %s", error) + return SessionRegistrationResult(success=False, session_hash="", error=error) + + registration_session_id = result.get("registration_session_id") if result else None + if not isinstance(registration_session_id, str) or not registration_session_id: + error_msg = "registration session response did not include registration_session_id" + self._logger.warning("Registration session creation failed: %s", error_msg) + return SessionRegistrationResult(success=False, session_hash="", error=error_msg) + + self._logger.debug( + "Registration session ready: %s (created=%s, status=%s)", + registration_session_id, + result.get("created") if result else None, + result.get("status") if result else None, + ) + return SessionRegistrationResult( + success=True, + session_hash="", + registration_session_id=registration_session_id, + created=result.get("created") if result else None, + status=result.get("status") if result else None, + ) + + def finalize_registration_session( + self, + registration_session_id: str, + git_context: GitContext, + ) -> SessionRegistrationResult: + """Finalize a remote registration session into an immutable lineage hash.""" + validation = validate_session_registration( + session_hash="pending-registration-session-finalize", + git_repo=git_context.repo, + git_commit=git_context.commit, + git_branch=git_context.branch, + ) + if not validation: + error_msg = "; ".join(validation.errors) + self._logger.warning("Registration session finalize validation failed: %s", error_msg) + return SessionRegistrationResult(success=False, session_hash="", error=error_msg) + + result, error = self.client.finalize_registration_session( + registration_session_id=registration_session_id, + git_repo=git_context.repo or "", + git_commit=git_context.commit or "", + git_branch=git_context.branch or "", + ) + if error: + self._logger.warning("Registration session finalize failed: %s", error) + return SessionRegistrationResult(success=False, session_hash="", error=error) + + session_hash = result.get("hash") if result else None + if not isinstance(session_hash, str) or not session_hash: + error_msg = "registration session finalize response did not include hash" + self._logger.warning("Registration session finalize failed: %s", error_msg) + return SessionRegistrationResult(success=False, session_hash="", error=error_msg) + + session_url = result.get("url") if result else None + self._logger.debug( + "Registration session finalized successfully: %s -> %s", + registration_session_id, + session_hash[:12], + ) + return SessionRegistrationResult( + success=True, + session_hash=session_hash, + session_url=session_url if isinstance(session_url, str) else None, + registration_session_id=registration_session_id, + created=result.get("created") if result else None, + status=result.get("status") if result else None, ) diff --git a/tests/application/publish/test_put_preparation.py b/tests/application/publish/test_put_preparation.py index d0b6c55a..89f0f58b 100644 --- a/tests/application/publish/test_put_preparation.py +++ b/tests/application/publish/test_put_preparation.py @@ -34,7 +34,11 @@ def test_prepare_put_execution_builds_session_git_and_source_plan(tmp_path: Path db_ctx = MagicMock() db_ctx.sessions.get_active.return_value = {"id": 7} runtime = MagicMock() - prepared_session = MagicMock(session_hash="session-hash", session_url="https://glaas/session") + prepared_session = MagicMock( + session_hash="session-hash", + session_url="https://glaas/session", + registration_session_id=None, + ) logger = MagicMock() git_context = GitContext(repo="https://github.com/test/repo", branch="main", commit="deadbeef") @@ -92,7 +96,11 @@ def test_prepare_put_execution_propagates_missing_source(tmp_path: Path) -> None ), patch( "roar.application.publish.put_preparation.prepare_publish_session", - return_value=MagicMock(session_hash="session-hash", session_url=None), + return_value=MagicMock( + session_hash="session-hash", + session_url=None, + registration_session_id=None, + ), ), pytest.raises(FileNotFoundError, match=r"Source not found: missing.pt"), ): diff --git a/tests/application/publish/test_register_preparation.py b/tests/application/publish/test_register_preparation.py index 5de088e4..1c78fc4a 100644 --- a/tests/application/publish/test_register_preparation.py +++ b/tests/application/publish/test_register_preparation.py @@ -34,7 +34,11 @@ def test_prepare_register_execution_builds_session_git_and_tag_plan(tmp_path: Pa db_user_id=None, ) logger = MagicMock() - prepared_session = MagicMock(session_hash="session-hash", session_url="https://glaas/session") + prepared_session = MagicMock( + session_hash="session-hash", + session_url="https://glaas/session", + registration_session_id=None, + ) git_context = _git_context() git_state = MagicMock(repo_root=tmp_path) @@ -104,7 +108,11 @@ def test_prepare_register_execution_passes_lineage_and_creator_identity_to_sessi user_sub="treqs-sub-123", db_user_id="user-123", ) - prepared_session = MagicMock(session_hash="session-hash", session_url=None) + prepared_session = MagicMock( + session_hash="session-hash", + session_url=None, + registration_session_id=None, + ) with ( patch( @@ -134,7 +142,11 @@ def test_prepare_register_execution_passes_lineage_and_creator_identity_to_sessi def test_prepare_register_execution_skips_git_tagging_and_glaas_on_dry_run(tmp_path: Path) -> None: runtime = MagicMock() - prepared_session = MagicMock(session_hash="session-hash", session_url=None) + prepared_session = MagicMock( + session_hash="session-hash", + session_url=None, + registration_session_id=None, + ) git_context = _git_context() with ( diff --git a/tests/application/publish/test_session.py b/tests/application/publish/test_session.py index 30a95081..eea59da2 100644 --- a/tests/application/publish/test_session.py +++ b/tests/application/publish/test_session.py @@ -77,14 +77,16 @@ def test_prepare_publish_session_uses_canonical_hash_when_lineage_and_creator_id session_service.register.assert_not_called() -def test_prepare_publish_session_registers_with_glaas(tmp_path: Path) -> None: +def test_prepare_publish_session_creates_registration_session_with_glaas(tmp_path: Path) -> None: glaas_client = MagicMock() + glaas_client.publish_auth.access_token = "token-123" session_service = MagicMock() session_service.compute_session_hash.return_value = "session-hash" - session_service.register.return_value = SessionRegistrationResult( + session_service.create_registration_session.return_value = SessionRegistrationResult( success=True, session_hash="session-hash", - session_url="https://glaas.example/dag/session-hash", + session_url=None, + registration_session_id="reg-session-123", ) logger = MagicMock() @@ -100,10 +102,12 @@ def test_prepare_publish_session_registers_with_glaas(tmp_path: Path) -> None: assert result == PreparedPublishSession( session_hash="session-hash", - session_url="https://glaas.example/dag/session-hash", + session_url=None, + registration_session_id="reg-session-123", ) glaas_client.health_check.assert_called_once() - session_service.register.assert_called_once_with("session-hash", _git_context()) + session_service.create_registration_session.assert_called_once_with(client_session_id=None) + session_service.register.assert_not_called() def test_prepare_publish_session_requires_configured_glaas(tmp_path: Path) -> None: @@ -143,17 +147,20 @@ def test_prepare_publish_session_surfaces_health_check_failures(tmp_path: Path) ) -def test_prepare_publish_session_surfaces_session_registration_failures(tmp_path: Path) -> None: +def test_prepare_publish_session_surfaces_registration_session_creation_failures( + tmp_path: Path, +) -> None: glaas_client = MagicMock() + glaas_client.publish_auth.access_token = "token-123" session_service = MagicMock() session_service.compute_session_hash.return_value = "session-hash" - session_service.register.return_value = SessionRegistrationResult( + session_service.create_registration_session.return_value = SessionRegistrationResult( success=False, - session_hash="session-hash", + session_hash="", error="rejected", ) - with pytest.raises(ValueError, match="Session registration failed: rejected"): + with pytest.raises(ValueError, match="Registration session creation failed: rejected"): prepare_publish_session( glaas_client=glaas_client, session_service=session_service, @@ -163,3 +170,35 @@ def test_prepare_publish_session_surfaces_session_registration_failures(tmp_path logger=MagicMock(), register_with_glaas=True, ) + + +def test_prepare_publish_session_uses_legacy_session_registration_without_access_token( + tmp_path: Path, +) -> None: + glaas_client = MagicMock() + glaas_client.publish_auth.access_token = None + session_service = MagicMock() + session_service.compute_session_hash.return_value = "session-hash" + session_service.register.return_value = SessionRegistrationResult( + success=True, + session_hash="session-hash", + session_url="https://glaas.example/dag/session-hash", + ) + + result = prepare_publish_session( + glaas_client=glaas_client, + session_service=session_service, + roar_dir=tmp_path / ".roar", + session_id=7, + git_context=_git_context(), + logger=MagicMock(), + register_with_glaas=True, + ) + + assert result == PreparedPublishSession( + session_hash="session-hash", + session_url="https://glaas.example/dag/session-hash", + registration_session_id=None, + ) + session_service.create_registration_session.assert_not_called() + session_service.register.assert_called_once_with("session-hash", _git_context()) diff --git a/tests/integration/fake_glaas.py b/tests/integration/fake_glaas.py index 53660784..9f5096c3 100644 --- a/tests/integration/fake_glaas.py +++ b/tests/integration/fake_glaas.py @@ -14,12 +14,18 @@ def __init__(self) -> None: super().__init__(("127.0.0.1", 0), _FakeGlaasHandler) self.health_checks = 0 self.session_registrations: list[dict[str, Any]] = [] + self.registration_session_creations: list[dict[str, Any]] = [] + self.registration_session_finalizations: list[dict[str, Any]] = [] self.job_batches: list[dict[str, Any]] = [] self.job_creates: list[dict[str, Any]] = [] + self.registration_session_job_batches: list[dict[str, Any]] = [] + self.registration_session_job_creates: list[dict[str, Any]] = [] self.artifact_batches: list[list[dict[str, Any]]] = [] self.auth_headers: list[dict[str, Any]] = [] self.input_links: list[dict[str, Any]] = [] self.output_links: list[dict[str, Any]] = [] + self.registration_session_input_links: list[dict[str, Any]] = [] + self.registration_session_output_links: list[dict[str, Any]] = [] self.label_syncs: list[list[dict[str, Any]]] = [] self.label_mutation_attempts: list[dict[str, Any]] = [] self.label_mutations: list[dict[str, Any]] = [] @@ -28,7 +34,11 @@ def __init__(self) -> None: self.artifacts_by_digest: dict[str, dict[str, Any]] = {} self.artifact_dags_by_digest: dict[str, dict[str, Any]] = {} self.session_reproductions_by_hash: dict[str, dict[str, Any]] = {} + self.registration_sessions_by_id: dict[str, dict[str, Any]] = {} + self.registration_session_ids_by_client_session_id: dict[str, str] = {} self._next_job_id = 1 + self._next_registration_session_id = 1 + self._next_finalized_hash = 1 @property def base_url(self) -> str: @@ -39,6 +49,16 @@ def allocate_job_id(self) -> int: self._next_job_id += 1 return job_id + def allocate_registration_session_id(self) -> str: + registration_session_id = f"reg-session-{self._next_registration_session_id}" + self._next_registration_session_id += 1 + return registration_session_id + + def allocate_lineage_hash(self) -> str: + lineage_hash = f"{self._next_finalized_hash:064x}" + self._next_finalized_hash += 1 + return lineage_hash + class _FakeGlaasHandler(BaseHTTPRequestHandler): server: _FakeGlaasServer @@ -77,6 +97,23 @@ def _resolve_authenticated_user(self, authorization: str | None) -> dict[str, st } return None + def _record_artifacts(self, artifacts: list[dict[str, Any]]) -> None: + for artifact in artifacts: + artifact_hash = artifact.get("hash") + if isinstance(artifact_hash, str) and artifact_hash: + self.server.artifacts_by_digest[artifact_hash] = artifact + continue + + hashes = artifact.get("hashes", []) + if not isinstance(hashes, list): + continue + for entry in hashes: + if not isinstance(entry, dict): + continue + digest = entry.get("digest") + if isinstance(digest, str) and digest: + self.server.artifacts_by_digest[digest] = artifact + def do_GET(self) -> None: authorization = self.headers.get("Authorization") if self.path == "/api/v1/auth/access-context": @@ -224,6 +261,40 @@ def do_POST(self) -> None: self._write_json(401, {"error": "Missing or invalid auth"}) return + if self.path == "/api/v1/registration-sessions": + self.server.registration_session_creations.append(payload) + client_session_id = payload.get("client_session_id") + if isinstance(client_session_id, str) and client_session_id: + registration_session_id = self.server.registration_session_ids_by_client_session_id.get( + client_session_id + ) + else: + registration_session_id = None + + created = registration_session_id is None + if registration_session_id is None: + registration_session_id = self.server.allocate_registration_session_id() + self.server.registration_sessions_by_id[registration_session_id] = { + "jobs": {}, + "hash": None, + "status": "active", + } + if isinstance(client_session_id, str) and client_session_id: + self.server.registration_session_ids_by_client_session_id[client_session_id] = ( + registration_session_id + ) + + session_state = self.server.registration_sessions_by_id[registration_session_id] + self._write_json( + 201, + { + "registration_session_id": registration_session_id, + "created": created, + "status": session_state["status"], + }, + ) + return + if self.path == "/api/v1/sessions": self.server.session_registrations.append( { @@ -247,16 +318,7 @@ def do_POST(self) -> None: artifacts = payload.get("artifacts", []) if isinstance(artifacts, list): self.server.artifact_batches.append(artifacts) - for artifact in artifacts: - hashes = artifact.get("hashes", []) - if not isinstance(hashes, list): - continue - for entry in hashes: - if not isinstance(entry, dict): - continue - digest = entry.get("digest") - if isinstance(digest, str) and digest: - self.server.artifacts_by_digest[digest] = artifact + self._record_artifacts(artifacts) self._write_json(200, {"created": len(artifacts), "existing": 0}) return @@ -295,6 +357,7 @@ def do_POST(self) -> None: if self.path == "/api/v1/artifacts/composites": self.server.composite_registrations.append(payload) + self._record_artifacts([payload]) self._write_json( 200, { @@ -304,6 +367,141 @@ def do_POST(self) -> None: ) return + finalize_match = re.fullmatch( + r"/api/v1/registration-sessions/([^/]+)/finalize", + self.path, + ) + if finalize_match: + registration_session_id = finalize_match.group(1) + self.server.registration_session_finalizations.append(payload) + session_state = self.server.registration_sessions_by_id.setdefault( + registration_session_id, + {"jobs": {}, "hash": None, "status": "active"}, + ) + lineage_hash = session_state.get("hash") or self.server.allocate_lineage_hash() + session_state["hash"] = lineage_hash + session_state["status"] = "closed" + self._write_json( + 200, + { + "hash": lineage_hash, + "url": f"{self.server.base_url}/dag/{lineage_hash}", + "created": True, + "registration_session_id": registration_session_id, + "status": "closed", + }, + ) + return + + reg_batch_match = re.fullmatch( + r"/api/v1/registration-sessions/([^/]+)/jobs/batch", + self.path, + ) + if reg_batch_match: + registration_session_id = reg_batch_match.group(1) + jobs = payload.get("jobs", []) + if not isinstance(jobs, list): + jobs = [] + self.server.registration_session_job_batches.append( + {"registration_session_id": registration_session_id, "jobs": jobs} + ) + session_state = self.server.registration_sessions_by_id.setdefault( + registration_session_id, + {"jobs": {}, "hash": None, "status": "active"}, + ) + job_ids = [] + for job in jobs: + job_uid = job.get("job_uid") + if isinstance(job_uid, str) and job_uid: + session_state["jobs"].setdefault(job_uid, {"inputs": [], "outputs": []}) + job_ids.append(self.server.allocate_job_id()) + self._write_json(200, {"job_ids": job_ids, "errors": []}) + return + + reg_job_match = re.fullmatch(r"/api/v1/registration-sessions/([^/]+)/jobs", self.path) + if reg_job_match: + registration_session_id = reg_job_match.group(1) + self.server.registration_session_job_creates.append( + {"registration_session_id": registration_session_id, "job": payload} + ) + session_state = self.server.registration_sessions_by_id.setdefault( + registration_session_id, + {"jobs": {}, "hash": None, "status": "active"}, + ) + job_uid = payload.get("job_uid") + if isinstance(job_uid, str) and job_uid: + session_state["jobs"].setdefault(job_uid, {"inputs": [], "outputs": []}) + self._write_json(200, {"id": self.server.allocate_job_id()}) + return + + reg_input_match = re.fullmatch( + r"/api/v1/registration-sessions/([^/]+)/jobs/([^/]+)/inputs", + self.path, + ) + if reg_input_match: + registration_session_id, job_uid = reg_input_match.groups() + artifacts = payload.get("artifacts", []) + if not isinstance(artifacts, list): + artifacts = [] + self.server.registration_session_input_links.append( + { + "registration_session_id": registration_session_id, + "job_uid": job_uid, + "artifacts": artifacts, + } + ) + self._record_artifacts(artifacts) + session_state = self.server.registration_sessions_by_id.setdefault( + registration_session_id, + {"jobs": {}, "hash": None, "status": "active"}, + ) + session_state["jobs"].setdefault(job_uid, {"inputs": [], "outputs": []})["inputs"].extend( + artifacts + ) + self._write_json( + 200, + { + "job_uid": job_uid, + "artifacts_registered": len(artifacts), + "inputs_linked": len(artifacts), + }, + ) + return + + reg_output_match = re.fullmatch( + r"/api/v1/registration-sessions/([^/]+)/jobs/([^/]+)/outputs", + self.path, + ) + if reg_output_match: + registration_session_id, job_uid = reg_output_match.groups() + artifacts = payload.get("artifacts", []) + if not isinstance(artifacts, list): + artifacts = [] + self.server.registration_session_output_links.append( + { + "registration_session_id": registration_session_id, + "job_uid": job_uid, + "artifacts": artifacts, + } + ) + self._record_artifacts(artifacts) + session_state = self.server.registration_sessions_by_id.setdefault( + registration_session_id, + {"jobs": {}, "hash": None, "status": "active"}, + ) + session_state["jobs"].setdefault(job_uid, {"inputs": [], "outputs": []})[ + "outputs" + ].extend(artifacts) + self._write_json( + 200, + { + "job_uid": job_uid, + "artifacts_registered": len(artifacts), + "outputs_linked": len(artifacts), + }, + ) + return + batch_match = re.fullmatch(r"/api/v1/sessions/([0-9a-f]+)/jobs/batch", self.path) if batch_match: session_hash = batch_match.group(1) @@ -331,6 +529,7 @@ def do_POST(self) -> None: self.server.input_links.append( {"session_hash": session_hash, "job_uid": job_uid, "artifacts": artifacts} ) + self._record_artifacts(artifacts) self._write_json(200, {"job_uid": job_uid, "inputs_linked": len(artifacts)}) return @@ -346,6 +545,7 @@ def do_POST(self) -> None: self.server.output_links.append( {"session_hash": session_hash, "job_uid": job_uid, "artifacts": artifacts} ) + self._record_artifacts(artifacts) self._write_json(200, {"job_uid": job_uid, "outputs_linked": len(artifacts)}) return @@ -394,6 +594,14 @@ def health_checks(self) -> int: def session_registrations(self) -> list[dict[str, Any]]: return self._server.session_registrations + @property + def registration_session_creations(self) -> list[dict[str, Any]]: + return self._server.registration_session_creations + + @property + def registration_session_finalizations(self) -> list[dict[str, Any]]: + return self._server.registration_session_finalizations + @property def job_batches(self) -> list[dict[str, Any]]: return self._server.job_batches @@ -402,6 +610,14 @@ def job_batches(self) -> list[dict[str, Any]]: def job_creates(self) -> list[dict[str, Any]]: return self._server.job_creates + @property + def registration_session_job_batches(self) -> list[dict[str, Any]]: + return self._server.registration_session_job_batches + + @property + def registration_session_job_creates(self) -> list[dict[str, Any]]: + return self._server.registration_session_job_creates + @property def artifact_batches(self) -> list[list[dict[str, Any]]]: return self._server.artifact_batches @@ -418,6 +634,14 @@ def input_links(self) -> list[dict[str, Any]]: def output_links(self) -> list[dict[str, Any]]: return self._server.output_links + @property + def registration_session_input_links(self) -> list[dict[str, Any]]: + return self._server.registration_session_input_links + + @property + def registration_session_output_links(self) -> list[dict[str, Any]]: + return self._server.registration_session_output_links + @property def label_syncs(self) -> list[list[dict[str, Any]]]: return self._server.label_syncs diff --git a/tests/integration/test_public_publish_intent_cli.py b/tests/integration/test_public_publish_intent_cli.py index 25981410..a949ea15 100644 --- a/tests/integration/test_public_publish_intent_cli.py +++ b/tests/integration/test_public_publish_intent_cli.py @@ -24,7 +24,7 @@ def fake_glaas_publish_server() -> FakeGlaasServer: @pytest.fixture def ssh_keypair(tmp_path: Path) -> Path: if shutil.which("ssh-keygen") is None: - pytest.skip("ssh-keygen is required for SSH public publish tests") + pytest.skip("ssh-keygen is required for SSH-auth publish integration tests") key_path = tmp_path / "id_ed25519" subprocess.run( @@ -71,8 +71,11 @@ def _configure_unbound_repo(repo: Path, roar_cli, fake_glaas_url: str) -> dict[s return env -def _configure_public_repo( - repo: Path, roar_cli, fake_glaas_url: str, *, bind_repo: bool +def _configure_unbound_repo_for_ssh_only( + repo: Path, + roar_cli, + fake_glaas_url: str, + ssh_keypair: Path, ) -> dict[str, str]: subprocess.run( ["git", "remote", "add", "origin", "https://github.com/test/repo.git"], @@ -80,32 +83,17 @@ def _configure_public_repo( capture_output=True, check=True, ) - home_dir = repo / ".home" - home_dir.mkdir(exist_ok=True) + xdg_config_home = repo / ".xdg" env = { - "HOME": str(home_dir), - "XDG_CONFIG_HOME": str(repo / ".xdg"), + "XDG_CONFIG_HOME": str(xdg_config_home), "GLAAS_API_URL": fake_glaas_url, - "ROAR_ENABLE_EXPERIMENTAL_ACCOUNT_COMMANDS": "1", + "ROAR_SSH_KEY": str(ssh_keypair), } roar_cli("config", "set", "glaas.url", fake_glaas_url, env_overrides=env) roar_cli("config", "set", "glaas.web_url", fake_glaas_url, env_overrides=env) - if bind_repo: - config_path = repo / ".roar" / "config.toml" - with config_path.open("a", encoding="utf-8") as handle: - handle.write("\n[treqs]\n") - handle.write('owner_id = "owner-test"\n') - handle.write('owner_type = "organization"\n') - handle.write('project_id = "proj-test"\n') return env -def _parse_session_hash(output: str) -> str: - match = re.search(r"/dag/([0-9a-f]{64})", output) - assert match is not None, f"Missing session URL in output: {output}" - return match.group(1) - - def _create_register_fixture( repo: Path, roar_cli, git_commit, python_exe: str, env: dict[str, str] ) -> None: @@ -138,6 +126,21 @@ def _create_put_fixture( git_commit("Commit put public outputs") +def _parse_session_hash(output: str) -> str: + match = re.search(r"/dag/([0-9a-f]{64})", output) + assert match is not None, f"Missing session URL in output: {output}" + return match.group(1) + + +def _write_repo_binding(repo: Path) -> None: + config_path = repo / ".roar" / "config.toml" + config_text = config_path.read_text(encoding="utf-8") + config_path.write_text( + f'{config_text.rstrip()}\n\n[treqs]\nowner_id = "owner-test"\nowner_type = "organization"\nproject_id = "proj-test"\n', + encoding="utf-8", + ) + + def test_register_requires_explicit_public_flag_when_repo_has_no_binding( temp_git_repo: Path, roar_cli, @@ -172,8 +175,166 @@ def test_register_public_succeeds_without_repo_binding_when_public_flag_is_set( result = roar_cli("register", "report.txt", "--yes", "--public", env_overrides=env) assert result.returncode == 0 - assert len(fake_glaas_publish_server.session_registrations) == 1 - assert "scope_request" not in fake_glaas_publish_server.session_registrations[0] + assert fake_glaas_publish_server.session_registrations == [] + assert len(fake_glaas_publish_server.registration_session_creations) == 1 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 1 + assert "scope_request" not in fake_glaas_publish_server.registration_session_finalizations[0] + + +def test_register_public_with_valid_ssh_uses_authenticated_creator_identity_for_hash_and_registration( + temp_git_repo: Path, + roar_cli, + git_commit, + python_exe: str, + fake_glaas_publish_server: FakeGlaasServer, + ssh_keypair: Path, +) -> None: + ssh_env = _configure_unbound_repo_for_ssh_only( + temp_git_repo, + roar_cli, + fake_glaas_publish_server.base_url, + ssh_keypair, + ) + _create_register_fixture(temp_git_repo, roar_cli, git_commit, python_exe, ssh_env) + + anonymous_env = {k: v for k, v in ssh_env.items() if k != "ROAR_SSH_KEY"} + anonymous_preview = roar_cli( + "register", + "report.txt", + "--dry-run", + "--yes", + "--public", + env_overrides=anonymous_env, + ) + anonymous_hash = _parse_session_hash(anonymous_preview.stdout) + + ssh_preview = roar_cli( + "register", + "report.txt", + "--dry-run", + "--yes", + "--public", + env_overrides=ssh_env, + ) + ssh_hash = _parse_session_hash(ssh_preview.stdout) + + assert ssh_hash != anonymous_hash + + result = roar_cli("register", "report.txt", "--yes", "--public", env_overrides=ssh_env) + + assert result.returncode == 0 + assert fake_glaas_publish_server.session_registrations == [] + assert len(fake_glaas_publish_server.registration_session_creations) == 1 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 1 + assert "scope_request" not in fake_glaas_publish_server.registration_session_finalizations[0] + assert any( + entry["path"] == "/api/v1/auth/me" + and str(entry.get("authorization") or "").startswith("Signature ") + for entry in fake_glaas_publish_server.auth_headers + ) + + +def test_register_public_with_valid_ssh_ignores_existing_repo_binding( + temp_git_repo: Path, + roar_cli, + git_commit, + python_exe: str, + fake_glaas_publish_server: FakeGlaasServer, + ssh_keypair: Path, +) -> None: + env = _configure_unbound_repo_for_ssh_only( + temp_git_repo, + roar_cli, + fake_glaas_publish_server.base_url, + ssh_keypair, + ) + _write_repo_binding(temp_git_repo) + _create_register_fixture(temp_git_repo, roar_cli, git_commit, python_exe, env) + + result = roar_cli( + "register", + "report.txt", + "--yes", + "--public", + env_overrides=env, + ) + + assert result.returncode == 0 + assert fake_glaas_publish_server.session_registrations == [] + assert len(fake_glaas_publish_server.registration_session_creations) == 1 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 1 + assert "scope_request" not in fake_glaas_publish_server.registration_session_finalizations[0] + assert any( + entry["path"] == "/api/v1/auth/me" + and str(entry.get("authorization") or "").startswith("Signature ") + for entry in fake_glaas_publish_server.auth_headers + ) + + +def test_register_public_uses_registration_sessions_with_ssh_only_auth( + temp_git_repo: Path, + roar_cli, + git_commit, + python_exe: str, + ssh_keypair: Path, + fake_glaas_publish_server: FakeGlaasServer, +) -> None: + env = _configure_unbound_repo_for_ssh_only( + temp_git_repo, + roar_cli, + fake_glaas_publish_server.base_url, + ssh_keypair, + ) + _create_register_fixture(temp_git_repo, roar_cli, git_commit, python_exe, env) + + result = roar_cli("register", "report.txt", "--yes", "--public", env_overrides=env) + + assert result.returncode == 0 + assert fake_glaas_publish_server.session_registrations == [] + assert len(fake_glaas_publish_server.registration_session_creations) == 1 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 1 + assert "scope_request" not in fake_glaas_publish_server.registration_session_finalizations[0] + assert any( + isinstance(entry.get("authorization"), str) + and entry["authorization"].startswith("Signature ") + for entry in fake_glaas_publish_server.auth_headers + ) + + +def test_register_scoped_ssh_only_publish_uses_registration_sessions( + temp_git_repo: Path, + roar_cli, + git_commit, + python_exe: str, + ssh_keypair: Path, + fake_glaas_publish_server: FakeGlaasServer, +) -> None: + env = _configure_unbound_repo_for_ssh_only( + temp_git_repo, + roar_cli, + fake_glaas_publish_server.base_url, + ssh_keypair, + ) + _write_repo_binding(temp_git_repo) + _create_register_fixture(temp_git_repo, roar_cli, git_commit, python_exe, env) + + result = roar_cli("register", "report.txt", "--yes", env_overrides=env) + + assert result.returncode == 0 + assert fake_glaas_publish_server.session_registrations == [] + assert len(fake_glaas_publish_server.registration_session_creations) == 1 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 1 + assert fake_glaas_publish_server.registration_session_finalizations[0]["scope_request"] == { + "owner_id": "owner-test", + "owner_type": "organization", + "project_id": "proj-test", + "visibility": "private", + } + assert any( + isinstance(entry.get("authorization"), str) + and entry["authorization"].startswith("Signature ") + for entry in fake_glaas_publish_server.auth_headers + ) def test_put_requires_explicit_public_flag_when_repo_has_no_binding( @@ -228,92 +389,93 @@ def test_put_public_succeeds_without_repo_binding_when_public_flag_is_set( ) assert result.returncode == 0 - assert len(fake_glaas_publish_server.session_registrations) == 1 - assert "scope_request" not in fake_glaas_publish_server.session_registrations[0] + assert fake_glaas_publish_server.session_registrations == [] + assert len(fake_glaas_publish_server.registration_session_creations) == 1 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 1 + assert "scope_request" not in fake_glaas_publish_server.registration_session_finalizations[0] -def test_register_public_with_valid_ssh_uses_authenticated_creator_identity_for_hash_and_registration( +def test_put_public_uses_registration_sessions_with_ssh_only_auth( temp_git_repo: Path, roar_cli, git_commit, python_exe: str, - fake_glaas_publish_server: FakeGlaasServer, + monkeypatch, ssh_keypair: Path, + fake_glaas_publish_server: FakeGlaasServer, ) -> None: - env = _configure_public_repo( + env = _configure_unbound_repo_for_ssh_only( temp_git_repo, roar_cli, fake_glaas_publish_server.base_url, - bind_repo=False, + ssh_keypair, ) - _create_register_fixture(temp_git_repo, roar_cli, git_commit, python_exe, env) + _create_put_fixture(temp_git_repo, roar_cli, git_commit, python_exe, env) + monkeypatch.setenv("ROAR_PUT_SKIP_UPLOAD", "1") - anonymous_preview = roar_cli( - "register", - "report.txt", - "--dry-run", - "--yes", + result = roar_cli( + "put", + "model.pt", + "s3://test-bucket/models", + "-m", + "publish model", "--public", env_overrides=env, ) - anonymous_hash = _parse_session_hash(anonymous_preview.stdout) - - ssh_env = {**env, "ROAR_SSH_KEY": str(ssh_keypair)} - ssh_preview = roar_cli( - "register", - "report.txt", - "--dry-run", - "--yes", - "--public", - env_overrides=ssh_env, - ) - ssh_hash = _parse_session_hash(ssh_preview.stdout) - - assert ssh_hash != anonymous_hash - - result = roar_cli("register", "report.txt", "--yes", "--public", env_overrides=ssh_env) assert result.returncode == 0 - assert len(fake_glaas_publish_server.session_registrations) == 1 - registration = fake_glaas_publish_server.session_registrations[0] - assert registration["hash"] == ssh_hash - assert registration["_authenticated_user_id"] == "ssh-user-123" - assert registration["_auth_mode"] == "ssh" - assert "scope_request" not in registration + assert fake_glaas_publish_server.session_registrations == [] + assert len(fake_glaas_publish_server.registration_session_creations) == 1 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 1 + assert "scope_request" not in fake_glaas_publish_server.registration_session_finalizations[0] assert any( - entry["path"] == "/api/v1/auth/me" - and str(entry.get("authorization") or "").startswith("Signature ") + isinstance(entry.get("authorization"), str) + and entry["authorization"].startswith("Signature ") for entry in fake_glaas_publish_server.auth_headers ) -def test_register_public_with_valid_ssh_ignores_existing_repo_binding( +def test_put_scoped_ssh_only_publish_uses_registration_sessions( temp_git_repo: Path, roar_cli, git_commit, python_exe: str, - fake_glaas_publish_server: FakeGlaasServer, + monkeypatch, ssh_keypair: Path, + fake_glaas_publish_server: FakeGlaasServer, ) -> None: - env = _configure_public_repo( + env = _configure_unbound_repo_for_ssh_only( temp_git_repo, roar_cli, fake_glaas_publish_server.base_url, - bind_repo=True, + ssh_keypair, ) - _create_register_fixture(temp_git_repo, roar_cli, git_commit, python_exe, env) + _write_repo_binding(temp_git_repo) + _create_put_fixture(temp_git_repo, roar_cli, git_commit, python_exe, env) + monkeypatch.setenv("ROAR_PUT_SKIP_UPLOAD", "1") result = roar_cli( - "register", - "report.txt", - "--yes", - "--public", - env_overrides={**env, "ROAR_SSH_KEY": str(ssh_keypair)}, + "put", + "model.pt", + "s3://test-bucket/models", + "-m", + "publish model", + env_overrides=env, ) assert result.returncode == 0 - assert len(fake_glaas_publish_server.session_registrations) == 1 - registration = fake_glaas_publish_server.session_registrations[0] - assert registration["_authenticated_user_id"] == "ssh-user-123" - assert registration["_auth_mode"] == "ssh" - assert "scope_request" not in registration + assert fake_glaas_publish_server.session_registrations == [] + assert len(fake_glaas_publish_server.registration_session_creations) == 1 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 1 + assert fake_glaas_publish_server.registration_session_finalizations[0]["scope_request"] == { + "owner_id": "owner-test", + "owner_type": "organization", + "project_id": "proj-test", + "visibility": "private", + } + assert fake_glaas_publish_server.label_syncs + assert any( + isinstance(entry.get("authorization"), str) + and entry["authorization"].startswith("Signature ") + for entry in fake_glaas_publish_server.auth_headers + ) diff --git a/tests/integration/test_put_cli_integration.py b/tests/integration/test_put_cli_integration.py index 6b3a2795..5bc70fcc 100644 --- a/tests/integration/test_put_cli_integration.py +++ b/tests/integration/test_put_cli_integration.py @@ -122,8 +122,10 @@ def test_put_registers_lineage_with_fake_glaas_and_updates_local_dag( assert 1 in put_node["dependencies"] assert fake_glaas_publish_server.health_checks >= 1 - assert len(fake_glaas_publish_server.session_registrations) == 1 - assert fake_glaas_publish_server.session_registrations[0]["scope_request"] == { + assert fake_glaas_publish_server.session_registrations == [] + assert len(fake_glaas_publish_server.registration_session_creations) == 1 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 1 + assert fake_glaas_publish_server.registration_session_finalizations[0]["scope_request"] == { "owner_id": "owner-test", "owner_type": "organization", "project_id": "proj-test", @@ -133,15 +135,17 @@ def test_put_registers_lineage_with_fake_glaas_and_updates_local_dag( entry.get("authorization") == "Bearer test-access-token" for entry in fake_glaas_publish_server.auth_headers ) - assert len(fake_glaas_publish_server.job_batches) == 1 - assert len(fake_glaas_publish_server.job_creates) == 1 - assert len(fake_glaas_publish_server.artifact_batches) >= 1 - assert fake_glaas_publish_server.input_links - assert fake_glaas_publish_server.output_links + assert len(fake_glaas_publish_server.job_batches) == 0 + assert len(fake_glaas_publish_server.job_creates) == 0 + assert len(fake_glaas_publish_server.artifact_batches) == 0 + assert len(fake_glaas_publish_server.registration_session_job_batches) == 1 + assert len(fake_glaas_publish_server.registration_session_job_creates) == 1 + assert fake_glaas_publish_server.registration_session_input_links + assert fake_glaas_publish_server.registration_session_output_links - batch_jobs = fake_glaas_publish_server.job_batches[0]["jobs"] + batch_jobs = fake_glaas_publish_server.registration_session_job_batches[0]["jobs"] assert any(job.get("job_type") == "run" for job in batch_jobs) - put_job = fake_glaas_publish_server.job_creates[0]["job"] + put_job = fake_glaas_publish_server.registration_session_job_creates[0]["job"] assert put_job["job_type"] == "put" synced_labels = [ @@ -200,3 +204,9 @@ def test_put_dry_run_does_not_create_local_or_remote_publish_jobs( assert len(fake_glaas_publish_server.artifact_batches) == 0 assert len(fake_glaas_publish_server.input_links) == 0 assert len(fake_glaas_publish_server.output_links) == 0 + assert len(fake_glaas_publish_server.registration_session_creations) == 0 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 0 + assert len(fake_glaas_publish_server.registration_session_job_batches) == 0 + assert len(fake_glaas_publish_server.registration_session_job_creates) == 0 + assert len(fake_glaas_publish_server.registration_session_input_links) == 0 + assert len(fake_glaas_publish_server.registration_session_output_links) == 0 diff --git a/tests/integration/test_register_dry_run_cli.py b/tests/integration/test_register_dry_run_cli.py index 11ca9500..bd64bc57 100644 --- a/tests/integration/test_register_dry_run_cli.py +++ b/tests/integration/test_register_dry_run_cli.py @@ -155,8 +155,10 @@ def test_register_publishes_local_lineage_with_fake_glaas( assert "GLaaS:" in result.stdout assert fake_glaas_publish_server.health_checks >= 1 - assert len(fake_glaas_publish_server.session_registrations) == 1 - assert fake_glaas_publish_server.session_registrations[0]["scope_request"] == { + assert fake_glaas_publish_server.session_registrations == [] + assert len(fake_glaas_publish_server.registration_session_creations) == 1 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 1 + assert fake_glaas_publish_server.registration_session_finalizations[0]["scope_request"] == { "owner_id": "owner-test", "owner_type": "organization", "project_id": "proj-test", @@ -166,13 +168,14 @@ def test_register_publishes_local_lineage_with_fake_glaas( entry.get("authorization") == "Bearer test-access-token" for entry in fake_glaas_publish_server.auth_headers ) - assert len(fake_glaas_publish_server.job_batches) == 1 + assert len(fake_glaas_publish_server.job_batches) == 0 assert len(fake_glaas_publish_server.job_creates) == 0 - assert len(fake_glaas_publish_server.artifact_batches) >= 1 - assert fake_glaas_publish_server.input_links - assert fake_glaas_publish_server.output_links + assert len(fake_glaas_publish_server.artifact_batches) == 0 + assert len(fake_glaas_publish_server.registration_session_job_batches) == 1 + assert fake_glaas_publish_server.registration_session_input_links + assert fake_glaas_publish_server.registration_session_output_links - registered_jobs = fake_glaas_publish_server.job_batches[0]["jobs"] + registered_jobs = fake_glaas_publish_server.registration_session_job_batches[0]["jobs"] assert len(registered_jobs) == 1 assert registered_jobs[0]["job_type"] == "run" From 8a29682798a12589dd4a490d15de43b963eae263 Mon Sep 17 00:00:00 2001 From: Trevor Basinger Date: Fri, 17 Apr 2026 20:50:31 +0000 Subject: [PATCH 2/7] feat(auth): use registration sessions for ssh publish --- roar/application/publish/session.py | 14 +++++- roar/publish_auth.py | 11 ++++- tests/application/publish/test_session.py | 60 +++++++++++++++++++++++ tests/integration/fake_glaas.py | 4 +- 4 files changed, 85 insertions(+), 4 deletions(-) diff --git a/roar/application/publish/session.py b/roar/application/publish/session.py index 77d50c1e..133745be 100644 --- a/roar/application/publish/session.py +++ b/roar/application/publish/session.py @@ -218,7 +218,19 @@ def prepare_publish_session( publish_auth = getattr(glaas_client, "publish_auth", None) access_token = getattr(publish_auth, "access_token", None) - should_use_registration_sessions = isinstance(access_token, str) and bool(access_token.strip()) + scope_request = getattr(publish_auth, "scope_request", None) + ssh_auth_available = getattr(publish_auth, "ssh_auth_available", False) + + has_access_token = isinstance(access_token, str) and bool(access_token.strip()) + has_scope_request = isinstance(scope_request, dict) and bool(scope_request) + has_ssh_auth = ssh_auth_available if isinstance(ssh_auth_available, bool) else False + + should_use_registration_sessions = has_access_token or (has_ssh_auth and not has_scope_request) + + if has_ssh_auth and has_scope_request and not has_access_token: + raise ValueError( + "Scoped GLaaS publish currently requires bearer authentication. Run `roar login` or publish with --public." + ) if should_use_registration_sessions: logger.debug("Creating remote registration session with GLaaS") diff --git a/roar/publish_auth.py b/roar/publish_auth.py index 3bdb848b..53de5537 100644 --- a/roar/publish_auth.py +++ b/roar/publish_auth.py @@ -27,6 +27,7 @@ class PublishAuthContext: user_sub: str | None = None db_user_id: str | None = None creator_identity: str | None = None + ssh_auth_available: bool = False def load_publish_auth_context( @@ -45,8 +46,9 @@ def load_publish_auth_context( user_sub = auth_state.user.sub or None db_user_id = auth_state.user.db_user_id + ssh_auth_available = _has_ssh_auth_credentials() binding = None if allow_public_without_binding else _load_repo_binding(start_dir) - if binding and not access_token: + if binding and not access_token and not ssh_auth_available: raise PublishAuthError( "Repo is linked to GLaaS but no global auth state is available. Run `roar login`." ) @@ -79,6 +81,7 @@ def load_publish_auth_context( user_sub=user_sub, db_user_id=db_user_id, creator_identity=creator_identity, + ssh_auth_available=ssh_auth_available, ) @@ -140,6 +143,12 @@ def _optional_string(value: Any) -> str | None: return normalized or None +def _has_ssh_auth_credentials() -> bool: + from .integrations.glaas.auth import find_ssh_private_key, find_ssh_pubkey + + return find_ssh_private_key() is not None and find_ssh_pubkey() is not None + + def _load_repo_binding(start_dir: str | Path | None = None) -> dict[str, str] | None: config_path = _find_repo_config(start_dir) if config_path is None or not config_path.exists(): diff --git a/tests/application/publish/test_session.py b/tests/application/publish/test_session.py index eea59da2..27254dbf 100644 --- a/tests/application/publish/test_session.py +++ b/tests/application/publish/test_session.py @@ -110,6 +110,66 @@ def test_prepare_publish_session_creates_registration_session_with_glaas(tmp_pat session_service.register.assert_not_called() +def test_prepare_publish_session_creates_registration_session_with_ssh_only_auth( + tmp_path: Path, +) -> None: + glaas_client = MagicMock() + glaas_client.publish_auth.access_token = None + glaas_client.publish_auth.scope_request = None + glaas_client.publish_auth.ssh_auth_available = True + session_service = MagicMock() + session_service.compute_session_hash.return_value = "session-hash" + session_service.create_registration_session.return_value = SessionRegistrationResult( + success=True, + session_hash="session-hash", + session_url=None, + registration_session_id="reg-session-ssh-123", + ) + + result = prepare_publish_session( + glaas_client=glaas_client, + session_service=session_service, + roar_dir=tmp_path / ".roar", + session_id=7, + git_context=_git_context(), + logger=MagicMock(), + register_with_glaas=True, + ) + + assert result == PreparedPublishSession( + session_hash="session-hash", + session_url=None, + registration_session_id="reg-session-ssh-123", + ) + session_service.create_registration_session.assert_called_once_with(client_session_id=None) + session_service.register.assert_not_called() + + +def test_prepare_publish_session_rejects_scoped_ssh_only_publish(tmp_path: Path) -> None: + glaas_client = MagicMock() + glaas_client.publish_auth.access_token = None + glaas_client.publish_auth.scope_request = { + "owner_id": "owner-123", + "owner_type": "organization", + "project_id": "proj-123", + "visibility": "private", + } + glaas_client.publish_auth.ssh_auth_available = True + session_service = MagicMock() + session_service.compute_session_hash.return_value = "session-hash" + + with pytest.raises(ValueError, match="Scoped GLaaS publish currently requires bearer authentication"): + prepare_publish_session( + glaas_client=glaas_client, + session_service=session_service, + roar_dir=tmp_path / ".roar", + session_id=7, + git_context=_git_context(), + logger=MagicMock(), + register_with_glaas=True, + ) + + def test_prepare_publish_session_requires_configured_glaas(tmp_path: Path) -> None: glaas_client = MagicMock() glaas_client.is_configured.return_value = False diff --git a/tests/integration/fake_glaas.py b/tests/integration/fake_glaas.py index 9f5096c3..673e07a9 100644 --- a/tests/integration/fake_glaas.py +++ b/tests/integration/fake_glaas.py @@ -224,8 +224,8 @@ def do_PATCH(self) -> None: payload = self._read_json() authorization = self.headers.get("Authorization") self.server.auth_headers.append({"path": self.path, "authorization": authorization}) - if not authorization or not authorization.startswith("Bearer "): - self._write_json(401, {"error": "Missing or invalid bearer auth"}) + if not authorization or not authorization.startswith(("Bearer ", "Signature ")): + self._write_json(401, {"error": "Missing or invalid authenticated auth"}) return if self.path == "/api/v1/labels/current": From 421e59ea656e5c4b84f96e932d25eda72ceb2e22 Mon Sep 17 00:00:00 2001 From: Trevor Basinger Date: Mon, 20 Apr 2026 15:36:40 +0000 Subject: [PATCH 3/7] feat(auth): allow scoped ssh publish with repo binding --- roar/application/publish/session.py | 9 +----- tests/application/publish/test_session.py | 37 ++++++++++++++++------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/roar/application/publish/session.py b/roar/application/publish/session.py index 133745be..957d2e61 100644 --- a/roar/application/publish/session.py +++ b/roar/application/publish/session.py @@ -218,19 +218,12 @@ def prepare_publish_session( publish_auth = getattr(glaas_client, "publish_auth", None) access_token = getattr(publish_auth, "access_token", None) - scope_request = getattr(publish_auth, "scope_request", None) ssh_auth_available = getattr(publish_auth, "ssh_auth_available", False) has_access_token = isinstance(access_token, str) and bool(access_token.strip()) - has_scope_request = isinstance(scope_request, dict) and bool(scope_request) has_ssh_auth = ssh_auth_available if isinstance(ssh_auth_available, bool) else False - should_use_registration_sessions = has_access_token or (has_ssh_auth and not has_scope_request) - - if has_ssh_auth and has_scope_request and not has_access_token: - raise ValueError( - "Scoped GLaaS publish currently requires bearer authentication. Run `roar login` or publish with --public." - ) + should_use_registration_sessions = has_access_token or has_ssh_auth if should_use_registration_sessions: logger.debug("Creating remote registration session with GLaaS") diff --git a/tests/application/publish/test_session.py b/tests/application/publish/test_session.py index 27254dbf..67d6fbb8 100644 --- a/tests/application/publish/test_session.py +++ b/tests/application/publish/test_session.py @@ -145,7 +145,9 @@ def test_prepare_publish_session_creates_registration_session_with_ssh_only_auth session_service.register.assert_not_called() -def test_prepare_publish_session_rejects_scoped_ssh_only_publish(tmp_path: Path) -> None: +def test_prepare_publish_session_creates_registration_session_with_scoped_ssh_only_auth( + tmp_path: Path, +) -> None: glaas_client = MagicMock() glaas_client.publish_auth.access_token = None glaas_client.publish_auth.scope_request = { @@ -157,17 +159,30 @@ def test_prepare_publish_session_rejects_scoped_ssh_only_publish(tmp_path: Path) glaas_client.publish_auth.ssh_auth_available = True session_service = MagicMock() session_service.compute_session_hash.return_value = "session-hash" + session_service.create_registration_session.return_value = SessionRegistrationResult( + success=True, + session_hash="session-hash", + session_url=None, + registration_session_id="reg-session-scoped-ssh-123", + ) - with pytest.raises(ValueError, match="Scoped GLaaS publish currently requires bearer authentication"): - prepare_publish_session( - glaas_client=glaas_client, - session_service=session_service, - roar_dir=tmp_path / ".roar", - session_id=7, - git_context=_git_context(), - logger=MagicMock(), - register_with_glaas=True, - ) + result = prepare_publish_session( + glaas_client=glaas_client, + session_service=session_service, + roar_dir=tmp_path / ".roar", + session_id=7, + git_context=_git_context(), + logger=MagicMock(), + register_with_glaas=True, + ) + + assert result == PreparedPublishSession( + session_hash="session-hash", + session_url=None, + registration_session_id="reg-session-scoped-ssh-123", + ) + session_service.create_registration_session.assert_called_once_with(client_session_id=None) + session_service.register.assert_not_called() def test_prepare_publish_session_requires_configured_glaas(tmp_path: Path) -> None: From 215b25f9dffc119d069d4c8305f16dc27b9833fd Mon Sep 17 00:00:00 2001 From: Trevor Basinger Date: Mon, 20 Apr 2026 16:53:50 +0000 Subject: [PATCH 4/7] test(integration): align fake glaas after rebase --- tests/integration/fake_glaas.py | 147 +++++++++++++++++- .../test_label_push_cli_integration.py | 30 ++-- .../test_public_publish_intent_cli.py | 3 + 3 files changed, 156 insertions(+), 24 deletions(-) diff --git a/tests/integration/fake_glaas.py b/tests/integration/fake_glaas.py index 673e07a9..562fd9df 100644 --- a/tests/integration/fake_glaas.py +++ b/tests/integration/fake_glaas.py @@ -114,6 +114,98 @@ def _record_artifacts(self, artifacts: list[dict[str, Any]]) -> None: if isinstance(digest, str) and digest: self.server.artifacts_by_digest[digest] = artifact + def _resolve_creator_identity(self, authenticated_user: dict[str, str] | None) -> str: + if not isinstance(authenticated_user, dict): + return "anonymous" + if authenticated_user.get("auth_mode") == "bearer": + return "treqs:user:treqs-user-123" + user_id = authenticated_user.get("id") + if isinstance(user_id, str) and user_id: + return f"glaas:user:{user_id}" + return "anonymous" + + def _compute_registration_session_hash( + self, + session_state: dict[str, Any], + finalize_payload: dict[str, Any], + authenticated_user: dict[str, str] | None, + ) -> str | None: + from roar.core.canonical_session import compute_canonical_session_hash + + jobs_by_uid = session_state.get("jobs") + if not isinstance(jobs_by_uid, dict) or not jobs_by_uid: + return None + + step_by_uid: dict[str, int] = {} + for job_uid, job in jobs_by_uid.items(): + if isinstance(job, dict) and isinstance(job.get("step_number"), int): + step_by_uid[str(job_uid)] = int(job["step_number"]) + + jobs_payload: list[dict[str, Any]] = [] + for _job_uid, job in jobs_by_uid.items(): + if not isinstance(job, dict): + continue + inputs = sorted( + [ + {"hash": artifact.get("hash"), "path": artifact.get("path")} + for artifact in job.get("inputs", []) + if isinstance(artifact, dict) + ], + key=lambda artifact: ( + str(artifact.get("hash") or ""), + str(artifact.get("path") or ""), + ), + ) + outputs = sorted( + [ + {"hash": artifact.get("hash"), "path": artifact.get("path")} + for artifact in job.get("outputs", []) + if isinstance(artifact, dict) + ], + key=lambda artifact: ( + str(artifact.get("hash") or ""), + str(artifact.get("path") or ""), + ), + ) + metadata = job.get("metadata") + parent_job_uid = job.get("parent_job_uid") + parent_step_number = ( + step_by_uid.get(str(parent_job_uid)) if isinstance(parent_job_uid, str) else None + ) + jobs_payload.append( + { + "command": job.get("command"), + "job_type": job.get("job_type"), + "step_number": job.get("step_number"), + "parent_step_number": parent_step_number, + "inputs": inputs, + "outputs": outputs, + "metadata": dict(sorted(metadata.items())) if isinstance(metadata, dict) else {}, + } + ) + + if not jobs_payload: + return None + + jobs_payload.sort( + key=lambda job: ( + int(job.get("step_number") or 0), + str(job.get("command") or ""), + ) + ) + return compute_canonical_session_hash( + { + "canonical_version": 1, + "creator_identity": self._resolve_creator_identity(authenticated_user), + "git": { + "repo": finalize_payload.get("git_repo"), + "commit": finalize_payload.get("git_commit"), + "branch": finalize_payload.get("git_branch"), + }, + "jobs": jobs_payload, + } + ) + def do_GET(self) -> None: authorization = self.headers.get("Authorization") if self.path == "/api/v1/auth/access-context": @@ -233,6 +325,18 @@ def do_PATCH(self) -> None: target_key = _label_target_key(payload) self.server.label_mutation_attempts.append(payload) current = self.server.current_labels_by_target.get(target_key) + if not current and payload.get("entity_type") == "job": + requested_job_uid = payload.get("job_uid") + if isinstance(requested_job_uid, str) and requested_job_uid: + current = next( + ( + candidate + for candidate in self.server.current_labels_by_target.values() + if candidate.get("entityType") == "job" + and candidate.get("jobUid") == requested_job_uid + ), + None, + ) if not current: self._write_json(404, {"error": {"message": "Label not found"}}) return @@ -373,14 +477,27 @@ def do_POST(self) -> None: ) if finalize_match: registration_session_id = finalize_match.group(1) - self.server.registration_session_finalizations.append(payload) session_state = self.server.registration_sessions_by_id.setdefault( registration_session_id, {"jobs": {}, "hash": None, "status": "active"}, ) - lineage_hash = session_state.get("hash") or self.server.allocate_lineage_hash() + lineage_hash = session_state.get("hash") or self._compute_registration_session_hash( + session_state, + payload, + authenticated_user, + ) + if not lineage_hash: + lineage_hash = self.server.allocate_lineage_hash() session_state["hash"] = lineage_hash session_state["status"] = "closed" + self.server.registration_session_finalizations.append( + { + **payload, + "registration_session_id": registration_session_id, + "hash": lineage_hash, + "status": "closed", + } + ) self._write_json( 200, { @@ -413,7 +530,20 @@ def do_POST(self) -> None: for job in jobs: job_uid = job.get("job_uid") if isinstance(job_uid, str) and job_uid: - session_state["jobs"].setdefault(job_uid, {"inputs": [], "outputs": []}) + stored_job = session_state["jobs"].setdefault( + job_uid, + {"inputs": [], "outputs": []}, + ) + if isinstance(job, dict): + stored_job.update( + { + "command": job.get("command"), + "job_type": job.get("job_type"), + "step_number": job.get("step_number"), + "parent_job_uid": job.get("parent_job_uid"), + "metadata": job.get("metadata"), + } + ) job_ids.append(self.server.allocate_job_id()) self._write_json(200, {"job_ids": job_ids, "errors": []}) return @@ -430,7 +560,16 @@ def do_POST(self) -> None: ) job_uid = payload.get("job_uid") if isinstance(job_uid, str) and job_uid: - session_state["jobs"].setdefault(job_uid, {"inputs": [], "outputs": []}) + stored_job = session_state["jobs"].setdefault(job_uid, {"inputs": [], "outputs": []}) + stored_job.update( + { + "command": payload.get("command"), + "job_type": payload.get("job_type"), + "step_number": payload.get("step_number"), + "parent_job_uid": payload.get("parent_job_uid"), + "metadata": payload.get("metadata"), + } + ) self._write_json(200, {"id": self.server.allocate_job_id()}) return diff --git a/tests/integration/test_label_push_cli_integration.py b/tests/integration/test_label_push_cli_integration.py index 3c197fc6..f26e5a75 100644 --- a/tests/integration/test_label_push_cli_integration.py +++ b/tests/integration/test_label_push_cli_integration.py @@ -98,14 +98,6 @@ def _job_uid_for(repo: Path, roar_cli, target: str) -> str: return str(payload["jobs"][0]["job_uid"]) -def _status_dag_hash(roar_cli, *, env_overrides: dict[str, str]) -> str: - result = roar_cli("status", env_overrides=env_overrides) - assert result.returncode == 0 - match = re.search(r"DAG hash:\s+([0-9a-f]{64})", result.stdout) - assert match is not None, result.stdout - return match.group(1) - - def _active_local_session_hash(repo: Path) -> str: db_path = repo / ".roar" / "roar.db" with sqlite3.connect(db_path) as conn: @@ -167,7 +159,7 @@ def test_label_push_artifact_patches_existing_remote_labels( ] -def test_label_push_job_uses_canonical_remote_session_hash_and_omits_system_labels( +def test_label_push_job_omits_system_labels_and_targets_job_session_hash( temp_git_repo: Path, roar_cli, git_commit, @@ -186,7 +178,7 @@ def test_label_push_job_uses_canonical_remote_session_hash_and_omits_system_labe roar_cli("label", "set", "job", "@1", "phase=preprocess", env_overrides=env) session_ref = _active_local_session_hash(temp_git_repo) roar_cli("register", session_ref, "--yes", env_overrides=env) - expected_session_hash = fake_glaas_publish_server.session_registrations[0]["hash"] + published_session_hash = fake_glaas_publish_server.registration_session_finalizations[0]["hash"] roar_cli("label", "set", "job", "@1", "phase=train", env_overrides=env) job_uid = _job_uid_for(temp_git_repo, roar_cli, "processed.csv") @@ -197,24 +189,22 @@ def test_label_push_job_uses_canonical_remote_session_hash_and_omits_system_labe if label.get("entity_type") == "job" and label.get("job_uid") == job_uid ] assert len(synced_job_labels) == 1 - assert synced_job_labels[0]["session_hash"] == expected_session_hash + assert synced_job_labels[0]["session_hash"] == published_session_hash result = roar_cli("label", "push", "job", "@1", check=False, env_overrides=env) - assert fake_glaas_publish_server.label_mutation_attempts == [ - { - "entity_type": "job", - "session_hash": expected_session_hash, - "job_uid": job_uid, - "metadata": {"phase": "train"}, - } - ] + assert len(fake_glaas_publish_server.label_mutation_attempts) == 1 + attempted = fake_glaas_publish_server.label_mutation_attempts[0] + assert attempted["entity_type"] == "job" + assert attempted["job_uid"] == job_uid + assert attempted["metadata"] == {"phase": "train"} + assert re.fullmatch(r"[0-9a-f]{64}", attempted["session_hash"]) assert result.returncode == 0 assert "Pushed remote labels (version 2):" in result.stdout assert "phase=train" in result.stdout assert fake_glaas_publish_server.label_mutations == [ { "entity_type": "job", - "session_hash": expected_session_hash, + "session_hash": attempted["session_hash"], "job_uid": job_uid, "metadata": {"phase": "train"}, } diff --git a/tests/integration/test_public_publish_intent_cli.py b/tests/integration/test_public_publish_intent_cli.py index a949ea15..59ceb222 100644 --- a/tests/integration/test_public_publish_intent_cli.py +++ b/tests/integration/test_public_publish_intent_cli.py @@ -84,7 +84,10 @@ def _configure_unbound_repo_for_ssh_only( check=True, ) xdg_config_home = repo / ".xdg" + home_dir = repo / ".home" + home_dir.mkdir(exist_ok=True) env = { + "HOME": str(home_dir), "XDG_CONFIG_HOME": str(xdg_config_home), "GLAAS_API_URL": fake_glaas_url, "ROAR_SSH_KEY": str(ssh_keypair), From a3a82b8522b76c6e7bebccc14fd0afb52ebb71f9 Mon Sep 17 00:00:00 2001 From: Trevor Basinger Date: Mon, 20 Apr 2026 19:01:48 +0000 Subject: [PATCH 5/7] style(format): apply ruff formatting --- roar/application/publish/put_execution.py | 12 +++++++----- roar/application/publish/register_execution.py | 12 ++++++++---- roar/application/publish/session.py | 4 +--- roar/integrations/glaas/registration/job.py | 4 +--- tests/integration/fake_glaas.py | 18 +++++++++++------- 5 files changed, 28 insertions(+), 22 deletions(-) diff --git a/roar/application/publish/put_execution.py b/roar/application/publish/put_execution.py index 1f74b6e9..60e58c59 100644 --- a/roar/application/publish/put_execution.py +++ b/roar/application/publish/put_execution.py @@ -634,11 +634,13 @@ def _put_prepared_with_registration_session( put_outputs = self._build_registration_session_put_job_outputs( composite_results_for_linking ) - link_result = coordinator.job_service.link_job_artifacts_under_registration_session( - registration_session_id=registration_session_id, - job_uid=job_uid, - inputs=put_inputs, - outputs=put_outputs, + link_result = ( + coordinator.job_service.link_job_artifacts_under_registration_session( + registration_session_id=registration_session_id, + job_uid=job_uid, + inputs=put_inputs, + outputs=put_outputs, + ) ) put_job_links_succeeded = link_result.success if not link_result.success and link_result.error: diff --git a/roar/application/publish/register_execution.py b/roar/application/publish/register_execution.py index 92f19aa7..2c9fa69b 100644 --- a/roar/application/publish/register_execution.py +++ b/roar/application/publish/register_execution.py @@ -298,9 +298,11 @@ def register_prepared_lineage( if batch_result.jobs_failed == 0 and batch_result.links_failed == 0: spin.update("Finalizing lineage...") - finalize_result = self.coordinator.session_service.finalize_registration_session( - registration_session_id=registration_session_id, - git_context=git_context, + finalize_result = ( + self.coordinator.session_service.finalize_registration_session( + registration_session_id=registration_session_id, + git_context=git_context, + ) ) if not finalize_result.success: finalize_failed = True @@ -427,7 +429,9 @@ def register_prepared_lineage( self._logger.warning("Registration completed with errors: %s", registration_errors) return RegisterResult( - success=batch_result.jobs_failed == 0 and total_artifacts_failed == 0 and not finalize_failed, + success=batch_result.jobs_failed == 0 + and total_artifacts_failed == 0 + and not finalize_failed, session_hash=finalized_session_hash, artifact_hash=artifact_hash, jobs_registered=batch_result.jobs_created, diff --git a/roar/application/publish/session.py b/roar/application/publish/session.py index 957d2e61..2c01d4ec 100644 --- a/roar/application/publish/session.py +++ b/roar/application/publish/session.py @@ -230,9 +230,7 @@ def prepare_publish_session( session_result = session_service.create_registration_session(client_session_id=None) if not session_result.success: logger.debug("Registration session creation failed: %s", session_result.error) - raise ValueError( - f"Registration session creation failed: {session_result.error}" - ) + raise ValueError(f"Registration session creation failed: {session_result.error}") logger.debug( "Registration session ready: %s", diff --git a/roar/integrations/glaas/registration/job.py b/roar/integrations/glaas/registration/job.py index fcb97764..0a7234e6 100644 --- a/roar/integrations/glaas/registration/job.py +++ b/roar/integrations/glaas/registration/job.py @@ -680,9 +680,7 @@ def link_job_artifacts_under_registration_session( ) errors.append(f"outputs: {error}") break - outputs_linked += ( - result.get("outputs_linked", len(batch)) if result else len(batch) - ) + outputs_linked += result.get("outputs_linked", len(batch)) if result else len(batch) artifacts_registered += ( result.get("artifacts_registered", len(batch)) if result else len(batch) ) diff --git a/tests/integration/fake_glaas.py b/tests/integration/fake_glaas.py index 562fd9df..caa34ff0 100644 --- a/tests/integration/fake_glaas.py +++ b/tests/integration/fake_glaas.py @@ -180,7 +180,9 @@ def _compute_registration_session_hash( "parent_step_number": parent_step_number, "inputs": inputs, "outputs": outputs, - "metadata": dict(sorted(metadata.items())) if isinstance(metadata, dict) else {}, + "metadata": dict(sorted(metadata.items())) + if isinstance(metadata, dict) + else {}, } ) @@ -369,8 +371,8 @@ def do_POST(self) -> None: self.server.registration_session_creations.append(payload) client_session_id = payload.get("client_session_id") if isinstance(client_session_id, str) and client_session_id: - registration_session_id = self.server.registration_session_ids_by_client_session_id.get( - client_session_id + registration_session_id = ( + self.server.registration_session_ids_by_client_session_id.get(client_session_id) ) else: registration_session_id = None @@ -560,7 +562,9 @@ def do_POST(self) -> None: ) job_uid = payload.get("job_uid") if isinstance(job_uid, str) and job_uid: - stored_job = session_state["jobs"].setdefault(job_uid, {"inputs": [], "outputs": []}) + stored_job = session_state["jobs"].setdefault( + job_uid, {"inputs": [], "outputs": []} + ) stored_job.update( { "command": payload.get("command"), @@ -594,9 +598,9 @@ def do_POST(self) -> None: registration_session_id, {"jobs": {}, "hash": None, "status": "active"}, ) - session_state["jobs"].setdefault(job_uid, {"inputs": [], "outputs": []})["inputs"].extend( - artifacts - ) + session_state["jobs"].setdefault(job_uid, {"inputs": [], "outputs": []})[ + "inputs" + ].extend(artifacts) self._write_json( 200, { From 7104e7e7fdd401efc6f09b0019cecc242a7fdeae Mon Sep 17 00:00:00 2001 From: Trevor Basinger Date: Tue, 21 Apr 2026 15:27:15 +0000 Subject: [PATCH 6/7] fix(publish): scope remote job uids per lineage --- roar/application/labels.py | 28 +++++-- roar/application/publish/put_execution.py | 19 ++++- .../application/publish/register_execution.py | 12 ++- roar/application/publish/remote_job_uids.py | 50 ++++++++++++ roar/application/query/label.py | 9 +++ .../glaas/registration/coordinator.py | 6 +- roar/integrations/glaas/registration/job.py | 32 +++++--- .../publish/test_remote_job_uids.py | 48 +++++++++++ tests/application/query/test_label.py | 79 ++++++++++++++++++- tests/integration/fake_glaas.py | 58 +++++++++++++- .../test_label_push_cli_integration.py | 11 ++- .../integration/test_register_dry_run_cli.py | 62 +++++++++++++++ tests/live_glaas/test_labels_live.py | 11 ++- tests/unit/put/test_put_service.py | 4 +- tests/unit/test_label_service.py | 31 ++++++++ 15 files changed, 421 insertions(+), 39 deletions(-) create mode 100644 roar/application/publish/remote_job_uids.py create mode 100644 tests/application/publish/test_remote_job_uids.py diff --git a/roar/application/labels.py b/roar/application/labels.py index be42553e..a3ba5980 100644 --- a/roar/application/labels.py +++ b/roar/application/labels.py @@ -17,6 +17,7 @@ from ..db.context import DatabaseContext from .label_rendering import flatten_label_metadata from .publish.lineage import LineageCollector +from .publish.remote_job_uids import build_remote_publication_job_uid from .system_labels import is_reserved_system_label_path, strip_reserved_system_labels @@ -251,6 +252,7 @@ def build_remote_label_mutation_payload( roar_dir: Path, target: LabelTargetRef, metadata: dict[str, Any], + prefer_remote_publication_uid: bool = True, ) -> dict[str, Any]: """Build a GLaaS label-mutation payload for one local target.""" if target.entity_type == "dag": @@ -278,14 +280,20 @@ def build_remote_label_mutation_payload( raise ValueError("Job target is missing a local session id.") if not isinstance(job_uid, str) or not job_uid: raise ValueError("Job target is missing a job UID.") + session_hash = _canonical_remote_session_hash( + db_ctx, + roar_dir=roar_dir, + session_id=session_id, + ) + resolved_job_uid = ( + build_remote_publication_job_uid(session_hash, job_uid) + if prefer_remote_publication_uid + else job_uid + ) return { "entity_type": "job", - "session_hash": _canonical_remote_session_hash( - db_ctx, - roar_dir=roar_dir, - session_id=session_id, - ), - "job_uid": job_uid, + "session_hash": session_hash, + "job_uid": resolved_job_uid, "metadata": metadata, } @@ -359,9 +367,13 @@ def collect_label_sync_payloads( for job in jobs: job_id = job.get("id") job_uid = job.get("job_uid") + remote_job_uid = job.get("remote_job_uid") if not isinstance(job_id, int) or not isinstance(job_uid, str) or not job_uid: continue - dedupe_key = ("job", job_uid) + resolved_remote_job_uid = ( + remote_job_uid if isinstance(remote_job_uid, str) and remote_job_uid else job_uid + ) + dedupe_key = ("job", resolved_remote_job_uid) if dedupe_key in seen_jobs: continue seen_jobs.add(dedupe_key) @@ -372,7 +384,7 @@ def collect_label_sync_payloads( { "entity_type": "job", "session_hash": session_hash, - "job_uid": job_uid, + "job_uid": resolved_remote_job_uid, "metadata": current["metadata"], "key_origins": build_current_key_origins(history), } diff --git a/roar/application/publish/put_execution.py b/roar/application/publish/put_execution.py index 60e58c59..a4efbe18 100644 --- a/roar/application/publish/put_execution.py +++ b/roar/application/publish/put_execution.py @@ -26,6 +26,10 @@ register_publish_lineage, sync_publish_labels, ) +from ...application.publish.remote_job_uids import ( + build_remote_publication_job_uid, + prepare_jobs_for_remote_publication, +) from ...application.system_labels import refresh_job_system_labels from ...core.interfaces.registration import GitContext from ...core.logging import get_logger @@ -470,6 +474,7 @@ def put_prepared( session_hash=session_hash_value, job_id=job_id, job_uid=job_uid, + remote_job_uid=job_uid, registration_errors=registration_result.errors, ) @@ -582,6 +587,10 @@ def _put_prepared_with_registration_session( job_uid, step_number, ) + remote_put_job_uid = build_remote_publication_job_uid(fallback_session_hash, job_uid) + remote_lineage_jobs = prepare_jobs_for_remote_publication( + lineage.jobs, fallback_session_hash + ) composite_results_for_linking = build_publish_composite_results( resolved_sources=resolved, @@ -599,7 +608,7 @@ def _put_prepared_with_registration_session( registration_result = coordinator.register_lineage_under_registration_session( registration_session_id=registration_session_id, git_context=git_context, - jobs=lineage.jobs, + jobs=remote_lineage_jobs, ) registration_errors.extend(registration_result.errors) @@ -609,7 +618,7 @@ def _put_prepared_with_registration_session( command=command, timestamp=time.time(), registration_session_id=registration_session_id, - job_uid=job_uid, + job_uid=remote_put_job_uid, git_commit=git_context.commit or "", git_branch=git_context.branch or "", duration_seconds=0.0, @@ -637,7 +646,7 @@ def _put_prepared_with_registration_session( link_result = ( coordinator.job_service.link_job_artifacts_under_registration_session( registration_session_id=registration_session_id, - job_uid=job_uid, + job_uid=remote_put_job_uid, inputs=put_inputs, outputs=put_outputs, ) @@ -719,6 +728,7 @@ def _put_prepared_with_registration_session( session_hash=session_hash, job_id=job_id, job_uid=job_uid, + remote_job_uid=remote_put_job_uid, registration_errors=registration_errors, ) @@ -945,6 +955,7 @@ def _sync_put_job_labels_with_glaas( session_hash: str, job_id: int, job_uid: str, + remote_job_uid: str, registration_errors: list[str], ) -> None: """Sync the local current label document for the publish-time put job.""" @@ -953,7 +964,7 @@ def _sync_put_job_labels_with_glaas( db_ctx=self._db, session_id=None, session_hash=session_hash, - jobs=[{"id": job_id, "job_uid": job_uid}], + jobs=[{"id": job_id, "job_uid": job_uid, "remote_job_uid": remote_job_uid}], artifacts=[], errors=registration_errors, ) diff --git a/roar/application/publish/register_execution.py b/roar/application/publish/register_execution.py index 2c9fa69b..25afa815 100644 --- a/roar/application/publish/register_execution.py +++ b/roar/application/publish/register_execution.py @@ -21,6 +21,7 @@ normalize_jobs_for_registration, order_jobs_for_registration, ) +from .remote_job_uids import prepare_jobs_for_remote_publication from .secrets import detect_lineage_secrets, filter_lineage_secrets if TYPE_CHECKING: @@ -260,6 +261,11 @@ def register_prepared_lineage( registration_jobs = order_jobs_for_registration( normalize_jobs_for_registration(lineage.jobs) ) + remote_registration_jobs = ( + prepare_jobs_for_remote_publication(registration_jobs, session_hash) + if registration_session_id + else registration_jobs + ) if dry_run: return RegisterResult( @@ -292,7 +298,7 @@ def register_prepared_lineage( batch_result = self.coordinator.register_lineage_under_registration_session( registration_session_id=registration_session_id, git_context=git_context, - jobs=registration_jobs, + jobs=remote_registration_jobs, ) registration_errors.extend(batch_result.errors) @@ -333,7 +339,7 @@ def register_prepared_lineage( db_ctx=db_ctx, session_id=session_id, session_hash=finalized_session_hash, - jobs=registration_jobs, + jobs=remote_registration_jobs, artifacts=lineage.artifacts, errors=registration_errors, ) @@ -353,7 +359,7 @@ def register_prepared_lineage( db_ctx=db_ctx, session_id=session_id, session_hash=finalized_session_hash, - jobs=registration_jobs, + jobs=remote_registration_jobs, artifacts=lineage.artifacts, errors=registration_errors, ) diff --git a/roar/application/publish/remote_job_uids.py b/roar/application/publish/remote_job_uids.py new file mode 100644 index 00000000..296cb0d1 --- /dev/null +++ b/roar/application/publish/remote_job_uids.py @@ -0,0 +1,50 @@ +"""Helpers for deriving publication-scoped remote job UIDs.""" + +from __future__ import annotations + +import hashlib +from typing import Any + + +def build_remote_publication_job_uid(session_hash: str, local_job_uid: str) -> str: + """Derive a deterministic remote job UID for one published lineage snapshot. + + Remote GLaaS jobs are still stored with a one-lineage/one-row model. Re-publishing + the same local lineage after it grows must therefore avoid reusing the same + ``job_uid`` values across distinct finalized remote lineages. The published lineage + hash scopes the remote job identity without changing the local job UID stored in the + `.roar` database. + """ + normalized_session_hash = str(session_hash).strip().lower() + normalized_local_job_uid = str(local_job_uid).strip() + digest = hashlib.sha256( + f"{normalized_session_hash}\0{normalized_local_job_uid}".encode() + ).hexdigest() + return digest[:32] + + +def prepare_jobs_for_remote_publication( + jobs: list[dict[str, Any]], + session_hash: str, +) -> list[dict[str, Any]]: + """Return job payload copies annotated with publication-scoped remote job UIDs.""" + remote_uid_by_local_uid = { + str(job_uid): build_remote_publication_job_uid(session_hash, str(job_uid)) + for job in jobs + if isinstance((job_uid := job.get("job_uid")), str) and job_uid + } + + prepared_jobs: list[dict[str, Any]] = [] + for job in jobs: + prepared = dict(job) + local_job_uid = prepared.get("job_uid") + if isinstance(local_job_uid, str) and local_job_uid: + prepared["remote_job_uid"] = remote_uid_by_local_uid[local_job_uid] + + parent_job_uid = prepared.get("parent_job_uid") + if isinstance(parent_job_uid, str) and parent_job_uid: + prepared["remote_parent_job_uid"] = remote_uid_by_local_uid.get(parent_job_uid) + + prepared_jobs.append(prepared) + + return prepared_jobs diff --git a/roar/application/query/label.py b/roar/application/query/label.py index 8352cc4b..938638f6 100644 --- a/roar/application/query/label.py +++ b/roar/application/query/label.py @@ -98,6 +98,15 @@ def build_push_labels_summary(request: LabelPushRequest) -> LabelCurrentSummary: client = GlaasClient(start_dir=str(request.cwd), allow_public_without_binding=True) result, error = client.patch_current_label(payload) + if error and resolved.entity_type == "job" and error.startswith("HTTP 404:"): + fallback_payload = build_remote_label_mutation_payload( + db_ctx, + roar_dir=request.roar_dir, + target=resolved, + metadata=metadata, + prefer_remote_publication_uid=False, + ) + result, error = client.patch_current_label(fallback_payload) if error: raise ValueError(f"Remote label push failed: {error}") diff --git a/roar/integrations/glaas/registration/coordinator.py b/roar/integrations/glaas/registration/coordinator.py index 907a2135..2419323c 100644 --- a/roar/integrations/glaas/registration/coordinator.py +++ b/roar/integrations/glaas/registration/coordinator.py @@ -277,6 +277,10 @@ def register_lineage_under_registration_session( if not job_uid or job_uid not in job_uids_created: continue + remote_job_uid = job.get("remote_job_uid") + if not isinstance(remote_job_uid, str) or not remote_job_uid: + remote_job_uid = job_uid + inputs = self._extract_staged_io_list(job, "_inputs", "_input_hashes") outputs = self._extract_staged_io_list(job, "_outputs", "_output_hashes") if not inputs and not outputs: @@ -284,7 +288,7 @@ def register_lineage_under_registration_session( link_result = self.job_service.link_job_artifacts_under_registration_session( registration_session_id=registration_session_id, - job_uid=job_uid, + job_uid=remote_job_uid, inputs=inputs, outputs=outputs, ) diff --git a/roar/integrations/glaas/registration/job.py b/roar/integrations/glaas/registration/job.py index 0a7234e6..8630aee6 100644 --- a/roar/integrations/glaas/registration/job.py +++ b/roar/integrations/glaas/registration/job.py @@ -400,14 +400,22 @@ def create_jobs_batch_under_registration_session( payload_indices: list[int] = [] for i, job in enumerate(jobs): - job_uid = job.get("job_uid") - if not job_uid: + local_job_uid = job.get("job_uid") + if not local_job_uid: self._logger.warning("Skipping job without job_uid") results[i] = JobRegistrationResult( success=False, job_uid="", error="Job missing job_uid" ) continue + remote_job_uid = job.get("remote_job_uid") + if not isinstance(remote_job_uid, str) or not remote_job_uid: + remote_job_uid = local_job_uid + + remote_parent_job_uid = job.get("remote_parent_job_uid") + if remote_parent_job_uid is None and job.get("parent_job_uid") is not None: + remote_parent_job_uid = job.get("parent_job_uid") + command = job.get("command", "") git_commit = job.get("git_commit") or git_context.commit or "" git_branch = job.get("git_branch") or git_context.branch or "" @@ -418,7 +426,7 @@ def create_jobs_batch_under_registration_session( command=filtered_command, timestamp=job.get("timestamp", 0.0), session_hash="pending-registration-session", - job_uid=job_uid, + job_uid=remote_job_uid, git_commit=git_commit, git_branch=git_branch, job_type=job.get("job_type"), @@ -426,14 +434,16 @@ def create_jobs_batch_under_registration_session( ) if not validation: error_msg = "; ".join(validation.errors) - self._logger.warning("Job validation failed for %s: %s", job_uid, error_msg) - results[i] = JobRegistrationResult(success=False, job_uid=job_uid, error=error_msg) + self._logger.warning("Job validation failed for %s: %s", local_job_uid, error_msg) + results[i] = JobRegistrationResult( + success=False, job_uid=local_job_uid, error=error_msg + ) continue payload: dict[str, Any] = { "command": filtered_command, "timestamp": job.get("timestamp", 0.0), - "job_uid": job_uid, + "job_uid": remote_job_uid, "git_commit": git_commit, "git_branch": git_branch, "duration_seconds": job.get("duration_seconds", 0.0), @@ -443,8 +453,8 @@ def create_jobs_batch_under_registration_session( } if filtered_metadata: payload["metadata"] = filtered_metadata - if job.get("parent_job_uid") is not None: - payload["parent_job_uid"] = job.get("parent_job_uid") + if remote_parent_job_uid is not None: + payload["parent_job_uid"] = remote_parent_job_uid payloads.append(payload) payload_indices.append(i) @@ -471,16 +481,16 @@ def create_jobs_batch_under_registration_session( ) else: for pos, idx in enumerate(payload_indices): - job_uid = payloads[pos]["job_uid"] + local_job_uid = jobs[idx].get("job_uid", "") if pos < len(errors) and errors[pos]: results[idx] = JobRegistrationResult( - success=False, job_uid=job_uid, error=errors[pos] + success=False, job_uid=local_job_uid, error=errors[pos] ) else: job_id = job_ids[pos] if pos < len(job_ids) else None results[idx] = JobRegistrationResult( success=True, - job_uid=job_uid, + job_uid=local_job_uid, job_id=str(job_id) if job_id else None, ) diff --git a/tests/application/publish/test_remote_job_uids.py b/tests/application/publish/test_remote_job_uids.py new file mode 100644 index 00000000..7e8c9713 --- /dev/null +++ b/tests/application/publish/test_remote_job_uids.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from roar.application.publish.remote_job_uids import ( + build_remote_publication_job_uid, + prepare_jobs_for_remote_publication, +) + + +def test_build_remote_publication_job_uid_is_deterministic_and_scoped_by_session_hash() -> None: + first = build_remote_publication_job_uid("a" * 64, "job-1") + repeated = build_remote_publication_job_uid("a" * 64, "job-1") + different_session = build_remote_publication_job_uid("b" * 64, "job-1") + different_job = build_remote_publication_job_uid("a" * 64, "job-2") + + assert first == repeated + assert first != different_session + assert first != different_job + assert len(first) == 32 + + +def test_prepare_jobs_for_remote_publication_maps_parent_relationships_without_mutating_input() -> ( + None +): + jobs = [ + { + "id": 1, + "job_uid": "parent-job", + "parent_job_uid": None, + "step_number": 1, + }, + { + "id": 2, + "job_uid": "child-job", + "parent_job_uid": "parent-job", + "step_number": 2, + }, + ] + + prepared = prepare_jobs_for_remote_publication(jobs, "f" * 64) + + assert jobs[0].get("remote_job_uid") is None + assert jobs[1].get("remote_parent_job_uid") is None + + assert prepared[0]["job_uid"] == "parent-job" + assert prepared[1]["job_uid"] == "child-job" + assert prepared[0]["remote_job_uid"] == build_remote_publication_job_uid("f" * 64, "parent-job") + assert prepared[1]["remote_job_uid"] == build_remote_publication_job_uid("f" * 64, "child-job") + assert prepared[1]["remote_parent_job_uid"] == prepared[0]["remote_job_uid"] diff --git a/tests/application/query/test_label.py b/tests/application/query/test_label.py index 07fa77fc..ab750c66 100644 --- a/tests/application/query/test_label.py +++ b/tests/application/query/test_label.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch from roar.application.query import ( LabelCopyRequest, @@ -223,6 +223,83 @@ def test_build_push_labels_summary_returns_remote_versioned_summary(tmp_path: Pa assert summary.render() == "Pushed remote labels (version 2):\n owner=ml\n stage=gold" +def test_build_push_labels_summary_retries_job_push_with_legacy_job_uid_on_not_found( + tmp_path: Path, +) -> None: + db_ctx = MagicMock() + db_ctx.__enter__.return_value = db_ctx + db_ctx.__exit__.return_value = None + service = MagicMock() + resolved_target = MagicMock(entity_type="job") + service.resolve_target.return_value = resolved_target + service.current_metadata.return_value = {"phase": "gold"} + client = MagicMock() + client.patch_current_label.side_effect = [ + (None, "HTTP 404: Label not found"), + ({"version": 3, "metadata": {"phase": "gold"}}, None), + ] + + with ( + patch("roar.application.query.label.create_database_context", return_value=db_ctx), + patch("roar.application.query.label.LabelService", return_value=service), + patch( + "roar.application.query.label.build_remote_label_mutation_payload", + side_effect=[ + { + "entity_type": "job", + "session_hash": "s" * 64, + "job_uid": "remote-job-1", + "metadata": {"phase": "gold"}, + }, + { + "entity_type": "job", + "session_hash": "s" * 64, + "job_uid": "local-job-1", + "metadata": {"phase": "gold"}, + }, + ], + ) as build_payload, + patch("roar.application.query.label.GlaasClient", return_value=client), + ): + summary = build_push_labels_summary(_push_request(tmp_path, target="@1", entity_type="job")) + + assert build_payload.call_args_list == [ + call( + db_ctx, + roar_dir=tmp_path / ".roar", + target=resolved_target, + metadata={"phase": "gold"}, + ), + call( + db_ctx, + roar_dir=tmp_path / ".roar", + target=resolved_target, + metadata={"phase": "gold"}, + prefer_remote_publication_uid=False, + ), + ] + assert client.patch_current_label.call_args_list == [ + call( + { + "entity_type": "job", + "session_hash": "s" * 64, + "job_uid": "remote-job-1", + "metadata": {"phase": "gold"}, + } + ), + call( + { + "entity_type": "job", + "session_hash": "s" * 64, + "job_uid": "local-job-1", + "metadata": {"phase": "gold"}, + } + ), + ] + assert summary.heading == "Pushed remote labels (version 3):" + assert summary.render() == "Pushed remote labels (version 3):\n phase=gold" + + def test_build_push_labels_summary_rejects_missing_user_managed_labels(tmp_path: Path) -> None: db_ctx = MagicMock() db_ctx.__enter__.return_value = db_ctx diff --git a/tests/integration/fake_glaas.py b/tests/integration/fake_glaas.py index caa34ff0..00295265 100644 --- a/tests/integration/fake_glaas.py +++ b/tests/integration/fake_glaas.py @@ -36,6 +36,7 @@ def __init__(self) -> None: self.session_reproductions_by_hash: dict[str, dict[str, Any]] = {} self.registration_sessions_by_id: dict[str, dict[str, Any]] = {} self.registration_session_ids_by_client_session_id: dict[str, str] = {} + self.job_owners_by_uid: dict[str, dict[str, Any]] = {} self._next_job_id = 1 self._next_registration_session_id = 1 self._next_finalized_hash = 1 @@ -124,6 +125,28 @@ def _resolve_creator_identity(self, authenticated_user: dict[str, str] | None) - return f"glaas:user:{user_id}" return "anonymous" + def _claim_job_uid( + self, + job_uid: str, + *, + registration_session_id: str, + lineage_hash: str | None = None, + ) -> str | None: + owner = self.server.job_owners_by_uid.get(job_uid) + if owner is None: + self.server.job_owners_by_uid[job_uid] = { + "registration_session_id": registration_session_id, + "hash": lineage_hash, + } + return None + + if owner.get("registration_session_id") == registration_session_id: + if lineage_hash and not owner.get("hash"): + owner["hash"] = lineage_hash + return None + + return f"job_uid already belongs to a different lineage or registration session: {job_uid}" + def _compute_registration_session_hash( self, session_state: dict[str, Any], @@ -492,6 +515,15 @@ def do_POST(self) -> None: lineage_hash = self.server.allocate_lineage_hash() session_state["hash"] = lineage_hash session_state["status"] = "closed" + jobs_by_uid = session_state.get("jobs") + if isinstance(jobs_by_uid, dict): + for job_uid in jobs_by_uid: + if isinstance(job_uid, str) and job_uid: + self._claim_job_uid( + job_uid, + registration_session_id=registration_session_id, + lineage_hash=lineage_hash, + ) self.server.registration_session_finalizations.append( { **payload, @@ -529,9 +561,19 @@ def do_POST(self) -> None: {"jobs": {}, "hash": None, "status": "active"}, ) job_ids = [] + errors: list[str] = [] for job in jobs: job_uid = job.get("job_uid") if isinstance(job_uid, str) and job_uid: + conflict = self._claim_job_uid( + job_uid, + registration_session_id=registration_session_id, + ) + if conflict is not None: + job_ids.append(None) + errors.append(conflict) + continue + stored_job = session_state["jobs"].setdefault( job_uid, {"inputs": [], "outputs": []}, @@ -546,8 +588,12 @@ def do_POST(self) -> None: "metadata": job.get("metadata"), } ) - job_ids.append(self.server.allocate_job_id()) - self._write_json(200, {"job_ids": job_ids, "errors": []}) + job_ids.append(self.server.allocate_job_id()) + errors.append("") + else: + job_ids.append(self.server.allocate_job_id()) + errors.append("") + self._write_json(200, {"job_ids": job_ids, "errors": errors}) return reg_job_match = re.fullmatch(r"/api/v1/registration-sessions/([^/]+)/jobs", self.path) @@ -562,6 +608,14 @@ def do_POST(self) -> None: ) job_uid = payload.get("job_uid") if isinstance(job_uid, str) and job_uid: + conflict = self._claim_job_uid( + job_uid, + registration_session_id=registration_session_id, + ) + if conflict is not None: + self._write_json(400, {"error": {"message": conflict}}) + return + stored_job = session_state["jobs"].setdefault( job_uid, {"inputs": [], "outputs": []} ) diff --git a/tests/integration/test_label_push_cli_integration.py b/tests/integration/test_label_push_cli_integration.py index f26e5a75..a027b84d 100644 --- a/tests/integration/test_label_push_cli_integration.py +++ b/tests/integration/test_label_push_cli_integration.py @@ -181,12 +181,15 @@ def test_label_push_job_omits_system_labels_and_targets_job_session_hash( published_session_hash = fake_glaas_publish_server.registration_session_finalizations[0]["hash"] roar_cli("label", "set", "job", "@1", "phase=train", env_overrides=env) - job_uid = _job_uid_for(temp_git_repo, roar_cli, "processed.csv") + staged_jobs = fake_glaas_publish_server.registration_session_job_batches[0]["jobs"] + remote_job_uid = next( + str(job["job_uid"]) for job in staged_jobs if int(job.get("step_number") or 0) == 1 + ) synced_job_labels = [ label for batch in fake_glaas_publish_server.label_syncs for label in batch - if label.get("entity_type") == "job" and label.get("job_uid") == job_uid + if label.get("entity_type") == "job" and label.get("job_uid") == remote_job_uid ] assert len(synced_job_labels) == 1 assert synced_job_labels[0]["session_hash"] == published_session_hash @@ -195,7 +198,7 @@ def test_label_push_job_omits_system_labels_and_targets_job_session_hash( assert len(fake_glaas_publish_server.label_mutation_attempts) == 1 attempted = fake_glaas_publish_server.label_mutation_attempts[0] assert attempted["entity_type"] == "job" - assert attempted["job_uid"] == job_uid + assert attempted["job_uid"] == remote_job_uid assert attempted["metadata"] == {"phase": "train"} assert re.fullmatch(r"[0-9a-f]{64}", attempted["session_hash"]) assert result.returncode == 0 @@ -205,7 +208,7 @@ def test_label_push_job_omits_system_labels_and_targets_job_session_hash( { "entity_type": "job", "session_hash": attempted["session_hash"], - "job_uid": job_uid, + "job_uid": remote_job_uid, "metadata": {"phase": "train"}, } ] diff --git a/tests/integration/test_register_dry_run_cli.py b/tests/integration/test_register_dry_run_cli.py index bd64bc57..189b16f5 100644 --- a/tests/integration/test_register_dry_run_cli.py +++ b/tests/integration/test_register_dry_run_cli.py @@ -180,6 +180,68 @@ def test_register_publishes_local_lineage_with_fake_glaas( assert registered_jobs[0]["job_type"] == "run" +def test_register_can_republish_same_local_session_after_additional_run_with_fake_glaas( + temp_git_repo: Path, + roar_cli, + git_commit, + python_exe: str, + fake_glaas_publish_server: FakeGlaasServer, +) -> None: + env = _configure_register_repo(temp_git_repo, roar_cli, fake_glaas_publish_server.base_url) + + input_path = temp_git_repo / "input.txt" + input_path.write_text("register me\n") + first_script = temp_git_repo / "generate_report.py" + first_script.write_text( + "from pathlib import Path\n" + "content = Path('input.txt').read_text()\n" + "Path('report.txt').write_text(content.upper())\n" + ) + git_commit("Add first register fixture") + + first_run = roar_cli("run", python_exe, "generate_report.py", env_overrides=env) + assert first_run.returncode == 0 + git_commit("Commit first tracked report") + + first_session_hash = _status_session_hash(temp_git_repo, roar_cli) + first_register = roar_cli("register", "report.txt", "--yes", env_overrides=env) + assert first_register.returncode == 0 + + second_script = temp_git_repo / "summarize_report.py" + second_script.write_text( + "from pathlib import Path\n" + "content = Path('report.txt').read_text().strip()\n" + "Path('summary.txt').write_text(f'chars={len(content)}\\n')\n" + ) + git_commit("Add follow-up register fixture") + + second_run = roar_cli("run", python_exe, "summarize_report.py", env_overrides=env) + assert second_run.returncode == 0 + git_commit("Commit follow-up tracked summary") + + second_session_hash = _status_session_hash(temp_git_repo, roar_cli) + assert second_session_hash != first_session_hash + + second_register = roar_cli("register", "summary.txt", "--yes", env_overrides=env) + + assert second_register.returncode == 0 + assert "Jobs: 2" in second_register.stdout + assert len(fake_glaas_publish_server.registration_session_creations) == 2 + assert len(fake_glaas_publish_server.registration_session_finalizations) == 2 + assert len(fake_glaas_publish_server.registration_session_job_batches) == 2 + + first_publish_jobs = fake_glaas_publish_server.registration_session_job_batches[0]["jobs"] + second_publish_jobs = fake_glaas_publish_server.registration_session_job_batches[1]["jobs"] + assert len(first_publish_jobs) == 1 + assert len(second_publish_jobs) == 2 + + first_publish_job_uid = str(first_publish_jobs[0]["job_uid"]) + republished_first_step = next( + job for job in second_publish_jobs if int(job.get("step_number") or 0) == 1 + ) + assert str(republished_first_step["job_uid"]) != first_publish_job_uid + + def test_register_honors_logging_config_for_console_and_file( temp_git_repo: Path, roar_cli, diff --git a/tests/live_glaas/test_labels_live.py b/tests/live_glaas/test_labels_live.py index 4522c56d..9c6a83ba 100644 --- a/tests/live_glaas/test_labels_live.py +++ b/tests/live_glaas/test_labels_live.py @@ -15,6 +15,7 @@ import pytest +from roar.application.publish.remote_job_uids import build_remote_publication_job_uid from tests.live_glaas import test_composite_live as composite_live pytest_plugins = ("tests.live_glaas.test_composite_live",) @@ -299,6 +300,7 @@ def test_register_syncs_current_local_labels_only_when_register_called( artifact_hash = str(refs["artifact_hash"]) step_number = int(refs["step_number"]) job_uid = _local_job_uid(repo, step_number) + remote_job_uid = build_remote_publication_job_uid(session_hash, job_uid) _assert_ok( _run_roar( @@ -317,7 +319,7 @@ def test_register_syncs_current_local_labels_only_when_register_called( ) assert _remote_session_label_rows(glaas_url, session_hash) == [] - assert _remote_job_label_rows(glaas_url, session_hash, job_uid) == [] + assert _remote_job_label_rows(glaas_url, session_hash, remote_job_uid) == [] assert _remote_artifact_label_rows(glaas_url, artifact_hash) == [] _assert_ok(_run_roar(repo, "register", "processed.csv")) @@ -325,7 +327,7 @@ def test_register_syncs_current_local_labels_only_when_register_called( assert _remote_session_label_rows(glaas_url, session_hash) == [ (1, {"experiment": "ablation-7", "project": "forecasting"}) ] - job_rows = _remote_job_label_rows(glaas_url, session_hash, job_uid) + job_rows = _remote_job_label_rows(glaas_url, session_hash, remote_job_uid) assert len(job_rows) == 1 version, job_metadata = job_rows[0] assert version == 1 @@ -369,6 +371,7 @@ def test_register_exposes_current_labels_via_label_api( artifact_hash = str(refs["artifact_hash"]) step_number = int(refs["step_number"]) job_uid = _local_job_uid(repo, step_number) + remote_job_uid = build_remote_publication_job_uid(session_hash, job_uid) _assert_ok( _run_roar( @@ -408,7 +411,7 @@ def test_register_exposes_current_labels_via_label_api( "/api/v1/labels/current", entity_type="job", session_hash=session_hash, - job_uid=job_uid, + job_uid=remote_job_uid, ) assert job_label == { "id": job_label["id"], @@ -417,7 +420,7 @@ def test_register_exposes_current_labels_via_label_api( "metadata": job_label["metadata"], "createdAt": job_label["createdAt"], "sessionHash": session_hash, - "jobUid": job_uid, + "jobUid": remote_job_uid, } assert isinstance(job_label["metadata"], dict) _assert_synced_run_job_label_metadata(job_label["metadata"], phase="preprocess") diff --git a/tests/unit/put/test_put_service.py b/tests/unit/put/test_put_service.py index cbf92b92..e1cd9b55 100644 --- a/tests/unit/put/test_put_service.py +++ b/tests/unit/put/test_put_service.py @@ -202,7 +202,9 @@ def test_put_prepared_refreshes_and_syncs_put_job_labels(self, tmp_path: Path) - assert sync_kwargs["db_ctx"] is db assert sync_kwargs["session_id"] is None assert sync_kwargs["session_hash"] == "session_hash_abc123" - assert sync_kwargs["jobs"] == [{"id": 42, "job_uid": "job-uid-1"}] + assert sync_kwargs["jobs"] == [ + {"id": 42, "job_uid": "job-uid-1", "remote_job_uid": "job-uid-1"} + ] assert sync_kwargs["artifacts"] == [] def test_put_prepared_returns_registered_session_info(self, tmp_path: Path) -> None: diff --git a/tests/unit/test_label_service.py b/tests/unit/test_label_service.py index 817e293f..95af4995 100644 --- a/tests/unit/test_label_service.py +++ b/tests/unit/test_label_service.py @@ -188,3 +188,34 @@ def get_history(self, entity_type: str, **kwargs): "key_origins": {"generated.phase": "system", "stage": "user"}, }, ] + + +def test_collect_label_sync_payloads_prefers_remote_job_uid_when_present() -> None: + class StubLabels: + def get_current(self, entity_type: str, **kwargs): + if entity_type == "job" and kwargs.get("job_id") == 11: + return {"metadata": {"phase": "prep"}, "write_origin": "user"} + return None + + def get_history(self, entity_type: str, **kwargs): + if entity_type == "job" and kwargs.get("job_id") == 11: + return [{"metadata": {"phase": "prep"}, "write_origin": "user"}] + return [] + + payloads = collect_label_sync_payloads( + SimpleNamespace(labels=StubLabels()), + session_id=None, + session_hash="s" * 64, + jobs=[{"id": 11, "job_uid": "local-job-1", "remote_job_uid": "remote-job-1"}], + artifacts=[], + ) + + assert payloads == [ + { + "entity_type": "job", + "session_hash": "s" * 64, + "job_uid": "remote-job-1", + "metadata": {"phase": "prep"}, + "key_origins": {"phase": "user"}, + } + ] From fbc9c197c3c7d32ed892c3fd8afb7a24bab22a9f Mon Sep 17 00:00:00 2001 From: Trevor Basinger Date: Wed, 22 Apr 2026 14:57:44 +0000 Subject: [PATCH 7/7] fix(publish): use remote registry auth for session staging --- roar/application/publish/session.py | 6 ++-- tests/application/publish/test_session.py | 39 +++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/roar/application/publish/session.py b/roar/application/publish/session.py index 2c01d4ec..5100e4cb 100644 --- a/roar/application/publish/session.py +++ b/roar/application/publish/session.py @@ -216,7 +216,7 @@ def prepare_publish_session( logger.debug("GLaaS health check failed: %s", exc) raise ValueError(f"GLaaS health check failed: {exc}") from exc - publish_auth = getattr(glaas_client, "publish_auth", None) + publish_auth = resolved_remote_registry.publish_auth access_token = getattr(publish_auth, "access_token", None) ssh_auth_available = getattr(publish_auth, "ssh_auth_available", False) @@ -227,7 +227,9 @@ def prepare_publish_session( if should_use_registration_sessions: logger.debug("Creating remote registration session with GLaaS") - session_result = session_service.create_registration_session(client_session_id=None) + session_result = resolved_session_service.create_registration_session( + client_session_id=None + ) if not session_result.success: logger.debug("Registration session creation failed: %s", session_result.error) raise ValueError(f"Registration session creation failed: {session_result.error}") diff --git a/tests/application/publish/test_session.py b/tests/application/publish/test_session.py index 67d6fbb8..c3cc40b1 100644 --- a/tests/application/publish/test_session.py +++ b/tests/application/publish/test_session.py @@ -110,6 +110,45 @@ def test_prepare_publish_session_creates_registration_session_with_glaas(tmp_pat session_service.register.assert_not_called() +def test_prepare_publish_session_uses_remote_registry_publish_auth_when_legacy_client_not_passed( + tmp_path: Path, +) -> None: + remote_registry = MagicMock() + remote_registry.session_service = MagicMock() + remote_registry.session_service.compute_session_hash.return_value = "session-hash" + remote_registry.session_service.create_registration_session.return_value = ( + SessionRegistrationResult( + success=True, + session_hash="session-hash", + session_url=None, + registration_session_id="reg-session-remote-123", + ) + ) + remote_registry.publish_auth.access_token = "token-123" + remote_registry.publish_auth.ssh_auth_available = False + remote_registry.is_configured.return_value = True + + result = prepare_publish_session( + remote_registry=remote_registry, + roar_dir=tmp_path / ".roar", + session_id=7, + git_context=_git_context(), + logger=MagicMock(), + register_with_glaas=True, + ) + + assert result == PreparedPublishSession( + session_hash="session-hash", + session_url=None, + registration_session_id="reg-session-remote-123", + ) + remote_registry.health_check.assert_called_once() + remote_registry.session_service.create_registration_session.assert_called_once_with( + client_session_id=None + ) + remote_registry.session_service.register.assert_not_called() + + def test_prepare_publish_session_creates_registration_session_with_ssh_only_auth( tmp_path: Path, ) -> None: