diff --git a/backend/openedx_ai_extensions/api/v1/workflows/views.py b/backend/openedx_ai_extensions/api/v1/workflows/views.py index 3bfe99d2..57baa5ce 100644 --- a/backend/openedx_ai_extensions/api/v1/workflows/views.py +++ b/backend/openedx_ai_extensions/api/v1/workflows/views.py @@ -7,11 +7,9 @@ import logging from datetime import datetime, timezone -from django.contrib.auth.decorators import login_required from django.core.exceptions import ValidationError from django.http import JsonResponse, StreamingHttpResponse from django.utils.decorators import method_decorator -from django.views import View from rest_framework import status from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response @@ -33,27 +31,22 @@ logger = logging.getLogger(__name__) -@method_decorator(login_required, name="dispatch") -@method_decorator(handle_ai_errors, name="dispatch") -class AIGenericWorkflowView(View): +class AIGenericWorkflowView(APIView): """ AI Workflow API endpoint """ + permission_classes = [IsAuthenticated] + + @method_decorator(handle_ai_errors) def post(self, request): - """Common handler for GET and POST requests""" + """Handle POST requests to execute an AI workflow.""" context = get_context_from_request(request) workflow_profile = AIWorkflowScope.get_profile(**context) - request_body = {} - if request.body: - try: - request_body = json.loads(request.body.decode("utf-8")) - except json.JSONDecodeError as e: - raise ValidationError("Invalid JSON format in request body.") from e - action = request_body.get("action", "") - user_input = request_body.get("user_input", {}) + action = request.data.get("action", "") + user_input = request.data.get("user_input", {}) result = workflow_profile.execute( user_input=user_input, diff --git a/backend/openedx_ai_extensions/decorators.py b/backend/openedx_ai_extensions/decorators.py index af69e4b5..ccbbbe4d 100644 --- a/backend/openedx_ai_extensions/decorators.py +++ b/backend/openedx_ai_extensions/decorators.py @@ -12,11 +12,13 @@ APIConnectionError, AuthenticationError, ContextWindowExceededError, + NotFoundError, RateLimitError, ServiceUnavailableError, Timeout, ) from rest_framework import status +from rest_framework.exceptions import ParseError logger = logging.getLogger(__name__) @@ -28,6 +30,11 @@ "message": "The AI service is currently unavailable due to an authentication error.", "status": status.HTTP_500_INTERNAL_SERVER_ERROR, }, + NotFoundError: { + "code": "llm_config_error", + "message": "The AI service is misconfigured. Please check the LLM settings.", + "status": status.HTTP_500_INTERNAL_SERVER_ERROR, + }, RateLimitError: { "code": "rate_limit_exceeded", "message": "The AI service is currently busy. Please try again later.", @@ -48,6 +55,11 @@ "message": "The provided input or configuration is invalid.", "status": status.HTTP_400_BAD_REQUEST, }, + ParseError: { + "code": "parse_error", + "message": "The request body could not be parsed. Please check the JSON format.", + "status": status.HTTP_400_BAD_REQUEST, + }, } diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index 01a80d44..c1d23049 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -12,6 +12,7 @@ from django.urls import reverse from opaque_keys.edx.keys import CourseKey from opaque_keys.edx.locator import BlockUsageLocator +from rest_framework.exceptions import ParseError from rest_framework.test import APIClient, APIRequestFactory # Mock the submissions module before any imports that depend on it @@ -33,6 +34,7 @@ AIWorkflowProfilesListView, AIWorkflowProfileView, ) +from openedx_ai_extensions.decorators import handle_ai_errors # noqa: E402 pylint: disable=wrong-import-position from openedx_ai_extensions.models import PromptTemplate # noqa: E402 pylint: disable=wrong-import-position from openedx_ai_extensions.workflows.models import ( # noqa: E402 pylint: disable=wrong-import-position AIWorkflowProfile, @@ -137,13 +139,9 @@ def test_workflows_endpoint_requires_authentication(api_client): # pylint: disa """ url = reverse("openedx_ai_extensions:api:v1:aiext_workflows") - # Test POST without authentication + # DRF IsAuthenticated with SessionAuthentication returns 403 (no WWW-Authenticate challenge) response = api_client.post(url, {}, format="json") - assert response.status_code == 302 # Redirect to login - - # Test GET without authentication - response = api_client.get(url) - assert response.status_code == 302 # Redirect to login + assert response.status_code == 403 @pytest.mark.django_db @@ -434,8 +432,8 @@ def test_workflows_post_with_invalid_json(api_client): # pylint: disable=redefi url, data="invalid json", content_type="application/json" ) - # Should return 400 or 500 for invalid JSON - assert response.status_code in [400, 500] + # ParseError is now mapped to 400 via handle_ai_errors + assert response.status_code == 400 data = response.json() assert "error" in data @@ -1407,3 +1405,21 @@ def test_profiles_list_view_unexpected_error(mock_list, api_client, course_key): response = api_client.get(url, {"context": context}) assert response.status_code == 500 assert response.json()["status"] == "error" + + +# ============================================================================ +# Unit Tests - handle_ai_errors decorator +# ============================================================================ + + +def test_handle_ai_errors_maps_parse_error_to_400(): + """DRF ParseError (malformed JSON body) is mapped to HTTP 400, not 500.""" + @handle_ai_errors + def fake_view(request): + raise ParseError("bad json") + + response = fake_view(Mock()) + assert response.status_code == 400 + data = json.loads(response.content) + assert data["error"]["code"] == "parse_error" + assert data["status"] == "error"