From 586cd1ecbaa1839aabfeafb8418d1b4acefff0a8 Mon Sep 17 00:00:00 2001 From: yekta yazar Date: Wed, 22 Apr 2026 13:53:35 -0700 Subject: [PATCH 1/3] Use dependencies.py service factories in API handlers --- app/api/v1/api_keys.py | 33 +++++++++++++++------------ app/api/v1/jobs.py | 10 ++++----- app/api/v1/pvs.py | 33 ++++++++++++++------------- app/api/v1/snapshots.py | 42 +++++++++++++---------------------- app/api/v1/tags.py | 49 +++++++++++++++++++++++------------------ app/dependencies.py | 16 ++++++++++++-- 6 files changed, 98 insertions(+), 85 deletions(-) diff --git a/app/api/v1/api_keys.py b/app/api/v1/api_keys.py index f3db670..59dd16e 100644 --- a/app/api/v1/api_keys.py +++ b/app/api/v1/api_keys.py @@ -1,9 +1,7 @@ -"""API endpoints for job status monitoring.""" +"""API endpoints for API key management.""" from fastapi import Depends, APIRouter, status -from sqlalchemy.ext.asyncio import AsyncSession -from app.db.session import get_db -from app.dependencies import require_read_access, require_write_access +from app.dependencies import get_api_key_service, require_read_access, require_write_access from app.api.responses import APIException from app.schemas.common import ApiResultResponse from app.schemas.api_key import ApiKeyCreateDTO @@ -13,18 +11,21 @@ @router.get("", dependencies=[Depends(require_read_access)]) -async def list_all_keys(active_only: bool = False, db: AsyncSession = Depends(get_db)) -> ApiResultResponse: +async def list_all_keys( + active_only: bool = False, + service: ApiKeyService = Depends(get_api_key_service), +) -> ApiResultResponse: """List all API Keys, optionally filtered by active status.""" - service = ApiKeyService(db) keys = await service.list_keys(active_only) - key_response = ApiResultResponse(errorCode=0, errorMessage=None, payload=keys) - return key_response + return ApiResultResponse(errorCode=0, errorMessage=None, payload=keys) @router.post("", dependencies=[Depends(require_write_access)]) -async def create_api_key(data: ApiKeyCreateDTO, db: AsyncSession = Depends(get_db)) -> ApiResultResponse: +async def create_api_key( + data: ApiKeyCreateDTO, + service: ApiKeyService = Depends(get_api_key_service), +) -> ApiResultResponse: """Create a new API Key.""" - service = ApiKeyService(db) try: new_key = await service.create_key(data) return ApiResultResponse(errorCode=0, errorMessage=None, payload=new_key) @@ -33,9 +34,11 @@ async def create_api_key(data: ApiKeyCreateDTO, db: AsyncSession = Depends(get_d @router.delete("/{key_id}", dependencies=[Depends(require_write_access)]) -async def deactivate_api_key(key_id: str, db: AsyncSession = Depends(get_db)) -> ApiResultResponse: +async def deactivate_api_key( + key_id: str, + service: ApiKeyService = Depends(get_api_key_service), +) -> ApiResultResponse: """Deactivate an API Key by ID.""" - service = ApiKeyService(db) try: deactivated_key = await service.deactivate_key(key_id) except LookupError as e: @@ -46,8 +49,10 @@ async def deactivate_api_key(key_id: str, db: AsyncSession = Depends(get_db)) -> @router.get("/count", dependencies=[Depends(require_read_access)]) -async def get_api_key_count(active_only: bool = False, db: AsyncSession = Depends(get_db)) -> ApiResultResponse: +async def get_api_key_count( + active_only: bool = False, + service: ApiKeyService = Depends(get_api_key_service), +) -> ApiResultResponse: """Get the current number of API Keys.""" - service = ApiKeyService(db) count = await service.get_count(active_only) return ApiResultResponse(errorCode=0, errorMessage=None, payload=count) diff --git a/app/api/v1/jobs.py b/app/api/v1/jobs.py index ea67311..d203188 100644 --- a/app/api/v1/jobs.py +++ b/app/api/v1/jobs.py @@ -1,9 +1,7 @@ """API endpoints for job status monitoring.""" from fastapi import Depends, Security, APIRouter -from sqlalchemy.ext.asyncio import AsyncSession -from app.db.session import get_db -from app.dependencies import require_read_access +from app.dependencies import get_job_service, require_read_access from app.api.responses import APIException, success_response from app.services.job_service import JobService @@ -11,14 +9,16 @@ @router.get("/{job_id}", dependencies=[Security(require_read_access)]) -async def get_job_status(job_id: str, db: AsyncSession = Depends(get_db)) -> dict: +async def get_job_status( + job_id: str, + job_service: JobService = Depends(get_job_service), +) -> dict: """ Get the status of a background job. Returns the current status, progress percentage, and result when complete. Poll this endpoint to track the progress of async operations. """ - job_service = JobService(db) job = await job_service.get_job(job_id) if not job: raise APIException(404, f"Job {job_id} not found", 404) diff --git a/app/api/v1/pvs.py b/app/api/v1/pvs.py index 731178a..c48e7ed 100644 --- a/app/api/v1/pvs.py +++ b/app/api/v1/pvs.py @@ -7,7 +7,7 @@ from app.db.session import get_db from app.schemas.pv import LivePVRequest, NewPVElementDTO, UpdatePVElementDTO -from app.dependencies import require_read_access, require_write_access +from app.dependencies import get_pv_service, require_read_access, require_write_access from app.api.responses import APIException, success_response from app.services.pv_service import PVService from app.services.epics_service import get_epics_service @@ -18,9 +18,11 @@ @router.get("", dependencies=[Security(require_read_access)]) -async def search_pvs(pvName: str | None = Query(None), db: AsyncSession = Depends(get_db)) -> dict: +async def search_pvs( + pvName: str | None = Query(None), + service: PVService = Depends(get_pv_service), +) -> dict: """Search PVs by name (non-paginated, for backward compatibility).""" - service = PVService(db) result = await service.search_paged(search=pvName, page_size=1000) return success_response(result.results) @@ -31,7 +33,7 @@ async def search_pvs_paged( pageSize: int = Query(100, ge=1, le=1000), continuationToken: str | None = Query(None), tagFilters: str | None = Query(None, description="JSON object: {groupId: [tagId1, tagId2], ...}"), - db: AsyncSession = Depends(get_db), + service: PVService = Depends(get_pv_service), ) -> dict: """ Search PVs with pagination and optional tag filtering. @@ -39,8 +41,6 @@ async def search_pvs_paged( Example tagFilters: {"group-1": ["tag-a", "tag-b"], "group-2": ["tag-c"]} This returns PVs that have (tag-a OR tag-b) AND (tag-c) """ - service = PVService(db) - # Parse tag filters from JSON string tag_filters = None if tagFilters: @@ -63,9 +63,8 @@ async def search_pvs_paged( @router.post("", dependencies=[Security(require_write_access)]) -async def create_pv(data: NewPVElementDTO, db: AsyncSession = Depends(get_db)) -> dict: +async def create_pv(data: NewPVElementDTO, service: PVService = Depends(get_pv_service)) -> dict: """Create a new PV.""" - service = PVService(db) try: pv = await service.create(data) return success_response(pv) @@ -74,9 +73,11 @@ async def create_pv(data: NewPVElementDTO, db: AsyncSession = Depends(get_db)) - @router.post("/multi", dependencies=[Security(require_write_access)]) -async def create_multiple_pvs(data: list[NewPVElementDTO], db: AsyncSession = Depends(get_db)) -> dict: +async def create_multiple_pvs( + data: list[NewPVElementDTO], + service: PVService = Depends(get_pv_service), +) -> dict: """Bulk create PVs (for CSV import).""" - service = PVService(db) try: pvs = await service.create_many(data) return success_response(pvs) @@ -85,13 +86,16 @@ async def create_multiple_pvs(data: list[NewPVElementDTO], db: AsyncSession = De @router.put("/{pv_id}", dependencies=[Security(require_write_access)]) -async def update_pv(pv_id: str, data: UpdatePVElementDTO, db: AsyncSession = Depends(get_db)) -> dict: +async def update_pv( + pv_id: str, + data: UpdatePVElementDTO, + service: PVService = Depends(get_pv_service), +) -> dict: """Update a PV.""" try: UUID(pv_id) except ValueError: raise APIException(404, f"PV {pv_id} not found", 404) - service = PVService(db) pv = await service.update(pv_id, data) if not pv: raise APIException(404, f"PV {pv_id} not found", 404) @@ -99,13 +103,12 @@ async def update_pv(pv_id: str, data: UpdatePVElementDTO, db: AsyncSession = Dep @router.delete("/{pv_id}", dependencies=[Security(require_write_access)]) -async def delete_pv(pv_id: str, db: AsyncSession = Depends(get_db)) -> dict: +async def delete_pv(pv_id: str, service: PVService = Depends(get_pv_service)) -> dict: """Delete a PV.""" try: UUID(pv_id) except ValueError: raise APIException(404, f"PV {pv_id} not found", 404) - service = PVService(db) success = await service.delete(pv_id) if not success: raise APIException(404, f"PV {pv_id} not found", 404) @@ -121,6 +124,7 @@ async def search_pvs_filtered( offset: int = Query(0, description="Offset for pagination"), include_live_values: bool = Query(False, description="Include Redis cache values"), db: AsyncSession = Depends(get_db), + service: PVService = Depends(get_pv_service), ) -> dict: """ Server-side filtered search with optional live values. @@ -128,7 +132,6 @@ async def search_pvs_filtered( This is more efficient than client-side filtering for large datasets. """ pv_repo = PVRepository(db) - service = PVService(db) pvs, total = await pv_repo.search_filtered( search_term=q, devices=devices if devices else None, tag_ids=tags if tags else None, limit=limit, offset=offset diff --git a/app/api/v1/snapshots.py b/app/api/v1/snapshots.py index efe1e46..c680665 100644 --- a/app/api/v1/snapshots.py +++ b/app/api/v1/snapshots.py @@ -10,12 +10,15 @@ from app.db.session import get_db from app.models.job import JobType from app.schemas.job import JobCreatedDTO -from app.dependencies import require_read_access, require_write_access +from app.dependencies import ( + get_job_service, + get_snapshot_service, + require_read_access, + require_write_access, +) from app.api.responses import APIException, success_response from app.schemas.snapshot import NewSnapshotDTO, RestoreRequestDTO, UpdateSnapshotDTO from app.services.job_service import JobService -from app.services.epics_service import EpicsService, get_epics_service -from app.services.redis_service import get_redis_service from app.services.background_tasks import run_snapshot_restore, run_snapshot_creation from app.services.snapshot_service import SnapshotService @@ -46,11 +49,9 @@ async def list_snapshots( title: str | None = Query(None), tags: list[str] | None = Query(None, description="Filter by tag IDs (returns snapshots containing PVs with any of these tags)"), - db: AsyncSession = Depends(get_db), - epics: EpicsService = Depends(get_epics_service), + service: SnapshotService = Depends(get_snapshot_service), ) -> dict: """List all snapshots, optionally filtered by title and/or tags.""" - service = SnapshotService(db, epics) snapshots = await service.list_snapshots(title=title, tag_ids=tags) return success_response(snapshots) @@ -60,8 +61,7 @@ async def get_snapshot( snapshot_id: str, limit: int | None = Query(None, description="Limit number of PV values returned"), offset: int = Query(0, description="Offset for pagination"), - db: AsyncSession = Depends(get_db), - epics: EpicsService = Depends(get_epics_service), + service: SnapshotService = Depends(get_snapshot_service), ) -> dict: """ Get snapshot by ID with values. @@ -73,7 +73,6 @@ async def get_snapshot( UUID(snapshot_id) except ValueError: raise APIException(404, f"Snapshot {snapshot_id} not found", 404) - service = SnapshotService(db, epics) snapshot = await service.get_by_id(snapshot_id, limit=limit, offset=offset) if not snapshot: raise APIException(404, f"Snapshot {snapshot_id} not found", 404) @@ -88,7 +87,8 @@ async def create_snapshot( async_mode: bool = Query(True, alias="async", description="Run snapshot creation in background"), use_cache: bool = Query(True, description="Read from Redis cache (instant) vs direct EPICS read"), use_arq: bool = Query(True, description="Use Arq persistent queue (recommended) vs FastAPI BackgroundTasks"), - epics: EpicsService = Depends(get_epics_service), + service: SnapshotService = Depends(get_snapshot_service), + job_service: JobService = Depends(get_job_service), ) -> dict: """ Create a new snapshot by reading all PVs. @@ -106,7 +106,6 @@ async def create_snapshot( """ if async_mode: # Create a job record - job_service = JobService(db) job = await job_service.create_job( JobType.SNAPSHOT_CREATE, job_data={"title": data.title, "description": data.description, "use_cache": use_cache}, @@ -152,9 +151,6 @@ async def create_snapshot( ) else: # Synchronous mode (legacy behavior) - redis = get_redis_service() - service = SnapshotService(db, epics, redis) - if use_cache: snapshot = await service.create_snapshot_from_cache(data) else: @@ -167,12 +163,9 @@ async def create_snapshot( async def update_snapshot( snapshot_id: str, data: UpdateSnapshotDTO, - db: AsyncSession = Depends(get_db), + service: SnapshotService = Depends(get_snapshot_service), ) -> dict: """Update snapshot title and/or description.""" - epics = get_epics_service() - service = SnapshotService(db, epics) - snapshot = await service.update_snapshot_metadata( snapshot_id, title=data.title, @@ -193,7 +186,8 @@ async def restore_snapshot( db: AsyncSession = Depends(get_db), async_mode: bool = Query(True, alias="async"), use_arq: bool = Query(True, description="Use Arq persistent queue (recommended) vs FastAPI BackgroundTasks"), - epics: EpicsService = Depends(get_epics_service), + service: SnapshotService = Depends(get_snapshot_service), + job_service: JobService = Depends(get_job_service), ) -> dict: """ Restore PV values from a snapshot to EPICS. @@ -204,7 +198,6 @@ async def restore_snapshot( UUID(snapshot_id) except ValueError: raise APIException(404, f"Snapshot {snapshot_id} not found", 404) - service = SnapshotService(db, epics) # Verify snapshot exists snapshot = await service.get_by_id(snapshot_id) @@ -212,7 +205,6 @@ async def restore_snapshot( raise APIException(404, f"Snapshot {snapshot_id} not found", 404) if async_mode: - job_service = JobService(db) job = await job_service.create_job( JobType.SNAPSHOT_RESTORE, job_data={"snapshotId": snapshot_id}, @@ -260,15 +252,13 @@ async def restore_snapshot( async def delete_snapshot( snapshot_id: str, deleteData: bool = Query(True), - db: AsyncSession = Depends(get_db), - epics: EpicsService = Depends(get_epics_service), + service: SnapshotService = Depends(get_snapshot_service), ) -> dict: """Delete a snapshot.""" try: UUID(snapshot_id) except ValueError: raise APIException(404, f"Snapshot {snapshot_id} not found", 404) - service = SnapshotService(db, epics) success = await service.delete_snapshot(snapshot_id) if not success: raise APIException(404, f"Snapshot {snapshot_id} not found", 404) @@ -279,8 +269,7 @@ async def delete_snapshot( async def compare_snapshots( snapshot1_id: str, snapshot2_id: str, - db: AsyncSession = Depends(get_db), - epics: EpicsService = Depends(get_epics_service), + service: SnapshotService = Depends(get_snapshot_service), ) -> dict: """Compare two snapshots and return differences.""" try: @@ -288,7 +277,6 @@ async def compare_snapshots( UUID(snapshot2_id) except ValueError: raise APIException(404, "Snapshot not found", 404) - service = SnapshotService(db, epics) try: result = await service.compare_snapshots(snapshot1_id, snapshot2_id) return success_response(result) diff --git a/app/api/v1/tags.py b/app/api/v1/tags.py index 3ff2dd7..947aae6 100644 --- a/app/api/v1/tags.py +++ b/app/api/v1/tags.py @@ -2,11 +2,9 @@ from fastapi import Query, Depends, Security, APIRouter from pydantic import BaseModel -from sqlalchemy.ext.asyncio import AsyncSession -from app.db.session import get_db from app.schemas.tag import TagCreate, TagUpdate, TagGroupCreate, TagGroupUpdate -from app.dependencies import require_read_access, require_write_access +from app.dependencies import get_tag_service, require_read_access, require_write_access from app.api.responses import APIException, success_response from app.services.tag_service import TagService @@ -14,21 +12,19 @@ @router.get("", dependencies=[Security(require_read_access)]) -async def get_all_tag_groups(db: AsyncSession = Depends(get_db)) -> dict: +async def get_all_tag_groups(service: TagService = Depends(get_tag_service)) -> dict: """Get all tag groups with tag counts.""" - service = TagService(db) groups = await service.get_all_groups_summary() return success_response(groups) @router.get("/{group_id}", dependencies=[Security(require_read_access)]) -async def get_tag_group(group_id: str, db: AsyncSession = Depends(get_db)) -> dict: +async def get_tag_group(group_id: str, service: TagService = Depends(get_tag_service)) -> dict: """Get tag group by ID with all tags.""" try: UUID(group_id) except ValueError: raise APIException(404, f"Tag group {group_id} not found", 404) - service = TagService(db) group = await service.get_group_by_id(group_id) if not group: raise APIException(404, f"Tag group {group_id} not found", 404) @@ -37,9 +33,8 @@ async def get_tag_group(group_id: str, db: AsyncSession = Depends(get_db)) -> di @router.post("", dependencies=[Security(require_write_access)]) -async def create_tag_group(data: TagGroupCreate, db: AsyncSession = Depends(get_db)) -> dict: +async def create_tag_group(data: TagGroupCreate, service: TagService = Depends(get_tag_service)) -> dict: """Create a new tag group.""" - service = TagService(db) try: group = await service.create_group(data) return success_response(group) @@ -48,13 +43,16 @@ async def create_tag_group(data: TagGroupCreate, db: AsyncSession = Depends(get_ @router.put("/{group_id}", dependencies=[Security(require_write_access)]) -async def update_tag_group(group_id: str, data: TagGroupUpdate, db: AsyncSession = Depends(get_db)) -> dict: +async def update_tag_group( + group_id: str, + data: TagGroupUpdate, + service: TagService = Depends(get_tag_service), +) -> dict: """Update a tag group.""" try: UUID(group_id) except ValueError: raise APIException(404, f"Tag group {group_id} not found", 404) - service = TagService(db) try: group = await service.update_group(group_id, data) if not group: @@ -65,13 +63,16 @@ async def update_tag_group(group_id: str, data: TagGroupUpdate, db: AsyncSession @router.delete("/{group_id}", dependencies=[Security(require_write_access)]) -async def delete_tag_group(group_id: str, force: bool = Query(False), db: AsyncSession = Depends(get_db)) -> dict: +async def delete_tag_group( + group_id: str, + force: bool = Query(False), + service: TagService = Depends(get_tag_service), +) -> dict: """Delete a tag group.""" try: UUID(group_id) except ValueError: raise APIException(404, f"Tag group {group_id} not found", 404) - service = TagService(db) success = await service.delete_group(group_id, force=force) if not success: raise APIException(404, f"Tag group {group_id} not found", 404) @@ -83,14 +84,13 @@ async def add_tag_to_group( group_id: str, data: TagCreate, skip_duplicates: bool = Query(False, description="Skip duplicate tags instead of raising error"), - db: AsyncSession = Depends(get_db), + service: TagService = Depends(get_tag_service), ) -> dict: """Add a tag to a group.""" try: UUID(group_id) except ValueError: raise APIException(404, f"Tag group {group_id} not found", 404) - service = TagService(db) try: group, was_created = await service.add_tag_to_group(group_id, data, skip_duplicates=skip_duplicates) if not group: @@ -101,14 +101,18 @@ async def add_tag_to_group( @router.put("/{group_id}/tags/{tag_id}", dependencies=[Security(require_write_access)]) -async def update_tag(group_id: str, tag_id: str, data: TagUpdate, db: AsyncSession = Depends(get_db)) -> dict: +async def update_tag( + group_id: str, + tag_id: str, + data: TagUpdate, + service: TagService = Depends(get_tag_service), +) -> dict: """Update a tag in a group.""" try: UUID(group_id) UUID(tag_id) except ValueError: raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) - service = TagService(db) group = await service.update_tag(group_id, tag_id, data) if not group: raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) @@ -116,14 +120,17 @@ async def update_tag(group_id: str, tag_id: str, data: TagUpdate, db: AsyncSessi @router.delete("/{group_id}/tags/{tag_id}", dependencies=[Security(require_write_access)]) -async def remove_tag(group_id: str, tag_id: str, db: AsyncSession = Depends(get_db)) -> dict: +async def remove_tag( + group_id: str, + tag_id: str, + service: TagService = Depends(get_tag_service), +) -> dict: """Remove a tag from a group.""" try: UUID(group_id) UUID(tag_id) except ValueError: raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) - service = TagService(db) group = await service.remove_tag(group_id, tag_id) if not group: raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) @@ -146,10 +153,8 @@ class BulkTagImportResponse(BaseModel): @router.post("/bulk", response_model=dict) -async def bulk_import_tags(data: BulkTagImportRequest, db: AsyncSession = Depends(get_db)): +async def bulk_import_tags(data: BulkTagImportRequest, service: TagService = Depends(get_tag_service)): """Bulk import tags with duplicate handling.""" - service = TagService(db) - groups_created = 0 tags_created = 0 tags_skipped = 0 diff --git a/app/dependencies.py b/app/dependencies.py index fd3c5a5..32c0b11 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -9,7 +9,9 @@ from app.api.responses import APIException from app.schemas.api_key import ApiKeyDTO from app.services.pv_service import PVService +from app.services.job_service import JobService from app.services.tag_service import TagService +from app.services.redis_service import get_redis_service from app.services.epics_service import EpicsService, get_epics_service from app.services.api_key_service import ApiKeyService from app.services.snapshot_service import SnapshotService @@ -30,8 +32,8 @@ def get_pv_service(db: AsyncSession = Depends(get_db)) -> PVService: def get_snapshot_service( db: AsyncSession = Depends(get_db), epics: EpicsService = Depends(get_epics_service) ) -> SnapshotService: - """Get Snapshot service instance.""" - return SnapshotService(db, epics) + """Get Snapshot service instance (Redis attached from the module singleton).""" + return SnapshotService(db, epics, get_redis_service()) def get_tag_service(db: AsyncSession = Depends(get_db)) -> TagService: @@ -39,6 +41,16 @@ def get_tag_service(db: AsyncSession = Depends(get_db)) -> TagService: return TagService(db) +def get_api_key_service(db: AsyncSession = Depends(get_db)) -> ApiKeyService: + """Get API key service instance.""" + return ApiKeyService(db) + + +def get_job_service(db: AsyncSession = Depends(get_db)) -> JobService: + """Get Job service instance.""" + return JobService(db) + + # --------------------------------------------------------------------------- # API Key auth dependencies # --------------------------------------------------------------------------- From 0ddbadfdba9c97d63e223bc77052ba1d1e021050 Mon Sep 17 00:00:00 2001 From: yekta yazar Date: Wed, 22 Apr 2026 14:14:25 -0700 Subject: [PATCH 2/3] STYL: formatting fix --- app/api/v1/api_keys.py | 6 +++++- app/api/v1/snapshots.py | 2 +- app/dependencies.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/app/api/v1/api_keys.py b/app/api/v1/api_keys.py index 59dd16e..1edc8be 100644 --- a/app/api/v1/api_keys.py +++ b/app/api/v1/api_keys.py @@ -1,7 +1,11 @@ """API endpoints for API key management.""" from fastapi import Depends, APIRouter, status -from app.dependencies import get_api_key_service, require_read_access, require_write_access +from app.dependencies import ( + get_api_key_service, + require_read_access, + require_write_access, +) from app.api.responses import APIException from app.schemas.common import ApiResultResponse from app.schemas.api_key import ApiKeyCreateDTO diff --git a/app/api/v1/snapshots.py b/app/api/v1/snapshots.py index c680665..ca4fd7b 100644 --- a/app/api/v1/snapshots.py +++ b/app/api/v1/snapshots.py @@ -12,8 +12,8 @@ from app.schemas.job import JobCreatedDTO from app.dependencies import ( get_job_service, - get_snapshot_service, require_read_access, + get_snapshot_service, require_write_access, ) from app.api.responses import APIException, success_response diff --git a/app/dependencies.py b/app/dependencies.py index 32c0b11..36117f4 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -11,8 +11,8 @@ from app.services.pv_service import PVService from app.services.job_service import JobService from app.services.tag_service import TagService -from app.services.redis_service import get_redis_service from app.services.epics_service import EpicsService, get_epics_service +from app.services.redis_service import get_redis_service from app.services.api_key_service import ApiKeyService from app.services.snapshot_service import SnapshotService From a5ed8712060d5f82340ad340332310323326aaf5 Mon Sep 17 00:00:00 2001 From: yekta yazar Date: Wed, 22 Apr 2026 14:30:58 -0700 Subject: [PATCH 3/3] Remove envelope response pattern and add Pydantic response models --- app/api/responses.py | 32 ---- app/api/v1/api_keys.py | 40 +++-- app/api/v1/health.py | 244 +++++++++++++++++-------------- app/api/v1/jobs.py | 12 +- app/api/v1/pvs.py | 222 ++++++++++++++++++---------- app/api/v1/snapshots.py | 152 ++++++++++--------- app/api/v1/tags.py | 144 +++++++++++------- app/dependencies.py | 18 +-- app/main.py | 132 +---------------- app/schemas/__init__.py | 3 +- app/schemas/common.py | 8 - app/schemas/health.py | 55 +++++++ app/schemas/pv.py | 52 +++++++ app/schemas/tag.py | 7 + tests/conftest.py | 10 +- tests/test_api/test_api_keys.py | 36 ++--- tests/test_api/test_pvs.py | 90 +++++------- tests/test_api/test_snapshots.py | 60 +++----- tests/test_api/test_tags.py | 101 +++++-------- tests/test_dependencies.py | 32 ++-- 20 files changed, 728 insertions(+), 722 deletions(-) delete mode 100644 app/api/responses.py diff --git a/app/api/responses.py b/app/api/responses.py deleted file mode 100644 index a0e2768..0000000 --- a/app/api/responses.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import TypeVar - -from fastapi import HTTPException -from fastapi.responses import JSONResponse - -from app.schemas.common import ApiResultResponse - -T = TypeVar("T") - - -def success_response(payload: T) -> dict: - """Wrap successful response.""" - return ApiResultResponse(errorCode=0, errorMessage=None, payload=payload).model_dump() - - -def error_response(code: int, message: str, status_code: int = 400) -> JSONResponse: - """Create error response.""" - return JSONResponse( - status_code=status_code, - content=ApiResultResponse(errorCode=code, errorMessage=message, payload=None).model_dump(), - ) - - -class APIException(HTTPException): - """Custom API exception with error codes.""" - - def __init__(self, error_code: int, message: str, status_code: int | None = None): - if status_code is None: - status_code = error_code - super().__init__(status_code=status_code, detail=message) - self.error_code = error_code - self.error_message = message diff --git a/app/api/v1/api_keys.py b/app/api/v1/api_keys.py index 1edc8be..81c1ae2 100644 --- a/app/api/v1/api_keys.py +++ b/app/api/v1/api_keys.py @@ -1,62 +1,56 @@ """API endpoints for API key management.""" -from fastapi import Depends, APIRouter, status +from fastapi import Depends, APIRouter, HTTPException, status from app.dependencies import ( get_api_key_service, require_read_access, require_write_access, ) -from app.api.responses import APIException -from app.schemas.common import ApiResultResponse -from app.schemas.api_key import ApiKeyCreateDTO +from app.schemas.api_key import ApiKeyDTO, ApiKeyCreateDTO, ApiKeyCreateResultDTO from app.services.api_key_service import ApiKeyService router = APIRouter(prefix="/api-keys", tags=["ApiKeys"]) -@router.get("", dependencies=[Depends(require_read_access)]) +@router.get("", dependencies=[Depends(require_read_access)], response_model=list[ApiKeyDTO]) async def list_all_keys( active_only: bool = False, service: ApiKeyService = Depends(get_api_key_service), -) -> ApiResultResponse: +) -> list[ApiKeyDTO]: """List all API Keys, optionally filtered by active status.""" - keys = await service.list_keys(active_only) - return ApiResultResponse(errorCode=0, errorMessage=None, payload=keys) + return await service.list_keys(active_only) -@router.post("", dependencies=[Depends(require_write_access)]) +@router.post("", dependencies=[Depends(require_write_access)], response_model=ApiKeyCreateResultDTO) async def create_api_key( data: ApiKeyCreateDTO, service: ApiKeyService = Depends(get_api_key_service), -) -> ApiResultResponse: +) -> ApiKeyCreateResultDTO: """Create a new API Key.""" try: - new_key = await service.create_key(data) - return ApiResultResponse(errorCode=0, errorMessage=None, payload=new_key) + return await service.create_key(data) except ValueError as e: - raise APIException(status.HTTP_409_CONFLICT, str(e), status_code=status.HTTP_409_CONFLICT) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) -@router.delete("/{key_id}", dependencies=[Depends(require_write_access)]) +@router.delete("/{key_id}", dependencies=[Depends(require_write_access)], response_model=ApiKeyDTO) async def deactivate_api_key( key_id: str, service: ApiKeyService = Depends(get_api_key_service), -) -> ApiResultResponse: +) -> ApiKeyDTO: """Deactivate an API Key by ID.""" try: - deactivated_key = await service.deactivate_key(key_id) + return await service.deactivate_key(key_id) except LookupError as e: - raise APIException(status.HTTP_404_NOT_FOUND, str(e), status_code=status.HTTP_404_NOT_FOUND) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) except ValueError as e: - raise APIException(status.HTTP_409_CONFLICT, str(e), status_code=status.HTTP_409_CONFLICT) - return ApiResultResponse(errorCode=0, errorMessage=None, payload=deactivated_key) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) -@router.get("/count", dependencies=[Depends(require_read_access)]) +@router.get("/count", dependencies=[Depends(require_read_access)], response_model=int) async def get_api_key_count( active_only: bool = False, service: ApiKeyService = Depends(get_api_key_service), -) -> ApiResultResponse: +) -> int: """Get the current number of API Keys.""" - count = await service.get_count(active_only) - return ApiResultResponse(errorCode=0, errorMessage=None, payload=count) + return await service.get_count(active_only) diff --git a/app/api/v1/health.py b/app/api/v1/health.py index 42778d3..44f61e9 100644 --- a/app/api/v1/health.py +++ b/app/api/v1/health.py @@ -12,11 +12,17 @@ from fastapi import Security, APIRouter, HTTPException from app.dependencies import require_read_access, require_write_access -from app.api.responses import success_response from app.schemas.health import ( + StalePVsResponse, + HeartbeatResponse, + CircuitStatsResponse, + CircuitActionResponse, + CircuitStatusResponse, HealthSummaryResponse, MonitorHealthResponse, WatchdogStatsResponse, + DisconnectedPVsResponse, + MonitorProcessStatusResponse, ) from app.services.watchdog import get_watchdog from app.services.pv_monitor import get_pv_monitor @@ -25,8 +31,8 @@ router = APIRouter(prefix="/health", tags=["Health"]) -@router.get("/heartbeat") -async def get_heartbeat() -> dict: +@router.get("/heartbeat", response_model=HeartbeatResponse) +async def get_heartbeat() -> HeartbeatResponse: """ Simple heartbeat check for frontend polling. @@ -39,39 +45,37 @@ async def get_heartbeat() -> dict: # Check if Redis is connected if not redis.is_connected(): - return success_response( - { - "timestamp": None, - "alive": False, - "age_seconds": None, - "error": "Redis not connected", - } + return HeartbeatResponse( + timestamp=None, + alive=False, + age_seconds=None, + error="Redis not connected", ) heartbeat = await redis.get_heartbeat() alive = await redis.is_monitor_alive(max_age_seconds=5.0) age_seconds = await redis.get_heartbeat_age() - return success_response( - { - "timestamp": heartbeat, - "alive": alive, - "age_seconds": age_seconds, - } + return HeartbeatResponse( + timestamp=heartbeat, + alive=alive, + age_seconds=age_seconds, ) except Exception as e: # Redis not available - monitor is definitely not healthy - return success_response( - { - "timestamp": None, - "alive": False, - "age_seconds": None, - "error": str(e), - } + return HeartbeatResponse( + timestamp=None, + alive=False, + age_seconds=None, + error=str(e), ) -@router.get("/monitor", dependencies=[Security(require_read_access)]) +@router.get( + "/monitor", + dependencies=[Security(require_read_access)], + response_model=MonitorHealthResponse, +) async def get_monitor_health() -> MonitorHealthResponse: """ Get detailed monitor health information. @@ -112,7 +116,11 @@ async def get_monitor_health() -> MonitorHealthResponse: raise HTTPException(status_code=503, detail=f"Health check failed: {e}") -@router.get("/watchdog", dependencies=[Security(require_read_access)]) +@router.get( + "/watchdog", + dependencies=[Security(require_read_access)], + response_model=WatchdogStatsResponse, +) async def get_watchdog_stats() -> WatchdogStatsResponse: """ Get watchdog statistics. @@ -138,7 +146,11 @@ async def get_watchdog_stats() -> WatchdogStatsResponse: raise HTTPException(status_code=503, detail=f"Watchdog stats failed: {e}") -@router.post("/watchdog/check", dependencies=[Security(require_write_access)]) +@router.post( + "/watchdog/check", + dependencies=[Security(require_write_access)], + response_model=WatchdogStatsResponse, +) async def force_watchdog_check() -> WatchdogStatsResponse: """ Force an immediate watchdog health check. @@ -165,7 +177,11 @@ async def force_watchdog_check() -> WatchdogStatsResponse: raise HTTPException(status_code=503, detail=f"Watchdog check failed: {e}") -@router.get("/summary", dependencies=[Security(require_read_access)]) +@router.get( + "/summary", + dependencies=[Security(require_read_access)], + response_model=HealthSummaryResponse, +) async def get_health_summary() -> HealthSummaryResponse: """ Get a complete health summary for monitoring dashboards. @@ -252,8 +268,12 @@ async def get_health_summary() -> HealthSummaryResponse: ) -@router.get("/disconnected", dependencies=[Security(require_read_access)]) -async def get_disconnected_pvs() -> dict: +@router.get( + "/disconnected", + dependencies=[Security(require_read_access)], + response_model=DisconnectedPVsResponse, +) +async def get_disconnected_pvs() -> DisconnectedPVsResponse: """ Get list of all disconnected PVs. @@ -262,18 +282,18 @@ async def get_disconnected_pvs() -> dict: try: redis = get_redis_service() disconnected = await redis.get_disconnected_pvs() - - return { - "count": len(disconnected), - "pvs": sorted(list(disconnected)), - } + return DisconnectedPVsResponse(count=len(disconnected), pvs=sorted(list(disconnected))) except Exception as e: raise HTTPException(status_code=503, detail=f"Failed to get disconnected PVs: {e}") -@router.get("/stale", dependencies=[Security(require_read_access)]) -async def get_stale_pvs(max_age_seconds: float = 300) -> dict: +@router.get( + "/stale", + dependencies=[Security(require_read_access)], + response_model=StalePVsResponse, +) +async def get_stale_pvs(max_age_seconds: float = 300) -> StalePVsResponse: """ Get list of stale PVs (connected but not updated recently). @@ -283,19 +303,18 @@ async def get_stale_pvs(max_age_seconds: float = 300) -> dict: try: redis = get_redis_service() stale = await redis.get_stale_pvs(max_age_seconds=max_age_seconds) - - return { - "count": len(stale), - "threshold_seconds": max_age_seconds, - "pvs": sorted(stale), - } + return StalePVsResponse(count=len(stale), threshold_seconds=max_age_seconds, pvs=sorted(stale)) except Exception as e: raise HTTPException(status_code=503, detail=f"Failed to get stale PVs: {e}") -@router.get("/circuits", dependencies=[Security(require_read_access)]) -async def get_circuit_breaker_status() -> dict: +@router.get( + "/circuits", + dependencies=[Security(require_read_access)], + response_model=CircuitStatusResponse, +) +async def get_circuit_breaker_status() -> CircuitStatusResponse: """ Get circuit breaker status for all EPICS IOCs. @@ -316,41 +335,43 @@ async def get_circuit_breaker_status() -> dict: stats = manager.get_all_stats() open_circuits = manager.get_open_circuits() - return { - "open_circuit_count": len(open_circuits), - "total_circuits": len(stats), - "open_circuits": open_circuits, - "circuits": [ - { - "name": s.name, - "state": s.state.value, - "failure_count": s.failure_count, - "success_count": s.success_count, - "call_count": s.call_count, - "last_failure": s.last_failure.isoformat() if s.last_failure else None, - "opened_at": s.opened_at.isoformat() if s.opened_at else None, - } + return CircuitStatusResponse( + open_circuit_count=len(open_circuits), + total_circuits=len(stats), + open_circuits=open_circuits, + circuits=[ + CircuitStatsResponse( + name=s.name, + state=s.state.value, + failure_count=s.failure_count, + success_count=s.success_count, + call_count=s.call_count, + last_failure=s.last_failure.isoformat() if s.last_failure else None, + opened_at=s.opened_at.isoformat() if s.opened_at else None, + ) for s in stats ], - } + ) except ImportError: - return { - "error": "Circuit breaker not available", - "open_circuit_count": 0, - "total_circuits": 0, - "circuits": [], - } + return CircuitStatusResponse( + open_circuit_count=0, + total_circuits=0, + error="Circuit breaker not available", + ) except Exception as e: - return { - "error": str(e), - "open_circuit_count": 0, - "total_circuits": 0, - "circuits": [], - } + return CircuitStatusResponse( + open_circuit_count=0, + total_circuits=0, + error=str(e), + ) -@router.post("/circuits/{circuit_name}/close", dependencies=[Security(require_write_access)]) -async def force_close_circuit(circuit_name: str) -> dict: +@router.post( + "/circuits/{circuit_name}/close", + dependencies=[Security(require_write_access)], + response_model=CircuitActionResponse, +) +async def force_close_circuit(circuit_name: str) -> CircuitActionResponse: """ Force close a circuit breaker (allow requests to IOC). @@ -362,13 +383,17 @@ async def force_close_circuit(circuit_name: str) -> dict: manager = get_circuit_breaker_manager() manager.force_close(circuit_name) - return {"success": True, "message": f"Circuit '{circuit_name}' forced closed"} + return CircuitActionResponse(success=True, message=f"Circuit '{circuit_name}' forced closed") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@router.post("/circuits/{circuit_name}/open", dependencies=[Security(require_write_access)]) -async def force_open_circuit(circuit_name: str) -> dict: +@router.post( + "/circuits/{circuit_name}/open", + dependencies=[Security(require_write_access)], + response_model=CircuitActionResponse, +) +async def force_open_circuit(circuit_name: str) -> CircuitActionResponse: """ Force open a circuit breaker (block requests to IOC). @@ -380,13 +405,17 @@ async def force_open_circuit(circuit_name: str) -> dict: manager = get_circuit_breaker_manager() manager.force_open(circuit_name) - return {"success": True, "message": f"Circuit '{circuit_name}' forced open"} + return CircuitActionResponse(success=True, message=f"Circuit '{circuit_name}' forced open") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@router.get("/monitor/status", dependencies=[Security(require_read_access)]) -async def monitor_process_status() -> dict: +@router.get( + "/monitor/status", + dependencies=[Security(require_read_access)], + response_model=MonitorProcessStatusResponse, +) +async def monitor_process_status() -> MonitorProcessStatusResponse: """ Check if the separate PV Monitor process is alive via Redis heartbeat. @@ -402,44 +431,39 @@ async def monitor_process_status() -> dict: redis = get_redis_service() if not redis.is_connected(): - return { - "status": "unknown", - "message": "Redis not connected", - "age_seconds": None, - "leader": None, - } + return MonitorProcessStatusResponse( + status="unknown", + message="Redis not connected", + ) heartbeat = await redis.get_monitor_heartbeat() heartbeat_age = await redis.get_heartbeat_age() leader = await redis.get_monitor_lock_holder() if heartbeat is None: - return { - "status": "unknown", - "message": "No heartbeat found - monitor may not be running", - "age_seconds": None, - "leader": leader, - } + return MonitorProcessStatusResponse( + status="unknown", + message="No heartbeat found - monitor may not be running", + leader=leader, + ) if heartbeat_age is not None and heartbeat_age > 30: - return { - "status": "stale", - "message": f"Heartbeat is {heartbeat_age:.1f}s old - monitor may be down", - "age_seconds": heartbeat_age, - "leader": leader, - } - - return { - "status": "healthy", - "message": "Monitor process is alive", - "age_seconds": heartbeat_age, - "leader": leader, - } + return MonitorProcessStatusResponse( + status="stale", + message=f"Heartbeat is {heartbeat_age:.1f}s old - monitor may be down", + age_seconds=heartbeat_age, + leader=leader, + ) + + return MonitorProcessStatusResponse( + status="healthy", + message="Monitor process is alive", + age_seconds=heartbeat_age, + leader=leader, + ) except Exception as e: - return { - "status": "error", - "message": str(e), - "age_seconds": None, - "leader": None, - } + return MonitorProcessStatusResponse( + status="error", + message=str(e), + ) diff --git a/app/api/v1/jobs.py b/app/api/v1/jobs.py index d203188..bbb9597 100644 --- a/app/api/v1/jobs.py +++ b/app/api/v1/jobs.py @@ -1,18 +1,18 @@ """API endpoints for job status monitoring.""" -from fastapi import Depends, Security, APIRouter +from fastapi import Depends, Security, APIRouter, HTTPException +from app.schemas.job import JobDTO from app.dependencies import get_job_service, require_read_access -from app.api.responses import APIException, success_response from app.services.job_service import JobService router = APIRouter(prefix="/jobs", tags=["Jobs"]) -@router.get("/{job_id}", dependencies=[Security(require_read_access)]) +@router.get("/{job_id}", dependencies=[Security(require_read_access)], response_model=JobDTO) async def get_job_status( job_id: str, job_service: JobService = Depends(get_job_service), -) -> dict: +) -> JobDTO: """ Get the status of a background job. @@ -21,5 +21,5 @@ async def get_job_status( """ job = await job_service.get_job(job_id) if not job: - raise APIException(404, f"Job {job_id} not found", 404) - return success_response(job) + raise HTTPException(status_code=404, detail=f"Job {job_id} not found") + return job diff --git a/app/api/v1/pvs.py b/app/api/v1/pvs.py index c48e7ed..ecad4ba 100644 --- a/app/api/v1/pvs.py +++ b/app/api/v1/pvs.py @@ -2,13 +2,23 @@ import json from uuid import UUID -from fastapi import Query, Depends, Security, APIRouter +from fastapi import Query, Depends, Security, APIRouter, HTTPException from sqlalchemy.ext.asyncio import AsyncSession from app.db.session import get_db -from app.schemas.pv import LivePVRequest, NewPVElementDTO, UpdatePVElementDTO +from app.schemas.pv import ( + PVElementDTO, + LivePVRequest, + NewPVElementDTO, + EpicsTestResponse, + UpdatePVElementDTO, + CacheStatusResponse, + PVCacheEntryResponse, + AllLiveValuesResponse, + FilteredSearchResponse, +) from app.dependencies import get_pv_service, require_read_access, require_write_access -from app.api.responses import APIException, success_response +from app.schemas.common import PagedResult from app.services.pv_service import PVService from app.services.epics_service import get_epics_service from app.services.redis_service import get_redis_service @@ -17,24 +27,32 @@ router = APIRouter(prefix="/pvs", tags=["PVs"]) -@router.get("", dependencies=[Security(require_read_access)]) +@router.get( + "", + dependencies=[Security(require_read_access)], + response_model=list[PVElementDTO], +) async def search_pvs( pvName: str | None = Query(None), service: PVService = Depends(get_pv_service), -) -> dict: +) -> list[PVElementDTO]: """Search PVs by name (non-paginated, for backward compatibility).""" result = await service.search_paged(search=pvName, page_size=1000) - return success_response(result.results) + return result.results -@router.get("/paged", dependencies=[Security(require_read_access)]) +@router.get( + "/paged", + dependencies=[Security(require_read_access)], + response_model=PagedResult[PVElementDTO], +) async def search_pvs_paged( pvName: str | None = Query(None), pageSize: int = Query(100, ge=1, le=1000), continuationToken: str | None = Query(None), tagFilters: str | None = Query(None, description="JSON object: {groupId: [tagId1, tagId2], ...}"), service: PVService = Depends(get_pv_service), -) -> dict: +) -> PagedResult[PVElementDTO]: """ Search PVs with pagination and optional tag filtering. @@ -49,73 +67,90 @@ async def search_pvs_paged( if not tag_filters: tag_filters = None except json.JSONDecodeError as e: - raise APIException(400, f"Invalid tagFilters JSON: {e}", 400) + raise HTTPException(status_code=400, detail=f"Invalid tagFilters JSON: {e}") except ValueError as e: - raise APIException(400, str(e), 400) + raise HTTPException(status_code=400, detail=str(e)) - result = await service.search_paged( + return await service.search_paged( search=pvName, page_size=pageSize, continuation_token=continuationToken, tag_filters=tag_filters, ) - return success_response(result) -@router.post("", dependencies=[Security(require_write_access)]) -async def create_pv(data: NewPVElementDTO, service: PVService = Depends(get_pv_service)) -> dict: +@router.post( + "", + dependencies=[Security(require_write_access)], + response_model=PVElementDTO, +) +async def create_pv(data: NewPVElementDTO, service: PVService = Depends(get_pv_service)) -> PVElementDTO: """Create a new PV.""" try: - pv = await service.create(data) - return success_response(pv) + return await service.create(data) except ValueError as e: - raise APIException(409, str(e), 409) + raise HTTPException(status_code=409, detail=str(e)) -@router.post("/multi", dependencies=[Security(require_write_access)]) +@router.post( + "/multi", + dependencies=[Security(require_write_access)], + response_model=list[PVElementDTO], +) async def create_multiple_pvs( data: list[NewPVElementDTO], service: PVService = Depends(get_pv_service), -) -> dict: +) -> list[PVElementDTO]: """Bulk create PVs (for CSV import).""" try: - pvs = await service.create_many(data) - return success_response(pvs) + return await service.create_many(data) except ValueError as e: - raise APIException(409, str(e), 409) + raise HTTPException(status_code=409, detail=str(e)) -@router.put("/{pv_id}", dependencies=[Security(require_write_access)]) +@router.put( + "/{pv_id}", + dependencies=[Security(require_write_access)], + response_model=PVElementDTO, +) async def update_pv( pv_id: str, data: UpdatePVElementDTO, service: PVService = Depends(get_pv_service), -) -> dict: +) -> PVElementDTO: """Update a PV.""" try: UUID(pv_id) except ValueError: - raise APIException(404, f"PV {pv_id} not found", 404) + raise HTTPException(status_code=404, detail=f"PV {pv_id} not found") pv = await service.update(pv_id, data) if not pv: - raise APIException(404, f"PV {pv_id} not found", 404) - return success_response(pv) + raise HTTPException(status_code=404, detail=f"PV {pv_id} not found") + return pv -@router.delete("/{pv_id}", dependencies=[Security(require_write_access)]) -async def delete_pv(pv_id: str, service: PVService = Depends(get_pv_service)) -> dict: +@router.delete( + "/{pv_id}", + dependencies=[Security(require_write_access)], + response_model=bool, +) +async def delete_pv(pv_id: str, service: PVService = Depends(get_pv_service)) -> bool: """Delete a PV.""" try: UUID(pv_id) except ValueError: - raise APIException(404, f"PV {pv_id} not found", 404) + raise HTTPException(status_code=404, detail=f"PV {pv_id} not found") success = await service.delete(pv_id) if not success: - raise APIException(404, f"PV {pv_id} not found", 404) - return success_response(True) + raise HTTPException(status_code=404, detail=f"PV {pv_id} not found") + return True -@router.get("/search", dependencies=[Security(require_read_access)]) +@router.get( + "/search", + dependencies=[Security(require_read_access)], + response_model=FilteredSearchResponse, +) async def search_pvs_filtered( q: str | None = Query(None, description="Text search"), devices: list[str] = Query(default=[], description="Filter by device"), @@ -125,7 +160,7 @@ async def search_pvs_filtered( include_live_values: bool = Query(False, description="Include Redis cache values"), db: AsyncSession = Depends(get_db), service: PVService = Depends(get_pv_service), -) -> dict: +) -> FilteredSearchResponse: """ Server-side filtered search with optional live values. @@ -140,9 +175,9 @@ async def search_pvs_filtered( # Convert to DTOs results = [service._to_dto(pv) for pv in pvs] - response = {"results": results, "totalCount": total, "limit": limit, "offset": offset} + live_values: dict[str, PVCacheEntryResponse] | None = None + live_values_error: str | None = None - # Optionally include live values from Redis if include_live_values: try: redis = get_redis_service() @@ -153,84 +188,115 @@ async def search_pvs_filtered( if pv.readback_address: pv_addresses.append(pv.readback_address) - live_values = await redis.get_pv_values_bulk(pv_addresses) - response["liveValues"] = live_values + entries = await redis.get_pv_values_bulk(pv_addresses) + live_values = {name: PVCacheEntryResponse(**entry.to_dict()) for name, entry in entries.items()} except Exception as e: - response["liveValuesError"] = str(e) - - return success_response(response) + live_values_error = str(e) + + return FilteredSearchResponse( + results=results, + totalCount=total, + limit=limit, + offset=offset, + liveValues=live_values, + liveValuesError=live_values_error, + ) -@router.get("/devices", dependencies=[Security(require_read_access)]) -async def get_all_devices(db: AsyncSession = Depends(get_db)) -> dict: +@router.get( + "/devices", + dependencies=[Security(require_read_access)], + response_model=list[str], +) +async def get_all_devices(db: AsyncSession = Depends(get_db)) -> list[str]: """Get all unique device names for filtering.""" pv_repo = PVRepository(db) - devices = await pv_repo.get_all_devices() - return success_response(devices) + return await pv_repo.get_all_devices() -@router.get("/live", dependencies=[Security(require_read_access)]) -async def get_live_values(pv_names: list[str] = Query(..., description="List of PV names to fetch")) -> dict: +@router.get( + "/live", + dependencies=[Security(require_read_access)], + response_model=dict[str, PVCacheEntryResponse], +) +async def get_live_values( + pv_names: list[str] = Query(..., description="List of PV names to fetch"), +) -> dict[str, PVCacheEntryResponse]: """Get current values from Redis cache (instant).""" try: redis = get_redis_service() entries = await redis.get_pv_values_bulk(pv_names) - # Convert PVCacheEntry objects to dicts for JSON serialization - values = {pv_name: entry.to_dict() for pv_name, entry in entries.items()} - return success_response(values) + return {pv_name: PVCacheEntryResponse(**entry.to_dict()) for pv_name, entry in entries.items()} except Exception as e: - raise APIException(500, f"Failed to get live values: {e}", 500) + raise HTTPException(status_code=500, detail=f"Failed to get live values: {e}") -@router.post("/live", dependencies=[Security(require_read_access)]) -async def get_live_values_post(request: LivePVRequest) -> dict: +@router.post( + "/live", + dependencies=[Security(require_read_access)], + response_model=dict[str, PVCacheEntryResponse], +) +async def get_live_values_post(request: LivePVRequest) -> dict[str, PVCacheEntryResponse]: """Get current values from Redis cache (instant) - POST version for large PV lists.""" try: redis = get_redis_service() entries = await redis.get_pv_values_bulk(request.pv_names) - # Convert PVCacheEntry objects to dicts for JSON serialization - values = {pv_name: entry.to_dict() for pv_name, entry in entries.items()} - return success_response(values) + return {pv_name: PVCacheEntryResponse(**entry.to_dict()) for pv_name, entry in entries.items()} except Exception as e: - raise APIException(500, f"Failed to get live values: {e}", 500) + raise HTTPException(status_code=500, detail=f"Failed to get live values: {e}") -@router.get("/live/all", dependencies=[Security(require_read_access)]) -async def get_all_live_values() -> dict: +@router.get( + "/live/all", + dependencies=[Security(require_read_access)], + response_model=AllLiveValuesResponse, +) +async def get_all_live_values() -> AllLiveValuesResponse: """Get all cached PV values (for initial table load).""" try: redis = get_redis_service() values = await redis.get_all_pv_values_as_dict() - return success_response({"values": values, "count": len(values)}) + return AllLiveValuesResponse( + values={name: PVCacheEntryResponse(**entry) for name, entry in values.items()}, + count=len(values), + ) except Exception as e: - raise APIException(500, f"Failed to get live values: {e}", 500) + raise HTTPException(status_code=500, detail=f"Failed to get live values: {e}") -@router.get("/cache/status", dependencies=[Security(require_read_access)]) -async def get_cache_status() -> dict: +@router.get( + "/cache/status", + dependencies=[Security(require_read_access)], + response_model=CacheStatusResponse, +) +async def get_cache_status() -> CacheStatusResponse: """Get Redis cache status.""" try: redis = get_redis_service() count = await redis.get_cached_pv_count() - return success_response({"cachedPvCount": count, "status": "connected"}) + return CacheStatusResponse(cachedPvCount=count, status="connected") except Exception as e: - return success_response({"cachedPvCount": 0, "status": "disconnected", "error": str(e)}) + return CacheStatusResponse(cachedPvCount=0, status="disconnected", error=str(e)) -@router.get("/test-epics", dependencies=[Security(require_read_access)]) -async def test_epics_connection(pv: str = Query("KLYS:LI22:31:KVAC", description="PV name to test")) -> dict: +@router.get( + "/test-epics", + dependencies=[Security(require_read_access)], + response_model=EpicsTestResponse, +) +async def test_epics_connection( + pv: str = Query("KLYS:LI22:31:KVAC", description="PV name to test"), +) -> EpicsTestResponse: """Test EPICS connectivity using aioca.""" epics = get_epics_service() result = await epics.get_single(pv) - return success_response( - { - "pv": pv, - "connected": result.connected, - "value": result.value, - "error": result.error, - "environment": { - "EPICS_CA_ADDR_LIST": os.environ.get("EPICS_CA_ADDR_LIST", "NOT SET"), - "EPICS_CA_AUTO_ADDR_LIST": os.environ.get("EPICS_CA_AUTO_ADDR_LIST", "NOT SET"), - }, - } + return EpicsTestResponse( + pv=pv, + connected=result.connected, + value=result.value, + error=result.error, + environment={ + "EPICS_CA_ADDR_LIST": os.environ.get("EPICS_CA_ADDR_LIST", "NOT SET"), + "EPICS_CA_AUTO_ADDR_LIST": os.environ.get("EPICS_CA_AUTO_ADDR_LIST", "NOT SET"), + }, ) diff --git a/app/api/v1/snapshots.py b/app/api/v1/snapshots.py index ca4fd7b..ffbaa5d 100644 --- a/app/api/v1/snapshots.py +++ b/app/api/v1/snapshots.py @@ -2,7 +2,7 @@ from uuid import UUID from arq import create_pool -from fastapi import Query, Depends, Security, APIRouter, BackgroundTasks +from fastapi import Query, Depends, Security, APIRouter, HTTPException, BackgroundTasks from arq.connections import RedisSettings from sqlalchemy.ext.asyncio import AsyncSession @@ -16,8 +16,15 @@ get_snapshot_service, require_write_access, ) -from app.api.responses import APIException, success_response -from app.schemas.snapshot import NewSnapshotDTO, RestoreRequestDTO, UpdateSnapshotDTO +from app.schemas.snapshot import ( + SnapshotDTO, + NewSnapshotDTO, + RestoreResultDTO, + RestoreRequestDTO, + UpdateSnapshotDTO, + SnapshotSummaryDTO, + ComparisonResultDTO, +) from app.services.job_service import JobService from app.services.background_tasks import run_snapshot_restore, run_snapshot_creation from app.services.snapshot_service import SnapshotService @@ -44,25 +51,32 @@ async def get_arq_pool(): return _arq_pool -@router.get("", dependencies=[Security(require_read_access)]) +@router.get( + "", + dependencies=[Security(require_read_access)], + response_model=list[SnapshotSummaryDTO], +) async def list_snapshots( title: str | None = Query(None), tags: list[str] | None = Query(None, description="Filter by tag IDs (returns snapshots containing PVs with any of these tags)"), service: SnapshotService = Depends(get_snapshot_service), -) -> dict: +) -> list[SnapshotSummaryDTO]: """List all snapshots, optionally filtered by title and/or tags.""" - snapshots = await service.list_snapshots(title=title, tag_ids=tags) - return success_response(snapshots) + return await service.list_snapshots(title=title, tag_ids=tags) -@router.get("/{snapshot_id}", dependencies=[Security(require_read_access)]) +@router.get( + "/{snapshot_id}", + dependencies=[Security(require_read_access)], + response_model=SnapshotDTO, +) async def get_snapshot( snapshot_id: str, limit: int | None = Query(None, description="Limit number of PV values returned"), offset: int = Query(0, description="Offset for pagination"), service: SnapshotService = Depends(get_snapshot_service), -) -> dict: +) -> SnapshotDTO: """ Get snapshot by ID with values. @@ -72,14 +86,18 @@ async def get_snapshot( try: UUID(snapshot_id) except ValueError: - raise APIException(404, f"Snapshot {snapshot_id} not found", 404) + raise HTTPException(status_code=404, detail=f"Snapshot {snapshot_id} not found") snapshot = await service.get_by_id(snapshot_id, limit=limit, offset=offset) if not snapshot: - raise APIException(404, f"Snapshot {snapshot_id} not found", 404) - return success_response(snapshot) + raise HTTPException(status_code=404, detail=f"Snapshot {snapshot_id} not found") + return snapshot -@router.post("", dependencies=[Security(require_write_access)]) +@router.post( + "", + dependencies=[Security(require_write_access)], + response_model=JobCreatedDTO | SnapshotSummaryDTO, +) async def create_snapshot( data: NewSnapshotDTO, background_tasks: BackgroundTasks, @@ -89,7 +107,7 @@ async def create_snapshot( use_arq: bool = Query(True, description="Use Arq persistent queue (recommended) vs FastAPI BackgroundTasks"), service: SnapshotService = Depends(get_snapshot_service), job_service: JobService = Depends(get_job_service), -) -> dict: +) -> JobCreatedDTO | SnapshotSummaryDTO: """ Create a new snapshot by reading all PVs. @@ -100,9 +118,9 @@ async def create_snapshot( - use_arq=true (default): Use Arq persistent queue (survives restarts) - use_arq=false: Use FastAPI BackgroundTasks (lost on restart) - By default (async=true), this returns immediately with a job ID that can be - polled for status. Set async=false for synchronous operation (may timeout - for large numbers of PVs). + By default (async=true), this returns immediately with a JobCreatedDTO that can be + polled for status. Set async=false for synchronous operation — the response is + the completed SnapshotSummaryDTO (may timeout for large numbers of PVs). """ if async_mode: # Create a job record @@ -128,12 +146,10 @@ async def create_snapshot( use_cache=use_cache, ) logger.info(f"Enqueued snapshot job to Arq: {job.id}") - return success_response( - JobCreatedDTO( - jobId=job.id, - message=f"Snapshot creation queued for '{data.title}'" - + (" (from cache)" if use_cache else " (direct EPICS)"), - ) + return JobCreatedDTO( + jobId=job.id, + message=f"Snapshot creation queued for '{data.title}'" + + (" (from cache)" if use_cache else " (direct EPICS)"), ) except Exception as e: logger.warning(f"Failed to enqueue to Arq, falling back to BackgroundTasks: {e}") @@ -142,29 +158,28 @@ async def create_snapshot( background_tasks.add_task(run_snapshot_creation, job.id, data.title, data.description, use_cache) logger.info(f"Scheduled snapshot job via BackgroundTasks: {job.id}") - return success_response( - JobCreatedDTO( - jobId=job.id, - message=f"Snapshot creation started for '{data.title}'" - + (" (from cache)" if use_cache else " (direct EPICS)"), - ) + return JobCreatedDTO( + jobId=job.id, + message=f"Snapshot creation started for '{data.title}'" + + (" (from cache)" if use_cache else " (direct EPICS)"), ) else: # Synchronous mode (legacy behavior) if use_cache: - snapshot = await service.create_snapshot_from_cache(data) - else: - snapshot = await service.create_snapshot(data) - - return success_response(snapshot) + return await service.create_snapshot_from_cache(data) + return await service.create_snapshot(data) -@router.put("/{snapshot_id}", dependencies=[Security(require_write_access)]) +@router.put( + "/{snapshot_id}", + dependencies=[Security(require_write_access)], + response_model=SnapshotSummaryDTO, +) async def update_snapshot( snapshot_id: str, data: UpdateSnapshotDTO, service: SnapshotService = Depends(get_snapshot_service), -) -> dict: +) -> SnapshotSummaryDTO: """Update snapshot title and/or description.""" snapshot = await service.update_snapshot_metadata( snapshot_id, @@ -173,12 +188,16 @@ async def update_snapshot( ) if not snapshot: - raise APIException(404, f"Snapshot {snapshot_id} not found", 404) + raise HTTPException(status_code=404, detail=f"Snapshot {snapshot_id} not found") - return success_response(snapshot) + return snapshot -@router.post("/{snapshot_id}/restore", dependencies=[Security(require_write_access)]) +@router.post( + "/{snapshot_id}/restore", + dependencies=[Security(require_write_access)], + response_model=JobCreatedDTO | RestoreResultDTO, +) async def restore_snapshot( snapshot_id: str, background_tasks: BackgroundTasks, @@ -188,7 +207,7 @@ async def restore_snapshot( use_arq: bool = Query(True, description="Use Arq persistent queue (recommended) vs FastAPI BackgroundTasks"), service: SnapshotService = Depends(get_snapshot_service), job_service: JobService = Depends(get_job_service), -) -> dict: +) -> JobCreatedDTO | RestoreResultDTO: """ Restore PV values from a snapshot to EPICS. @@ -197,12 +216,12 @@ async def restore_snapshot( try: UUID(snapshot_id) except ValueError: - raise APIException(404, f"Snapshot {snapshot_id} not found", 404) + raise HTTPException(status_code=404, detail=f"Snapshot {snapshot_id} not found") # Verify snapshot exists snapshot = await service.get_by_id(snapshot_id) if not snapshot: - raise APIException(404, f"Snapshot {snapshot_id} not found", 404) + raise HTTPException(status_code=404, detail=f"Snapshot {snapshot_id} not found") if async_mode: job = await job_service.create_job( @@ -224,12 +243,9 @@ async def restore_snapshot( pv_ids=pv_ids, ) logger.info(f"Enqueued restore job to Arq: {job.id}") - - return success_response( - JobCreatedDTO( - jobId=job.id, - message=f"Snapshot restore queued ({snapshot_id})", - ) + return JobCreatedDTO( + jobId=job.id, + message=f"Snapshot restore queued ({snapshot_id})", ) except Exception as e: logger.warning(f"Failed to enqueue to Arq, falling back to BackgroundTasks: {e}") @@ -237,48 +253,52 @@ async def restore_snapshot( # Fallback to FastAPI BackgroundTasks background_tasks.add_task(run_snapshot_restore, str(job.id), snapshot_id, pv_ids) logger.info(f"Scheduled restore job via BackgroundTasks: {job.id}") - return success_response( - JobCreatedDTO( - jobId=job.id, - message=f"Snapshot restore started ({snapshot_id})", - ) + return JobCreatedDTO( + jobId=job.id, + message=f"Snapshot restore started ({snapshot_id})", ) - result = await service.restore_snapshot(snapshot_id, request) - return success_response(result) + return await service.restore_snapshot(snapshot_id, request) -@router.delete("/{snapshot_id}", dependencies=[Security(require_write_access)]) +@router.delete( + "/{snapshot_id}", + dependencies=[Security(require_write_access)], + response_model=bool, +) async def delete_snapshot( snapshot_id: str, deleteData: bool = Query(True), service: SnapshotService = Depends(get_snapshot_service), -) -> dict: +) -> bool: """Delete a snapshot.""" try: UUID(snapshot_id) except ValueError: - raise APIException(404, f"Snapshot {snapshot_id} not found", 404) + raise HTTPException(status_code=404, detail=f"Snapshot {snapshot_id} not found") success = await service.delete_snapshot(snapshot_id) if not success: - raise APIException(404, f"Snapshot {snapshot_id} not found", 404) - return success_response(True) + raise HTTPException(status_code=404, detail=f"Snapshot {snapshot_id} not found") + return True -@router.get("/{snapshot1_id}/compare/{snapshot2_id}", dependencies=[Security(require_read_access)]) +@router.get( + "/{snapshot1_id}/compare/{snapshot2_id}", + dependencies=[Security(require_read_access)], + response_model=ComparisonResultDTO, +) async def compare_snapshots( snapshot1_id: str, snapshot2_id: str, service: SnapshotService = Depends(get_snapshot_service), -) -> dict: +) -> ComparisonResultDTO: """Compare two snapshots and return differences.""" try: UUID(snapshot1_id) UUID(snapshot2_id) except ValueError: - raise APIException(404, "Snapshot not found", 404) + raise HTTPException(status_code=404, detail="Snapshot not found") try: - result = await service.compare_snapshots(snapshot1_id, snapshot2_id) - return success_response(result) + return await service.compare_snapshots(snapshot1_id, snapshot2_id) except ValueError as e: - raise APIException(404, str(e), 404) + raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/v1/tags.py b/app/api/v1/tags.py index 947aae6..c582023 100644 --- a/app/api/v1/tags.py +++ b/app/api/v1/tags.py @@ -1,140 +1,177 @@ from uuid import UUID -from fastapi import Query, Depends, Security, APIRouter +from fastapi import Query, Depends, Security, APIRouter, HTTPException from pydantic import BaseModel -from app.schemas.tag import TagCreate, TagUpdate, TagGroupCreate, TagGroupUpdate +from app.schemas.tag import ( + TagCreate, + TagUpdate, + TagGroupDTO, + AddTagResponse, + TagGroupCreate, + TagGroupUpdate, + TagGroupSummaryDTO, +) from app.dependencies import get_tag_service, require_read_access, require_write_access -from app.api.responses import APIException, success_response from app.services.tag_service import TagService router = APIRouter(prefix="/tags", tags=["Tags"]) -@router.get("", dependencies=[Security(require_read_access)]) -async def get_all_tag_groups(service: TagService = Depends(get_tag_service)) -> dict: +@router.get( + "", + dependencies=[Security(require_read_access)], + response_model=list[TagGroupSummaryDTO], +) +async def get_all_tag_groups(service: TagService = Depends(get_tag_service)) -> list[TagGroupSummaryDTO]: """Get all tag groups with tag counts.""" - groups = await service.get_all_groups_summary() - return success_response(groups) + return await service.get_all_groups_summary() -@router.get("/{group_id}", dependencies=[Security(require_read_access)]) -async def get_tag_group(group_id: str, service: TagService = Depends(get_tag_service)) -> dict: +@router.get( + "/{group_id}", + dependencies=[Security(require_read_access)], + response_model=list[TagGroupDTO], +) +async def get_tag_group(group_id: str, service: TagService = Depends(get_tag_service)) -> list[TagGroupDTO]: """Get tag group by ID with all tags.""" try: UUID(group_id) except ValueError: - raise APIException(404, f"Tag group {group_id} not found", 404) + raise HTTPException(status_code=404, detail=f"Tag group {group_id} not found") group = await service.get_group_by_id(group_id) if not group: - raise APIException(404, f"Tag group {group_id} not found", 404) + raise HTTPException(status_code=404, detail=f"Tag group {group_id} not found") # Return as array to match frontend expectations - return success_response([group]) + return [group] -@router.post("", dependencies=[Security(require_write_access)]) -async def create_tag_group(data: TagGroupCreate, service: TagService = Depends(get_tag_service)) -> dict: +@router.post( + "", + dependencies=[Security(require_write_access)], + response_model=TagGroupDTO, +) +async def create_tag_group(data: TagGroupCreate, service: TagService = Depends(get_tag_service)) -> TagGroupDTO: """Create a new tag group.""" try: - group = await service.create_group(data) - return success_response(group) + return await service.create_group(data) except ValueError as e: - raise APIException(409, str(e), 409) + raise HTTPException(status_code=409, detail=str(e)) -@router.put("/{group_id}", dependencies=[Security(require_write_access)]) +@router.put( + "/{group_id}", + dependencies=[Security(require_write_access)], + response_model=TagGroupDTO, +) async def update_tag_group( group_id: str, data: TagGroupUpdate, service: TagService = Depends(get_tag_service), -) -> dict: +) -> TagGroupDTO: """Update a tag group.""" try: UUID(group_id) except ValueError: - raise APIException(404, f"Tag group {group_id} not found", 404) + raise HTTPException(status_code=404, detail=f"Tag group {group_id} not found") try: group = await service.update_group(group_id, data) if not group: - raise APIException(404, f"Tag group {group_id} not found", 404) - return success_response(group) + raise HTTPException(status_code=404, detail=f"Tag group {group_id} not found") + return group except ValueError as e: - raise APIException(409, str(e), 409) + raise HTTPException(status_code=409, detail=str(e)) -@router.delete("/{group_id}", dependencies=[Security(require_write_access)]) +@router.delete( + "/{group_id}", + dependencies=[Security(require_write_access)], + response_model=bool, +) async def delete_tag_group( group_id: str, force: bool = Query(False), service: TagService = Depends(get_tag_service), -) -> dict: +) -> bool: """Delete a tag group.""" try: UUID(group_id) except ValueError: - raise APIException(404, f"Tag group {group_id} not found", 404) + raise HTTPException(status_code=404, detail=f"Tag group {group_id} not found") success = await service.delete_group(group_id, force=force) if not success: - raise APIException(404, f"Tag group {group_id} not found", 404) - return success_response(True) + raise HTTPException(status_code=404, detail=f"Tag group {group_id} not found") + return True -@router.post("/{group_id}/tags", dependencies=[Security(require_write_access)]) +@router.post( + "/{group_id}/tags", + dependencies=[Security(require_write_access)], + response_model=AddTagResponse, +) async def add_tag_to_group( group_id: str, data: TagCreate, skip_duplicates: bool = Query(False, description="Skip duplicate tags instead of raising error"), service: TagService = Depends(get_tag_service), -) -> dict: +) -> AddTagResponse: """Add a tag to a group.""" try: UUID(group_id) except ValueError: - raise APIException(404, f"Tag group {group_id} not found", 404) + raise HTTPException(status_code=404, detail=f"Tag group {group_id} not found") try: group, was_created = await service.add_tag_to_group(group_id, data, skip_duplicates=skip_duplicates) if not group: - raise APIException(404, f"Tag group {group_id} not found", 404) - return success_response({"group": group, "wasCreated": was_created}) + raise HTTPException(status_code=404, detail=f"Tag group {group_id} not found") + return AddTagResponse(group=group, wasCreated=was_created) except ValueError as e: - raise APIException(409, str(e), 409) + raise HTTPException(status_code=409, detail=str(e)) -@router.put("/{group_id}/tags/{tag_id}", dependencies=[Security(require_write_access)]) +@router.put( + "/{group_id}/tags/{tag_id}", + dependencies=[Security(require_write_access)], + response_model=TagGroupDTO, +) async def update_tag( group_id: str, tag_id: str, data: TagUpdate, service: TagService = Depends(get_tag_service), -) -> dict: +) -> TagGroupDTO: """Update a tag in a group.""" try: UUID(group_id) UUID(tag_id) except ValueError: - raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) + raise HTTPException(status_code=404, detail=f"Tag {tag_id} not found in group {group_id}") group = await service.update_tag(group_id, tag_id, data) if not group: - raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) - return success_response(group) + raise HTTPException(status_code=404, detail=f"Tag {tag_id} not found in group {group_id}") + return group -@router.delete("/{group_id}/tags/{tag_id}", dependencies=[Security(require_write_access)]) +@router.delete( + "/{group_id}/tags/{tag_id}", + dependencies=[Security(require_write_access)], + response_model=TagGroupDTO, +) async def remove_tag( group_id: str, tag_id: str, service: TagService = Depends(get_tag_service), -) -> dict: +) -> TagGroupDTO: """Remove a tag from a group.""" try: UUID(group_id) UUID(tag_id) except ValueError: - raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) + raise HTTPException(status_code=404, detail=f"Tag {tag_id} not found in group {group_id}") group = await service.remove_tag(group_id, tag_id) if not group: - raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) - return success_response(group) + raise HTTPException(status_code=404, detail=f"Tag {tag_id} not found in group {group_id}") + return group class BulkTagImportRequest(BaseModel): @@ -152,8 +189,11 @@ class BulkTagImportResponse(BaseModel): warnings: list[str] -@router.post("/bulk", response_model=dict) -async def bulk_import_tags(data: BulkTagImportRequest, service: TagService = Depends(get_tag_service)): +@router.post("/bulk", response_model=BulkTagImportResponse) +async def bulk_import_tags( + data: BulkTagImportRequest, + service: TagService = Depends(get_tag_service), +) -> BulkTagImportResponse: """Bulk import tags with duplicate handling.""" groups_created = 0 tags_created = 0 @@ -192,11 +232,9 @@ async def bulk_import_tags(data: BulkTagImportRequest, service: TagService = Dep except Exception as e: warnings.append(f"Failed to add tag '{tag_name}' to group '{group_name}': {e}") - return success_response( - BulkTagImportResponse( - groupsCreated=groups_created, - tagsCreated=tags_created, - tagsSkipped=tags_skipped, - warnings=warnings, - ) + return BulkTagImportResponse( + groupsCreated=groups_created, + tagsCreated=tags_created, + tagsSkipped=tags_skipped, + warnings=warnings, ) diff --git a/app/dependencies.py b/app/dependencies.py index 36117f4..381e9f2 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -1,12 +1,11 @@ from typing import Annotated -from fastapi import Depends, Security, WebSocket, status +from fastapi import Depends, Security, WebSocket, HTTPException, status from fastapi.security import APIKeyHeader from fastapi.exceptions import WebSocketException from sqlalchemy.ext.asyncio import AsyncSession from app.db.session import get_db -from app.api.responses import APIException from app.schemas.api_key import ApiKeyDTO from app.services.pv_service import PVService from app.services.job_service import JobService @@ -66,30 +65,27 @@ async def get_api_key( if api_key_dto and api_key_dto.isActive: return api_key_dto - raise APIException( - error_code=status.HTTP_401_UNAUTHORIZED, - message="Missing or deactivated API key", + raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing or deactivated API key", ) def require_read_access(api_key_dto: Annotated[ApiKeyDTO, Security(get_api_key)]): """Dependency that requires a valid, active API Key with read access.""" if not api_key_dto.readAccess: - raise APIException( - error_code=status.HTTP_401_UNAUTHORIZED, - message="API key does not have read access", + raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key does not have read access", ) def require_write_access(api_key_dto: Annotated[ApiKeyDTO, Security(get_api_key)]): """Dependency that requires a valid, active API Key with write access.""" if not api_key_dto.writeAccess: - raise APIException( - error_code=status.HTTP_401_UNAUTHORIZED, - message="API key does not have write access", + raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key does not have write access", ) diff --git a/app/main.py b/app/main.py index e92ec15..e306e5f 100644 --- a/app/main.py +++ b/app/main.py @@ -12,7 +12,6 @@ - Independent scaling and deployment """ -import os import logging from contextlib import asynccontextmanager @@ -21,13 +20,10 @@ from fastapi.middleware.cors import CORSMiddleware from app.config import get_settings -from app.api.responses import APIException from app.api.v1.router import router as v1_router from app.api.v1.websocket import get_diff_manager -from app.services.pv_protocol import is_unprefixed, parse_pv_name from app.services.epics_service import get_epics_service from app.services.redis_service import get_redis_service -from app.services.pvaccess_monitor import get_pvaccess_monitor # Configure logging logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") @@ -35,9 +31,6 @@ settings = get_settings() -# Environment variable to optionally enable embedded monitor (for backward compatibility) -EMBEDDED_MONITOR = os.environ.get("SQUIRREL_EMBEDDED_MONITOR", "false").lower() == "true" - @asynccontextmanager async def lifespan(app: FastAPI): @@ -78,15 +71,7 @@ async def lifespan(app: FastAPI): if monitor_alive: logger.info("PV Monitor process detected (via heartbeat)") else: - logger.warning( - "PV Monitor process not detected - start squirrel-monitor separately, " - "or set SQUIRREL_EMBEDDED_MONITOR=true to run embedded" - ) - - # Optionally start embedded monitor (for backward compatibility/development) - if EMBEDDED_MONITOR: - logger.info("Starting EMBEDDED PV Monitor (SQUIRREL_EMBEDDED_MONITOR=true)") - await _start_embedded_monitor(redis_service, epics) + logger.warning("PV Monitor process not detected - start squirrel-monitor separately") # Start WebSocket diff stream manager (subscribes to Redis pub/sub) diff_manager = get_diff_manager() @@ -104,10 +89,6 @@ async def lifespan(app: FastAPI): # Cleanup logger.info("Shutting down Squirrel API...") - # Stop embedded monitor if running - if EMBEDDED_MONITOR: - await _stop_embedded_monitor() - # Stop WebSocket manager try: diff_manager = get_diff_manager() @@ -132,100 +113,6 @@ async def lifespan(app: FastAPI): logger.info("Squirrel API shutdown complete") -async def _start_embedded_monitor(redis_service, epics): - """ - Start PV Monitor and Watchdog in embedded mode (backward compatibility). - - This is enabled by setting SQUIRREL_EMBEDDED_MONITOR=true. - """ - from app.db.session import async_session_maker - from app.services.watchdog import get_watchdog - from app.services.pv_monitor import get_pv_monitor - from app.repositories.pv_repository import PVRepository - - pv_monitor = get_pv_monitor(redis_service) - pva_monitor = None - - # Get all PV addresses from database - async with async_session_maker() as session: - pv_repo = PVRepository(session) - pv_addresses_data = await pv_repo.get_all_addresses() - - # Extract unique addresses (setpoint and readback) - pv_addresses = set() - for _, setpoint, readback, config in pv_addresses_data: - if setpoint: - pv_addresses.add(setpoint) - if readback: - pv_addresses.add(readback) - - # Start PV monitoring (with batched startup) - ca_pvs: list[str] = [] - pva_pvs: list[str] = [] - for pv_name in pv_addresses: - protocol, _ = parse_pv_name(pv_name) - if protocol == "pva": - pva_pvs.append(pv_name) - else: - ca_pvs.append(pv_name) - if settings.epics_unprefixed_pva_fallback and is_unprefixed(pv_name): - pva_pvs.append(pv_name) - - if pv_addresses: - logger.info(f"[EMBEDDED] Starting PV Monitor for {len(ca_pvs)} CA and {len(pva_pvs)} PVA addresses") - await pv_monitor.start(ca_pvs) - logger.info(f"[EMBEDDED] PV Monitor started for {len(ca_pvs)} CA addresses") - - if pva_pvs: - pva_monitor = get_pvaccess_monitor(redis_service) - await pva_monitor.start(pva_pvs) - logger.info(f"[EMBEDDED] PVAccess Monitor started for {len(pva_pvs)} PVA addresses") - else: - logger.info("[EMBEDDED] No PVA addresses found; PVAccess Monitor not started") - else: - logger.warning("[EMBEDDED] No PV addresses found in database") - - # Start Watchdog if enabled - if settings.watchdog_enabled: - watchdog = get_watchdog(redis_service, epics, pv_monitor, pva_monitor if pva_pvs else None) - await watchdog.start() - logger.info("[EMBEDDED] Watchdog started") - - -async def _stop_embedded_monitor(): - """Stop embedded PV Monitor and Watchdog.""" - from app.services.watchdog import get_watchdog - from app.services.pv_monitor import get_pv_monitor - - # Stop Watchdog - if settings.watchdog_enabled: - try: - watchdog = get_watchdog() - if watchdog.is_running(): - await watchdog.stop() - logger.info("[EMBEDDED] Watchdog stopped") - except Exception as e: - logger.error(f"Error stopping Watchdog: {e}") - - # Stop PV Monitor - try: - pv_monitor = get_pv_monitor() - if pv_monitor.is_running(): - await pv_monitor.stop() - logger.info("[EMBEDDED] PV Monitor stopped") - except Exception as e: - logger.error(f"Error stopping PV Monitor: {e}") - - # Stop PVAccess Monitor - try: - pva_monitor = get_pvaccess_monitor() - if pva_monitor.is_running(): - await pva_monitor.stop() - logger.info("[EMBEDDED] PVAccess Monitor stopped") - except Exception as e: - logger.error(f"Error stopping PVAccess Monitor: {e}") - - app = FastAPI( title="Squirrel Backend", description="High-performance EPICS snapshot/restore backend with 40k PV support", @@ -243,26 +130,13 @@ async def _stop_embedded_monitor(): ) -# Exception handler for APIException -@app.exception_handler(APIException) -async def api_exception_handler(request: Request, exc: APIException): - return JSONResponse( - status_code=exc.status_code, - content={"errorCode": exc.error_code, "errorMessage": exc.error_message, "payload": None}, - ) - - -# Generic exception handler +# Generic exception handler — returns FastAPI-style {"detail": ...} body @app.exception_handler(Exception) async def generic_exception_handler(request: Request, exc: Exception): logger.exception(f"Unhandled exception: {exc}") return JSONResponse( status_code=500, - content={ - "errorCode": 500, - "errorMessage": str(exc) if settings.debug else "Internal server error", - "payload": None, - }, + content={"detail": str(exc) if settings.debug else "Internal server error"}, ) diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py index 416236d..f5a89d7 100644 --- a/app/schemas/__init__.py +++ b/app/schemas/__init__.py @@ -1,4 +1,4 @@ -from app.schemas.common import ApiResultResponse, PagedResult +from app.schemas.common import PagedResult from app.schemas.tag import ( TagDTO, TagCreate, @@ -21,7 +21,6 @@ ) __all__ = [ - "ApiResultResponse", "PagedResult", "TagDTO", "TagCreate", diff --git a/app/schemas/common.py b/app/schemas/common.py index 15be6b6..25e9901 100644 --- a/app/schemas/common.py +++ b/app/schemas/common.py @@ -5,14 +5,6 @@ T = TypeVar("T") -class ApiResultResponse(BaseModel, Generic[T]): - """Standard API response wrapper matching frontend expectations.""" - - errorCode: int = 0 - errorMessage: str | None = None - payload: T - - class PagedResult(BaseModel, Generic[T]): """Paginated result with continuation token.""" diff --git a/app/schemas/health.py b/app/schemas/health.py index 7f23280..f5d540f 100644 --- a/app/schemas/health.py +++ b/app/schemas/health.py @@ -8,6 +8,8 @@ class HeartbeatResponse(BaseModel): timestamp: float | None alive: bool + age_seconds: float | None = None + error: str | None = None class MonitorHealthResponse(BaseModel): @@ -52,3 +54,56 @@ class HealthSummaryResponse(BaseModel): watchdog_running: bool last_watchdog_check: datetime | None issues: list[str] + + +class DisconnectedPVsResponse(BaseModel): + """List of disconnected PVs.""" + + count: int + pvs: list[str] + + +class StalePVsResponse(BaseModel): + """List of stale (un-updated) PVs.""" + + count: int + threshold_seconds: float + pvs: list[str] + + +class CircuitStatsResponse(BaseModel): + """Per-circuit statistics.""" + + name: str + state: str + failure_count: int + success_count: int + call_count: int + last_failure: str | None = None + opened_at: str | None = None + + +class CircuitStatusResponse(BaseModel): + """Aggregated circuit-breaker status for all EPICS IOCs.""" + + open_circuit_count: int + total_circuits: int + open_circuits: list[str] = [] + circuits: list[CircuitStatsResponse] = [] + error: str | None = None + + +class CircuitActionResponse(BaseModel): + """Result of a force-open / force-close action on a circuit breaker.""" + + success: bool + message: str + + +class MonitorProcessStatusResponse(BaseModel): + """Liveness of the separate PV Monitor process.""" + + status: str # "healthy", "stale", "unknown", "error" + message: str | None = None + age_seconds: float | None = None + leader: str | None = None diff --git a/app/schemas/pv.py b/app/schemas/pv.py index e41d157..f382722 100644 --- a/app/schemas/pv.py +++ b/app/schemas/pv.py @@ -1,3 +1,4 @@ +from typing import Any from datetime import datetime from pydantic import Field, BaseModel, model_validator @@ -56,3 +57,54 @@ class LivePVRequest(BaseModel): """DTO for requesting live PV values via POST.""" pv_names: list[str] = Field(..., description="List of PV names to fetch") + + +class PVCacheEntryResponse(BaseModel): + """Single PV cache entry as returned by live-value endpoints.""" + + model_config = {"extra": "allow"} # accept any extra metadata fields + + value: Any | None = None + connected: bool = False + updated_at: float | None = None + status: str | None = None + severity: int | None = None + timestamp: float | None = None + units: str | None = None + error: str | None = None + + +class FilteredSearchResponse(BaseModel): + """Response from GET /v1/pvs/search.""" + + results: list[PVElementDTO] + totalCount: int + limit: int + offset: int + liveValues: dict[str, PVCacheEntryResponse] | None = None + liveValuesError: str | None = None + + +class AllLiveValuesResponse(BaseModel): + """Response from GET /v1/pvs/live/all.""" + + values: dict[str, PVCacheEntryResponse] + count: int + + +class CacheStatusResponse(BaseModel): + """Response from GET /v1/pvs/cache/status.""" + + cachedPvCount: int + status: str # "connected" | "disconnected" + error: str | None = None + + +class EpicsTestResponse(BaseModel): + """Response from GET /v1/pvs/test-epics.""" + + pv: str + connected: bool + value: Any | None = None + error: str | None = None + environment: dict[str, str] diff --git a/app/schemas/tag.py b/app/schemas/tag.py index 564a962..1fbceac 100644 --- a/app/schemas/tag.py +++ b/app/schemas/tag.py @@ -60,3 +60,10 @@ class TagGroupDTO(TagGroupBase): class Config: from_attributes = True + + +class AddTagResponse(BaseModel): + """Response from POST /v1/tags/{group_id}/tags.""" + + group: TagGroupDTO + wasCreated: bool diff --git a/tests/conftest.py b/tests/conftest.py index afc25d9..ea0407a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -151,7 +151,7 @@ async def sample_tag_group(client: AsyncClient) -> dict: """Create a sample tag group for testing.""" response = await client.post("/v1/tags", json={"name": "Test Location", "description": "Test location tags"}) assert response.status_code == 200 - return response.json()["payload"] + return response.json() @pytest_asyncio.fixture @@ -162,7 +162,7 @@ async def sample_tag(client: AsyncClient, sample_tag_group: dict) -> tuple[dict, f"/v1/tags/{group_id}/tags", json={"name": "Building-A", "description": "Building A location"} ) assert response.status_code == 200 - group = response.json()["payload"]["group"] + group = response.json()["group"] tag = group["tags"][0] return group, tag @@ -183,7 +183,7 @@ async def sample_pv(client: AsyncClient) -> dict: }, ) assert response.status_code == 200 - return response.json()["payload"] + return response.json() @pytest_asyncio.fixture @@ -203,7 +203,7 @@ async def sample_pvs(client: AsyncClient) -> list[dict]: ] response = await client.post("/v1/pvs/multi", json=pvs_data) assert response.status_code == 200 - return response.json()["payload"] + return response.json() @pytest_asyncio.fixture @@ -222,4 +222,4 @@ async def sample_snapshot(client: AsyncClient, sample_pvs: list[dict], mock_epic json={"title": "Test Snapshot", "description": "Snapshot for unit tests"}, ) assert response.status_code == 200 - return response.json()["payload"] + return response.json() diff --git a/tests/test_api/test_api_keys.py b/tests/test_api/test_api_keys.py index 75ebbc0..4a99da3 100644 --- a/tests/test_api/test_api_keys.py +++ b/tests/test_api/test_api_keys.py @@ -15,10 +15,10 @@ async def _create_key( client: AsyncClient, app_name: str = "TestApp", *, read: bool = True, write: bool = False ) -> dict: - """Helper to create an API key and return the response payload.""" + """Helper to create an API key and return the response body.""" response = await client.post("/v1/api-keys", json={"appName": app_name, "readAccess": read, "writeAccess": write}) assert response.status_code == 200 - return response.json()["payload"] + return response.json() # --------------------------------------------------------------------------- @@ -67,9 +67,7 @@ async def test_create_key_duplicate_app_name_returns_409(self, client: AsyncClie ) assert response.status_code == 409 - data = response.json() - assert data["errorCode"] == 409 - assert "DuplicateApp" in data["errorMessage"] + assert "DuplicateApp" in response.json()["detail"] @pytest.mark.asyncio async def test_create_key_missing_required_field_returns_422(self, client: AsyncClient): @@ -93,7 +91,7 @@ async def test_list_empty_returns_empty_list(self, client: AsyncClient): response = await client.get("/v1/api-keys") assert response.status_code == 200 - assert response.json()["payload"] == [] + assert response.json() == [] @pytest.mark.asyncio async def test_list_returns_created_keys(self, client: AsyncClient): @@ -104,7 +102,7 @@ async def test_list_returns_created_keys(self, client: AsyncClient): response = await client.get("/v1/api-keys") assert response.status_code == 200 - data = response.json()["payload"] + data = response.json() assert len(data) == 2 app_names = {k["appName"] for k in data} assert app_names == {"App1", "App2"} @@ -115,7 +113,7 @@ async def test_list_does_not_include_token(self, client: AsyncClient): await _create_key(client) response = await client.get("/v1/api-keys") - data = response.json()["payload"] + data = response.json() for key in data: assert "token" not in key @@ -134,7 +132,7 @@ async def test_list_active_only_excludes_inactive(self, client: AsyncClient): response = await client.get("/v1/api-keys", params={"active_only": "true"}) assert response.status_code == 200 - data = response.json()["payload"] + data = response.json() ids = [k["id"] for k in data] assert key1["id"] in ids assert key2["id"] not in ids @@ -148,7 +146,7 @@ async def test_list_without_active_only_includes_inactive(self, client: AsyncCli response = await client.get("/v1/api-keys") assert response.status_code == 200 - data = response.json()["payload"] + data = response.json() assert any(k["id"] == key["id"] for k in data) @@ -166,7 +164,7 @@ async def test_count_empty(self, client: AsyncClient): response = await client.get("/v1/api-keys/count") assert response.status_code == 200 - assert response.json()["payload"] == 0 + assert response.json() == 0 @pytest.mark.asyncio async def test_count_reflects_created_keys(self, client: AsyncClient): @@ -177,7 +175,7 @@ async def test_count_reflects_created_keys(self, client: AsyncClient): response = await client.get("/v1/api-keys/count") assert response.status_code == 200 - assert response.json()["payload"] == 2 + assert response.json() == 2 @pytest.mark.asyncio async def test_count_active_only_excludes_inactive(self, client: AsyncClient): @@ -189,7 +187,7 @@ async def test_count_active_only_excludes_inactive(self, client: AsyncClient): response = await client.get("/v1/api-keys/count", params={"active_only": "true"}) assert response.status_code == 200 - assert response.json()["payload"] == 1 + assert response.json() == 1 @pytest.mark.asyncio async def test_count_total_includes_inactive(self, client: AsyncClient): @@ -200,7 +198,7 @@ async def test_count_total_includes_inactive(self, client: AsyncClient): response = await client.get("/v1/api-keys/count") assert response.status_code == 200 - assert response.json()["payload"] == 1 + assert response.json() == 1 # --------------------------------------------------------------------------- @@ -220,7 +218,7 @@ async def test_deactivate_sets_is_active_false(self, client: AsyncClient): response = await client.delete(f"/v1/api-keys/{key_id}") assert response.status_code == 200 - data = response.json()["payload"] + data = response.json() assert data["id"] == key_id assert data["isActive"] is False @@ -231,7 +229,7 @@ async def test_deactivated_key_still_retrievable(self, client: AsyncClient): await client.delete(f"/v1/api-keys/{key['id']}") response = await client.get("/v1/api-keys") - data = response.json()["payload"] + data = response.json() deactivated = next((k for k in data if k["id"] == key["id"]), None) assert deactivated is not None @@ -243,8 +241,6 @@ async def test_deactivate_nonexistent_key_returns_404(self, client: AsyncClient) response = await client.delete("/v1/api-keys/00000000-0000-0000-0000-000000000000") assert response.status_code == 404 - data = response.json() - assert data["errorCode"] == 404 @pytest.mark.asyncio async def test_deactivate_already_inactive_key_returns_409(self, client: AsyncClient): @@ -256,8 +252,6 @@ async def test_deactivate_already_inactive_key_returns_409(self, client: AsyncCl response = await client.delete(f"/v1/api-keys/{key['id']}") assert response.status_code == 409 - data = response.json() - assert data["errorCode"] == 409 @pytest.mark.asyncio async def test_deactivate_allows_reuse_of_app_name(self, client: AsyncClient): @@ -270,6 +264,6 @@ async def test_deactivate_allows_reuse_of_app_name(self, client: AsyncClient): ) assert response.status_code == 200 - new_key = response.json()["payload"] + new_key = response.json() assert new_key["id"] != key["id"] assert new_key["isActive"] is True diff --git a/tests/test_api/test_pvs.py b/tests/test_api/test_pvs.py index 15fbfea..edbe415 100644 --- a/tests/test_api/test_pvs.py +++ b/tests/test_api/test_pvs.py @@ -17,10 +17,9 @@ async def test_create_pv_with_setpoint(self, client: AsyncClient): assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["setpointAddress"] == "CREATE:TEST:SP" - assert data["payload"]["id"] is not None - assert data["payload"]["device"] == "TEST-DEVICE" + assert data["setpointAddress"] == "CREATE:TEST:SP" + assert data["id"] is not None + assert data["device"] == "TEST-DEVICE" @pytest.mark.asyncio async def test_create_pv_with_all_addresses(self, client: AsyncClient): @@ -42,13 +41,11 @@ async def test_create_pv_with_all_addresses(self, client: AsyncClient): assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - payload = data["payload"] - assert payload["setpointAddress"] == "FULL:TEST:SP" - assert payload["readbackAddress"] == "FULL:TEST:RB" - assert payload["configAddress"] == "FULL:TEST:CFG" - assert payload["absTolerance"] == 0.5 - assert payload["relTolerance"] == 0.05 + assert data["setpointAddress"] == "FULL:TEST:SP" + assert data["readbackAddress"] == "FULL:TEST:RB" + assert data["configAddress"] == "FULL:TEST:CFG" + assert data["absTolerance"] == 0.5 + assert data["relTolerance"] == 0.05 @pytest.mark.asyncio async def test_create_pv_requires_at_least_one_address(self, client: AsyncClient): @@ -67,9 +64,7 @@ async def test_create_pv_duplicate_address_fails(self, client: AsyncClient, samp ) assert response.status_code == 409 - data = response.json() - assert data["errorCode"] == 409 - assert "already exists" in data["errorMessage"] + assert "already exists" in response.json()["detail"] @pytest.mark.asyncio async def test_create_pv_with_tags(self, client: AsyncClient, sample_tag: tuple): @@ -81,9 +76,8 @@ async def test_create_pv_with_tags(self, client: AsyncClient, sample_tag: tuple) assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert len(data["payload"]["tags"]) == 1 - assert data["payload"]["tags"][0]["id"] == tag["id"] + assert len(data["tags"]) == 1 + assert data["tags"][0]["id"] == tag["id"] class TestPVBulkCreate: @@ -97,9 +91,7 @@ async def test_bulk_create_pvs(self, client: AsyncClient): response = await client.post("/v1/pvs/multi", json=pvs_data) assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert len(data["payload"]) == 10 + assert len(response.json()) == 10 @pytest.mark.asyncio async def test_bulk_create_empty_list(self, client: AsyncClient): @@ -107,8 +99,7 @@ async def test_bulk_create_empty_list(self, client: AsyncClient): response = await client.post("/v1/pvs/multi", json=[]) assert response.status_code == 200 - data = response.json() - assert data["payload"] == [] + assert response.json() == [] @pytest.mark.asyncio async def test_bulk_create_allows_blank_optional_addresses(self, client: AsyncClient): @@ -122,12 +113,11 @@ async def test_bulk_create_allows_blank_optional_addresses(self, client: AsyncCl assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert len(data["payload"]) == 2 - assert data["payload"][0]["readbackAddress"] is None - assert data["payload"][0]["configAddress"] is None - assert data["payload"][1]["readbackAddress"] is None - assert data["payload"][1]["configAddress"] is None + assert len(data) == 2 + assert data[0]["readbackAddress"] is None + assert data[0]["configAddress"] is None + assert data[1]["readbackAddress"] is None + assert data[1]["configAddress"] is None @pytest.mark.asyncio async def test_bulk_create_allows_duplicate_readback(self, client: AsyncClient): @@ -140,9 +130,7 @@ async def test_bulk_create_allows_duplicate_readback(self, client: AsyncClient): response = await client.post("/v1/pvs/multi", json=pvs_data) assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert len(data["payload"]) == 2 + assert len(response.json()) == 2 class TestPVSearch: @@ -154,9 +142,7 @@ async def test_search_pvs_simple(self, client: AsyncClient, sample_pvs: list): response = await client.get("/v1/pvs", params={"pvName": "TEST:PV"}) assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert len(data["payload"]) >= 1 + assert len(response.json()) >= 1 @pytest.mark.asyncio async def test_search_pvs_paged(self, client: AsyncClient, sample_pvs: list): @@ -165,10 +151,9 @@ async def test_search_pvs_paged(self, client: AsyncClient, sample_pvs: list): assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert len(data["payload"]["results"]) == 2 - assert data["payload"]["totalCount"] >= 5 - assert data["payload"]["continuationToken"] is not None + assert len(data["results"]) == 2 + assert data["totalCount"] >= 5 + assert data["continuationToken"] is not None @pytest.mark.asyncio async def test_search_pvs_pagination_continuation(self, client: AsyncClient, sample_pvs: list): @@ -176,7 +161,7 @@ async def test_search_pvs_pagination_continuation(self, client: AsyncClient, sam # First page response1 = await client.get("/v1/pvs/paged", params={"pvName": "TEST", "pageSize": 2}) data1 = response1.json() - token = data1["payload"]["continuationToken"] + token = data1["continuationToken"] # Second page response2 = await client.get( @@ -185,8 +170,8 @@ async def test_search_pvs_pagination_continuation(self, client: AsyncClient, sam data2 = response2.json() # Ensure different results - ids1 = {p["id"] for p in data1["payload"]["results"]} - ids2 = {p["id"] for p in data2["payload"]["results"]} + ids1 = {p["id"] for p in data1["results"]} + ids2 = {p["id"] for p in data2["results"]} assert ids1.isdisjoint(ids2) # No overlap @pytest.mark.asyncio @@ -196,9 +181,8 @@ async def test_search_pvs_no_results(self, client: AsyncClient): assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert len(data["payload"]["results"]) == 0 - assert data["payload"]["totalCount"] == 0 + assert len(data["results"]) == 0 + assert data["totalCount"] == 0 class TestPVUpdate: @@ -211,9 +195,7 @@ async def test_update_pv_description(self, client: AsyncClient, sample_pv: dict) response = await client.put(f"/v1/pvs/{pv_id}", json={"description": "Updated description"}) assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["description"] == "Updated description" + assert response.json()["description"] == "Updated description" @pytest.mark.asyncio async def test_update_pv_tolerances(self, client: AsyncClient, sample_pv: dict): @@ -223,8 +205,8 @@ async def test_update_pv_tolerances(self, client: AsyncClient, sample_pv: dict): assert response.status_code == 200 data = response.json() - assert data["payload"]["absTolerance"] == 1.0 - assert data["payload"]["relTolerance"] == 0.1 + assert data["absTolerance"] == 1.0 + assert data["relTolerance"] == 0.1 @pytest.mark.asyncio async def test_update_pv_not_found(self, client: AsyncClient): @@ -232,8 +214,6 @@ async def test_update_pv_not_found(self, client: AsyncClient): response = await client.put("/v1/pvs/nonexistent-id", json={"description": "Should fail"}) assert response.status_code == 404 - data = response.json() - assert data["errorCode"] == 404 class TestPVDelete: @@ -246,13 +226,11 @@ async def test_delete_pv(self, client: AsyncClient, sample_pv: dict): response = await client.delete(f"/v1/pvs/{pv_id}") assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert data["payload"] is True + assert response.json() is True # Verify deletion search_response = await client.get("/v1/pvs/paged", params={"pvName": sample_pv["setpointAddress"]}) - assert len(search_response.json()["payload"]["results"]) == 0 + assert len(search_response.json()["results"]) == 0 @pytest.mark.asyncio async def test_delete_pv_not_found(self, client: AsyncClient): @@ -260,5 +238,3 @@ async def test_delete_pv_not_found(self, client: AsyncClient): response = await client.delete("/v1/pvs/nonexistent-id") assert response.status_code == 404 - data = response.json() - assert data["errorCode"] == 404 diff --git a/tests/test_api/test_snapshots.py b/tests/test_api/test_snapshots.py index 798c529..808640b 100644 --- a/tests/test_api/test_snapshots.py +++ b/tests/test_api/test_snapshots.py @@ -28,10 +28,9 @@ async def test_create_snapshot(self, client: AsyncClient, sample_pvs: list[dict] assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["title"] == "Test Snapshot Creation" - assert data["payload"]["pvCount"] >= len(sample_pvs) - assert data["payload"]["id"] is not None + assert data["title"] == "Test Snapshot Creation" + assert data["pvCount"] >= len(sample_pvs) + assert data["id"] is not None @pytest.mark.asyncio async def test_create_snapshot_with_disconnected_pvs( @@ -48,8 +47,6 @@ async def test_create_snapshot_with_disconnected_pvs( ) assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 # Should still create snapshot even with some missing values @pytest.mark.asyncio @@ -71,11 +68,10 @@ async def test_get_snapshot_by_id(self, client: AsyncClient, sample_snapshot: di assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["id"] == snapshot_id - assert data["payload"]["title"] == sample_snapshot["title"] - assert "pvValues" in data["payload"] - assert len(data["payload"]["pvValues"]) > 0 + assert data["id"] == snapshot_id + assert data["title"] == sample_snapshot["title"] + assert "pvValues" in data + assert len(data["pvValues"]) > 0 @pytest.mark.asyncio async def test_get_snapshot_not_found(self, client: AsyncClient): @@ -83,8 +79,6 @@ async def test_get_snapshot_not_found(self, client: AsyncClient): response = await client.get("/v1/snapshots/nonexistent-id") assert response.status_code == 404 - data = response.json() - assert data["errorCode"] == 404 @pytest.mark.asyncio async def test_list_snapshots(self, client: AsyncClient, sample_snapshot: dict): @@ -92,9 +86,7 @@ async def test_list_snapshots(self, client: AsyncClient, sample_snapshot: dict): response = await client.get("/v1/snapshots") assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert len(data["payload"]) >= 1 + assert len(response.json()) >= 1 @pytest.mark.asyncio async def test_list_snapshots_with_filter(self, client: AsyncClient, sample_snapshot: dict): @@ -102,9 +94,7 @@ async def test_list_snapshots_with_filter(self, client: AsyncClient, sample_snap response = await client.get("/v1/snapshots", params={"title": "Test"}) assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert len(data["payload"]) >= 1 + assert len(response.json()) >= 1 class TestSnapshotRestore: @@ -118,10 +108,9 @@ async def test_restore_snapshot(self, client: AsyncClient, sample_snapshot: dict assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert "successCount" in data["payload"] - assert "failureCount" in data["payload"] - assert data["payload"]["failureCount"] == 0 + assert "successCount" in data + assert "failureCount" in data + assert data["failureCount"] == 0 @pytest.mark.asyncio async def test_restore_snapshot_partial( @@ -134,10 +123,8 @@ async def test_restore_snapshot_partial( response = await client.post(f"/v1/snapshots/{snapshot_id}/restore?async=false", json={"pvIds": pv_ids}) assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 # Only 2 PVs should be restored - assert data["payload"]["totalPVs"] <= 2 + assert response.json()["totalPVs"] <= 2 @pytest.mark.asyncio async def test_restore_snapshot_not_found(self, client: AsyncClient): @@ -168,7 +155,7 @@ async def test_compare_snapshots_identical( params={"async": "false", "use_cache": "false"}, json={"title": "Snapshot 1"}, ) - snap1_id = resp1.json()["payload"]["id"] + snap1_id = resp1.json()["id"] # Create second snapshot (same values) resp2 = await client.post( @@ -176,16 +163,15 @@ async def test_compare_snapshots_identical( params={"async": "false", "use_cache": "false"}, json={"title": "Snapshot 2"}, ) - snap2_id = resp2.json()["payload"]["id"] + snap2_id = resp2.json()["id"] # Compare response = await client.get(f"/v1/snapshots/{snap1_id}/compare/{snap2_id}") assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["differenceCount"] == 0 - assert data["payload"]["matchCount"] > 0 + assert data["differenceCount"] == 0 + assert data["matchCount"] > 0 @pytest.mark.asyncio async def test_compare_snapshots_different( @@ -205,7 +191,7 @@ async def test_compare_snapshots_different( params={"async": "false", "use_cache": "false"}, json={"title": "Before Change"}, ) - snap1_id = resp1.json()["payload"]["id"] + snap1_id = resp1.json()["id"] # Change values significantly (beyond default tolerance) for pv in sample_pvs: @@ -220,15 +206,13 @@ async def test_compare_snapshots_different( params={"async": "false", "use_cache": "false"}, json={"title": "After Change"}, ) - snap2_id = resp2.json()["payload"]["id"] + snap2_id = resp2.json()["id"] # Compare response = await client.get(f"/v1/snapshots/{snap1_id}/compare/{snap2_id}") assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["differenceCount"] > 0 + assert response.json()["differenceCount"] > 0 @pytest.mark.asyncio async def test_compare_snapshots_not_found(self, client: AsyncClient, sample_snapshot: dict): @@ -249,9 +233,7 @@ async def test_delete_snapshot(self, client: AsyncClient, sample_snapshot: dict) response = await client.delete(f"/v1/snapshots/{snapshot_id}") assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert data["payload"] is True + assert response.json() is True # Verify deletion get_response = await client.get(f"/v1/snapshots/{snapshot_id}") diff --git a/tests/test_api/test_tags.py b/tests/test_api/test_tags.py index 058be6f..b42c8d3 100644 --- a/tests/test_api/test_tags.py +++ b/tests/test_api/test_tags.py @@ -16,11 +16,10 @@ async def test_create_tag_group(self, client: AsyncClient): assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["name"] == "Location" - assert data["payload"]["description"] == "Physical location tags" - assert data["payload"]["id"] is not None - assert data["payload"]["tags"] == [] + assert data["name"] == "Location" + assert data["description"] == "Physical location tags" + assert data["id"] is not None + assert data["tags"] == [] @pytest.mark.asyncio async def test_create_tag_group_without_description(self, client: AsyncClient): @@ -29,9 +28,8 @@ async def test_create_tag_group_without_description(self, client: AsyncClient): assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["name"] == "System" - assert data["payload"]["description"] is None + assert data["name"] == "System" + assert data["description"] is None @pytest.mark.asyncio async def test_create_tag_group_duplicate_name_fails(self, client: AsyncClient, sample_tag_group: dict): @@ -39,9 +37,7 @@ async def test_create_tag_group_duplicate_name_fails(self, client: AsyncClient, response = await client.post("/v1/tags", json={"name": sample_tag_group["name"]}) # Duplicate name assert response.status_code == 409 - data = response.json() - assert data["errorCode"] == 409 - assert "already exists" in data["errorMessage"] + assert "already exists" in response.json()["detail"] @pytest.mark.asyncio async def test_create_tag_group_requires_name(self, client: AsyncClient): @@ -61,11 +57,10 @@ async def test_get_all_tag_groups(self, client: AsyncClient, sample_tag_group: d assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert len(data["payload"]) >= 1 + assert len(data) >= 1 # Check that summary includes tag count - group = next(g for g in data["payload"] if g["id"] == sample_tag_group["id"]) + group = next(g for g in data if g["id"] == sample_tag_group["id"]) assert "tagCount" in group @pytest.mark.asyncio @@ -76,11 +71,10 @@ async def test_get_tag_group_by_id(self, client: AsyncClient, sample_tag_group: assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 # Response is wrapped in array per frontend expectation - assert isinstance(data["payload"], list) - assert len(data["payload"]) == 1 - assert data["payload"][0]["id"] == group_id + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["id"] == group_id @pytest.mark.asyncio async def test_get_tag_group_not_found(self, client: AsyncClient): @@ -88,8 +82,6 @@ async def test_get_tag_group_not_found(self, client: AsyncClient): response = await client.get("/v1/tags/nonexistent-id") assert response.status_code == 404 - data = response.json() - assert data["errorCode"] == 404 class TestTagGroupUpdate: @@ -102,9 +94,7 @@ async def test_update_tag_group_name(self, client: AsyncClient, sample_tag_group response = await client.put(f"/v1/tags/{group_id}", json={"name": "Updated Location"}) assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["name"] == "Updated Location" + assert response.json()["name"] == "Updated Location" @pytest.mark.asyncio async def test_update_tag_group_description(self, client: AsyncClient, sample_tag_group: dict): @@ -113,9 +103,7 @@ async def test_update_tag_group_description(self, client: AsyncClient, sample_ta response = await client.put(f"/v1/tags/{group_id}", json={"description": "Updated description"}) assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["description"] == "Updated description" + assert response.json()["description"] == "Updated description" @pytest.mark.asyncio async def test_update_tag_group_not_found(self, client: AsyncClient): @@ -135,9 +123,7 @@ async def test_delete_tag_group(self, client: AsyncClient, sample_tag_group: dic response = await client.delete(f"/v1/tags/{group_id}") assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 - assert data["payload"] is True + assert response.json() is True # Verify deletion get_response = await client.get(f"/v1/tags/{group_id}") @@ -164,9 +150,8 @@ async def test_add_tag_to_group(self, client: AsyncClient, sample_tag_group: dic assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert len(data["payload"]["group"]["tags"]) == 1 - assert data["payload"]["group"]["tags"][0]["name"] == "Building-A" + assert len(data["group"]["tags"]) == 1 + assert data["group"]["tags"][0]["name"] == "Building-A" @pytest.mark.asyncio async def test_add_multiple_tags_to_group(self, client: AsyncClient, sample_tag_group: dict): @@ -180,8 +165,7 @@ async def test_add_multiple_tags_to_group(self, client: AsyncClient, sample_tag_ response = await client.post(f"/v1/tags/{group_id}/tags", json={"name": "Tag-2"}) assert response.status_code == 200 - data = response.json() - assert len(data["payload"]["group"]["tags"]) == 2 + assert len(response.json()["group"]["tags"]) == 2 @pytest.mark.asyncio async def test_add_duplicate_tag_fails(self, client: AsyncClient, sample_tag: tuple): @@ -192,8 +176,7 @@ async def test_add_duplicate_tag_fails(self, client: AsyncClient, sample_tag: tu response = await client.post(f"/v1/tags/{group_id}/tags", json={"name": tag["name"]}) # Duplicate name assert response.status_code == 409 - data = response.json() - assert "already exists" in data["errorMessage"] + assert "already exists" in response.json()["detail"] @pytest.mark.asyncio async def test_add_duplicate_tag_with_skip_duplicates_succeeds(self, client: AsyncClient, sample_tag: tuple): @@ -208,9 +191,8 @@ async def test_add_duplicate_tag_with_skip_duplicates_succeeds(self, client: Asy assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["wasCreated"] is False # Tag already existed - assert data["payload"]["group"]["id"] == group_id + assert data["wasCreated"] is False # Tag already existed + assert data["group"]["id"] == group_id @pytest.mark.asyncio async def test_add_tag_to_nonexistent_group(self, client: AsyncClient): @@ -233,8 +215,7 @@ async def test_update_tag(self, client: AsyncClient, sample_tag: tuple): assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - updated_tag = next(t for t in data["payload"]["tags"] if t["id"] == tag_id) + updated_tag = next(t for t in data["tags"] if t["id"] == tag_id) assert updated_tag["name"] == "Updated-Tag-Name" @pytest.mark.asyncio @@ -255,10 +236,8 @@ async def test_remove_tag(self, client: AsyncClient, sample_tag: tuple): response = await client.delete(f"/v1/tags/{group_id}/tags/{tag_id}") assert response.status_code == 200 - data = response.json() - assert data["errorCode"] == 0 # Tag should be removed from the group - assert len(data["payload"]["tags"]) == 0 + assert len(response.json()["tags"]) == 0 @pytest.mark.asyncio async def test_remove_tag_not_found(self, client: AsyncClient, sample_tag_group: dict): @@ -287,11 +266,10 @@ async def test_bulk_import_tags_creates_new_groups_and_tags(self, client: AsyncC assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["groupsCreated"] == 2 - assert data["payload"]["tagsCreated"] == 5 - assert data["payload"]["tagsSkipped"] == 0 - assert len(data["payload"]["warnings"]) == 0 + assert data["groupsCreated"] == 2 + assert data["tagsCreated"] == 5 + assert data["tagsSkipped"] == 0 + assert len(data["warnings"]) == 0 @pytest.mark.asyncio async def test_bulk_import_tags_skips_duplicates(self, client: AsyncClient, sample_tag: tuple): @@ -310,12 +288,11 @@ async def test_bulk_import_tags_skips_duplicates(self, client: AsyncClient, samp assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["groupsCreated"] == 0 # Group already exists - assert data["payload"]["tagsCreated"] == 2 # Only new tags created - assert data["payload"]["tagsSkipped"] == 1 # Duplicate tag skipped - assert len(data["payload"]["warnings"]) == 1 - assert "already exists" in data["payload"]["warnings"][0] + assert data["groupsCreated"] == 0 # Group already exists + assert data["tagsCreated"] == 2 # Only new tags created + assert data["tagsSkipped"] == 1 # Duplicate tag skipped + assert len(data["warnings"]) == 1 + assert "already exists" in data["warnings"][0] @pytest.mark.asyncio async def test_bulk_import_tags_with_existing_group(self, client: AsyncClient, sample_tag_group: dict): @@ -333,10 +310,9 @@ async def test_bulk_import_tags_with_existing_group(self, client: AsyncClient, s assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["groupsCreated"] == 0 # Group already exists - assert data["payload"]["tagsCreated"] == 2 - assert data["payload"]["tagsSkipped"] == 0 + assert data["groupsCreated"] == 0 # Group already exists + assert data["tagsCreated"] == 2 + assert data["tagsSkipped"] == 0 @pytest.mark.asyncio async def test_bulk_import_tags_empty_groups(self, client: AsyncClient): @@ -345,7 +321,6 @@ async def test_bulk_import_tags_empty_groups(self, client: AsyncClient): assert response.status_code == 200 data = response.json() - assert data["errorCode"] == 0 - assert data["payload"]["groupsCreated"] == 0 - assert data["payload"]["tagsCreated"] == 0 - assert data["payload"]["tagsSkipped"] == 0 + assert data["groupsCreated"] == 0 + assert data["tagsCreated"] == 0 + assert data["tagsSkipped"] == 0 diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index b545292..34dde46 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from fastapi import status +from fastapi import HTTPException, status from fastapi.exceptions import WebSocketException from app.dependencies import ( @@ -19,7 +19,6 @@ ws_require_read_access, ws_require_write_access, ) -from app.api.responses import APIException from app.schemas.api_key import ApiKeyDTO from app.services.pv_service import PVService from app.services.tag_service import TagService @@ -114,10 +113,9 @@ class TestGetApiKey: async def test_no_header_raises_401(self): """Missing header (None) should raise 401 immediately.""" db = MagicMock() - with pytest.raises(APIException) as exc_info: + with pytest.raises(HTTPException) as exc_info: await get_api_key(db, None) assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED - assert exc_info.value.error_code == status.HTTP_401_UNAUTHORIZED @pytest.mark.asyncio async def test_unknown_token_raises_401(self): @@ -125,10 +123,9 @@ async def test_unknown_token_raises_401(self): db = MagicMock() with patch("app.dependencies.ApiKeyService") as MockService: MockService.return_value.get_by_token = AsyncMock(return_value=None) - with pytest.raises(APIException) as exc_info: + with pytest.raises(HTTPException) as exc_info: await get_api_key(db, "sq_unknown") assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED - assert exc_info.value.error_code == status.HTTP_401_UNAUTHORIZED @pytest.mark.asyncio async def test_inactive_key_raises_401(self): @@ -137,10 +134,9 @@ async def test_inactive_key_raises_401(self): inactive_key = _make_api_key_dto(isActive=False) with patch("app.dependencies.ApiKeyService") as MockService: MockService.return_value.get_by_token = AsyncMock(return_value=inactive_key) - with pytest.raises(APIException) as exc_info: + with pytest.raises(HTTPException) as exc_info: await get_api_key(db, "sq_inactive") assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED - assert exc_info.value.error_code == status.HTTP_401_UNAUTHORIZED @pytest.mark.asyncio async def test_active_key_returns_dto(self): @@ -156,9 +152,9 @@ async def test_active_key_returns_dto(self): async def test_error_message_mentions_api_key(self): """401 error message should reference the API key.""" db = MagicMock() - with pytest.raises(APIException) as exc_info: + with pytest.raises(HTTPException) as exc_info: await get_api_key(db, None) - assert "api key" in exc_info.value.error_message.lower() + assert "api key" in exc_info.value.detail.lower() @pytest.mark.asyncio async def test_service_receives_provided_token(self): @@ -168,7 +164,7 @@ async def test_service_receives_provided_token(self): with patch("app.dependencies.ApiKeyService") as MockService: mock_get = AsyncMock(return_value=None) MockService.return_value.get_by_token = mock_get - with pytest.raises(APIException): + with pytest.raises(HTTPException): await get_api_key(db, token) mock_get.assert_awaited_once_with(token) @@ -184,10 +180,9 @@ class TestRequireReadAccess: def test_raises_401_when_read_access_false(self): """A key without read access should be rejected.""" dto = _make_api_key_dto(readAccess=False) - with pytest.raises(APIException) as exc_info: + with pytest.raises(HTTPException) as exc_info: require_read_access(dto) assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED - assert exc_info.value.error_code == status.HTTP_401_UNAUTHORIZED def test_passes_when_read_access_true(self): """A key with read access should not raise.""" @@ -197,9 +192,9 @@ def test_passes_when_read_access_true(self): def test_error_message_mentions_read(self): """Error message should indicate lack of read access.""" dto = _make_api_key_dto(readAccess=False) - with pytest.raises(APIException) as exc_info: + with pytest.raises(HTTPException) as exc_info: require_read_access(dto) - assert "read" in exc_info.value.error_message.lower() + assert "read" in exc_info.value.detail.lower() # --------------------------------------------------------------------------- @@ -213,10 +208,9 @@ class TestRequireWriteAccess: def test_raises_401_when_write_access_false(self): """A key without write access should be rejected.""" dto = _make_api_key_dto(writeAccess=False) - with pytest.raises(APIException) as exc_info: + with pytest.raises(HTTPException) as exc_info: require_write_access(dto) assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED - assert exc_info.value.error_code == status.HTTP_401_UNAUTHORIZED def test_passes_when_write_access_true(self): """A key with write access should not raise.""" @@ -226,9 +220,9 @@ def test_passes_when_write_access_true(self): def test_error_message_mentions_write(self): """Error message should indicate lack of write access.""" dto = _make_api_key_dto(writeAccess=False) - with pytest.raises(APIException) as exc_info: + with pytest.raises(HTTPException) as exc_info: require_write_access(dto) - assert "write" in exc_info.value.error_message.lower() + assert "write" in exc_info.value.detail.lower() # ---------------------------------------------------------------------------