From 75ae9b7e3861ccd2773a18d8f7318de5e5ea3eec Mon Sep 17 00:00:00 2001 From: LucasTargett Date: Tue, 5 May 2026 21:37:08 +1000 Subject: [PATCH] Combine backend API gateway routes and test fixes --- .../backend/app/auth/dependencies.py | 7 +- .../backend/app/auth/jwt.py | 27 +-- .../backend/app/config.py | 6 +- .../backend/app/database.py | 23 +- .../backend/app/main.py | 2 +- .../backend/app/models.py | 13 +- .../backend/app/routes/auth.py | 69 +++--- .../backend/app/routes/crowd.py | 26 ++- .../backend/app/routes/jobs.py | 29 +++ .../backend/app/routes/players.py | 24 ++- .../backend/app/routes/upload.py | 61 +++--- .../backend/app/schemas/auth.py | 9 - .../backend/app/schemas/jobs.py | 21 +- .../backend/tests/test_auth.py | 8 +- .../backend/tests/test_jobs.py | 202 ++++++++++++++++++ .../backend/tests/test_upload.py | 41 ++++ 16 files changed, 419 insertions(+), 149 deletions(-) diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/auth/dependencies.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/auth/dependencies.py index bf9862c37..e402e7b00 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/auth/dependencies.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/auth/dependencies.py @@ -1,11 +1,10 @@ from fastapi import Depends, HTTPException, status -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi.security import OAuth2PasswordBearer from app.auth.jwt import decode_access_token -http_bearer = HTTPBearer() +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login") -def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(http_bearer)) -> dict: - token = credentials.credentials +def get_current_user(token: str = Depends(oauth2_scheme)) -> dict: payload = decode_access_token(token) if payload is None: raise HTTPException( diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/auth/jwt.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/auth/jwt.py index 24dc5b5d6..e86dc0d87 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/auth/jwt.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/auth/jwt.py @@ -1,39 +1,16 @@ from datetime import datetime, timedelta, timezone from jose import JWTError, jwt -from app.config import JWT_SECRET_KEY, JWT_ALGORITHM, JWT_EXPIRE_MINUTES, REFRESH_TOKEN_EXPIRE_DAYS - +from app.config import JWT_SECRET_KEY, JWT_ALGORITHM, JWT_EXPIRE_MINUTES def create_access_token(data: dict) -> str: to_encode = data.copy() expire = datetime.now(timezone.utc) + timedelta(minutes=JWT_EXPIRE_MINUTES) - to_encode.update({"exp": expire, "type": "access"}) + to_encode.update({"exp": expire}) return jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) - -def create_refresh_token(data: dict) -> tuple[str, datetime]: - """Returns (token, expires_at)""" - to_encode = data.copy() - expires_at = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) - to_encode.update({"exp": expires_at, "type": "refresh"}) - token = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) - return token, expires_at - - def decode_access_token(token: str) -> dict: try: payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) - if payload.get("type") != "access": - return None - return payload - except JWTError: - return None - - -def decode_refresh_token(token: str) -> dict: - try: - payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) - if payload.get("type") != "refresh": - return None return payload except JWTError: return None diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/config.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/config.py index 101f90591..c3df82ff5 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/config.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/config.py @@ -10,14 +10,12 @@ USE_MOCK_SERVICES = os.getenv("USE_MOCK_SERVICES", "true").lower() == "true" -DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:password@localhost:5432/orion_db") +# Have just added async driver ('+asyncpg') to URL to match app - Lucas +DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://user:password@localhost:5432/orion_db") JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "your-secret-key-here") JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256") JWT_EXPIRE_MINUTES = int(os.getenv("JWT_EXPIRE_MINUTES", 60)) -REFRESH_TOKEN_EXPIRE_DAYS = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", 7)) DEBUG = os.getenv("DEBUG", "True").lower() == "true" LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") - - diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/database.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/database.py index 13cc2c051..885b8d80f 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/database.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/database.py @@ -1,17 +1,20 @@ -from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker, declarative_base from app.config import DATABASE_URL -engine = create_engine(DATABASE_URL) +engine = create_async_engine( + DATABASE_URL, + echo=True +) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +SessionLocal = sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False +) Base = declarative_base() - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() +async def get_db(): + async with SessionLocal() as session: + yield session \ No newline at end of file diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/main.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/main.py index b90ffeb01..e6933ecda 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/main.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/main.py @@ -13,7 +13,7 @@ ) logger = logging.getLogger(__name__) -Base.metadata.create_all(bind=engine) +# Base.metadata.create_all(bind=engine.sync_engine) app = FastAPI( title="Project Orion Backend API", diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/models.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/models.py index dd4213220..743d460ae 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/models.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/models.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean +from sqlalchemy import Column, String, DateTime, ForeignKey from sqlalchemy.dialects.postgresql import UUID, JSONB from sqlalchemy.orm import relationship # ✅ ADDED from app.database import Base @@ -19,17 +19,6 @@ class User(Base): jobs = relationship("Job", back_populates="user") -class RefreshToken(Base): - __tablename__ = "refresh_tokens" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.user_id"), nullable=False) - token = Column(String, unique=True, nullable=False) - is_active = Column(Boolean, default=True, nullable=False) - created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) - expires_at = Column(DateTime, nullable=False) - - class Job(Base): __tablename__ = "jobs" diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/auth.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/auth.py index cdcdabb41..00b9705c8 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/auth.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/auth.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from app.database import get_db from app.models import User, RefreshToken @@ -13,7 +13,6 @@ def _issue_tokens(user: User, db: Session) -> dict: - """Create access + refresh tokens, store refresh token in DB.""" token_data = { "sub": str(user.user_id), "email": user.email, @@ -41,30 +40,44 @@ def _issue_tokens(user: User, db: Session) -> dict: @router.post("/register", response_model=AuthResponse) def register(user: RegisterRequest, db: Session = Depends(get_db)): - if db.query(User).filter(User.email == user.email).first(): - raise HTTPException(status_code=400, detail="Email already registered") - if db.query(User).filter(User.username == user.username).first(): - raise HTTPException(status_code=400, detail="Username already taken") - - new_user = User( - email=user.email, - username=user.username, - password=hash_password(user.password) - ) - db.add(new_user) - db.commit() - db.refresh(new_user) + try: + if db.query(User).filter(User.email == user.email).first(): + raise HTTPException(status_code=409, detail="Email already registered") + if db.query(User).filter(User.username == user.username).first(): + raise HTTPException(status_code=409, detail="Username already taken") + + new_user = User( + email=user.email, + username=user.username, + password=hash_password(user.password) + ) + db.add(new_user) + db.commit() + db.refresh(new_user) - return _issue_tokens(new_user, db) + return _issue_tokens(new_user, db) + + except HTTPException: + raise + + except Exception: + raise HTTPException(status_code=500, detail="Internal server error during registration") @router.post("/login", response_model=AuthResponse) def login(user: LoginRequest, db: Session = Depends(get_db)): - db_user = db.query(User).filter(User.email == user.email).first() - if not db_user or not verify_password(user.password, db_user.password): - raise HTTPException(status_code=401, detail="Invalid email or password") + try: + db_user = db.query(User).filter(User.email == user.email).first() + if not db_user or not verify_password(user.password, db_user.password): + raise HTTPException(status_code=401, detail="Invalid email or password") + + return _issue_tokens(db_user, db) - return _issue_tokens(db_user, db) + except HTTPException: + raise + + except Exception: + raise HTTPException(status_code=500, detail="Internal server error during login") @router.post("/refresh", response_model=AuthResponse) @@ -90,7 +103,6 @@ def refresh(body: RefreshRequest, db: Session = Depends(get_db)): if not user: raise HTTPException(status_code=404, detail="User not found") - # Revoke old refresh token and issue a new pair db_token.is_active = False db.commit() @@ -112,7 +124,14 @@ def logout(body: LogoutRequest, db: Session = Depends(get_db)): @router.get("/me", response_model=UserResponse) def get_me(current_user: dict = Depends(get_current_user), db: Session = Depends(get_db)): - user = db.query(User).filter(User.user_id == current_user["sub"]).first() - if not user: - raise HTTPException(status_code=404, detail="User not found") - return user + try: + user = db.query(User).filter(User.user_id == current_user["sub"]).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + return user + + except HTTPException: + raise + + except Exception: + raise HTTPException(status_code=500, detail="Internal server error retrieving user") \ No newline at end of file diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/crowd.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/crowd.py index 4ac1293bf..b635b9782 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/crowd.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/crowd.py @@ -1,14 +1,26 @@ -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from app.services.crowd_client import get_crowd_data + router = APIRouter(prefix="/api", tags=["Crowd"]) @router.get("/crowd") async def get_crowd(): - data = await get_crowd_data() - return { - "status": "success", - "message": "Crowd data retrieved successfully", - "data": data - } + try: + data = await get_crowd_data() + + if not data: + raise HTTPException(status_code=404, detail="Crowd data not found") + + return { + "status": "success", + "message": "Crowd data retrieved successfully", + "data": data + } + + except HTTPException: + raise + + except Exception: + raise HTTPException(status_code=500, detail="Internal server error while retrieving crowd data") diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/jobs.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/jobs.py index e76b83c8d..8a428e86c 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/jobs.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/jobs.py @@ -1,6 +1,8 @@ import asyncio +import httpx from datetime import datetime, timezone from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from app.database import get_db from app.models import Job @@ -8,6 +10,7 @@ from app.auth.dependencies import get_current_user from app.services.player_client import get_player_data from app.services.crowd_client import get_crowd_data +from app.config import CROWD_SERVICE_URL router = APIRouter() @@ -129,6 +132,32 @@ async def retry_job( return {"job_id": str(job.job_id), "status": job.status} +@router.get("/jobs/{job_id}/heatmap") +async def get_heatmap( + job_id: str, + current_user: dict = Depends(get_current_user), + db: Session = Depends(get_db) +): + job = db.query(Job).filter(Job.job_id == job_id).first() + if not job: + raise HTTPException(status_code=404, detail="Job not found") + check_job_access(job, current_user) + + crowd = job.crowd_result + if not crowd or not crowd.get("heatmap") or not crowd["heatmap"].get("image_path"): + raise HTTPException(status_code=404, detail="Heatmap not available for this job") + + image_path = crowd["heatmap"]["image_path"].replace("\\", "/") + url = f"{CROWD_SERVICE_URL}/artifacts/{image_path}" + + async with httpx.AsyncClient(timeout=10.0) as client: + r = await client.get(url) + if r.status_code != 200: + raise HTTPException(status_code=502, detail="Could not fetch heatmap from crowd service") + + return StreamingResponse(iter([r.content]), media_type="image/png") + + @router.delete("/jobs/{job_id}") def delete_job( job_id: str, diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/players.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/players.py index dafa008cd..13495f00e 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/players.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/players.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from app.services.player_client import get_player_data router = APIRouter(prefix="/api", tags=["Players"]) @@ -6,9 +6,19 @@ @router.get("/players") async def get_players(): - data = await get_player_data() - return { - "status": "success", - "message": "Players data retrieved successfully", - "data": data - } + try: + data = await get_player_data() + + if not data: + raise HTTPException(status_code=404, detail="Player data not found") + + return { + "status": "success", + "message": "Players data retrieved successfully", + "data": data + } + except HTTPException: + raise + + except Exception: + raise HTTPException(status_code=500, detail="Internal server error while retrieving player data") diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/upload.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/upload.py index 9e059af9a..a53256e6c 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/upload.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/routes/upload.py @@ -12,6 +12,8 @@ from app.services.player_client import get_player_data from app.services.crowd_client import get_crowd_data +from sqlalchemy import select + router = APIRouter() ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov"} @@ -44,16 +46,17 @@ async def process_video(job_id: str, file_path: str): status = "done" error = None - job = db.query(Job).filter(Job.job_id == job_id).first() + result = await db.execute(select(Job).where(Job.job_id == job_id)) + job = result.scalar_one_or_none() if job: job.status = status job.player_result = player_data job.crowd_result = crowd_data job.error = error job.updated_at = datetime.now(timezone.utc) - db.commit() + await db.commit() finally: - db.close() + await db.close() if os.path.exists(file_path): os.remove(file_path) @@ -71,27 +74,33 @@ async def upload_video( status_code=400, detail="Invalid video format. Accepted formats: .mp4, .avi, .mov" ) + try: + os.makedirs(UPLOAD_DIR, exist_ok=True) + filename = f"{uuid.uuid4()}{ext}" + file_path = os.path.join(UPLOAD_DIR, filename) + + with open(file_path, "wb") as f: + f.write(await file.read()) + + job = Job( + user_id=current_user["sub"], + status="processing", + video_path=file_path + ) + db.add(job) + db.commit() + db.refresh(job) + + background_tasks.add_task(process_video, str(job.job_id), file_path) + + return { + "job_id": str(job.job_id), + "status": job.status, + "created_at": job.created_at + } + + except HTTPException: + raise - os.makedirs(UPLOAD_DIR, exist_ok=True) - filename = f"{uuid.uuid4()}{ext}" - file_path = os.path.join(UPLOAD_DIR, filename) - - with open(file_path, "wb") as f: - f.write(await file.read()) - - job = Job( - user_id=current_user["sub"], - status="processing", - video_path=file_path - ) - db.add(job) - db.commit() - db.refresh(job) - - background_tasks.add_task(process_video, str(job.job_id), file_path) - - return { - "job_id": str(job.job_id), - "status": job.status, - "created_at": job.created_at - } + except Exception: + raise HTTPException(status_code=500, detail="Internal server error while uploading video") diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/schemas/auth.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/schemas/auth.py index aae717cfa..a104f317d 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/schemas/auth.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/schemas/auth.py @@ -27,20 +27,11 @@ class UserResponse(BaseModel): class AuthResponse(BaseModel): access_token: str - refresh_token: str token_type: str user: Optional[UserResponse] = None expires_in: int -class RefreshRequest(BaseModel): - refresh_token: str - - -class LogoutRequest(BaseModel): - refresh_token: str - - diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/schemas/jobs.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/schemas/jobs.py index 8810c1499..c39934f47 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/schemas/jobs.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/app/schemas/jobs.py @@ -1,31 +1,22 @@ -from pydantic import BaseModel, ConfigDict, field_serializer +from pydantic import BaseModel, ConfigDict from typing import Optional, List, Any from datetime import datetime -from uuid import UUID class UploadResponse(BaseModel): - job_id: UUID + job_id: str status: str created_at: datetime - @field_serializer("job_id") - def serialize_job_id(self, v: UUID) -> str: - return str(v) - class JobSummary(BaseModel): model_config = ConfigDict(from_attributes=True) - job_id: UUID + job_id: str status: str created_at: datetime updated_at: datetime - @field_serializer("job_id") - def serialize_job_id(self, v: UUID) -> str: - return str(v) - class JobResults(BaseModel): player: Optional[Any] = None @@ -40,17 +31,13 @@ class JobErrors(BaseModel): class JobDetail(BaseModel): model_config = ConfigDict(from_attributes=True) - job_id: UUID + job_id: str status: str created_at: datetime updated_at: datetime results: Optional[JobResults] = None errors: Optional[JobErrors] = None - @field_serializer("job_id") - def serialize_job_id(self, v: UUID) -> str: - return str(v) - class JobListResponse(BaseModel): total: int diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_auth.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_auth.py index 6806a5106..6d3f29816 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_auth.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_auth.py @@ -44,15 +44,19 @@ def test_register_success(client, mock_db): def test_register_duplicate_email(client, mock_db): mock_db.query.return_value.filter.return_value.first.return_value = make_mock_user() response = client.post("/auth/register", json=REGISTER_PAYLOAD) - assert response.status_code == 400 + assert response.status_code == 409 assert "Email already registered" in response.json()["detail"] +def test_register_duplicate_username(client, mock_db): + mock_db.query.return_value.filter.return_value.first.side_effect = [None, make_mock_user()] + response = client.post("/auth/register", json=REGISTER_PAYLOAD) + assert response.status_code == 409 + assert "Username already taken" in response.json()["detail"] def test_register_missing_fields(client, mock_db): response = client.post("/auth/register", json={"email": "test@example.com"}) assert response.status_code == 422 - # --- Login --- def test_login_success(client, mock_db): diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_jobs.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_jobs.py index e69de29bb..766b74de9 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_jobs.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_jobs.py @@ -0,0 +1,202 @@ +from datetime import datetime, timezone +from types import SimpleNamespace + +from app.main import app +from app.auth.dependencies import get_current_user + +# Bypass authentication +def override_get_current_user(): + return {"sub": "test_user", "role": "admin"} + +# Mock DB record +def make_job( + job_id="11111111-1111-1111-1111-111111111111", + status="done", + user_id="test_user", + player_result=None, + crowd_result=None, + error=None, + video_path="uploads/test.mp4", + ): + now = datetime.now(timezone.utc) + + return SimpleNamespace( + job_id=job_id, + status=status, + user_id=user_id, + player_result=player_result, + crowd_result=crowd_result, + error=error, + video_path=video_path, + created_at=now, + updated_at=now, + ) + +# Test: Get job status (success) +def test_get_status_success(client, mock_db): + app.dependency_overrides[get_current_user] = override_get_current_user + fake_job = make_job( + status="done", + player_result={"players": 10}, + crowd_result={"crowd": 50}, + ) + mock_db.query.return_value.filter.return_value.first.return_value = fake_job # Mock DB query to return fake job + response = client.get("/status/11111111-1111-1111-1111-111111111111") # Send GET request to endpoint + assert response.status_code == 200 + data = response.json() + assert data["job_id"] == "11111111-1111-1111-1111-111111111111" + assert data["status"] == "done" + assert "results" in data + +# Test: Job status not found +def test_get_status_not_found(client, mock_db): + app.dependency_overrides[get_current_user] = override_get_current_user + + mock_db.query.return_value.filter.return_value.first.return_value = None + + response = client.get("/status/missing-job") + + assert response.status_code == 404 + assert response.json()["detail"] == "Job not found" + +# Test: List jobs (pagination) +def test_list_jobs_with_pagination(client, mock_db): + app.dependency_overrides[get_current_user] = override_get_current_user + + # Create fake jobs list + jobs = [ + make_job(job_id="11111111-1111-1111-1111-111111111111"), + make_job(job_id="22222222-2222-2222-2222-222222222222"), + make_job(job_id="33333333-3333-3333-3333-333333333333"), + ] + + mock_db.query.return_value.count.return_value = 3 + mock_db.query.return_value.order_by.return_value.offset.return_value.limit.return_value.all.return_value = jobs[:2] + + response = client.get("/jobs?page=1&limit=2") + assert response.status_code == 200 + data = response.json() + + # Check pagination fields + assert data["total"] == 3 + assert data["page"] == 1 + assert data["limit"] == 2 + assert len(data["jobs"]) == 2 # only 2 returned + +# Test: Get job details +def test_get_job_success(client, mock_db): + app.dependency_overrides[get_current_user] = override_get_current_user + + # Fake completed job + fake_job = make_job( + status="done", + player_result={"players": 10}, + crowd_result={"crowd": 50}, + ) + + mock_db.query.return_value.filter.return_value.first.return_value = fake_job + + response = client.get("/jobs/11111111-1111-1111-1111-111111111111") + assert response.status_code == 200 + data = response.json() + # Validate job data + assert data["job_id"] == "11111111-1111-1111-1111-111111111111" + assert data["status"] == "done" + assert "results" in data + +# Test Get job not found +def test_get_job_not_found(client, mock_db): + app.dependency_overrides[get_current_user] = override_get_current_user + + # No job returned + mock_db.query.return_value.filter.return_value.first.return_value = None + + response = client.get("/jobs/missing-job") + assert response.status_code == 404 + assert response.json()["detail"] == "Job not found" + +# Test: Retry job (success case) +def test_retry_job_success(client, mock_db, monkeypatch): + app.dependency_overrides[get_current_user] = override_get_current_user + + # Partial job (only one result missing) + fake_job = make_job( + status="partial", + player_result=None, + crowd_result={"crowd": 50}, + ) + + mock_db.query.return_value.filter.return_value.first.return_value = fake_job + + # Mock async player service response + async def fake_player_data(video_path): + return {"players": 12} + + # Replace real service call with fake one + monkeypatch.setattr("app.routes.jobs.get_player_data", fake_player_data) + + response = client.post("/jobs/11111111-1111-1111-1111-111111111111/retry") + assert response.status_code == 200 + data = response.json() + + # Job should now be completed + assert data["job_id"] == "11111111-1111-1111-1111-111111111111" + assert data["status"] == "done" + +# Test: Retry invalid job +def test_retry_job_not_partial(client, mock_db): + app.dependency_overrides[get_current_user] = override_get_current_user + + # Job already complete + fake_job = make_job( + status="done", + player_result={"players": 10}, + crowd_result={"crowd": 50}, + ) + + mock_db.query.return_value.filter.return_value.first.return_value = fake_job + + response = client.post("/jobs/11111111-1111-1111-1111-111111111111/retry") + + # Should fail + assert response.status_code == 400 + assert response.json()["detail"] == "Only partial jobs can be retried" + +# Test: Delete job +def test_delete_job_success(client, mock_db): + app.dependency_overrides[get_current_user] = override_get_current_user + + fake_job = make_job() + + # Mock DB returning a job + mock_db.query.return_value.filter.return_value.first.return_value = fake_job + + response = client.delete("/jobs/11111111-1111-1111-1111-111111111111") + assert response.status_code == 200 + assert response.json()["message"] == "job deleted" + + # Ensure DB methods were called + mock_db.delete.assert_called_once_with(fake_job) + mock_db.commit.assert_called() + +# When attempting to delete a job that does not exist +def test_delete_job_not_found(client, mock_db): + app.dependency_overrides[get_current_user] = override_get_current_user + + mock_db.query.return_value.filter.return_value.first.return_value = None + + response = client.delete("/jobs/missing-job") + + assert response.status_code == 404 + assert response.json()["detail"] == "Job not found" + +# When attempting to retry a job that does not exist +def test_retry_job_not_found(client, mock_db): + app.dependency_overrides[get_current_user] = override_get_current_user + + mock_db.query.return_value.filter.return_value.first.return_value = None + + response = client.post("/jobs/missing-job/retry") + + assert response.status_code == 404 + assert response.json()["detail"] == "Job not found" \ No newline at end of file diff --git a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_upload.py b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_upload.py index e69de29bb..db1976580 100644 --- a/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_upload.py +++ b/26_T1/afl_player_tracking_and_crowd_monitoring/backend/tests/test_upload.py @@ -0,0 +1,41 @@ +from app.main import app +from datetime import datetime, timezone +from app.auth.dependencies import get_current_user + +def override_get_current_user(): + return {"sub": "test_user", "role": "admin"} + +async def fake_process_video(job_id, file_path): + return "11111111-1111-1111-1111-111111111111" + +def test_upload_valid_file(client, monkeypatch, mock_db): + app.dependency_overrides[get_current_user] = override_get_current_user + monkeypatch.setattr("app.routes.upload.process_video", fake_process_video) + + def fake_refresh(job): + job.job_id = "11111111-1111-1111-1111-111111111111" + job.created_at = datetime.now(timezone.utc) + + mock_db.refresh.side_effect = fake_refresh + response = client.post("/upload", files={"file": ("test.mp4", b"fake video content", "video/mp4")}) + assert response.status_code == 200 + assert response.json()["job_id"] == "11111111-1111-1111-1111-111111111111" + assert response.json()["status"] == "processing" + +def test_upload_invalid_file_type(client): + app.dependency_overrides[get_current_user] = override_get_current_user + response = client.post("/upload", files={"file": ("text.txt", b"dummy,data", "text/plain")}) + assert response.status_code == 400 + assert "invalid" in str(response.json()).lower() + +def test_missing_file(client): + app.dependency_overrides[get_current_user] = override_get_current_user + response = client.post("/upload", files={}) + assert response.status_code == 422 + assert "file" in str(response.json()) + +def test_upload_invalid_mime_type(client): + app.dependency_overrides[get_current_user] = override_get_current_user + response = client.post("/upload", files={"file": ("test.mp4", b"fake video content", "text/plain")}) + assert response.status_code == 400 + assert "Invalid video format" in response.json()["detail"] \ No newline at end of file