diff --git a/app/api/v1/api_keys.py b/app/api/v1/api_keys.py index f3db670..3aca68c 100644 --- a/app/api/v1/api_keys.py +++ b/app/api/v1/api_keys.py @@ -1,30 +1,30 @@ -"""API endpoints for job status monitoring.""" +"""API endpoints for API key management.""" from fastapi import Depends, APIRouter, status -from sqlalchemy.ext.asyncio import AsyncSession -from app.db.session import get_db -from app.dependencies import require_read_access, require_write_access +from app.dependencies import ApiKeyServiceDep, require_read_access, require_write_access from app.api.responses import APIException from app.schemas.common import ApiResultResponse from app.schemas.api_key import ApiKeyCreateDTO -from app.services.api_key_service import ApiKeyService router = APIRouter(prefix="/api-keys", tags=["ApiKeys"]) @router.get("", dependencies=[Depends(require_read_access)]) -async def list_all_keys(active_only: bool = False, db: AsyncSession = Depends(get_db)) -> ApiResultResponse: +async def list_all_keys( + service: ApiKeyServiceDep, + active_only: bool = False, +) -> ApiResultResponse: """List all API Keys, optionally filtered by active status.""" - service = ApiKeyService(db) keys = await service.list_keys(active_only) - key_response = ApiResultResponse(errorCode=0, errorMessage=None, payload=keys) - return key_response + return ApiResultResponse(errorCode=0, errorMessage=None, payload=keys) @router.post("", dependencies=[Depends(require_write_access)]) -async def create_api_key(data: ApiKeyCreateDTO, db: AsyncSession = Depends(get_db)) -> ApiResultResponse: +async def create_api_key( + data: ApiKeyCreateDTO, + service: ApiKeyServiceDep, +) -> ApiResultResponse: """Create a new API Key.""" - service = ApiKeyService(db) try: new_key = await service.create_key(data) return ApiResultResponse(errorCode=0, errorMessage=None, payload=new_key) @@ -33,9 +33,11 @@ async def create_api_key(data: ApiKeyCreateDTO, db: AsyncSession = Depends(get_d @router.delete("/{key_id}", dependencies=[Depends(require_write_access)]) -async def deactivate_api_key(key_id: str, db: AsyncSession = Depends(get_db)) -> ApiResultResponse: +async def deactivate_api_key( + key_id: str, + service: ApiKeyServiceDep, +) -> ApiResultResponse: """Deactivate an API Key by ID.""" - service = ApiKeyService(db) try: deactivated_key = await service.deactivate_key(key_id) except LookupError as e: @@ -46,8 +48,10 @@ async def deactivate_api_key(key_id: str, db: AsyncSession = Depends(get_db)) -> @router.get("/count", dependencies=[Depends(require_read_access)]) -async def get_api_key_count(active_only: bool = False, db: AsyncSession = Depends(get_db)) -> ApiResultResponse: +async def get_api_key_count( + service: ApiKeyServiceDep, + active_only: bool = False, +) -> ApiResultResponse: """Get the current number of API Keys.""" - service = ApiKeyService(db) count = await service.get_count(active_only) return ApiResultResponse(errorCode=0, errorMessage=None, payload=count) diff --git a/app/api/v1/health.py b/app/api/v1/health.py index 42778d3..9d31341 100644 --- a/app/api/v1/health.py +++ b/app/api/v1/health.py @@ -11,7 +11,7 @@ from fastapi import Security, APIRouter, HTTPException -from app.dependencies import require_read_access, require_write_access +from app.dependencies import RedisServiceDep, require_read_access, require_write_access from app.api.responses import success_response from app.schemas.health import ( HealthSummaryResponse, @@ -20,13 +20,12 @@ ) from app.services.watchdog import get_watchdog from app.services.pv_monitor import get_pv_monitor -from app.services.redis_service import get_redis_service router = APIRouter(prefix="/health", tags=["Health"]) @router.get("/heartbeat") -async def get_heartbeat() -> dict: +async def get_heartbeat(redis: RedisServiceDep) -> dict: """ Simple heartbeat check for frontend polling. @@ -35,8 +34,6 @@ async def get_heartbeat() -> dict: banner if alive=false. """ try: - redis = get_redis_service() - # Check if Redis is connected if not redis.is_connected(): return success_response( @@ -72,14 +69,13 @@ async def get_heartbeat() -> dict: @router.get("/monitor", dependencies=[Security(require_read_access)]) -async def get_monitor_health() -> MonitorHealthResponse: +async def get_monitor_health(redis: RedisServiceDep) -> MonitorHealthResponse: """ Get detailed monitor health information. Includes connection counts, monitor status, and watchdog status. """ try: - redis = get_redis_service() pv_monitor = get_pv_monitor() watchdog = get_watchdog() @@ -166,7 +162,7 @@ async def force_watchdog_check() -> WatchdogStatsResponse: @router.get("/summary", dependencies=[Security(require_read_access)]) -async def get_health_summary() -> HealthSummaryResponse: +async def get_health_summary(redis: RedisServiceDep) -> HealthSummaryResponse: """ Get a complete health summary for monitoring dashboards. @@ -175,7 +171,6 @@ async def get_health_summary() -> HealthSummaryResponse: issues = [] try: - redis = get_redis_service() pv_monitor = get_pv_monitor() watchdog = get_watchdog() @@ -253,14 +248,13 @@ async def get_health_summary() -> HealthSummaryResponse: @router.get("/disconnected", dependencies=[Security(require_read_access)]) -async def get_disconnected_pvs() -> dict: +async def get_disconnected_pvs(redis: RedisServiceDep) -> dict: """ Get list of all disconnected PVs. Useful for diagnostics and debugging connection issues. """ try: - redis = get_redis_service() disconnected = await redis.get_disconnected_pvs() return { @@ -273,7 +267,10 @@ async def get_disconnected_pvs() -> dict: @router.get("/stale", dependencies=[Security(require_read_access)]) -async def get_stale_pvs(max_age_seconds: float = 300) -> dict: +async def get_stale_pvs( + redis: RedisServiceDep, + max_age_seconds: float = 300, +) -> dict: """ Get list of stale PVs (connected but not updated recently). @@ -281,7 +278,6 @@ async def get_stale_pvs(max_age_seconds: float = 300) -> dict: max_age_seconds: Consider stale if not updated in this many seconds """ try: - redis = get_redis_service() stale = await redis.get_stale_pvs(max_age_seconds=max_age_seconds) return { @@ -386,7 +382,7 @@ async def force_open_circuit(circuit_name: str) -> dict: @router.get("/monitor/status", dependencies=[Security(require_read_access)]) -async def monitor_process_status() -> dict: +async def monitor_process_status(redis: RedisServiceDep) -> dict: """ Check if the separate PV Monitor process is alive via Redis heartbeat. @@ -399,8 +395,6 @@ async def monitor_process_status() -> dict: leader: Instance ID of current monitor leader (if available) """ try: - redis = get_redis_service() - if not redis.is_connected(): return { "status": "unknown", diff --git a/app/api/v1/jobs.py b/app/api/v1/jobs.py index ea67311..82c62ad 100644 --- a/app/api/v1/jobs.py +++ b/app/api/v1/jobs.py @@ -1,24 +1,20 @@ """API endpoints for job status monitoring.""" -from fastapi import Depends, Security, APIRouter -from sqlalchemy.ext.asyncio import AsyncSession +from fastapi import Security, APIRouter -from app.db.session import get_db -from app.dependencies import require_read_access +from app.dependencies import JobServiceDep, require_read_access from app.api.responses import APIException, success_response -from app.services.job_service import JobService router = APIRouter(prefix="/jobs", tags=["Jobs"]) @router.get("/{job_id}", dependencies=[Security(require_read_access)]) -async def get_job_status(job_id: str, db: AsyncSession = Depends(get_db)) -> dict: +async def get_job_status(job_id: str, job_service: JobServiceDep) -> dict: """ Get the status of a background job. Returns the current status, progress percentage, and result when complete. Poll this endpoint to track the progress of async operations. """ - job_service = JobService(db) job = await job_service.get_job(job_id) if not job: raise APIException(404, f"Job {job_id} not found", 404) diff --git a/app/api/v1/pvs.py b/app/api/v1/pvs.py index 731178a..c127781 100644 --- a/app/api/v1/pvs.py +++ b/app/api/v1/pvs.py @@ -2,36 +2,40 @@ import json from uuid import UUID -from fastapi import Query, Depends, Security, APIRouter -from sqlalchemy.ext.asyncio import AsyncSession +from fastapi import Query, Security, APIRouter -from app.db.session import get_db from app.schemas.pv import LivePVRequest, NewPVElementDTO, UpdatePVElementDTO -from app.dependencies import require_read_access, require_write_access +from app.dependencies import ( + DataBaseDep, + PVServiceDep, + EpicsServiceDep, + RedisServiceDep, + require_read_access, + require_write_access, +) from app.api.responses import APIException, success_response -from app.services.pv_service import PVService -from app.services.epics_service import get_epics_service -from app.services.redis_service import get_redis_service from app.repositories.pv_repository import PVRepository router = APIRouter(prefix="/pvs", tags=["PVs"]) @router.get("", dependencies=[Security(require_read_access)]) -async def search_pvs(pvName: str | None = Query(None), db: AsyncSession = Depends(get_db)) -> dict: +async def search_pvs( + service: PVServiceDep, + pvName: str | None = Query(None), +) -> dict: """Search PVs by name (non-paginated, for backward compatibility).""" - service = PVService(db) result = await service.search_paged(search=pvName, page_size=1000) return success_response(result.results) @router.get("/paged", dependencies=[Security(require_read_access)]) async def search_pvs_paged( + service: PVServiceDep, pvName: str | None = Query(None), pageSize: int = Query(100, ge=1, le=1000), continuationToken: str | None = Query(None), tagFilters: str | None = Query(None, description="JSON object: {groupId: [tagId1, tagId2], ...}"), - db: AsyncSession = Depends(get_db), ) -> dict: """ Search PVs with pagination and optional tag filtering. @@ -39,8 +43,6 @@ async def search_pvs_paged( Example tagFilters: {"group-1": ["tag-a", "tag-b"], "group-2": ["tag-c"]} This returns PVs that have (tag-a OR tag-b) AND (tag-c) """ - service = PVService(db) - # Parse tag filters from JSON string tag_filters = None if tagFilters: @@ -63,9 +65,8 @@ async def search_pvs_paged( @router.post("", dependencies=[Security(require_write_access)]) -async def create_pv(data: NewPVElementDTO, db: AsyncSession = Depends(get_db)) -> dict: +async def create_pv(data: NewPVElementDTO, service: PVServiceDep) -> dict: """Create a new PV.""" - service = PVService(db) try: pv = await service.create(data) return success_response(pv) @@ -74,9 +75,11 @@ async def create_pv(data: NewPVElementDTO, db: AsyncSession = Depends(get_db)) - @router.post("/multi", dependencies=[Security(require_write_access)]) -async def create_multiple_pvs(data: list[NewPVElementDTO], db: AsyncSession = Depends(get_db)) -> dict: +async def create_multiple_pvs( + data: list[NewPVElementDTO], + service: PVServiceDep, +) -> dict: """Bulk create PVs (for CSV import).""" - service = PVService(db) try: pvs = await service.create_many(data) return success_response(pvs) @@ -85,13 +88,16 @@ async def create_multiple_pvs(data: list[NewPVElementDTO], db: AsyncSession = De @router.put("/{pv_id}", dependencies=[Security(require_write_access)]) -async def update_pv(pv_id: str, data: UpdatePVElementDTO, db: AsyncSession = Depends(get_db)) -> dict: +async def update_pv( + pv_id: str, + data: UpdatePVElementDTO, + service: PVServiceDep, +) -> dict: """Update a PV.""" try: UUID(pv_id) except ValueError: raise APIException(404, f"PV {pv_id} not found", 404) - service = PVService(db) pv = await service.update(pv_id, data) if not pv: raise APIException(404, f"PV {pv_id} not found", 404) @@ -99,13 +105,12 @@ async def update_pv(pv_id: str, data: UpdatePVElementDTO, db: AsyncSession = Dep @router.delete("/{pv_id}", dependencies=[Security(require_write_access)]) -async def delete_pv(pv_id: str, db: AsyncSession = Depends(get_db)) -> dict: +async def delete_pv(pv_id: str, service: PVServiceDep) -> dict: """Delete a PV.""" try: UUID(pv_id) except ValueError: raise APIException(404, f"PV {pv_id} not found", 404) - service = PVService(db) success = await service.delete(pv_id) if not success: raise APIException(404, f"PV {pv_id} not found", 404) @@ -114,13 +119,15 @@ async def delete_pv(pv_id: str, db: AsyncSession = Depends(get_db)) -> dict: @router.get("/search", dependencies=[Security(require_read_access)]) async def search_pvs_filtered( + db: DataBaseDep, + service: PVServiceDep, + redis: RedisServiceDep, q: str | None = Query(None, description="Text search"), devices: list[str] = Query(default=[], description="Filter by device"), tags: list[str] = Query(default=[], description="Filter by tag IDs"), limit: int = Query(100, le=1000, description="Max results"), offset: int = Query(0, description="Offset for pagination"), include_live_values: bool = Query(False, description="Include Redis cache values"), - db: AsyncSession = Depends(get_db), ) -> dict: """ Server-side filtered search with optional live values. @@ -128,7 +135,6 @@ async def search_pvs_filtered( This is more efficient than client-side filtering for large datasets. """ pv_repo = PVRepository(db) - service = PVService(db) pvs, total = await pv_repo.search_filtered( search_term=q, devices=devices if devices else None, tag_ids=tags if tags else None, limit=limit, offset=offset @@ -142,7 +148,6 @@ async def search_pvs_filtered( # Optionally include live values from Redis if include_live_values: try: - redis = get_redis_service() pv_addresses = [] for pv in pvs: if pv.setpoint_address: @@ -159,7 +164,7 @@ async def search_pvs_filtered( @router.get("/devices", dependencies=[Security(require_read_access)]) -async def get_all_devices(db: AsyncSession = Depends(get_db)) -> dict: +async def get_all_devices(db: DataBaseDep) -> dict: """Get all unique device names for filtering.""" pv_repo = PVRepository(db) devices = await pv_repo.get_all_devices() @@ -167,10 +172,12 @@ async def get_all_devices(db: AsyncSession = Depends(get_db)) -> dict: @router.get("/live", dependencies=[Security(require_read_access)]) -async def get_live_values(pv_names: list[str] = Query(..., description="List of PV names to fetch")) -> dict: +async def get_live_values( + redis: RedisServiceDep, + pv_names: list[str] = Query(..., description="List of PV names to fetch"), +) -> dict: """Get current values from Redis cache (instant).""" try: - redis = get_redis_service() entries = await redis.get_pv_values_bulk(pv_names) # Convert PVCacheEntry objects to dicts for JSON serialization values = {pv_name: entry.to_dict() for pv_name, entry in entries.items()} @@ -180,10 +187,9 @@ async def get_live_values(pv_names: list[str] = Query(..., description="List of @router.post("/live", dependencies=[Security(require_read_access)]) -async def get_live_values_post(request: LivePVRequest) -> dict: +async def get_live_values_post(request: LivePVRequest, redis: RedisServiceDep) -> dict: """Get current values from Redis cache (instant) - POST version for large PV lists.""" try: - redis = get_redis_service() entries = await redis.get_pv_values_bulk(request.pv_names) # Convert PVCacheEntry objects to dicts for JSON serialization values = {pv_name: entry.to_dict() for pv_name, entry in entries.items()} @@ -193,10 +199,9 @@ async def get_live_values_post(request: LivePVRequest) -> dict: @router.get("/live/all", dependencies=[Security(require_read_access)]) -async def get_all_live_values() -> dict: +async def get_all_live_values(redis: RedisServiceDep) -> dict: """Get all cached PV values (for initial table load).""" try: - redis = get_redis_service() values = await redis.get_all_pv_values_as_dict() return success_response({"values": values, "count": len(values)}) except Exception as e: @@ -204,10 +209,9 @@ async def get_all_live_values() -> dict: @router.get("/cache/status", dependencies=[Security(require_read_access)]) -async def get_cache_status() -> dict: +async def get_cache_status(redis: RedisServiceDep) -> dict: """Get Redis cache status.""" try: - redis = get_redis_service() count = await redis.get_cached_pv_count() return success_response({"cachedPvCount": count, "status": "connected"}) except Exception as e: @@ -215,9 +219,13 @@ async def get_cache_status() -> dict: @router.get("/test-epics", dependencies=[Security(require_read_access)]) -async def test_epics_connection(pv: str = Query("KLYS:LI22:31:KVAC", description="PV name to test")) -> dict: +async def test_epics_connection( + epics: EpicsServiceDep, + pv: str = Query( + "KLYS:LI22:31:KVAC", description="PV name to test; Defaults to KLYS:LI22:31:KVAC, a SLAC-specific PV" + ), +) -> dict: """Test EPICS connectivity using aioca.""" - epics = get_epics_service() result = await epics.get_single(pv) return success_response( { diff --git a/app/api/v1/snapshots.py b/app/api/v1/snapshots.py index efe1e46..3d09fc8 100644 --- a/app/api/v1/snapshots.py +++ b/app/api/v1/snapshots.py @@ -2,22 +2,22 @@ from uuid import UUID from arq import create_pool -from fastapi import Query, Depends, Security, APIRouter, BackgroundTasks +from fastapi import Query, Security, APIRouter, BackgroundTasks from arq.connections import RedisSettings -from sqlalchemy.ext.asyncio import AsyncSession from app.config import get_settings -from app.db.session import get_db from app.models.job import JobType from app.schemas.job import JobCreatedDTO -from app.dependencies import require_read_access, require_write_access +from app.dependencies import ( + DataBaseDep, + JobServiceDep, + SnapshotServiceDep, + require_read_access, + require_write_access, +) from app.api.responses import APIException, success_response from app.schemas.snapshot import NewSnapshotDTO, RestoreRequestDTO, UpdateSnapshotDTO -from app.services.job_service import JobService -from app.services.epics_service import EpicsService, get_epics_service -from app.services.redis_service import get_redis_service from app.services.background_tasks import run_snapshot_restore, run_snapshot_creation -from app.services.snapshot_service import SnapshotService logger = logging.getLogger(__name__) settings = get_settings() @@ -43,14 +43,12 @@ async def get_arq_pool(): @router.get("", dependencies=[Security(require_read_access)]) async def list_snapshots( + service: SnapshotServiceDep, title: str | None = Query(None), tags: list[str] | None = Query(None, description="Filter by tag IDs (returns snapshots containing PVs with any of these tags)"), - db: AsyncSession = Depends(get_db), - epics: EpicsService = Depends(get_epics_service), ) -> dict: """List all snapshots, optionally filtered by title and/or tags.""" - service = SnapshotService(db, epics) snapshots = await service.list_snapshots(title=title, tag_ids=tags) return success_response(snapshots) @@ -58,10 +56,9 @@ async def list_snapshots( @router.get("/{snapshot_id}", dependencies=[Security(require_read_access)]) async def get_snapshot( snapshot_id: str, + service: SnapshotServiceDep, limit: int | None = Query(None, description="Limit number of PV values returned"), offset: int = Query(0, description="Offset for pagination"), - db: AsyncSession = Depends(get_db), - epics: EpicsService = Depends(get_epics_service), ) -> dict: """ Get snapshot by ID with values. @@ -73,7 +70,6 @@ async def get_snapshot( UUID(snapshot_id) except ValueError: raise APIException(404, f"Snapshot {snapshot_id} not found", 404) - service = SnapshotService(db, epics) snapshot = await service.get_by_id(snapshot_id, limit=limit, offset=offset) if not snapshot: raise APIException(404, f"Snapshot {snapshot_id} not found", 404) @@ -84,11 +80,12 @@ async def get_snapshot( async def create_snapshot( data: NewSnapshotDTO, background_tasks: BackgroundTasks, - db: AsyncSession = Depends(get_db), + db: DataBaseDep, + service: SnapshotServiceDep, + job_service: JobServiceDep, async_mode: bool = Query(True, alias="async", description="Run snapshot creation in background"), use_cache: bool = Query(True, description="Read from Redis cache (instant) vs direct EPICS read"), use_arq: bool = Query(True, description="Use Arq persistent queue (recommended) vs FastAPI BackgroundTasks"), - epics: EpicsService = Depends(get_epics_service), ) -> dict: """ Create a new snapshot by reading all PVs. @@ -106,7 +103,6 @@ async def create_snapshot( """ if async_mode: # Create a job record - job_service = JobService(db) job = await job_service.create_job( JobType.SNAPSHOT_CREATE, job_data={"title": data.title, "description": data.description, "use_cache": use_cache}, @@ -152,9 +148,6 @@ async def create_snapshot( ) else: # Synchronous mode (legacy behavior) - redis = get_redis_service() - service = SnapshotService(db, epics, redis) - if use_cache: snapshot = await service.create_snapshot_from_cache(data) else: @@ -167,12 +160,9 @@ async def create_snapshot( async def update_snapshot( snapshot_id: str, data: UpdateSnapshotDTO, - db: AsyncSession = Depends(get_db), + service: SnapshotServiceDep, ) -> dict: """Update snapshot title and/or description.""" - epics = get_epics_service() - service = SnapshotService(db, epics) - snapshot = await service.update_snapshot_metadata( snapshot_id, title=data.title, @@ -189,11 +179,12 @@ async def update_snapshot( async def restore_snapshot( snapshot_id: str, background_tasks: BackgroundTasks, + db: DataBaseDep, + service: SnapshotServiceDep, + job_service: JobServiceDep, request: RestoreRequestDTO | None = None, - db: AsyncSession = Depends(get_db), async_mode: bool = Query(True, alias="async"), use_arq: bool = Query(True, description="Use Arq persistent queue (recommended) vs FastAPI BackgroundTasks"), - epics: EpicsService = Depends(get_epics_service), ) -> dict: """ Restore PV values from a snapshot to EPICS. @@ -204,7 +195,6 @@ async def restore_snapshot( UUID(snapshot_id) except ValueError: raise APIException(404, f"Snapshot {snapshot_id} not found", 404) - service = SnapshotService(db, epics) # Verify snapshot exists snapshot = await service.get_by_id(snapshot_id) @@ -212,7 +202,6 @@ async def restore_snapshot( raise APIException(404, f"Snapshot {snapshot_id} not found", 404) if async_mode: - job_service = JobService(db) job = await job_service.create_job( JobType.SNAPSHOT_RESTORE, job_data={"snapshotId": snapshot_id}, @@ -259,16 +248,14 @@ async def restore_snapshot( @router.delete("/{snapshot_id}", dependencies=[Security(require_write_access)]) async def delete_snapshot( snapshot_id: str, + service: SnapshotServiceDep, deleteData: bool = Query(True), - db: AsyncSession = Depends(get_db), - epics: EpicsService = Depends(get_epics_service), ) -> dict: """Delete a snapshot.""" try: UUID(snapshot_id) except ValueError: raise APIException(404, f"Snapshot {snapshot_id} not found", 404) - service = SnapshotService(db, epics) success = await service.delete_snapshot(snapshot_id) if not success: raise APIException(404, f"Snapshot {snapshot_id} not found", 404) @@ -279,8 +266,7 @@ async def delete_snapshot( async def compare_snapshots( snapshot1_id: str, snapshot2_id: str, - db: AsyncSession = Depends(get_db), - epics: EpicsService = Depends(get_epics_service), + service: SnapshotServiceDep, ) -> dict: """Compare two snapshots and return differences.""" try: @@ -288,7 +274,6 @@ async def compare_snapshots( UUID(snapshot2_id) except ValueError: raise APIException(404, "Snapshot not found", 404) - service = SnapshotService(db, epics) try: result = await service.compare_snapshots(snapshot1_id, snapshot2_id) return success_response(result) diff --git a/app/api/v1/tags.py b/app/api/v1/tags.py index 3ff2dd7..87c026b 100644 --- a/app/api/v1/tags.py +++ b/app/api/v1/tags.py @@ -1,34 +1,29 @@ from uuid import UUID -from fastapi import Query, Depends, Security, APIRouter +from fastapi import Query, Security, APIRouter from pydantic import BaseModel -from sqlalchemy.ext.asyncio import AsyncSession -from app.db.session import get_db from app.schemas.tag import TagCreate, TagUpdate, TagGroupCreate, TagGroupUpdate -from app.dependencies import require_read_access, require_write_access +from app.dependencies import TagServiceDep, require_read_access, require_write_access from app.api.responses import APIException, success_response -from app.services.tag_service import TagService router = APIRouter(prefix="/tags", tags=["Tags"]) @router.get("", dependencies=[Security(require_read_access)]) -async def get_all_tag_groups(db: AsyncSession = Depends(get_db)) -> dict: +async def get_all_tag_groups(service: TagServiceDep) -> dict: """Get all tag groups with tag counts.""" - service = TagService(db) groups = await service.get_all_groups_summary() return success_response(groups) @router.get("/{group_id}", dependencies=[Security(require_read_access)]) -async def get_tag_group(group_id: str, db: AsyncSession = Depends(get_db)) -> dict: +async def get_tag_group(group_id: str, service: TagServiceDep) -> dict: """Get tag group by ID with all tags.""" try: UUID(group_id) except ValueError: raise APIException(404, f"Tag group {group_id} not found", 404) - service = TagService(db) group = await service.get_group_by_id(group_id) if not group: raise APIException(404, f"Tag group {group_id} not found", 404) @@ -37,9 +32,8 @@ async def get_tag_group(group_id: str, db: AsyncSession = Depends(get_db)) -> di @router.post("", dependencies=[Security(require_write_access)]) -async def create_tag_group(data: TagGroupCreate, db: AsyncSession = Depends(get_db)) -> dict: +async def create_tag_group(data: TagGroupCreate, service: TagServiceDep) -> dict: """Create a new tag group.""" - service = TagService(db) try: group = await service.create_group(data) return success_response(group) @@ -48,13 +42,16 @@ async def create_tag_group(data: TagGroupCreate, db: AsyncSession = Depends(get_ @router.put("/{group_id}", dependencies=[Security(require_write_access)]) -async def update_tag_group(group_id: str, data: TagGroupUpdate, db: AsyncSession = Depends(get_db)) -> dict: +async def update_tag_group( + group_id: str, + data: TagGroupUpdate, + service: TagServiceDep, +) -> dict: """Update a tag group.""" try: UUID(group_id) except ValueError: raise APIException(404, f"Tag group {group_id} not found", 404) - service = TagService(db) try: group = await service.update_group(group_id, data) if not group: @@ -65,13 +62,16 @@ async def update_tag_group(group_id: str, data: TagGroupUpdate, db: AsyncSession @router.delete("/{group_id}", dependencies=[Security(require_write_access)]) -async def delete_tag_group(group_id: str, force: bool = Query(False), db: AsyncSession = Depends(get_db)) -> dict: +async def delete_tag_group( + group_id: str, + service: TagServiceDep, + force: bool = Query(False), +) -> dict: """Delete a tag group.""" try: UUID(group_id) except ValueError: raise APIException(404, f"Tag group {group_id} not found", 404) - service = TagService(db) success = await service.delete_group(group_id, force=force) if not success: raise APIException(404, f"Tag group {group_id} not found", 404) @@ -82,15 +82,14 @@ async def delete_tag_group(group_id: str, force: bool = Query(False), db: AsyncS async def add_tag_to_group( group_id: str, data: TagCreate, + service: TagServiceDep, skip_duplicates: bool = Query(False, description="Skip duplicate tags instead of raising error"), - db: AsyncSession = Depends(get_db), ) -> dict: """Add a tag to a group.""" try: UUID(group_id) except ValueError: raise APIException(404, f"Tag group {group_id} not found", 404) - service = TagService(db) try: group, was_created = await service.add_tag_to_group(group_id, data, skip_duplicates=skip_duplicates) if not group: @@ -101,14 +100,18 @@ async def add_tag_to_group( @router.put("/{group_id}/tags/{tag_id}", dependencies=[Security(require_write_access)]) -async def update_tag(group_id: str, tag_id: str, data: TagUpdate, db: AsyncSession = Depends(get_db)) -> dict: +async def update_tag( + group_id: str, + tag_id: str, + data: TagUpdate, + service: TagServiceDep, +) -> dict: """Update a tag in a group.""" try: UUID(group_id) UUID(tag_id) except ValueError: raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) - service = TagService(db) group = await service.update_tag(group_id, tag_id, data) if not group: raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) @@ -116,14 +119,17 @@ async def update_tag(group_id: str, tag_id: str, data: TagUpdate, db: AsyncSessi @router.delete("/{group_id}/tags/{tag_id}", dependencies=[Security(require_write_access)]) -async def remove_tag(group_id: str, tag_id: str, db: AsyncSession = Depends(get_db)) -> dict: +async def remove_tag( + group_id: str, + tag_id: str, + service: TagServiceDep, +) -> dict: """Remove a tag from a group.""" try: UUID(group_id) UUID(tag_id) except ValueError: raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) - service = TagService(db) group = await service.remove_tag(group_id, tag_id) if not group: raise APIException(404, f"Tag {tag_id} not found in group {group_id}", 404) @@ -146,10 +152,8 @@ class BulkTagImportResponse(BaseModel): @router.post("/bulk", response_model=dict) -async def bulk_import_tags(data: BulkTagImportRequest, db: AsyncSession = Depends(get_db)): +async def bulk_import_tags(data: BulkTagImportRequest, service: TagServiceDep): """Bulk import tags with duplicate handling.""" - service = TagService(db) - groups_created = 0 tags_created = 0 tags_skipped = 0 diff --git a/app/api/v1/websocket.py b/app/api/v1/websocket.py index c5f6608..d0adb92 100644 --- a/app/api/v1/websocket.py +++ b/app/api/v1/websocket.py @@ -26,7 +26,11 @@ from fastapi import Security, APIRouter, WebSocket, WebSocketDisconnect from app.config import get_settings -from app.dependencies import require_read_access, ws_require_read_access +from app.dependencies import ( + RedisServiceDep, + require_read_access, + ws_require_read_access, +) from app.services.redis_service import get_redis_service from app.services.subscription_registry import ( SubscriptionRegistry, @@ -428,7 +432,7 @@ def get_connection_manager() -> DiffStreamManager: @router.websocket("/pvs", dependencies=[Security(ws_require_read_access)]) -async def websocket_pvs(websocket: WebSocket): +async def websocket_pvs(websocket: WebSocket, redis: RedisServiceDep): """ WebSocket endpoint for real-time PV updates with diff streaming. @@ -467,7 +471,6 @@ async def websocket_pvs(websocket: WebSocket): elif message_type == "get_all": # Send all cached values (legacy support) try: - redis = get_redis_service() all_values = await redis.get_all_pv_values_as_dict() await websocket.send_json( { diff --git a/app/dependencies.py b/app/dependencies.py index fd3c5a5..fa883f9 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -9,8 +9,10 @@ from app.api.responses import APIException from app.schemas.api_key import ApiKeyDTO from app.services.pv_service import PVService +from app.services.job_service import JobService from app.services.tag_service import TagService from app.services.epics_service import EpicsService, get_epics_service +from app.services.redis_service import RedisService, get_redis_service from app.services.api_key_service import ApiKeyService from app.services.snapshot_service import SnapshotService @@ -18,37 +20,61 @@ # --------------------------------------------------------------------------- -# Service factory dependencies +# Service dependencies # --------------------------------------------------------------------------- +DataBaseDep = Annotated[AsyncSession, Depends(get_db)] -def get_pv_service(db: AsyncSession = Depends(get_db)) -> PVService: - """Get PV service instance.""" +EpicsServiceDep = Annotated[EpicsService, Depends(get_epics_service)] + +RedisServiceDep = Annotated[RedisService, Depends(get_redis_service)] + + +def get_pv_service(db: DataBaseDep) -> PVService: return PVService(db) -def get_snapshot_service( - db: AsyncSession = Depends(get_db), epics: EpicsService = Depends(get_epics_service) -) -> SnapshotService: - """Get Snapshot service instance.""" - return SnapshotService(db, epics) +PVServiceDep = Annotated[PVService, Depends(get_pv_service)] -def get_tag_service(db: AsyncSession = Depends(get_db)) -> TagService: - """Get Tag service instance.""" +def get_tag_service(db: DataBaseDep) -> TagService: return TagService(db) +TagServiceDep = Annotated[TagService, Depends(get_tag_service)] + + +def get_api_key_service(db: DataBaseDep) -> ApiKeyService: + return ApiKeyService(db) + + +ApiKeyServiceDep = Annotated[ApiKeyService, Depends(get_api_key_service)] + + +def get_job_service(db: DataBaseDep) -> JobService: + return JobService(db) + + +JobServiceDep = Annotated[JobService, Depends(get_job_service)] + + +def get_snapshot_service(db: DataBaseDep, epics: EpicsServiceDep, redis: RedisServiceDep) -> SnapshotService: + return SnapshotService(db, epics, redis) + + +SnapshotServiceDep = Annotated[SnapshotService, Depends(get_snapshot_service)] + + # --------------------------------------------------------------------------- # API Key auth dependencies # --------------------------------------------------------------------------- async def get_api_key( - db: Annotated[AsyncSession, Depends(get_db)], api_key_header: Annotated[str, Security(api_key_header)] + service: ApiKeyServiceDep, + api_key_header: Annotated[str, Security(api_key_header)], ) -> ApiKeyDTO | None: if api_key_header: - service = ApiKeyService(db) api_key_dto = await service.get_by_token(api_key_header) if api_key_dto and api_key_dto.isActive: @@ -86,11 +112,10 @@ def require_write_access(api_key_dto: Annotated[ApiKeyDTO, Security(get_api_key) # --------------------------------------------------------------------------- -async def ws_get_api_key(websocket: WebSocket, db: AsyncSession = Security(get_db)) -> ApiKeyDTO: +async def ws_get_api_key(websocket: WebSocket, service: ApiKeyServiceDep) -> ApiKeyDTO: """WebSocket variant of get_api_key — raises WebSocketException on failure.""" key_value = websocket.headers.get("X-API-Key") if key_value: - service = ApiKeyService(db) api_key_dto = await service.get_by_token(key_value) if api_key_dto and api_key_dto.isActive: diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index b545292..71ef79b 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -2,7 +2,7 @@ Tests for app/dependencies.py — service factories and auth guards. """ from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest from fastapi import status @@ -91,13 +91,15 @@ class TestGetSnapshotService: def test_returns_snapshot_service_instance(self): db = MagicMock() epics = MagicMock() - result = get_snapshot_service(db, epics) + redis = MagicMock() + result = get_snapshot_service(db, epics, redis) assert isinstance(result, SnapshotService) def test_passes_db_and_epics_to_service(self): db = MagicMock() epics = MagicMock() - result = get_snapshot_service(db, epics) + redis = MagicMock() + result = get_snapshot_service(db, epics, redis) assert result.session is db assert result.epics is epics @@ -107,70 +109,68 @@ def test_passes_db_and_epics_to_service(self): # --------------------------------------------------------------------------- +def _make_mock_api_key_service(return_value=None) -> MagicMock: + """Build a mock ApiKeyService with an async get_by_token.""" + service = MagicMock() + service.get_by_token = AsyncMock(return_value=return_value) + return service + + class TestGetApiKey: """Tests for get_api_key dependency.""" @pytest.mark.asyncio async def test_no_header_raises_401(self): """Missing header (None) should raise 401 immediately.""" - db = MagicMock() + service = _make_mock_api_key_service() with pytest.raises(APIException) as exc_info: - await get_api_key(db, None) + await get_api_key(service, None) assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.error_code == status.HTTP_401_UNAUTHORIZED @pytest.mark.asyncio async def test_unknown_token_raises_401(self): """A token not found in the DB should raise 401.""" - db = MagicMock() - with patch("app.dependencies.ApiKeyService") as MockService: - MockService.return_value.get_by_token = AsyncMock(return_value=None) - with pytest.raises(APIException) as exc_info: - await get_api_key(db, "sq_unknown") + service = _make_mock_api_key_service(return_value=None) + with pytest.raises(APIException) as exc_info: + await get_api_key(service, "sq_unknown") assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.error_code == status.HTTP_401_UNAUTHORIZED @pytest.mark.asyncio async def test_inactive_key_raises_401(self): """A deactivated key should raise 401.""" - db = MagicMock() inactive_key = _make_api_key_dto(isActive=False) - with patch("app.dependencies.ApiKeyService") as MockService: - MockService.return_value.get_by_token = AsyncMock(return_value=inactive_key) - with pytest.raises(APIException) as exc_info: - await get_api_key(db, "sq_inactive") + service = _make_mock_api_key_service(return_value=inactive_key) + with pytest.raises(APIException) as exc_info: + await get_api_key(service, "sq_inactive") assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert exc_info.value.error_code == status.HTTP_401_UNAUTHORIZED @pytest.mark.asyncio async def test_active_key_returns_dto(self): """A valid, active key should be returned as-is.""" - db = MagicMock() active_key = _make_api_key_dto(isActive=True) - with patch("app.dependencies.ApiKeyService") as MockService: - MockService.return_value.get_by_token = AsyncMock(return_value=active_key) - result = await get_api_key(db, "sq_valid_token") + service = _make_mock_api_key_service(return_value=active_key) + result = await get_api_key(service, "sq_valid_token") assert result is active_key @pytest.mark.asyncio async def test_error_message_mentions_api_key(self): """401 error message should reference the API key.""" - db = MagicMock() + service = _make_mock_api_key_service() with pytest.raises(APIException) as exc_info: - await get_api_key(db, None) + await get_api_key(service, None) assert "api key" in exc_info.value.error_message.lower() @pytest.mark.asyncio async def test_service_receives_provided_token(self): """The exact header value should be forwarded to ApiKeyService.get_by_token.""" - db = MagicMock() token = "sq_specific_token_value" - with patch("app.dependencies.ApiKeyService") as MockService: - mock_get = AsyncMock(return_value=None) - MockService.return_value.get_by_token = mock_get - with pytest.raises(APIException): - await get_api_key(db, token) - mock_get.assert_awaited_once_with(token) + service = _make_mock_api_key_service(return_value=None) + with pytest.raises(APIException): + await get_api_key(service, token) + service.get_by_token.assert_awaited_once_with(token) # --------------------------------------------------------------------------- @@ -243,43 +243,37 @@ class TestWsGetApiKey: async def test_missing_header_raises_ws_exception(self): """No X-API-Key header should raise WebSocketException 1008.""" ws = _make_websocket() - db = MagicMock() + service = _make_mock_api_key_service() with pytest.raises(WebSocketException) as exc_info: - await ws_get_api_key(ws, db) + await ws_get_api_key(ws, service) assert exc_info.value.code == status.WS_1008_POLICY_VIOLATION @pytest.mark.asyncio async def test_unknown_token_raises_ws_exception(self): """A token not found in the DB should raise WebSocketException 1008.""" ws = _make_websocket("sq_unknown") - db = MagicMock() - with patch("app.dependencies.ApiKeyService") as MockService: - MockService.return_value.get_by_token = AsyncMock(return_value=None) - with pytest.raises(WebSocketException) as exc_info: - await ws_get_api_key(ws, db) + service = _make_mock_api_key_service(return_value=None) + with pytest.raises(WebSocketException) as exc_info: + await ws_get_api_key(ws, service) assert exc_info.value.code == status.WS_1008_POLICY_VIOLATION @pytest.mark.asyncio async def test_inactive_key_raises_ws_exception(self): """A deactivated key should raise WebSocketException 1008.""" ws = _make_websocket("sq_inactive") - db = MagicMock() inactive_key = _make_api_key_dto(isActive=False) - with patch("app.dependencies.ApiKeyService") as MockService: - MockService.return_value.get_by_token = AsyncMock(return_value=inactive_key) - with pytest.raises(WebSocketException) as exc_info: - await ws_get_api_key(ws, db) + service = _make_mock_api_key_service(return_value=inactive_key) + with pytest.raises(WebSocketException) as exc_info: + await ws_get_api_key(ws, service) assert exc_info.value.code == status.WS_1008_POLICY_VIOLATION @pytest.mark.asyncio async def test_active_key_returns_dto(self): """A valid, active key should be returned as-is.""" ws = _make_websocket("sq_valid_token") - db = MagicMock() active_key = _make_api_key_dto(isActive=True) - with patch("app.dependencies.ApiKeyService") as MockService: - MockService.return_value.get_by_token = AsyncMock(return_value=active_key) - result = await ws_get_api_key(ws, db) + service = _make_mock_api_key_service(return_value=active_key) + result = await ws_get_api_key(ws, service) assert result is active_key @pytest.mark.asyncio @@ -287,21 +281,18 @@ async def test_service_receives_provided_token(self): """The exact header value should be forwarded to ApiKeyService.get_by_token.""" token = "sq_specific_token_value" ws = _make_websocket(token) - db = MagicMock() - with patch("app.dependencies.ApiKeyService") as MockService: - mock_get = AsyncMock(return_value=None) - MockService.return_value.get_by_token = mock_get - with pytest.raises(WebSocketException): - await ws_get_api_key(ws, db) - mock_get.assert_awaited_once_with(token) + service = _make_mock_api_key_service(return_value=None) + with pytest.raises(WebSocketException): + await ws_get_api_key(ws, service) + service.get_by_token.assert_awaited_once_with(token) @pytest.mark.asyncio async def test_error_reason_mentions_api_key(self): """WebSocketException reason should reference the API key.""" ws = _make_websocket() - db = MagicMock() + service = _make_mock_api_key_service() with pytest.raises(WebSocketException) as exc_info: - await ws_get_api_key(ws, db) + await ws_get_api_key(ws, service) assert "api key" in exc_info.value.reason.lower()