From 7ddd04e5bc07fe982a6057c0ba543feb009ce917 Mon Sep 17 00:00:00 2001 From: jbellister-slac Date: Fri, 27 Mar 2026 09:47:55 -0700 Subject: [PATCH 1/6] ENH: Add ability for snapshot restore to report progress to the frontend --- app/api/v1/snapshots.py | 31 ++++++++++- app/services/background_tasks.py | 88 +++++++++++++++++++++++++++++++- app/services/epics_service.py | 48 +++++++++++++++++ app/services/snapshot_service.py | 27 ++++++++-- app/tasks/snapshot_tasks.py | 20 +++++++- 5 files changed, 206 insertions(+), 8 deletions(-) diff --git a/app/api/v1/snapshots.py b/app/api/v1/snapshots.py index a10bf84..f73b345 100644 --- a/app/api/v1/snapshots.py +++ b/app/api/v1/snapshots.py @@ -11,7 +11,7 @@ from app.models.job import JobType from app.schemas.job import JobCreatedDTO from app.dependencies import require_read_access, require_write_access -from app.api.responses import APIException, success_response +from app.api.responses import APIException, error_response, success_response from app.schemas.snapshot import NewSnapshotDTO, RestoreRequestDTO, UpdateSnapshotDTO from app.services.job_service import JobService from app.services.epics_service import EpicsService, get_epics_service @@ -190,6 +190,7 @@ async def restore_snapshot( snapshot_id: str, request: RestoreRequestDTO | None = None, db: AsyncSession = Depends(get_db), + async_mode: bool = Query(True, alias="async"), epics: EpicsService = Depends(get_epics_service), ) -> dict: """ @@ -208,6 +209,34 @@ async def restore_snapshot( if not snapshot: raise APIException(404, f"Snapshot {snapshot_id} not found", 404) + if async_mode: + job_service = JobService(db) + job = await job_service.create_job( + JobType.SNAPSHOT_RESTORE, + job_data={"snapshotId": snapshot_id}, + ) + + await db.commit() + pool = await get_arq_pool() + if pool: + try: + await pool.enqueue_job( + "restore_snapshot_task", + job_id=str(job.id), + snapshot_id=snapshot_id, + pv_ids=request.pvIds if request else None, + ) + logger.info(f"Enqueued restore job to Arq: {job.id}") + + return success_response( + JobCreatedDTO( + jobId=job.id, + message=f"Snapshot restore queued ({snapshot_id})", + ) + ) + except Exception as e: + return error_response(500, f"Failed to enqueue to Arq: {e}") + result = await service.restore_snapshot(snapshot_id, request) return success_response(result) diff --git a/app/services/background_tasks.py b/app/services/background_tasks.py index 99b86b3..0435d6d 100644 --- a/app/services/background_tasks.py +++ b/app/services/background_tasks.py @@ -4,7 +4,7 @@ from datetime import datetime from app.db.session import async_session_maker -from app.schemas.snapshot import NewSnapshotDTO +from app.schemas.snapshot import NewSnapshotDTO, RestoreRequestDTO from app.services.epics_service import get_epics_service from app.services.redis_service import get_redis_service from app.services.snapshot_service import SnapshotService @@ -111,6 +111,92 @@ async def progress_update(current: int, total: int, message: str): logger.exception(f"Failed to update job status: {inner_e}") +async def run_snapshot_restore( + job_id: str, + snapshot_id: str, + pv_ids: list[str] | None = None, +) -> None: + """ + Background task to restore a snapshot. + + This runs in a separate asyncio task and uses its own database session. + """ + logger.info(f"Background task started for job {job_id}: Restoring snapshot '{snapshot_id}'") + + async with async_session_maker() as session: + try: + job_repo = JobRepository(session) + + # Mark job as running + await job_repo.mark_running(job_id) + await session.commit() + await asyncio.sleep(0) + + # Initial progress update + await job_repo.update_progress(job_id, 5, "Loading snapshot values...") + await session.commit() + await asyncio.sleep(0) + + epics = get_epics_service() + snapshot_service = SnapshotService(session, epics) + + # Optional restore request + request = RestoreRequestDTO(pvIds=pv_ids) if pv_ids else None + + last_update = {"progress": 5, "last_time": datetime.now()} + + async def progress_update(current: int, total: int, message: str): + try: + write_progress = int((current / total) * 85) if total > 0 else 0 + job_progress = 10 + write_progress + + now = datetime.now() + time_elapsed = (now - last_update["last_time"]).total_seconds() + progress_changed = job_progress - last_update["progress"] >= 2 + + if progress_changed or time_elapsed >= 2.0 or current >= total: + last_update["progress"] = job_progress + last_update["last_time"] = now + await job_repo.update_progress(job_id, job_progress, message) + await session.commit() + await asyncio.sleep(0) + except Exception as e: + logger.error(f"Error in restore progress_update: {e}") + + result = await snapshot_service.restore_snapshot( + snapshot_id, + request, + progress_callback=progress_update, + ) + + # Final completion update + completion_message = f"Restore completed: {result.successCount}/{result.totalPVs} PVs restored" + ( + f", {result.failureCount} failed" if result.failureCount > 0 else "" + ) + + await job_repo.mark_completed( + job_id, + message=completion_message, + ) + await session.commit() + + logger.info( + f"Background restore completed for job {job_id}: " + f"{result.successCount}/{result.totalPVs} succeeded, " + f"{result.failureCount} failed" + ) + + except Exception as e: + logger.exception(f"Background restore failed for job {job_id}: {e}") + error_msg = f"{type(e).__name__}: {str(e)}" + try: + await session.rollback() + await job_repo.mark_failed(job_id, error_msg) + await session.commit() + except Exception as inner_e: + logger.exception(f"Failed to update restore job status: {inner_e}") + + def schedule_snapshot_creation(job_id: str, title: str, description: str | None = None, use_cache: bool = True) -> None: """ Schedule a snapshot creation task to run in the background. diff --git a/app/services/epics_service.py b/app/services/epics_service.py index 9c2f534..672ebe7 100644 --- a/app/services/epics_service.py +++ b/app/services/epics_service.py @@ -471,6 +471,54 @@ async def put_many(self, values: dict[str, Any]) -> dict[str, tuple[bool, str | return results + async def put_many_with_progress( + self, + values: dict[str, Any], + progress_callback: Callable | None = None, + ) -> dict[str, tuple[bool, str | None]]: + """ + Put values to multiple PVs with progress tracking. + Can be used to update the user on progress when a snapshot restore is initiated. + """ + total_pvs = len(values) + results: dict[str, tuple[bool, str | None]] = {} + + logger.info(f"Starting put_many_with_progress for {total_pvs} PVs") + + if progress_callback: + await progress_callback(0, total_pvs, f"Starting restore of {total_pvs:,} PVs") + + items = list(values.items()) + batch_size = self._chunk_size + + for i in range(0, total_pvs, batch_size): + batch_items = items[i : i + batch_size] + batch_values = dict(batch_items) + + try: + batch_results = await self.put_many(batch_values) + results.update(batch_results) + except Exception as e: + logger.error(f"Chunk put error ({i}-{i + len(batch_items)}): {e}") + for pv_name, _ in batch_items: + if pv_name not in results: + results[pv_name] = (False, str(e)) + + current = min(i + batch_size, total_pvs) + success_count = sum(1 for ok, _ in results.values() if ok) + + if progress_callback: + await progress_callback( + current, + total_pvs, + f"Restored {current:,}/{total_pvs:,} PVs ({success_count:,} successful)", + ) + + logger.info(f"Restored {current:,}/{total_pvs:,} PVs " f"({success_count:,} successful)") + + logger.info(f"Completed put_many_with_progress: {len(results)}/{total_pvs} PVs processed") + return results + async def shutdown(self): """Cleanup resources.""" # aioca manages its own connections via libca diff --git a/app/services/snapshot_service.py b/app/services/snapshot_service.py index 818baa4..3feca37 100644 --- a/app/services/snapshot_service.py +++ b/app/services/snapshot_service.py @@ -522,7 +522,12 @@ def _format_ts(ts: float | None) -> str | None: logger.exception(f"Error creating snapshot from cache '{data.title}': {e}") raise - async def restore_snapshot(self, snapshot_id: str, request: RestoreRequestDTO | None = None) -> RestoreResultDTO: + async def restore_snapshot( + self, + snapshot_id: str, + request: RestoreRequestDTO | None = None, + progress_callback: Callable | None = None, + ) -> RestoreResultDTO: """ Restore PV values from a snapshot to EPICS. @@ -560,10 +565,18 @@ async def restore_snapshot(self, snapshot_id: str, request: RestoreRequestDTO | values_to_write[pv.setpoint_address] = write_value pv_id_by_address[pv.setpoint_address] = pv.id - logger.info(f"Writing {len(values_to_write)} PV values") + total_pvs = len(values_to_write) + logger.info(f"Writing {total_pvs} PV values") + + # Initial progress update + if progress_callback: + await progress_callback(0, total_pvs, f"Starting restore of {total_pvs:,} PVs") # Write to EPICS in parallel - results = await self.epics.put_many(values_to_write) + if progress_callback: + results = await self.epics.put_many_with_progress(values_to_write, progress_callback) + else: + results = await self.epics.put_many(values_to_write) # Process results failures = [] @@ -576,6 +589,12 @@ async def restore_snapshot(self, snapshot_id: str, request: RestoreRequestDTO | pv_id = pv_id_by_address.get(address, "") failures.append({"pvId": pv_id, "pvName": address, "error": error or "Unknown error"}) + if progress_callback: + await progress_callback( + total_pvs, + total_pvs, + f"Completed restore: {success_count}/{total_pvs} PVs successful", + ) total_time = datetime.now() logger.info( f"Restore completed in {(total_time - start_time).total_seconds():.2f}s " @@ -583,7 +602,7 @@ async def restore_snapshot(self, snapshot_id: str, request: RestoreRequestDTO | ) return RestoreResultDTO( - totalPVs=len(values_to_write), successCount=success_count, failureCount=len(failures), failures=failures + totalPVs=total_pvs, successCount=success_count, failureCount=len(failures), failures=failures ) async def compare_snapshots(self, snapshot1_id: str, snapshot2_id: str) -> ComparisonResultDTO: diff --git a/app/tasks/snapshot_tasks.py b/app/tasks/snapshot_tasks.py index fa969c6..bea2e3a 100644 --- a/app/tasks/snapshot_tasks.py +++ b/app/tasks/snapshot_tasks.py @@ -128,6 +128,13 @@ async def restore_snapshot_task(ctx: dict, job_id: str, snapshot_id: str, pv_ids await job_repo.mark_running(job_id) await session.commit() + # Create progress callback for job updates + async def on_progress(current: int, total: int, message: str) -> None: + progress = int((current / total) * 100) if total > 0 else 0 + await job_repo.update_progress(job_id, progress, message) + await session.commit() + logger.debug(f"Restore job {job_id} progress: {progress}% - {message}") + # Initialize services epics = ctx.get("epics") or get_epics_service() @@ -140,7 +147,11 @@ async def restore_snapshot_task(ctx: dict, job_id: str, snapshot_id: str, pv_ids request = RestoreRequestDTO(pvIds=pv_ids) if pv_ids else None # Restore the snapshot - result = await snapshot_service.restore_snapshot(snapshot_id, request) + result = await snapshot_service.restore_snapshot( + snapshot_id, + request, + progress_callback=on_progress, + ) # Mark job as completed result_data = { @@ -148,8 +159,13 @@ async def restore_snapshot_task(ctx: dict, job_id: str, snapshot_id: str, pv_ids "success_count": result.successCount, "failure_count": result.failureCount, } + completion_message = f"Restored {result.successCount}/{result.totalPVs} PVs" + ( + f" ({result.failureCount} failed)" if result.failureCount > 0 else "" + ) await job_repo.mark_completed( - job_id, result_id=snapshot_id, message=f"Restored {result.successCount}/{result.totalPVs} PVs" + job_id, + result_id=snapshot_id, + message=completion_message, ) await session.commit() From 746f719a3aebcf3523a7218bb399d8b2a8b10455 Mon Sep 17 00:00:00 2001 From: jbellister-slac Date: Mon, 30 Mar 2026 17:16:08 -0700 Subject: [PATCH 2/6] MNT: Arq is likely sufficient --- app/services/background_tasks.py | 88 +------------------------------- 1 file changed, 1 insertion(+), 87 deletions(-) diff --git a/app/services/background_tasks.py b/app/services/background_tasks.py index 0435d6d..99b86b3 100644 --- a/app/services/background_tasks.py +++ b/app/services/background_tasks.py @@ -4,7 +4,7 @@ from datetime import datetime from app.db.session import async_session_maker -from app.schemas.snapshot import NewSnapshotDTO, RestoreRequestDTO +from app.schemas.snapshot import NewSnapshotDTO from app.services.epics_service import get_epics_service from app.services.redis_service import get_redis_service from app.services.snapshot_service import SnapshotService @@ -111,92 +111,6 @@ async def progress_update(current: int, total: int, message: str): logger.exception(f"Failed to update job status: {inner_e}") -async def run_snapshot_restore( - job_id: str, - snapshot_id: str, - pv_ids: list[str] | None = None, -) -> None: - """ - Background task to restore a snapshot. - - This runs in a separate asyncio task and uses its own database session. - """ - logger.info(f"Background task started for job {job_id}: Restoring snapshot '{snapshot_id}'") - - async with async_session_maker() as session: - try: - job_repo = JobRepository(session) - - # Mark job as running - await job_repo.mark_running(job_id) - await session.commit() - await asyncio.sleep(0) - - # Initial progress update - await job_repo.update_progress(job_id, 5, "Loading snapshot values...") - await session.commit() - await asyncio.sleep(0) - - epics = get_epics_service() - snapshot_service = SnapshotService(session, epics) - - # Optional restore request - request = RestoreRequestDTO(pvIds=pv_ids) if pv_ids else None - - last_update = {"progress": 5, "last_time": datetime.now()} - - async def progress_update(current: int, total: int, message: str): - try: - write_progress = int((current / total) * 85) if total > 0 else 0 - job_progress = 10 + write_progress - - now = datetime.now() - time_elapsed = (now - last_update["last_time"]).total_seconds() - progress_changed = job_progress - last_update["progress"] >= 2 - - if progress_changed or time_elapsed >= 2.0 or current >= total: - last_update["progress"] = job_progress - last_update["last_time"] = now - await job_repo.update_progress(job_id, job_progress, message) - await session.commit() - await asyncio.sleep(0) - except Exception as e: - logger.error(f"Error in restore progress_update: {e}") - - result = await snapshot_service.restore_snapshot( - snapshot_id, - request, - progress_callback=progress_update, - ) - - # Final completion update - completion_message = f"Restore completed: {result.successCount}/{result.totalPVs} PVs restored" + ( - f", {result.failureCount} failed" if result.failureCount > 0 else "" - ) - - await job_repo.mark_completed( - job_id, - message=completion_message, - ) - await session.commit() - - logger.info( - f"Background restore completed for job {job_id}: " - f"{result.successCount}/{result.totalPVs} succeeded, " - f"{result.failureCount} failed" - ) - - except Exception as e: - logger.exception(f"Background restore failed for job {job_id}: {e}") - error_msg = f"{type(e).__name__}: {str(e)}" - try: - await session.rollback() - await job_repo.mark_failed(job_id, error_msg) - await session.commit() - except Exception as inner_e: - logger.exception(f"Failed to update restore job status: {inner_e}") - - def schedule_snapshot_creation(job_id: str, title: str, description: str | None = None, use_cache: bool = True) -> None: """ Schedule a snapshot creation task to run in the background. From 78ecffb3a23efff69f131f037488ed1d0ceac7af Mon Sep 17 00:00:00 2001 From: jbellister-slac Date: Mon, 30 Mar 2026 17:16:33 -0700 Subject: [PATCH 3/6] MNT: Make the connection failed error message more obvious for the UI --- app/services/epics_service.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/app/services/epics_service.py b/app/services/epics_service.py index 672ebe7..5acc1eb 100644 --- a/app/services/epics_service.py +++ b/app/services/epics_service.py @@ -4,7 +4,7 @@ from datetime import datetime from collections.abc import Callable -from aioca import FORMAT_TIME, caget, caput, connect, purge_channel_caches +from aioca import FORMAT_TIME, CANothing, caget, caput, connect, purge_channel_caches from app.config import get_settings from app.services.epics_types import EpicsValue @@ -109,6 +109,13 @@ def _sanitize_value(self, value: Any) -> Any: return self._sanitize_value(value.tolist()) return value + def _ca_error_message(self, error_msg: CANothing) -> str: + """Convert CA error result to a more user-friendly message.""" + msg = str(error_msg).strip() + if "user specified timeout" in msg.lower(): + return "Connection timeout" + return msg if msg else "Unknown error" + def _augmented_to_epics_value(self, pv_name: str, result) -> EpicsValue: """Convert aioca AugmentedValue to our EpicsValue dataclass.""" if not result.ok: @@ -421,7 +428,7 @@ async def put_many(self, values: dict[str, Any]) -> dict[str, tuple[bool, str | if result.ok: results[original] = (True, None) else: - results[original] = (False, f"Failed: {getattr(result, 'errorcode', 'unknown')}") + results[original] = (False, self._ca_error_message(result)) except Exception as e: logger.error(f"Batch put error: {e}") @@ -511,7 +518,7 @@ async def put_many_with_progress( await progress_callback( current, total_pvs, - f"Restored {current:,}/{total_pvs:,} PVs ({success_count:,} successful)", + f"{current:,}/{total_pvs:,} PVs", ) logger.info(f"Restored {current:,}/{total_pvs:,} PVs " f"({success_count:,} successful)") From cf1f07891eb437e0cf7c8ab47fe6df367ec4d57b Mon Sep 17 00:00:00 2001 From: jbellister-slac Date: Mon, 30 Mar 2026 17:17:49 -0700 Subject: [PATCH 4/6] ENH: Make information about final job status available to the frontend (PVs that failed to restore in this case) --- app/repositories/job_repository.py | 35 +++++++++++++++++++++++------- app/schemas/job.py | 1 + app/services/job_service.py | 1 + app/services/snapshot_service.py | 2 +- app/tasks/snapshot_tasks.py | 13 ++++++++--- 5 files changed, 40 insertions(+), 12 deletions(-) diff --git a/app/repositories/job_repository.py b/app/repositories/job_repository.py index 8d3e880..7391c9b 100644 --- a/app/repositories/job_repository.py +++ b/app/repositories/job_repository.py @@ -69,15 +69,34 @@ async def mark_running(self, job_id: str) -> Job | None: """Mark a job as running.""" return await self.update_status(job_id, JobStatus.RUNNING, progress=0, message="Job started") - async def mark_completed(self, job_id: str, result_id: str | None = None, message: str | None = None) -> Job | None: + async def mark_completed( + self, + job_id: str, + result_id: str | None = None, + message: str | None = None, + result_data: dict | None = None, + ) -> Job | None: """Mark a job as completed.""" - return await self.update_status( - job_id, - JobStatus.COMPLETED, - progress=100, - result_id=result_id, - message=message or "Job completed successfully", - ) + job = await self.get_by_id(job_id) + if not job: + return None + + job.status = JobStatus.COMPLETED.value + job.progress = 100 + job.message = message or "Job completed successfully" + job.completed_at = datetime.now() + + if result_id is not None: + job.result_id = result_id + + if result_data: + existing = dict(job.job_data or {}) + existing["result"] = result_data + job.job_data = existing + + await self.session.flush() + await self.session.refresh(job) + return job async def mark_failed(self, job_id: str, error: str) -> Job | None: """Mark a job as failed.""" diff --git a/app/schemas/job.py b/app/schemas/job.py index d0fcc5d..0361ae9 100644 --- a/app/schemas/job.py +++ b/app/schemas/job.py @@ -14,6 +14,7 @@ class JobDTO(BaseModel): message: str | None = None resultId: str | None = None error: str | None = None + jobData: dict | None = None createdAt: datetime startedAt: datetime | None = None completedAt: datetime | None = None diff --git a/app/services/job_service.py b/app/services/job_service.py index 46317af..741421b 100644 --- a/app/services/job_service.py +++ b/app/services/job_service.py @@ -61,6 +61,7 @@ def _to_dto(self, job: Job) -> JobDTO: message=job.message, resultId=job.result_id, error=job.error, + jobData=job.job_data, createdAt=job.created_at, startedAt=job.started_at, completedAt=job.completed_at, diff --git a/app/services/snapshot_service.py b/app/services/snapshot_service.py index 3feca37..0c9895f 100644 --- a/app/services/snapshot_service.py +++ b/app/services/snapshot_service.py @@ -593,7 +593,7 @@ async def restore_snapshot( await progress_callback( total_pvs, total_pvs, - f"Completed restore: {success_count}/{total_pvs} PVs successful", + f"{success_count:,}/{total_pvs:,} PVs restored", ) total_time = datetime.now() logger.info( diff --git a/app/tasks/snapshot_tasks.py b/app/tasks/snapshot_tasks.py index bea2e3a..0e60b31 100644 --- a/app/tasks/snapshot_tasks.py +++ b/app/tasks/snapshot_tasks.py @@ -158,14 +158,21 @@ async def on_progress(current: int, total: int, message: str) -> None: "total_pvs": result.totalPVs, "success_count": result.successCount, "failure_count": result.failureCount, + "failures": [ + {"pvId": f["pvId"], "pvName": f["pvName"], "error": f["error"]} for f in result.failures[:50] + ], } - completion_message = f"Restored {result.successCount}/{result.totalPVs} PVs" + ( - f" ({result.failureCount} failed)" if result.failureCount > 0 else "" - ) + if result.failureCount > 0: + completion_message = ( + f"Restored {result.successCount:,}/{result.totalPVs:,} PVs " f"({result.failureCount} failed)" + ) + else: + completion_message = f"All {result.totalPVs:,} PVs have been restored to their snapshot values." await job_repo.mark_completed( job_id, result_id=snapshot_id, message=completion_message, + result_data=result_data, ) await session.commit() From 4a54cf02c1f6d48f9ec9676d4381fb936e23f6ad Mon Sep 17 00:00:00 2001 From: jbellister-slac Date: Thu, 16 Apr 2026 14:43:26 -0700 Subject: [PATCH 5/6] Add fallback for snapshot restore to use fastapi background tasks if arq cannot enqueue the task for any reason --- app/api/v1/snapshots.py | 55 +++++++++++------ app/services/background_tasks.py | 102 ++++++++++++++++++++++++++++++- 2 files changed, 136 insertions(+), 21 deletions(-) diff --git a/app/api/v1/snapshots.py b/app/api/v1/snapshots.py index f73b345..efe1e46 100644 --- a/app/api/v1/snapshots.py +++ b/app/api/v1/snapshots.py @@ -11,12 +11,12 @@ from app.models.job import JobType from app.schemas.job import JobCreatedDTO from app.dependencies import require_read_access, require_write_access -from app.api.responses import APIException, error_response, success_response +from app.api.responses import APIException, success_response from app.schemas.snapshot import NewSnapshotDTO, RestoreRequestDTO, UpdateSnapshotDTO from app.services.job_service import JobService from app.services.epics_service import EpicsService, get_epics_service from app.services.redis_service import get_redis_service -from app.services.background_tasks import run_snapshot_creation +from app.services.background_tasks import run_snapshot_restore, run_snapshot_creation from app.services.snapshot_service import SnapshotService logger = logging.getLogger(__name__) @@ -188,9 +188,11 @@ async def update_snapshot( @router.post("/{snapshot_id}/restore", dependencies=[Security(require_write_access)]) async def restore_snapshot( snapshot_id: str, + background_tasks: BackgroundTasks, request: RestoreRequestDTO | None = None, db: AsyncSession = Depends(get_db), async_mode: bool = Query(True, alias="async"), + use_arq: bool = Query(True, description="Use Arq persistent queue (recommended) vs FastAPI BackgroundTasks"), epics: EpicsService = Depends(get_epics_service), ) -> dict: """ @@ -217,25 +219,38 @@ async def restore_snapshot( ) await db.commit() - pool = await get_arq_pool() - if pool: - try: - await pool.enqueue_job( - "restore_snapshot_task", - job_id=str(job.id), - snapshot_id=snapshot_id, - pv_ids=request.pvIds if request else None, - ) - logger.info(f"Enqueued restore job to Arq: {job.id}") - - return success_response( - JobCreatedDTO( - jobId=job.id, - message=f"Snapshot restore queued ({snapshot_id})", + pv_ids = request.pvIds if request else None + + if use_arq: + pool = await get_arq_pool() + if pool: + try: + await pool.enqueue_job( + "restore_snapshot_task", + job_id=str(job.id), + snapshot_id=snapshot_id, + pv_ids=pv_ids, + ) + logger.info(f"Enqueued restore job to Arq: {job.id}") + + return success_response( + JobCreatedDTO( + jobId=job.id, + message=f"Snapshot restore queued ({snapshot_id})", + ) ) - ) - except Exception as e: - return error_response(500, f"Failed to enqueue to Arq: {e}") + except Exception as e: + logger.warning(f"Failed to enqueue to Arq, falling back to BackgroundTasks: {e}") + + # Fallback to FastAPI BackgroundTasks + background_tasks.add_task(run_snapshot_restore, str(job.id), snapshot_id, pv_ids) + logger.info(f"Scheduled restore job via BackgroundTasks: {job.id}") + return success_response( + JobCreatedDTO( + jobId=job.id, + message=f"Snapshot restore started ({snapshot_id})", + ) + ) result = await service.restore_snapshot(snapshot_id, request) return success_response(result) diff --git a/app/services/background_tasks.py b/app/services/background_tasks.py index 99b86b3..79ec284 100644 --- a/app/services/background_tasks.py +++ b/app/services/background_tasks.py @@ -4,7 +4,7 @@ from datetime import datetime from app.db.session import async_session_maker -from app.schemas.snapshot import NewSnapshotDTO +from app.schemas.snapshot import NewSnapshotDTO, RestoreRequestDTO from app.services.epics_service import get_epics_service from app.services.redis_service import get_redis_service from app.services.snapshot_service import SnapshotService @@ -111,6 +111,106 @@ async def progress_update(current: int, total: int, message: str): logger.exception(f"Failed to update job status: {inner_e}") +async def run_snapshot_restore( + job_id: str, + snapshot_id: str, + pv_ids: list[str] | None = None, +) -> None: + """ + Background task to restore a snapshot. + + This runs in a separate asyncio task and uses its own database session. + """ + logger.info(f"Background task started for job {job_id}: Restoring snapshot '{snapshot_id}'") + + async with async_session_maker() as session: + try: + job_repo = JobRepository(session) + + # Mark job as running + await job_repo.mark_running(job_id) + await session.commit() + await asyncio.sleep(0) + + # Initial progress update + await job_repo.update_progress(job_id, 5, "Loading snapshot values...") + await session.commit() + await asyncio.sleep(0) + + epics = get_epics_service() + snapshot_service = SnapshotService(session, epics) + + # Optional restore request + request = RestoreRequestDTO(pvIds=pv_ids) if pv_ids else None + + last_update = {"progress": 5, "last_time": datetime.now()} + + async def progress_update(current: int, total: int, message: str): + try: + write_progress = int((current / total) * 85) if total > 0 else 0 + job_progress = 10 + write_progress + + now = datetime.now() + time_elapsed = (now - last_update["last_time"]).total_seconds() + progress_changed = job_progress - last_update["progress"] >= 2 + + if progress_changed or time_elapsed >= 2.0 or current >= total: + last_update["progress"] = job_progress + last_update["last_time"] = now + await job_repo.update_progress(job_id, job_progress, message) + await session.commit() + await asyncio.sleep(0) + except Exception as e: + logger.error(f"Error in restore progress_update: {e}") + + result = await snapshot_service.restore_snapshot( + snapshot_id, + request, + progress_callback=progress_update, + ) + + # Build result data with capped failures + result_data = { + "total_pvs": result.totalPVs, + "success_count": result.successCount, + "failure_count": result.failureCount, + "failures": [ + {"pvId": f["pvId"], "pvName": f["pvName"], "error": f["error"]} for f in result.failures[:50] + ], + } + + # Final completion update + if result.failureCount > 0: + completion_message = ( + f"Restored {result.successCount:,}/{result.totalPVs:,} PVs " f"({result.failureCount} failed)" + ) + else: + completion_message = f"All {result.totalPVs:,} PVs have been restored to their snapshot values." + + await job_repo.mark_completed( + job_id, + message=completion_message, + result_data=result_data, + ) + await session.commit() + + logger.info( + f"Background restore completed for job {job_id}: " + f"{result.successCount}/{result.totalPVs} succeeded, " + f"{result.failureCount} failed" + ) + + except Exception as e: + logger.exception(f"Background restore failed for job {job_id}: {e}") + error_msg = f"{type(e).__name__}: {str(e)}" + try: + await session.rollback() + await job_repo.mark_failed(job_id, error_msg) + await session.commit() + except Exception as inner_e: + logger.exception(f"Failed to update restore job status: {inner_e}") + + def schedule_snapshot_creation(job_id: str, title: str, description: str | None = None, use_cache: bool = True) -> None: """ Schedule a snapshot creation task to run in the background. From 563d542099eec08f3a899c1b1931988d654e656e Mon Sep 17 00:00:00 2001 From: jbellister-slac Date: Thu, 16 Apr 2026 15:03:48 -0700 Subject: [PATCH 6/6] TST: Update snapshot restore tests to avoid using the queue --- tests/test_api/test_snapshots.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_api/test_snapshots.py b/tests/test_api/test_snapshots.py index 100f716..798c529 100644 --- a/tests/test_api/test_snapshots.py +++ b/tests/test_api/test_snapshots.py @@ -114,7 +114,7 @@ class TestSnapshotRestore: async def test_restore_snapshot(self, client: AsyncClient, sample_snapshot: dict, mock_epics: MockEpicsService): """Test restoring a snapshot writes values to EPICS.""" snapshot_id = sample_snapshot["id"] - response = await client.post(f"/v1/snapshots/{snapshot_id}/restore") + response = await client.post(f"/v1/snapshots/{snapshot_id}/restore?async=false") assert response.status_code == 200 data = response.json() @@ -131,7 +131,7 @@ async def test_restore_snapshot_partial( snapshot_id = sample_snapshot["id"] pv_ids = [sample_pvs[0]["id"], sample_pvs[1]["id"]] - response = await client.post(f"/v1/snapshots/{snapshot_id}/restore", json={"pvIds": pv_ids}) + response = await client.post(f"/v1/snapshots/{snapshot_id}/restore?async=false", json={"pvIds": pv_ids}) assert response.status_code == 200 data = response.json()