From cdf22fc334e87e6ba6e7225ef83f5adc71478f6a Mon Sep 17 00:00:00 2001 From: Polo Date: Wed, 25 Feb 2026 15:08:55 +0200 Subject: [PATCH 1/2] Add delete and checksum validation endpoints with client methods --- app/api/v1/endpoints/artifacts.py | 46 ++++++- app/grpc/clients/artifact_storage_client.py | 57 +++++++++ tests/unit/test_artifact_client.py | 105 ++++++++++++++++ tests/unit/test_artifact_endpoints.py | 126 ++++++++++++++++++++ 4 files changed, 333 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_artifact_client.py create mode 100644 tests/unit/test_artifact_endpoints.py diff --git a/app/api/v1/endpoints/artifacts.py b/app/api/v1/endpoints/artifacts.py index 171da84..0015651 100644 --- a/app/api/v1/endpoints/artifacts.py +++ b/app/api/v1/endpoints/artifacts.py @@ -1,6 +1,6 @@ """Artifact management endpoints.""" -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel, Field from typing import Optional @@ -190,3 +190,47 @@ async def list_artifacts( status_code=500, detail=f"Failed to list artifacts: {str(e)}" ) + + +@router.delete("/{artifact_id}", status_code=status.HTTP_200_OK) +async def delete_artifact( + artifact_id: str, + current_user: User = Depends(deps.get_current_user), + client: ArtifactStorageClient = Depends(get_artifact_storage_client) +): + """ + Delete an artifact. + """ + try: + success = await client.delete_artifact(artifact_id) + if success: + return {"status": "deleted", "artifact_id": artifact_id} + else: + raise HTTPException( + status_code=500, + detail="Failed to delete artifact" + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to delete artifact: {str(e)}" + ) + + +@router.post("/{artifact_id}/validate-checksum", response_model=dict) +async def validate_checksum( + artifact_id: str, + current_user: User = Depends(deps.get_current_user), + client: ArtifactStorageClient = Depends(get_artifact_storage_client) +): + """ + Validate artifact checksum integrity. + """ + try: + validation_result = await client.validate_checksum(artifact_id) + return validation_result + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to validate checksum: {str(e)}" + ) diff --git a/app/grpc/clients/artifact_storage_client.py b/app/grpc/clients/artifact_storage_client.py index 679fb93..f16a431 100644 --- a/app/grpc/clients/artifact_storage_client.py +++ b/app/grpc/clients/artifact_storage_client.py @@ -264,6 +264,63 @@ async def list_artifacts( logger.error(f"gRPC error listing artifacts: {e.code()} - {e.details()}") raise + async def delete_artifact(self, artifact_id: str) -> bool: + """ + Delete an artifact. + + Args: + artifact_id: Artifact ID to delete + + Returns: + True if successful + """ + if not self.stub: + raise RuntimeError("Client not connected. Call connect() first.") + + try: + from app.grpc.generated import artifact_pb2 + + request = artifact_pb2.DeleteArtifactRequest(artifact_id=artifact_id) + + metadata = (("x-api-key", self.api_key),) + response = await self.stub.DeleteArtifact(request, metadata=metadata) + + return response.status == "deleted" + except grpc.RpcError as e: + logger.error(f"gRPC error deleting artifact: {e.code()} - {e.details()}") + raise + + async def validate_checksum(self, artifact_id: str) -> Dict[str, Any]: + """ + Validate artifact checksum integrity. + + Args: + artifact_id: Artifact ID to validate + + Returns: + Dict with validation results + """ + if not self.stub: + raise RuntimeError("Client not connected. Call connect() first.") + + try: + from app.grpc.generated import artifact_pb2 + + request = artifact_pb2.ValidateChecksumRequest(artifact_id=artifact_id) + + metadata = (("x-api-key", self.api_key),) + response = await self.stub.ValidateChecksum(request, metadata=metadata) + + return { + "valid": response.valid, + "expected_checksum": response.expected_checksum, + "actual_checksum": response.actual_checksum if response.HasField("actual_checksum") else None, + "validation_time": response.validation_time.ToDatetime().isoformat() if response.HasField("validation_time") else None + } + except grpc.RpcError as e: + logger.error(f"gRPC error validating checksum: {e.code()} - {e.details()}") + raise + async def close(self): """Close gRPC connection.""" if self.channel: diff --git a/tests/unit/test_artifact_client.py b/tests/unit/test_artifact_client.py new file mode 100644 index 0000000..4f0ac76 --- /dev/null +++ b/tests/unit/test_artifact_client.py @@ -0,0 +1,105 @@ +"""Unit tests for Artifact Storage gRPC client.""" + +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from app.grpc.clients.artifact_storage_client import ArtifactStorageClient + + +@pytest.mark.asyncio +async def test_delete_artifact_success(): + """Test successful artifact deletion.""" + client = ArtifactStorageClient("localhost", 50051, "test-key") + + with patch("app.grpc.clients.artifact_storage_client.grpc.aio.insecure_channel"), \ + patch("app.grpc.generated.artifact_pb2_grpc.ArtifactServiceStub") as mock_stub: + + # Mock successful response + mock_response = MagicMock() + mock_response.status = "deleted" + mock_stub.return_value.DeleteArtifact.return_value = mock_response + + await client.connect() + result = await client.delete_artifact("test-artifact-id") + + assert result is True + + +@pytest.mark.asyncio +async def test_delete_artifact_failure(): + """Test artifact deletion failure.""" + client = ArtifactStorageClient("localhost", 50051, "test-key") + + with patch("app.grpc.clients.artifact_storage_client.grpc.aio.insecure_channel"), \ + patch("app.grpc.generated.artifact_pb2_grpc.ArtifactServiceStub") as mock_stub: + + # Mock failed response + mock_response = MagicMock() + mock_response.status = "error" + mock_stub.return_value.DeleteArtifact.return_value = mock_response + + await client.connect() + result = await client.delete_artifact("test-artifact-id") + + assert result is False + + +@pytest.mark.asyncio +async def test_validate_checksum_success(): + """Test successful checksum validation.""" + client = ArtifactStorageClient("localhost", 50051, "test-key") + + with patch("app.grpc.clients.artifact_storage_client.grpc.aio.insecure_channel"), \ + patch("app.grpc.generated.artifact_pb2_grpc.ArtifactServiceStub") as mock_stub, \ + patch("app.grpc.clients.artifact_storage_client.datetime") as mock_datetime: + + # Mock successful response + mock_response = MagicMock() + mock_response.valid = True + mock_response.expected_checksum = "abc123" + mock_response.actual_checksum = "abc123" + mock_response.HasField.return_value = True + mock_response.validation_time.ToDatetime.return_value.isoformat.return_value = "2026-02-25T15:00:00" + mock_stub.return_value.ValidateChecksum.return_value = mock_response + + await client.connect() + result = await client.validate_checksum("test-artifact-id") + + assert result["valid"] is True + assert result["expected_checksum"] == "abc123" + assert result["actual_checksum"] == "abc123" + assert result["validation_time"] == "2026-02-25T15:00:00" + + +@pytest.mark.asyncio +async def test_validate_checksum_failure(): + """Test checksum validation failure.""" + client = ArtifactStorageClient("localhost", 50051, "test-key") + + with patch("app.grpc.clients.artifact_storage_client.grpc.aio.insecure_channel"), \ + patch("app.grpc.generated.artifact_pb2_grpc.ArtifactServiceStub") as mock_stub: + + # Mock failed response + mock_response = MagicMock() + mock_response.valid = False + mock_response.expected_checksum = "abc123" + mock_response.HasField.return_value = False + mock_stub.return_value.ValidateChecksum.return_value = mock_response + + await client.connect() + result = await client.validate_checksum("test-artifact-id") + + assert result["valid"] is False + assert result["expected_checksum"] == "abc123" + assert result["actual_checksum"] is None + + +@pytest.mark.asyncio +async def test_client_not_connected_error(): + """Test error when client is not connected.""" + client = ArtifactStorageClient("localhost", 50051, "test-key") + + with pytest.raises(RuntimeError, match="Client not connected"): + await client.delete_artifact("test-id") + + with pytest.raises(RuntimeError, match="Client not connected"): + await client.validate_checksum("test-id") diff --git a/tests/unit/test_artifact_endpoints.py b/tests/unit/test_artifact_endpoints.py new file mode 100644 index 0000000..61b5d7f --- /dev/null +++ b/tests/unit/test_artifact_endpoints.py @@ -0,0 +1,126 @@ +"""Unit tests for artifact endpoints.""" + +import pytest +from unittest.mock import AsyncMock, patch +from fastapi.testclient import TestClient +from app.main import app +from app.api.v1.endpoints.artifacts import ( + delete_artifact, + validate_checksum, + UploadURLRequest, + UploadURLResponse, + ArtifactMetadata +) + +client = TestClient(app) + + +@pytest.mark.asyncio +async def test_delete_artifact_success(): + """Test successful artifact deletion.""" + mock_client = AsyncMock() + mock_client.delete_artifact.return_value = True + + with patch("app.api.v1.endpoints.artifacts.get_artifact_storage_client", return_value=mock_client), \ + patch("app.api.v1.endpoints.artifacts.deps.get_current_user", return_value={"id": 1}): + + result = await delete_artifact("test-artifact-id", mock_client, mock_client) + + assert result["status"] == "deleted" + assert result["artifact_id"] == "test-artifact-id" + mock_client.delete_artifact.assert_called_once_with("test-artifact-id") + + +@pytest.mark.asyncio +async def test_delete_artifact_failure(): + """Test artifact deletion failure.""" + mock_client = AsyncMock() + mock_client.delete_artifact.return_value = False + + with patch("app.api.v1.endpoints.artifacts.get_artifact_storage_client", return_value=mock_client), \ + patch("app.api.v1.endpoints.artifacts.deps.get_current_user", return_value={"id": 1}): + + with pytest.raises(Exception): # Should raise HTTPException + await delete_artifact("test-artifact-id", mock_client, mock_client) + + +@pytest.mark.asyncio +async def test_validate_checksum_success(): + """Test successful checksum validation.""" + mock_client = AsyncMock() + mock_client.validate_checksum.return_value = { + "valid": True, + "expected_checksum": "abc123", + "actual_checksum": "abc123", + "validation_time": "2026-02-25T15:00:00" + } + + with patch("app.api.v1.endpoints.artifacts.get_artifact_storage_client", return_value=mock_client), \ + patch("app.api.v1.endpoints.artifacts.deps.get_current_user", return_value={"id": 1}): + + result = await validate_checksum("test-artifact-id", mock_client, mock_client) + + assert result["valid"] is True + assert result["expected_checksum"] == "abc123" + assert result["actual_checksum"] == "abc123" + mock_client.validate_checksum.assert_called_once_with("test-artifact-id") + + +@pytest.mark.asyncio +async def test_validate_checksum_failure(): + """Test checksum validation failure.""" + mock_client = AsyncMock() + mock_client.validate_checksum.side_effect = Exception("Storage error") + + with patch("app.api.v1.endpoints.artifacts.get_artifact_storage_client", return_value=mock_client), \ + patch("app.api.v1.endpoints.artifacts.deps.get_current_user", return_value={"id": 1}): + + with pytest.raises(Exception): # Should raise HTTPException + await validate_checksum("test-artifact-id", mock_client, mock_client) + + +# Test the request/response models +def test_upload_url_request_model(): + """Test UploadURLRequest model validation.""" + # Valid request + request = UploadURLRequest( + name="test.txt", + bucket="test-bucket", + key="test-key", + content_type="text/plain" + ) + assert request.name == "test.txt" + assert request.bucket == "test-bucket" + assert request.key == "test-key" + assert request.content_type == "text/plain" + + +def test_upload_url_response_model(): + """Test UploadURLResponse model validation.""" + response = UploadURLResponse( + artifact_id="test-id", + upload_url="https://example.com/upload" + ) + assert response.artifact_id == "test-id" + assert response.upload_url == "https://example.com/upload" + + +def test_artifact_metadata_model(): + """Test ArtifactMetadata model validation.""" + metadata = ArtifactMetadata( + id="test-id", + name="test.txt", + bucket="test-bucket", + key="test-key", + size=1024, + content_type="text/plain", + checksum_sha256="abc123", + pipeline_id="pipeline-1", + project_id="project-1" + ) + assert metadata.id == "test-id" + assert metadata.name == "test.txt" + assert metadata.size == 1024 + assert metadata.checksum_sha256 == "abc123" + assert metadata.pipeline_id == "pipeline-1" + assert metadata.project_id == "project-1" From 657a47507aa25d05d3b0927068912bb67cf841c0 Mon Sep 17 00:00:00 2001 From: Polo Date: Wed, 25 Feb 2026 15:13:11 +0200 Subject: [PATCH 2/2] Add multipart upload endpoints and client methods --- app/api/v1/endpoints/artifacts.py | 135 ++++++++++++++++- app/grpc/clients/artifact_storage_client.py | 132 ++++++++++++++++ tests/unit/test_multipart_endpoints.py | 158 ++++++++++++++++++++ tests/unit/test_multipart_upload.py | 101 +++++++++++++ 4 files changed, 525 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_multipart_endpoints.py create mode 100644 tests/unit/test_multipart_upload.py diff --git a/app/api/v1/endpoints/artifacts.py b/app/api/v1/endpoints/artifacts.py index 0015651..f449670 100644 --- a/app/api/v1/endpoints/artifacts.py +++ b/app/api/v1/endpoints/artifacts.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel, Field -from typing import Optional +from typing import Optional, List from app.core.grpc_clients import get_artifact_storage_client from app.grpc.clients.artifact_storage_client import ArtifactStorageClient @@ -234,3 +234,136 @@ async def validate_checksum( status_code=500, detail=f"Failed to validate checksum: {str(e)}" ) + + +# Multipart Upload Request/Response Models +class InitiateMultipartRequest(BaseModel): + """Request to initiate multipart upload.""" + name: str = Field(..., description="Artifact name") + bucket: str = Field(..., description="S3 bucket name") + key: str = Field(..., description="S3 object key") + content_type: str = Field(default="application/octet-stream", description="MIME type") + pipeline_id: Optional[str] = Field(None, description="Associated pipeline ID") + execution_id: Optional[str] = Field(None, description="Associated execution ID") + project_id: Optional[str] = Field(None, description="Associated project ID") + retention_days: Optional[int] = Field(None, description="Retention period in days") + storage_class: Optional[str] = Field(None, description="S3 storage class") + + +class InitiateMultipartResponse(BaseModel): + """Response with upload ID and artifact ID.""" + artifact_id: str = Field(..., description="Unique artifact ID") + upload_id: str = Field(..., description="Multipart upload ID") + + +class MultipartPartURLRequest(BaseModel): + """Request for multipart part upload URL.""" + artifact_id: str = Field(..., description="Artifact ID") + upload_id: str = Field(..., description="Multipart upload ID") + part_number: int = Field(..., description="Part number", ge=1) + expires_in: int = Field(default=3600, description="URL expiration (seconds)") + + +class MultipartPartURLResponse(BaseModel): + """Response with part upload URL.""" + part_url: str = Field(..., description="Pre-signed URL for part upload") + expires_at: Optional[str] = Field(None, description="URL expiration time") + + +class CompletePart(BaseModel): + """Completed multipart part.""" + part_number: int = Field(..., description="Part number", ge=1) + etag: str = Field(..., description="Part ETag from S3") + + +class CompleteMultipartRequest(BaseModel): + """Request to complete multipart upload.""" + artifact_id: str = Field(..., description="Artifact ID") + upload_id: str = Field(..., description="Multipart upload ID") + parts: List[CompletePart] = Field(..., description="List of completed parts") + + +class CompleteMultipartResponse(BaseModel): + """Response with completion status.""" + status: str = Field(..., description="Upload completion status") + version_id: Optional[str] = Field(None, description="S3 object version ID") + + +# Multipart Upload Endpoints +@router.post("/multipart/initiate", response_model=InitiateMultipartResponse, status_code=status.HTTP_200_OK) +async def initiate_multipart_upload( + request: InitiateMultipartRequest, + current_user: User = Depends(deps.get_current_user), + client: ArtifactStorageClient = Depends(get_artifact_storage_client) +): + """ + Initiate multipart upload for large files. + """ + try: + result = await client.initiate_multipart_upload( + name=request.name, + bucket=request.bucket, + key=request.key, + content_type=request.content_type, + pipeline_id=request.pipeline_id, + execution_id=request.execution_id, + project_id=request.project_id, + retention_days=request.retention_days, + storage_class=request.storage_class + ) + return result + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to initiate multipart upload: {str(e)}" + ) + + +@router.post("/multipart/part-url", response_model=MultipartPartURLResponse, status_code=status.HTTP_200_OK) +async def get_multipart_part_url( + request: MultipartPartURLRequest, + current_user: User = Depends(deps.get_current_user), + client: ArtifactStorageClient = Depends(get_artifact_storage_client) +): + """ + Get pre-signed URL for uploading a multipart part. + """ + try: + result = await client.get_multipart_part_url( + artifact_id=request.artifact_id, + upload_id=request.upload_id, + part_number=request.part_number, + expires_in=request.expires_in + ) + return result + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to get multipart part URL: {str(e)}" + ) + + +@router.post("/multipart/complete", response_model=CompleteMultipartResponse, status_code=status.HTTP_200_OK) +async def complete_multipart_upload( + request: CompleteMultipartRequest, + current_user: User = Depends(deps.get_current_user), + client: ArtifactStorageClient = Depends(get_artifact_storage_client) +): + """ + Complete multipart upload and assemble final object. + """ + try: + # Convert Pydantic parts to dict for gRPC client + parts_dict = [{"part_number": part.part_number, "etag": part.etag} for part in request.parts] + + result = await client.complete_multipart_upload( + artifact_id=request.artifact_id, + upload_id=request.upload_id, + parts=parts_dict + ) + return result + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to complete multipart upload: {str(e)}" + ) diff --git a/app/grpc/clients/artifact_storage_client.py b/app/grpc/clients/artifact_storage_client.py index f16a431..c0d0b37 100644 --- a/app/grpc/clients/artifact_storage_client.py +++ b/app/grpc/clients/artifact_storage_client.py @@ -321,6 +321,138 @@ async def validate_checksum(self, artifact_id: str) -> Dict[str, Any]: logger.error(f"gRPC error validating checksum: {e.code()} - {e.details()}") raise + async def initiate_multipart_upload(self, name: str, bucket: str, key: str, content_type: str = None, + pipeline_id: str = None, execution_id: str = None, + project_id: str = None, retention_days: int = None, + storage_class: str = None) -> Dict[str, Any]: + """ + Initiate multipart upload for large files. + + Args: + name: Artifact name + bucket: S3 bucket name + key: S3 object key + content_type: MIME type + pipeline_id: Associated pipeline ID + execution_id: Associated execution ID + project_id: Associated project ID + retention_days: Retention period in days + storage_class: S3 storage class + + Returns: + Dict with upload ID and artifact ID + """ + if not self.stub: + raise RuntimeError("Client not connected. Call connect() first.") + + try: + from app.grpc.generated import artifact_pb2 + + request = artifact_pb2.InitiateMultipartRequest( + name=name, + bucket=bucket, + key=key, + content_type=content_type or "application/octet-stream", + pipeline_id=pipeline_id, + execution_id=execution_id, + project_id=project_id, + retention_days=retention_days, + storage_class=storage_class + ) + + metadata = (("x-api-key", self.api_key),) + response = await self.stub.InitiateMultipart(request, metadata=metadata) + + return { + "artifact_id": response.artifact_id, + "upload_id": response.upload_id + } + except grpc.RpcError as e: + logger.error(f"gRPC error initiating multipart upload: {e.code()} - {e.details()}") + raise + + async def get_multipart_part_url(self, artifact_id: str, upload_id: str, part_number: int, + expires_in: int = 3600) -> Dict[str, Any]: + """ + Get pre-signed URL for multipart upload part. + + Args: + artifact_id: Artifact ID + upload_id: Multipart upload ID + part_number: Part number + expires_in: URL expiration in seconds + + Returns: + Dict with part upload URL + """ + if not self.stub: + raise RuntimeError("Client not connected. Call connect() first.") + + try: + from app.grpc.generated import artifact_pb2 + + request = artifact_pb2.MultipartPartURLRequest( + artifact_id=artifact_id, + upload_id=upload_id, + part_number=part_number, + expires_in=expires_in + ) + + metadata = (("x-api-key", self.api_key),) + response = await self.stub.GetMultipartPartURL(request, metadata=metadata) + + return { + "part_url": response.part_url, + "expires_at": response.expires_at.ToDatetime().isoformat() if response.HasField("expires_at") else None + } + except grpc.RpcError as e: + logger.error(f"gRPC error getting multipart part URL: {e.code()} - {e.details()}") + raise + + async def complete_multipart_upload(self, artifact_id: str, upload_id: str, parts: list) -> Dict[str, Any]: + """ + Complete multipart upload. + + Args: + artifact_id: Artifact ID + upload_id: Multipart upload ID + parts: List of completed parts with ETags + + Returns: + Dict with completion status and version info + """ + if not self.stub: + raise RuntimeError("Client not connected. Call connect() first.") + + try: + from app.grpc.generated import artifact_pb2 + + # Convert parts to protobuf format + grpc_parts = [] + for part in parts: + grpc_part = artifact_pb2.CompletePart( + part_number=part.get("part_number"), + etag=part.get("etag") + ) + grpc_parts.append(grpc_part) + + request = artifact_pb2.CompleteMultipartRequest( + artifact_id=artifact_id, + upload_id=upload_id, + parts=grpc_parts + ) + + metadata = (("x-api-key", self.api_key),) + response = await self.stub.CompleteMultipart(request, metadata=metadata) + + return { + "status": "completed", + "version_id": response.version_id if response.HasField("version_id") else None + } + except grpc.RpcError as e: + logger.error(f"gRPC error completing multipart upload: {e.code()} - {e.details()}") + raise + async def close(self): """Close gRPC connection.""" if self.channel: diff --git a/tests/unit/test_multipart_endpoints.py b/tests/unit/test_multipart_endpoints.py new file mode 100644 index 0000000..b389214 --- /dev/null +++ b/tests/unit/test_multipart_endpoints.py @@ -0,0 +1,158 @@ +"""Unit tests for multipart upload endpoints.""" + +import pytest +from unittest.mock import AsyncMock, patch +from app.api.v1.endpoints.artifacts import ( + initiate_multipart_upload, + get_multipart_part_url, + complete_multipart_upload, + InitiateMultipartRequest, + InitiateMultipartResponse, + MultipartPartURLRequest, + MultipartPartURLResponse, + CompletePart, + CompleteMultipartRequest, + CompleteMultipartResponse +) + + +@pytest.mark.asyncio +async def test_initiate_multipart_endpoint_success(): + """Test successful multipart upload initiation endpoint.""" + mock_client = AsyncMock() + mock_client.initiate_multipart_upload.return_value = { + "artifact_id": "test-artifact-id", + "upload_id": "test-upload-id" + } + + with patch("app.api.v1.endpoints.artifacts.get_artifact_storage_client", return_value=mock_client), \ + patch("app.api.v1.endpoints.artifacts.deps.get_current_user", return_value={"id": 1}): + + request = InitiateMultipartRequest( + name="large-file.zip", + bucket="test-bucket", + key="test-key", + content_type="application/zip" + ) + + result = await initiate_multipart_upload(request, mock_client, mock_client) + + assert result["artifact_id"] == "test-artifact-id" + assert result["upload_id"] == "test-upload-id" + mock_client.initiate_multipart_upload.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_multipart_part_url_endpoint_success(): + """Test successful multipart part URL endpoint.""" + mock_client = AsyncMock() + mock_client.get_multipart_part_url.return_value = { + "part_url": "https://example.com/upload-part", + "expires_at": "2026-02-25T16:00:00" + } + + with patch("app.api.v1.endpoints.artifacts.get_artifact_storage_client", return_value=mock_client), \ + patch("app.api.v1.endpoints.artifacts.deps.get_current_user", return_value={"id": 1}): + + request = MultipartPartURLRequest( + artifact_id="test-artifact-id", + upload_id="test-upload-id", + part_number=1 + ) + + result = await get_multipart_part_url(request, mock_client, mock_client) + + assert result["part_url"] == "https://example.com/upload-part" + assert result["expires_at"] == "2026-02-25T16:00:00" + mock_client.get_multipart_part_url.assert_called_once() + + +@pytest.mark.asyncio +async def test_complete_multipart_endpoint_success(): + """Test successful multipart upload completion endpoint.""" + mock_client = AsyncMock() + mock_client.complete_multipart_upload.return_value = { + "status": "completed", + "version_id": "test-version-id" + } + + with patch("app.api.v1.endpoints.artifacts.get_artifact_storage_client", return_value=mock_client), \ + patch("app.api.v1.endpoints.artifacts.deps.get_current_user", return_value={"id": 1}): + + parts = [ + CompletePart(part_number=1, etag="etag1"), + CompletePart(part_number=2, etag="etag2") + ] + request = CompleteMultipartRequest( + artifact_id="test-artifact-id", + upload_id="test-upload-id", + parts=parts + ) + + result = await complete_multipart_upload(request, mock_client, mock_client) + + assert result["status"] == "completed" + assert result["version_id"] == "test-version-id" + mock_client.complete_multipart_upload.assert_called_once() + + +@pytest.mark.asyncio +async def test_multipart_endpoint_error_handling(): + """Test multipart upload endpoint error handling.""" + mock_client = AsyncMock() + mock_client.initiate_multipart_upload.side_effect = Exception("Storage error") + + with patch("app.api.v1.endpoints.artifacts.get_artifact_storage_client", return_value=mock_client), \ + patch("app.api.v1.endpoints.artifacts.deps.get_current_user", return_value={"id": 1}): + + request = InitiateMultipartRequest( + name="test.txt", + bucket="test-bucket", + key="test-key" + ) + + with pytest.raises(Exception): # Should raise HTTPException + await initiate_multipart_upload(request, mock_client, mock_client) + + +# Test request/response models +def test_initiate_multipart_request_model(): + """Test InitiateMultipartRequest model validation.""" + request = InitiateMultipartRequest( + name="large-file.zip", + bucket="test-bucket", + key="test-key", + content_type="application/zip", + pipeline_id="pipeline-1", + retention_days=30 + ) + assert request.name == "large-file.zip" + assert request.bucket == "test-bucket" + assert request.content_type == "application/zip" + assert request.pipeline_id == "pipeline-1" + assert request.retention_days == 30 + + +def test_complete_part_model(): + """Test CompletePart model validation.""" + part = CompletePart(part_number=1, etag="test-etag") + assert part.part_number == 1 + assert part.etag == "test-etag" + + +def test_complete_multipart_request_model(): + """Test CompleteMultipartRequest model validation.""" + parts = [ + CompletePart(part_number=1, etag="etag1"), + CompletePart(part_number=2, etag="etag2") + ] + request = CompleteMultipartRequest( + artifact_id="test-id", + upload_id="upload-id", + parts=parts + ) + assert request.artifact_id == "test-id" + assert request.upload_id == "upload-id" + assert len(request.parts) == 2 + assert request.parts[0].part_number == 1 + assert request.parts[0].etag == "etag1" diff --git a/tests/unit/test_multipart_upload.py b/tests/unit/test_multipart_upload.py new file mode 100644 index 0000000..410e726 --- /dev/null +++ b/tests/unit/test_multipart_upload.py @@ -0,0 +1,101 @@ +"""Unit tests for multipart upload functionality.""" + +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from app.grpc.clients.artifact_storage_client import ArtifactStorageClient + + +@pytest.mark.asyncio +async def test_initiate_multipart_upload_success(): + """Test successful multipart upload initiation.""" + client = ArtifactStorageClient("localhost", 50051, "test-key") + + with patch("app.grpc.clients.artifact_storage_client.grpc.aio.insecure_channel"), \ + patch("app.grpc.generated.artifact_pb2_grpc.ArtifactServiceStub") as mock_stub: + + # Mock successful response + mock_response = MagicMock() + mock_response.artifact_id = "test-artifact-id" + mock_response.upload_id = "test-upload-id" + mock_stub.return_value.InitiateMultipart.return_value = mock_response + + await client.connect() + result = await client.initiate_multipart_upload( + name="large-file.zip", + bucket="test-bucket", + key="test-key", + content_type="application/zip", + pipeline_id="pipeline-1" + ) + + assert result["artifact_id"] == "test-artifact-id" + assert result["upload_id"] == "test-upload-id" + + +@pytest.mark.asyncio +async def test_get_multipart_part_url_success(): + """Test successful multipart part URL generation.""" + client = ArtifactStorageClient("localhost", 50051, "test-key") + + with patch("app.grpc.clients.artifact_storage_client.grpc.aio.insecure_channel"), \ + patch("app.grpc.generated.artifact_pb2_grpc.ArtifactServiceStub") as mock_stub: + + # Mock successful response + mock_response = MagicMock() + mock_response.part_url = "https://example.com/upload-part" + mock_response.HasField.return_value = True + mock_response.expires_at.ToDatetime.return_value.isoformat.return_value = "2026-02-25T16:00:00" + mock_stub.return_value.GetMultipartPartURL.return_value = mock_response + + await client.connect() + result = await client.get_multipart_part_url( + artifact_id="test-artifact-id", + upload_id="test-upload-id", + part_number=1 + ) + + assert result["part_url"] == "https://example.com/upload-part" + assert result["expires_at"] == "2026-02-25T16:00:00" + + +@pytest.mark.asyncio +async def test_complete_multipart_upload_success(): + """Test successful multipart upload completion.""" + client = ArtifactStorageClient("localhost", 50051, "test-key") + + with patch("app.grpc.clients.artifact_storage_client.grpc.aio.insecure_channel"), \ + patch("app.grpc.generated.artifact_pb2_grpc.ArtifactServiceStub") as mock_stub: + + # Mock successful response + mock_response = MagicMock() + mock_response.version_id = "test-version-id" + mock_response.HasField.return_value = True + mock_stub.return_value.CompleteMultipart.return_value = mock_response + + await client.connect() + result = await client.complete_multipart_upload( + artifact_id="test-artifact-id", + upload_id="test-upload-id", + parts=[ + {"part_number": 1, "etag": "etag1"}, + {"part_number": 2, "etag": "etag2"} + ] + ) + + assert result["status"] == "completed" + assert result["version_id"] == "test-version-id" + + +@pytest.mark.asyncio +async def test_multipart_upload_error_handling(): + """Test multipart upload error handling.""" + client = ArtifactStorageClient("localhost", 50051, "test-key") + + with pytest.raises(RuntimeError, match="Client not connected"): + await client.initiate_multipart_upload("test", "bucket", "key") + + with pytest.raises(RuntimeError, match="Client not connected"): + await client.get_multipart_part_url("test-id", "upload-id", 1) + + with pytest.raises(RuntimeError, match="Client not connected"): + await client.complete_multipart_upload("test-id", "upload-id", [])