diff --git a/app/api/deps.py b/app/api/deps.py index ec9261d..57194a0 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -2,7 +2,7 @@ from collections.abc import AsyncGenerator from typing import Annotated -from fastapi import Depends, HTTPException, status +from fastapi import Depends, HTTPException, Request, status from fastapi.security import OAuth2PasswordBearer from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession @@ -17,7 +17,9 @@ from app.schemas.token import TokenPayload from app.schemas.user import SystemRole -reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login") +reusable_oauth2 = OAuth2PasswordBearer( + tokenUrl=f"{settings.API_V1_STR}/auth/login", auto_error=False +) async def get_db() -> AsyncGenerator[AsyncSession, None]: @@ -27,12 +29,22 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]: async def get_current_user( + request: Request, db: Annotated[AsyncSession, Depends(get_db)], - token: Annotated[str, Depends(reusable_oauth2)], + bearer_token: Annotated[str | None, Depends(reusable_oauth2)] = None, ) -> User: """ Get current authenticated user from JWT token. + Cookie takes priority; falls back to Authorization Bearer header. """ + token = request.cookies.get("access_token") or bearer_token + if not token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ErrorMessages.INVALID_TOKEN, + headers={"WWW-Authenticate": "Bearer"}, + ) + try: # Check if token is blacklisted if await is_token_blacklisted(db, token): diff --git a/app/api/routes/auth.py b/app/api/routes/auth.py index e0bf618..b9876cc 100644 --- a/app/api/routes/auth.py +++ b/app/api/routes/auth.py @@ -9,7 +9,10 @@ from app.core.messages.error_message import ErrorMessages from app.core.messages.success_message import SuccessMessages from app.schemas.msg import Message -from app.schemas.token import LoginResponse, Token +from app.schemas.token import ( + CookieLoginResponse, + CookieRefreshResponse, +) from app.schemas.user import ( ForgotPassword, NewPassword, @@ -34,13 +37,15 @@ router = APIRouter() -@router.post("/login", response_model=LoginResponse, status_code=status.HTTP_200_OK) +@router.post( + "/login", response_model=CookieLoginResponse, status_code=status.HTTP_200_OK +) async def login_access_token( response: Response, request: Request, session: SessionDep, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], -) -> LoginResponse: +) -> CookieLoginResponse: """ OAuth2 compatible token login, get an access token for future requests. """ @@ -52,6 +57,17 @@ async def login_access_token( password=form_data.password, ) + # Set access token in HttpOnly cookie + response.set_cookie( + key="access_token", + value=result.access_token, + httponly=True, + secure=settings.ENVIRONMENT != "local", + samesite="lax", + max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + path="/", + ) + # Set refresh token in HttpOnly cookie response.set_cookie( key="refresh_token", @@ -63,8 +79,7 @@ async def login_access_token( path=f"{settings.API_V1_STR}/auth/refresh", ) - return LoginResponse( - access_token=result.access_token, + return CookieLoginResponse( user=result.user, message=result.message, ) @@ -83,11 +98,14 @@ async def login_access_token( raise HTTPException(status_code=500, detail=ErrorMessages.INTERNAL_SERVER_ERROR) -@router.post("/refresh", response_model=Token, status_code=status.HTTP_200_OK) +@router.post( + "/refresh", response_model=CookieRefreshResponse, status_code=status.HTTP_200_OK +) async def refresh_token( request: Request, + response: Response, session: SessionDep, -) -> Token: +) -> CookieRefreshResponse: """ Refresh access token using the refresh token from cookie. """ @@ -99,9 +117,22 @@ async def refresh_token( ) try: - return await refresh_token_service( + result = await refresh_token_service( request=request, session=session, refresh_token=refresh_token_cookie ) + + # Set new access token in HttpOnly cookie + response.set_cookie( + key="access_token", + value=result.access_token, + httponly=True, + secure=settings.ENVIRONMENT != "local", + samesite="lax", + max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + path="/", + ) + + return CookieRefreshResponse(message=result.message) except HTTPException: raise except Exception as e: @@ -129,6 +160,7 @@ async def logout(request: Request, response: Response, session: SessionDep) -> M request=request, session=session, refresh_token=refresh_token ) + response.delete_cookie(key="access_token", path="/") response.delete_cookie( key="refresh_token", path=f"{settings.API_V1_STR}/auth/refresh", diff --git a/app/schemas/token.py b/app/schemas/token.py index b669f0e..1e30a39 100644 --- a/app/schemas/token.py +++ b/app/schemas/token.py @@ -10,16 +10,27 @@ class Token(BaseModel): message: str | None = None -class LoginResponse(Token): +class AuthTokens(Token): + refresh_token: str user: UserPublic + message: str | None = None -class AuthTokens(Token): - refresh_token: str +class CookieLoginResponse(BaseModel): + """Login response without access_token in body (token is in HttpOnly cookie).""" + + token_type: str = "bearer" user: UserPublic message: str | None = None +class CookieRefreshResponse(BaseModel): + """Refresh response without access_token in body (token is in HttpOnly cookie).""" + + token_type: str = "bearer" + message: str | None = None + + # Contents of JWT token class TokenPayload(BaseModel): sub: str | None = None diff --git a/app/tests/test_auth.py b/app/tests/test_auth.py index e478479..e603427 100644 --- a/app/tests/test_auth.py +++ b/app/tests/test_auth.py @@ -86,7 +86,7 @@ async def test_refresh_token_and_logout(client: AsyncClient): refresh_response = await client.post("/auth/refresh") assert refresh_response.status_code == 200 data = refresh_response.json() - assert "access_token" in data + assert refresh_response.cookies.get("access_token") is not None assert data["message"] == SuccessMessages.LOGIN_SUCCESS # Test Logout Endpoint @@ -149,7 +149,7 @@ async def test_verify_email_flow(client: AsyncClient): ) assert login_response.status_code == 200 data = login_response.json() - assert "access_token" in data + assert login_response.cookies.get("access_token") is not None assert data["message"] == SuccessMessages.LOGIN_SUCCESS @@ -282,19 +282,17 @@ async def test_change_password(client: AsyncClient): ) await session.commit() - # 3. Login to get access token + # 3. Login to get access token cookie login_response = await client.post( "/auth/login", data={"username": email, "password": old_password} ) assert login_response.status_code == 200 - access_token = login_response.json()["access_token"] - headers = {"Authorization": f"Bearer {access_token}"} + assert login_response.cookies.get("access_token") is not None # 4. Test Change Password - Failure (Wrong current password) fail_response = await client.patch( "/auth/change-password", json={"current_password": "wrongpassword", "new_password": "newPassword456"}, - headers=headers, ) assert fail_response.status_code == 400 assert fail_response.json()["error"] == ErrorMessages.INVALID_CURRENT_PASSWORD @@ -303,7 +301,6 @@ async def test_change_password(client: AsyncClient): success_response = await client.patch( "/auth/change-password", json={"current_password": old_password, "new_password": new_password}, - headers=headers, ) assert success_response.status_code == 200 assert success_response.json()["success"] is True @@ -314,7 +311,7 @@ async def test_change_password(client: AsyncClient): "/auth/login", data={"username": email, "password": new_password} ) assert new_login_response.status_code == 200 - assert "access_token" in new_login_response.json() + assert new_login_response.cookies.get("access_token") is not None # 7. Verify Login fails with OLD password old_login_response = await client.post( diff --git a/app/tests/test_users.py b/app/tests/test_users.py index 1d82f3a..3c42a90 100644 --- a/app/tests/test_users.py +++ b/app/tests/test_users.py @@ -35,7 +35,7 @@ async def auth_client(client: AsyncClient) -> AsyncClient: ) await session.commit() - # Login and get token + # Login — access_token is set as HttpOnly cookie; httpx stores and forwards it automatically response = await client.post( "/auth/login", data={ @@ -43,11 +43,8 @@ async def auth_client(client: AsyncClient) -> AsyncClient: "password": "password123", }, ) - data = response.json() - token = data["access_token"] - - # Set auth header - client.headers["Authorization"] = f"Bearer {token}" + assert response.status_code == 200 + assert response.cookies.get("access_token") is not None return client