Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 179 additions & 2 deletions app/api/v1/endpoints/artifacts.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Artifact management endpoints."""

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from typing import Optional
from typing import Optional, List

from app.core.grpc_clients import get_artifact_storage_client
from app.grpc.clients.artifact_storage_client import ArtifactStorageClient
Expand Down Expand Up @@ -190,3 +190,180 @@ async def list_artifacts(
status_code=500,
detail=f"Failed to list artifacts: {str(e)}"
)


@router.delete("/{artifact_id}", status_code=status.HTTP_200_OK)
async def delete_artifact(
artifact_id: str,
current_user: User = Depends(deps.get_current_user),
client: ArtifactStorageClient = Depends(get_artifact_storage_client)
):
"""
Delete an artifact.
"""
try:
success = await client.delete_artifact(artifact_id)
if success:
return {"status": "deleted", "artifact_id": artifact_id}
else:
raise HTTPException(
status_code=500,
detail="Failed to delete artifact"
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to delete artifact: {str(e)}"
)


@router.post("/{artifact_id}/validate-checksum", response_model=dict)
async def validate_checksum(
artifact_id: str,
current_user: User = Depends(deps.get_current_user),
client: ArtifactStorageClient = Depends(get_artifact_storage_client)
):
"""
Validate artifact checksum integrity.
"""
try:
validation_result = await client.validate_checksum(artifact_id)
return validation_result
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to validate checksum: {str(e)}"
)


# Multipart Upload Request/Response Models
class InitiateMultipartRequest(BaseModel):
"""Request to initiate multipart upload."""
name: str = Field(..., description="Artifact name")
bucket: str = Field(..., description="S3 bucket name")
key: str = Field(..., description="S3 object key")
content_type: str = Field(default="application/octet-stream", description="MIME type")
pipeline_id: Optional[str] = Field(None, description="Associated pipeline ID")
execution_id: Optional[str] = Field(None, description="Associated execution ID")
project_id: Optional[str] = Field(None, description="Associated project ID")
retention_days: Optional[int] = Field(None, description="Retention period in days")
storage_class: Optional[str] = Field(None, description="S3 storage class")


class InitiateMultipartResponse(BaseModel):
"""Response with upload ID and artifact ID."""
artifact_id: str = Field(..., description="Unique artifact ID")
upload_id: str = Field(..., description="Multipart upload ID")


class MultipartPartURLRequest(BaseModel):
"""Request for multipart part upload URL."""
artifact_id: str = Field(..., description="Artifact ID")
upload_id: str = Field(..., description="Multipart upload ID")
part_number: int = Field(..., description="Part number", ge=1)
expires_in: int = Field(default=3600, description="URL expiration (seconds)")


class MultipartPartURLResponse(BaseModel):
"""Response with part upload URL."""
part_url: str = Field(..., description="Pre-signed URL for part upload")
expires_at: Optional[str] = Field(None, description="URL expiration time")


class CompletePart(BaseModel):
"""Completed multipart part."""
part_number: int = Field(..., description="Part number", ge=1)
etag: str = Field(..., description="Part ETag from S3")


class CompleteMultipartRequest(BaseModel):
"""Request to complete multipart upload."""
artifact_id: str = Field(..., description="Artifact ID")
upload_id: str = Field(..., description="Multipart upload ID")
parts: List[CompletePart] = Field(..., description="List of completed parts")


class CompleteMultipartResponse(BaseModel):
"""Response with completion status."""
status: str = Field(..., description="Upload completion status")
version_id: Optional[str] = Field(None, description="S3 object version ID")


# Multipart Upload Endpoints
@router.post("/multipart/initiate", response_model=InitiateMultipartResponse, status_code=status.HTTP_200_OK)
async def initiate_multipart_upload(
request: InitiateMultipartRequest,
current_user: User = Depends(deps.get_current_user),
client: ArtifactStorageClient = Depends(get_artifact_storage_client)
):
"""
Initiate multipart upload for large files.
"""
try:
result = await client.initiate_multipart_upload(
name=request.name,
bucket=request.bucket,
key=request.key,
content_type=request.content_type,
pipeline_id=request.pipeline_id,
execution_id=request.execution_id,
project_id=request.project_id,
retention_days=request.retention_days,
storage_class=request.storage_class
)
return result
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to initiate multipart upload: {str(e)}"
)


@router.post("/multipart/part-url", response_model=MultipartPartURLResponse, status_code=status.HTTP_200_OK)
async def get_multipart_part_url(
request: MultipartPartURLRequest,
current_user: User = Depends(deps.get_current_user),
client: ArtifactStorageClient = Depends(get_artifact_storage_client)
):
"""
Get pre-signed URL for uploading a multipart part.
"""
try:
result = await client.get_multipart_part_url(
artifact_id=request.artifact_id,
upload_id=request.upload_id,
part_number=request.part_number,
expires_in=request.expires_in
)
return result
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to get multipart part URL: {str(e)}"
)


@router.post("/multipart/complete", response_model=CompleteMultipartResponse, status_code=status.HTTP_200_OK)
async def complete_multipart_upload(
request: CompleteMultipartRequest,
current_user: User = Depends(deps.get_current_user),
client: ArtifactStorageClient = Depends(get_artifact_storage_client)
):
"""
Complete multipart upload and assemble final object.
"""
try:
# Convert Pydantic parts to dict for gRPC client
parts_dict = [{"part_number": part.part_number, "etag": part.etag} for part in request.parts]

result = await client.complete_multipart_upload(
artifact_id=request.artifact_id,
upload_id=request.upload_id,
parts=parts_dict
)
return result
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to complete multipart upload: {str(e)}"
)
189 changes: 189 additions & 0 deletions app/grpc/clients/artifact_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,195 @@ async def list_artifacts(
logger.error(f"gRPC error listing artifacts: {e.code()} - {e.details()}")
raise

async def delete_artifact(self, artifact_id: str) -> bool:
"""
Delete an artifact.

Args:
artifact_id: Artifact ID to delete

Returns:
True if successful
"""
if not self.stub:
raise RuntimeError("Client not connected. Call connect() first.")

try:
from app.grpc.generated import artifact_pb2

request = artifact_pb2.DeleteArtifactRequest(artifact_id=artifact_id)

metadata = (("x-api-key", self.api_key),)
response = await self.stub.DeleteArtifact(request, metadata=metadata)

return response.status == "deleted"
except grpc.RpcError as e:
logger.error(f"gRPC error deleting artifact: {e.code()} - {e.details()}")
raise

async def validate_checksum(self, artifact_id: str) -> Dict[str, Any]:
"""
Validate artifact checksum integrity.

Args:
artifact_id: Artifact ID to validate

Returns:
Dict with validation results
"""
if not self.stub:
raise RuntimeError("Client not connected. Call connect() first.")

try:
from app.grpc.generated import artifact_pb2

request = artifact_pb2.ValidateChecksumRequest(artifact_id=artifact_id)

metadata = (("x-api-key", self.api_key),)
response = await self.stub.ValidateChecksum(request, metadata=metadata)

return {
"valid": response.valid,
"expected_checksum": response.expected_checksum,
"actual_checksum": response.actual_checksum if response.HasField("actual_checksum") else None,
"validation_time": response.validation_time.ToDatetime().isoformat() if response.HasField("validation_time") else None
}
except grpc.RpcError as e:
logger.error(f"gRPC error validating checksum: {e.code()} - {e.details()}")
raise

async def initiate_multipart_upload(self, name: str, bucket: str, key: str, content_type: str = None,
pipeline_id: str = None, execution_id: str = None,
project_id: str = None, retention_days: int = None,
storage_class: str = None) -> Dict[str, Any]:
"""
Initiate multipart upload for large files.

Args:
name: Artifact name
bucket: S3 bucket name
key: S3 object key
content_type: MIME type
pipeline_id: Associated pipeline ID
execution_id: Associated execution ID
project_id: Associated project ID
retention_days: Retention period in days
storage_class: S3 storage class

Returns:
Dict with upload ID and artifact ID
"""
if not self.stub:
raise RuntimeError("Client not connected. Call connect() first.")

try:
from app.grpc.generated import artifact_pb2

request = artifact_pb2.InitiateMultipartRequest(
name=name,
bucket=bucket,
key=key,
content_type=content_type or "application/octet-stream",
pipeline_id=pipeline_id,
execution_id=execution_id,
project_id=project_id,
retention_days=retention_days,
storage_class=storage_class
)

metadata = (("x-api-key", self.api_key),)
response = await self.stub.InitiateMultipart(request, metadata=metadata)

return {
"artifact_id": response.artifact_id,
"upload_id": response.upload_id
}
except grpc.RpcError as e:
logger.error(f"gRPC error initiating multipart upload: {e.code()} - {e.details()}")
raise

async def get_multipart_part_url(self, artifact_id: str, upload_id: str, part_number: int,
expires_in: int = 3600) -> Dict[str, Any]:
"""
Get pre-signed URL for multipart upload part.

Args:
artifact_id: Artifact ID
upload_id: Multipart upload ID
part_number: Part number
expires_in: URL expiration in seconds

Returns:
Dict with part upload URL
"""
if not self.stub:
raise RuntimeError("Client not connected. Call connect() first.")

try:
from app.grpc.generated import artifact_pb2

request = artifact_pb2.MultipartPartURLRequest(
artifact_id=artifact_id,
upload_id=upload_id,
part_number=part_number,
expires_in=expires_in
)

metadata = (("x-api-key", self.api_key),)
response = await self.stub.GetMultipartPartURL(request, metadata=metadata)

return {
"part_url": response.part_url,
"expires_at": response.expires_at.ToDatetime().isoformat() if response.HasField("expires_at") else None
}
except grpc.RpcError as e:
logger.error(f"gRPC error getting multipart part URL: {e.code()} - {e.details()}")
raise

async def complete_multipart_upload(self, artifact_id: str, upload_id: str, parts: list) -> Dict[str, Any]:
"""
Complete multipart upload.

Args:
artifact_id: Artifact ID
upload_id: Multipart upload ID
parts: List of completed parts with ETags

Returns:
Dict with completion status and version info
"""
if not self.stub:
raise RuntimeError("Client not connected. Call connect() first.")

try:
from app.grpc.generated import artifact_pb2

# Convert parts to protobuf format
grpc_parts = []
for part in parts:
grpc_part = artifact_pb2.CompletePart(
part_number=part.get("part_number"),
etag=part.get("etag")
)
grpc_parts.append(grpc_part)

request = artifact_pb2.CompleteMultipartRequest(
artifact_id=artifact_id,
upload_id=upload_id,
parts=grpc_parts
)

metadata = (("x-api-key", self.api_key),)
response = await self.stub.CompleteMultipart(request, metadata=metadata)

return {
"status": "completed",
"version_id": response.version_id if response.HasField("version_id") else None
}
except grpc.RpcError as e:
logger.error(f"gRPC error completing multipart upload: {e.code()} - {e.details()}")
raise

async def close(self):
"""Close gRPC connection."""
if self.channel:
Expand Down
Loading
Loading