Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ jobs:
run: black --check --diff app/ tests/

- name: Check import ordering with isort
run: isort --check-only --diff app/ tests/
run: isort --check-only --diff --profile black app/ tests/

- name: Lint with flake8
run: flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402
run: flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402,E203

test:
name: Test
Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ test-cov:

lint:
black --check --diff app/ tests/
isort --check-only --diff app/ tests/
flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402
isort --check-only --diff --profile black app/ tests/
flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402,E203

format:
black app/ tests/
isort app/ tests/
isort --profile black app/ tests/

clean:
find . -type d -name "__pycache__" -exec rm -r {} +
Expand Down
8 changes: 5 additions & 3 deletions app/api/api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from fastapi import APIRouter

from app.api.endpoints import auth, users, knowledge_bases, documents, messages
from app.api.endpoints import auth, documents, knowledge_bases, messages, users

api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(knowledge_bases.router, prefix="/knowledge-bases", tags=["knowledge-bases"])
api_router.include_router(
knowledge_bases.router, prefix="/knowledge-bases", tags=["knowledge-bases"]
)
api_router.include_router(documents.router, prefix="/documents", tags=["documents"])
api_router.include_router(messages.router, prefix="/messages", tags=["messages"])
api_router.include_router(messages.router, prefix="/messages", tags=["messages"])
34 changes: 21 additions & 13 deletions app/api/deps.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,40 @@
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from datetime import datetime, timedelta
from typing import Optional

from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.database import get_db
from app.db.models.user import User
from app.core.config import settings
from app.schemas.user import UserResponse

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")

async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> UserResponse:

async def get_current_user(
token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)
) -> UserResponse:
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
user_id = payload.get("sub")
if not user_id:
raise HTTPException(status_code=401, detail="Invalid token")


# get db
# get db
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=401, detail="User not found")

return UserResponse.model_validate(user)
except JWTError:
raise HTTPException(status_code=401, detail="Invalid token")


def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None) -> str:
"""
Create a new JWT access token for a user
Expand All @@ -39,11 +45,13 @@ def create_access_token(user_id: str, expires_delta: Optional[timedelta] = None)
str: JWT access token
"""
to_encode = {"sub": user_id}

if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)

to_encode.update({"exp": expire})
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
18 changes: 8 additions & 10 deletions app/api/endpoints/auth.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from datetime import timedelta

from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from datetime import timedelta
from sqlalchemy.orm import Session

from app.services.user_service import UserService
from app.api.deps import create_access_token
from app.core.config import settings
from app.db.database import get_db
from app.services.user_service import UserService

router = APIRouter()


@router.post("/token")
async def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db)
form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)
):
"""
OAuth2 compatible token login, get an access token for future requests
Expand All @@ -30,10 +31,7 @@ async def login_for_access_token(
# Create access token
access_token = create_access_token(
user_id=str(user.id),
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
)

return {
"access_token": access_token,
"token_type": "bearer"
}

return {"access_token": access_token, "token_type": "bearer"}
52 changes: 35 additions & 17 deletions app/api/endpoints/conversations.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,100 @@
import logging
from functools import lru_cache
from typing import List

from fastapi import APIRouter, Body, Depends
from fastapi.responses import JSONResponse
from functools import lru_cache
from sqlalchemy.orm import Session

from app.schemas.conversation import ConversationCreate, ConversationUpdate, ConversationResponse
from app.api.deps import get_current_user
from app.api.endpoints.knowledge_bases import get_knowledge_base_service
from app.db.database import get_db
from app.repositories.conversation_repository import ConversationRepository
from app.schemas.conversation import (
ConversationCreate,
ConversationResponse,
ConversationUpdate,
)
from app.schemas.user import UserResponse
from app.services.conversation_service import ConversationService
from app.repositories.conversation_repository import ConversationRepository
from app.services.knowledge_base_service import KnowledgeBaseService
from app.api.endpoints.knowledge_bases import get_knowledge_base_service
from app.db.database import get_db
import logging

router = APIRouter()
logger = logging.getLogger(__name__)


@lru_cache()
def get_conversation_repository() -> ConversationRepository:
"""Get conversation repository instance"""
return ConversationRepository()


def get_conversation_service(
conversation_repository: ConversationRepository = Depends(get_conversation_repository),
conversation_repository: ConversationRepository = Depends(
get_conversation_repository
),
knowledge_base_service: KnowledgeBaseService = Depends(get_knowledge_base_service),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
) -> ConversationService:
"""Get conversation service instance"""
return ConversationService(
conversation_repository=conversation_repository,
knowledge_base_service=knowledge_base_service,
db=db
db=db,
)


@router.post("", response_model=ConversationResponse)
async def create_conversation(
payload: ConversationCreate = Body(..., description="Conversation details"),
current_user: UserResponse = Depends(get_current_user),
conversation_service: ConversationService = Depends(get_conversation_service)
conversation_service: ConversationService = Depends(get_conversation_service),
):
"""Create a new conversation"""
return await conversation_service.create_conversation(payload, current_user)


@router.get("", response_model=List[ConversationResponse])
async def list_conversations(
current_user: UserResponse = Depends(get_current_user),
conversation_service: ConversationService = Depends(get_conversation_service)
conversation_service: ConversationService = Depends(get_conversation_service),
):
"""List all conversations for the current user"""
logger.info(f"Listing conversations for user {current_user.id}")
return await conversation_service.list_conversations(current_user)


@router.get("/{conversation_id}", response_model=ConversationResponse)
async def get_conversation(
conversation_id: str,
current_user: UserResponse = Depends(get_current_user),
conversation_service: ConversationService = Depends(get_conversation_service)
conversation_service: ConversationService = Depends(get_conversation_service),
):
"""Get conversation details including messages"""
return await conversation_service.get_conversation(conversation_id, current_user)


@router.put("/{conversation_id}", response_model=ConversationResponse)
async def update_conversation(
conversation_id: str,
conversation_update: ConversationUpdate = Body(..., description="Conversation details"),
conversation_update: ConversationUpdate = Body(
..., description="Conversation details"
),
current_user: UserResponse = Depends(get_current_user),
conversation_service: ConversationService = Depends(get_conversation_service)
conversation_service: ConversationService = Depends(get_conversation_service),
):
"""Update conversation details"""
return await conversation_service.update_conversation(conversation_id, conversation_update, current_user)
return await conversation_service.update_conversation(
conversation_id, conversation_update, current_user
)


@router.delete("/{conversation_id}")
async def delete_conversation(
conversation_id: str,
current_user: UserResponse = Depends(get_current_user),
conversation_service: ConversationService = Depends(get_conversation_service)
conversation_service: ConversationService = Depends(get_conversation_service),
):
"""Delete a conversation and all its messages"""
await conversation_service.delete_conversation(conversation_id, current_user)
return JSONResponse(content={"message": "Conversation deleted successfully"})
return JSONResponse(content={"message": "Conversation deleted successfully"})
Loading
Loading