Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion app/api/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,26 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):

async def dispatch(self, request: Request, call_next) -> Response:
response = await call_next(request)
path = request.url.path

response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
response.headers["Content-Security-Policy"] = "default-src 'self'"

# Define CSP
if path in ["/docs", "/redoc", "/api/v1/openapi.json"]:
# Documentation needs CDN access and inline scripts/styles
csp = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline' 'unsafe-eval' cdn.jsdelivr.net; "
"style-src 'self' 'unsafe-inline' cdn.jsdelivr.net; "
"img-src 'self' data: fastly.jsdelivr.net cdn.jsdelivr.net fastapi.tiangolo.com; "
"connect-src 'self'"
)
else:
# Strict CSP for API endpoints
csp = "default-src 'self'"

response.headers["Content-Security-Policy"] = csp
return response
20 changes: 16 additions & 4 deletions app/db/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,30 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
if cached_val:
try:
return json.loads(cached_val)
except Exception:
except (json.JSONDecodeError, TypeError):
# If corrupted or not JSON, ignore and fetch fresh
pass

result = await func(*args, **kwargs)

if result is not None:
# Handle Pydantic models or SQLAlchemy models if needed
# For now, assumes return is JSON serializable or has a __dict__
# Only cache if it's potentially JSON serializable
# We avoid using str(result) as a fallback because it leads to
# caching strings like "<Team 1>" which break logic expecting objects.
try:
dump = json.dumps(result, default=str)
# Generic check for Pydantic/SQLAlchemy models or dicts
if hasattr(result, "model_dump"): # Pydantic v2
dump_data = result.model_dump(mode="json")
elif hasattr(result, "__dict__"):
# Basic dict fallback, being careful with SQLAlchemy state
dump_data = {k: v for k, v in result.__dict__.items() if not k.startswith("_")}
else:
dump_data = result

dump = json.dumps(dump_data, default=str)
await redis.setex(cache_key, expire, dump)
except Exception:
# If we can't safely serialize, don't cache
pass
return result
return wrapper
Expand Down
8 changes: 6 additions & 2 deletions app/services/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ async def publish(
nats_bus = await get_event_bus()

# Convert event_type to NATS subject format
# "pipeline.executed" -> "backend.pipeline.executed"
subject = f"backend.{event_type}"
# If it already starts with a known prefix, use it as is
if event_type.startswith(("backend.", "pipeline.")):
subject = event_type
else:
# "pipeline.executed" -> "backend.pipeline.executed"
subject = f"backend.{event_type}"

# Prepare event data
event_data = {
Expand Down
19 changes: 18 additions & 1 deletion app/services/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ async def delete_pipeline(db: AsyncSession, pipeline_id: int) -> bool:

async def trigger_pipeline_execution(db: AsyncSession, pipeline_id: int, meta_data: Optional[dict] = None) -> PipelineExecution:
"""Trigger a new pipeline execution."""
pipeline = await get_pipeline(db, pipeline_id)
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")

db_execution = PipelineExecution(
pipeline_id=pipeline_id,
status="pending",
Expand All @@ -74,7 +78,20 @@ async def trigger_pipeline_execution(db: AsyncSession, pipeline_id: int, meta_da
await db.commit()
await db.refresh(db_execution)

await events.publish("pipeline.executed", {"pipeline_id": pipeline_id, "execution_id": db_execution.id}, resource_id=pipeline_id)
# Prepare task payload for worker
task_payload = {
"execution_id": str(db_execution.id),
"pipeline_id": pipeline_id,
"pipeline": {
"name": pipeline.name,
"config": pipeline.config,
# Add other fields if needed by PipelineExecutor
},
"meta_data": meta_data
}

await events.publish("pipeline.tasks", task_payload)
logger.info(f"Triggered pipeline execution {db_execution.id} on subject pipeline.tasks")

return db_execution

Expand Down
1 change: 0 additions & 1 deletion app/services/team.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ async def create_team(db: AsyncSession, team_in: TeamCreate, owner_id: int) -> T
return db_team


@cache("team", expire=1800, include_args=["team_id"])
async def get_team(db: AsyncSession, team_id: int) -> Optional[Team]:
"""Get a team by ID."""
result = await db.execute(select(Team).where(Team.id == team_id))
Expand Down
1 change: 0 additions & 1 deletion app/services/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]:
return result.scalars().first()


@cache("user", expire=600, include_args=["user_id"])
async def get_user_by_id(db: AsyncSession, user_id: int) -> Optional[User]:
"""Get a user by ID."""
result = await db.execute(select(User).where(User.id == user_id))
Expand Down
2 changes: 1 addition & 1 deletion scripts/test_artifact_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def test_artifact_storage():
print("\nTesting upload URL generation...")
result = await client.get_upload_url(
name="test-artifact.txt",
bucket="test-bucket",
bucket="artifacts",
key="test/artifact.txt",
content_type="text/plain",
expires_in_seconds=3600,
Expand Down
113 changes: 113 additions & 0 deletions scripts/test_e2e_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#!/usr/bin/env python3
"""End-to-end test script for Pipeline execution flow."""

import asyncio
import sys
import json
from pathlib import Path
import httpx

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from app.core.config import get_settings

settings = get_settings()
BASE_URL = "http://localhost:8000"
API_V1 = f"{BASE_URL}{settings.api_v1_prefix}"

async def test_e2e_pipeline():
"""Run E2E pipeline test."""
print("=" * 60)
print("End-to-End Pipeline Integration Test")
print("=" * 60)

async with httpx.AsyncClient() as client:
# 1. Login
print("\n1. Logging in...")
try:
login_res = await client.post(
f"{API_V1}/auth/login",
data={"username": "test@xether.ai", "password": "testpassword123"}
)
login_res.raise_for_status()
token = login_res.json()["access_token"]
headers = {"Authorization": f"Bearer {token}"}
print("✅ Login successful")
except Exception as e:
print(f"❌ Login failed: {e}")
return False

# 2. Create Team
print("\n2. Creating Team...")
team_res = await client.post(f"{API_V1}/teams/", json={"name": "E2E Test Team"}, headers=headers)
team_res.raise_for_status()
team_id = team_res.json()["id"]
print(f"✅ Team created: ID={team_id}")

# 3. Create Project
print("\n3. Creating Project...")
project_res = await client.post(
f"{API_V1}/projects/",
json={"name": "E2E Pipeline Project", "team_id": team_id},
headers=headers
)
project_res.raise_for_status()
project_id = project_res.json()["id"]
print(f"✅ Project created: ID={project_id}")

# 4. Create Pipeline
print("\n4. Creating Pipeline...")
pipeline_data = {
"name": "E2E Test Pipeline",
"project_id": project_id,
"config": {"input": "test.csv", "operations": ["clean", "normalize"]}
}
pipeline_res = await client.post(f"{API_V1}/pipelines/", json=pipeline_data, headers=headers)
pipeline_res.raise_for_status()
pipeline_id = pipeline_res.json()["id"]
print(f"✅ Pipeline created: ID={pipeline_id}")

# 5. Trigger Execution
print("\n5. Triggering Execution...")
exec_res = await client.post(f"{API_V1}/pipelines/{pipeline_id}/execute", headers=headers)
exec_res.raise_for_status()
execution_id = exec_res.json()["id"]
print(f"✅ Execution triggered: ID={execution_id}")

# 6. Poll for Completion
print("\n6. Polling for completion (max 30s)...")
for i in range(30):
await asyncio.sleep(2)
status_res = await client.get(f"{API_V1}/pipelines/{pipeline_id}/executions", headers=headers)
status_res.raise_for_status()
executions = status_res.json()

# Find our execution
execution = next((e for e in executions if e["id"] == execution_id), None)
if not execution:
print("⚠️ Execution not found in list")
continue

print(f" Status: {execution['status']} (Iteration {i+1})")
if execution["status"] == "completed":
print("\n🎉 SUCCESS: Pipeline execution completed!")
return True
if execution["status"] == "failed":
print(f"\n❌ FAILED: Pipeline execution failed with error: {execution.get('error_message')}")
return False

print("\n⌛ TIMEOUT: Pipeline execution did not complete in time.")
return False

async def main():
success = await test_e2e_pipeline()
if success:
print("\nEnd-to-End Integration Verified! 🚀")
sys.exit(0)
else:
print("\nEnd-to-End Integration Failed. ⚠️")
sys.exit(1)

if __name__ == "__main__":
asyncio.run(main())
Loading