Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions app/auth/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,24 @@
from app.database import get_db
from app.models import User, UserRole
from passlib.context import CryptContext
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator

router = APIRouter(prefix="/auth", tags=["authentication"])


class UserCreate(BaseModel):
username: str = Field(..., min_length=3, max_length=50)
email: str
password: str
role: UserRole

@validator('role')
@field_validator('role')
def validate_role(cls, v):
if v not in [UserRole.admin, UserRole.case_worker]:
raise ValueError('Role must be either admin or case_worker')
return v


class UserResponse(BaseModel):
username: str
email: str
Expand All @@ -31,6 +33,7 @@ class UserResponse(BaseModel):
class Config:
from_attributes = True


# Configuration
SECRET_KEY = "your-secret-key-here"
ALGORITHM = "HS256"
Expand All @@ -39,18 +42,22 @@ class Config:
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")


def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
return pwd_context.hash(password)


def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
user = db.query(User).filter(User.username == username).first()
if not user or not verify_password(password, user.hashed_password):
return None
return user


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
if expires_delta:
Expand All @@ -61,9 +68,10 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


async def get_current_user(
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db)
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db)
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand All @@ -77,12 +85,13 @@ async def get_current_user(
raise credentials_exception
except JWTError:
raise credentials_exception

user = db.query(User).filter(User.username == username).first()
if user is None:
raise credentials_exception
return user


def get_admin_user(current_user: User = Depends(get_current_user)):
if current_user.role != UserRole.admin:
raise HTTPException(
Expand All @@ -91,10 +100,11 @@ def get_admin_user(current_user: User = Depends(get_current_user)):
)
return current_user


@router.post("/token")
async def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db)
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db)
):
user = authenticate_user(db, form_data.username, form_data.password)
if not user:
Expand All @@ -109,11 +119,12 @@ async def login_for_access_token(
)
return {"access_token": access_token, "token_type": "bearer"}


@router.post("/users", response_model=UserResponse)
async def create_user(
user_data: UserCreate,
current_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
user_data: UserCreate,
current_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
):
"""Create a new user (admin only)"""
# Check if username exists
Expand All @@ -122,7 +133,7 @@ async def create_user(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already registered"
)

# Check if email exists
if db.query(User).filter(User.email == user_data.email).first():
raise HTTPException(
Expand All @@ -137,7 +148,7 @@ async def create_user(
hashed_password=get_password_hash(user_data.password),
role=user_data.role
)

try:
db.add(db_user)
db.commit()
Expand All @@ -148,4 +159,4 @@ async def create_user(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
)
23 changes: 7 additions & 16 deletions app/clients/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
Router module for client-related endpoints.
Handles all HTTP requests for client operations including create, read, update, and delete.
"""

from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from typing import List, Optional
from app.auth.router import get_current_user, get_admin_user
from app.models import User, UserRole
from fastapi import APIRouter, Depends, status, Query
from sqlalchemy.orm import Session

from app.database import get_db
from app.clients.service.client_service import ClientService
Expand All @@ -23,17 +20,18 @@

@router.get("/", response_model=ClientListResponse)
async def get_clients(
current_user: User = Depends(get_admin_user),
skip: int = Query(default=0, ge=0, description="Number of records to skip"),
limit: int = Query(default=50, ge=1, le=150, description="Maximum number of records to return"),
db: Session = Depends(get_db)
):
"""
Get a list of clients with pagination.
"""
return ClientService.get_clients(db, skip, limit)

@router.get("/{client_id}", response_model=ClientResponse)
async def get_client(
client_id: int,
current_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
):
"""Get a specific client by ID"""
Expand Down Expand Up @@ -65,7 +63,6 @@ async def get_clients_by_criteria(
substance_use: Optional[bool] = None,
time_unemployed: Optional[int] = Query(None, ge=0),
need_mental_health_support_bool: Optional[bool] = None,
current_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
):
"""Search clients by any combination of criteria"""
Expand Down Expand Up @@ -106,7 +103,6 @@ async def get_clients_by_services(
employment_related_financial_supports: Optional[bool] = None,
employer_financial_supports: Optional[bool] = None,
enhanced_referrals: Optional[bool] = None,
current_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
):
"""Get clients filtered by multiple service statuses"""
Expand All @@ -124,7 +120,6 @@ async def get_clients_by_services(
@router.get("/{client_id}/services", response_model=List[ServiceResponse])
async def get_client_services(
client_id: int,
current_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
):
"""Get all services and their status for a specific client, including case worker info"""
Expand All @@ -133,7 +128,6 @@ async def get_client_services(
@router.get("/search/success-rate", response_model=List[ClientResponse])
async def get_clients_by_success_rate(
min_rate: int = Query(70, ge=0, le=100, description="Minimum success rate percentage"),
current_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
):
"""Get clients with success rate above specified threshold"""
Expand All @@ -142,16 +136,15 @@ async def get_clients_by_success_rate(
@router.get("/case-worker/{case_worker_id}", response_model=List[ClientResponse])
async def get_clients_by_case_worker(
case_worker_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Get clients by caseworker id."""
return ClientService.get_clients_by_case_worker(db, case_worker_id)

@router.put("/{client_id}", response_model=ClientResponse)
async def update_client(
client_id: int,
client_data: ClientUpdate,
current_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
):
"""Update a client's information"""
Expand All @@ -162,16 +155,15 @@ async def update_client_services(
client_id: int,
user_id: int,
service_update: ServiceUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Update a client services"""
return ClientService.update_client_services(db, client_id, user_id, service_update)

@router.post("/{client_id}/case-assignment", response_model=ServiceResponse)
async def create_case_assignment(
client_id: int,
case_worker_id: int = Query(..., description="Case worker ID to assign"),
current_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
):
"""Create a new case assignment for a client with a case worker"""
Expand All @@ -180,7 +172,6 @@ async def create_case_assignment(
@router.delete("/{client_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_client(
client_id: int,
current_user: User = Depends(get_admin_user),
db: Session = Depends(get_db)
):
"""Delete a client"""
Expand Down
2 changes: 1 addition & 1 deletion app/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlalchemy.orm import sessionmaker

#Here is where the database is located
SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db"
SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db"

#Open up a connection so that we are able to use the database
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
Expand Down
8 changes: 5 additions & 3 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@
This module initializes the FastAPI application and includes all routers.
Handles database initialization and CORS middleware configuration.
"""

# Third-party imports
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Local application imports
from app import models
from app.database import engine
from app.clients.router import router as clients_router
from app.auth.router import router as auth_router
from fastapi.middleware.cors import CORSMiddleware

# Initialize database tables
models.Base.metadata.create_all(bind=engine)

# Create FastAPI application
app = FastAPI(title="Case Management API", description="API for managing client cases", version="1.0.0")
app = FastAPI(title="Case Management API",
description="API for managing client cases", version="1.0.0")

# Include routers
app.include_router(auth_router)
Expand Down
Loading