diff --git a/apps/api/src/app/s3.ts b/apps/api/src/app/s3.ts index e74ea40ce0..cf0b43350f 100644 --- a/apps/api/src/app/s3.ts +++ b/apps/api/src/app/s3.ts @@ -1,5 +1,6 @@ import { GetObjectCommand, + HeadObjectCommand, PutObjectCommand, S3Client, type GetObjectCommandOutput, @@ -195,3 +196,27 @@ export async function getObjectAsBuffer( const bytes = await response.Body.transformToByteArray(); return Buffer.from(bytes); } + +/** + * Fetch an S3 object's size (in bytes) via a HEAD request, WITHOUT downloading + * the body. Used to reject oversized uploads before loading them into memory — + * `getObjectAsBuffer` would otherwise buffer the entire object (and base64 + * callers expand it ~1.33x on top), so a single huge file could OOM the API. + * + * Returns `undefined` if S3 doesn't report a ContentLength (callers should treat + * that as "size unknown" rather than "zero"). + */ +export async function getObjectContentLength( + bucket: string, + key: string, +): Promise<number | undefined> { + if (!s3Client) { + throw new Error('S3 client not configured'); + } + + const response = await s3Client.send( + new HeadObjectCommand({ Bucket: bucket, Key: key }), + ); + + return response.ContentLength; +} diff --git a/apps/api/src/attachments/attachments.module.ts b/apps/api/src/attachments/attachments.module.ts index 52999437e8..dc8cab4b32 100644 --- a/apps/api/src/attachments/attachments.module.ts +++ b/apps/api/src/attachments/attachments.module.ts @@ -1,10 +1,12 @@ import { Module } from '@nestjs/common'; import { AuthModule } from '../auth/auth.module'; +import { UploadsModule } from '../uploads/uploads.module'; import { AttachmentsController } from './attachments.controller'; import { AttachmentsService } from './attachments.service'; @Module({ - imports: [AuthModule], // Import AuthModule for HybridAuthGuard dependencies + // AuthModule: HybridAuthGuard deps. UploadsModule: presigned-upload s3Key reads. 
+ imports: [AuthModule, UploadsModule], controllers: [AttachmentsController], providers: [AttachmentsService], exports: [AttachmentsService], diff --git a/apps/api/src/attachments/attachments.service.spec.ts b/apps/api/src/attachments/attachments.service.spec.ts new file mode 100644 index 0000000000..cde9ffe991 --- /dev/null +++ b/apps/api/src/attachments/attachments.service.spec.ts @@ -0,0 +1,86 @@ +import { BadRequestException } from '@nestjs/common'; + +// Mocks must be declared before importing the service under test. +jest.mock('@/app/s3', () => ({ + s3Client: { send: jest.fn().mockResolvedValue({}) }, + getSignedUrl: jest.fn().mockResolvedValue('https://signed.example/file'), +})); + +jest.mock('@db', () => ({ + db: { attachment: { create: jest.fn() } }, + AttachmentType: { + image: 'image', + video: 'video', + audio: 'audio', + document: 'document', + other: 'other', + }, + AttachmentEntityType: { task: 'task', offboarding_checklist: 'offboarding_checklist' }, +})); + +jest.mock('../utils/file-type-validation', () => ({ + validateFileContent: jest.fn(), +})); + +import { db } from '@db'; +import { AttachmentsService } from './attachments.service'; + +const mockUploadsService = { readUploadAsBase64: jest.fn() }; + +describe('AttachmentsService — presigned s3Key uploads', () => { + let service: AttachmentsService; + + beforeEach(() => { + jest.clearAllMocks(); + process.env.APP_AWS_BUCKET_NAME = 'test-bucket'; + service = new AttachmentsService(mockUploadsService as never); + }); + + it('resolves the file from s3Key (presigned) — no base64 through the LLM — and uploads it', async () => { + mockUploadsService.readUploadAsBase64.mockResolvedValue( + Buffer.from('hello world').toString('base64'), + ); + (db.attachment.create as jest.Mock).mockResolvedValue({ + id: 'att_1', + name: 'rbac.pdf', + type: 'document', + url: 'org_1/attachments/task/tsk_1/key', + createdAt: new Date(), + }); + + const result = await service.uploadAttachment( + 'org_1', + 'tsk_1', + 
'task' as never, + { + fileName: 'rbac.pdf', + fileType: 'application/pdf', + s3Key: 'org_1/uploads/attachment/123-rbac.pdf', + } as never, + 'usr_1', + ); + + // Fetched the bytes from the org-scoped presigned key instead of base64. + expect(mockUploadsService.readUploadAsBase64).toHaveBeenCalledWith( + 'org_1', + 'org_1/uploads/attachment/123-rbac.pdf', + ); + expect(db.attachment.create).toHaveBeenCalled(); + expect(result.id).toBe('att_1'); + }); + + it('throws when neither fileData nor s3Key is provided', async () => { + await expect( + service.uploadAttachment( + 'org_1', + 'tsk_1', + 'task' as never, + { fileName: 'rbac.pdf', fileType: 'application/pdf' } as never, + 'usr_1', + ), + ).rejects.toBeInstanceOf(BadRequestException); + + expect(mockUploadsService.readUploadAsBase64).not.toHaveBeenCalled(); + expect(db.attachment.create).not.toHaveBeenCalled(); + }); +}); diff --git a/apps/api/src/attachments/attachments.service.ts b/apps/api/src/attachments/attachments.service.ts index 6bb88bed94..849037373a 100644 --- a/apps/api/src/attachments/attachments.service.ts +++ b/apps/api/src/attachments/attachments.service.ts @@ -15,6 +15,7 @@ import { import { randomBytes } from 'crypto'; import { AttachmentResponseDto } from '../tasks/dto/task-responses.dto'; import { UploadAttachmentDto } from './upload-attachment.dto'; +import { UploadsService } from '../uploads/uploads.service'; import { validateFileContent } from '../utils/file-type-validation'; @Injectable() @@ -24,7 +25,7 @@ export class AttachmentsService { private readonly MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024; // 100MB private readonly SIGNED_URL_EXPIRY = 900; // 15 minutes - constructor() { + constructor(private readonly uploadsService: UploadsService) { // AWS configuration is validated at startup via ConfigModule // Safe to access environment variables directly since they're validated this.bucketName = process.env.APP_AWS_BUCKET_NAME!; @@ -115,8 +116,26 @@ export class AttachmentsService { ); } + // 
Resolve the file content from either inline base64 (UI/direct callers) + // or a presigned-upload s3Key (AI/MCP clients — avoids slow base64 through + // an LLM). readUploadAsBase64 enforces that the key belongs to this org. + const fileData = + uploadDto.fileData ?? + (uploadDto.s3Key + ? await this.uploadsService.readUploadAsBase64( + organizationId, + uploadDto.s3Key, + ) + : undefined); + + if (!fileData) { + throw new BadRequestException( + 'Provide either fileData (base64) or s3Key from /v1/uploads/presign.', + ); + } + // Validate file size - const fileBuffer = Buffer.from(uploadDto.fileData, 'base64'); + const fileBuffer = Buffer.from(fileData, 'base64'); if (fileBuffer.length > this.MAX_FILE_SIZE_BYTES) { throw new BadRequestException( `File size exceeds maximum allowed size of ${this.MAX_FILE_SIZE_BYTES / (1024 * 1024)}MB`, diff --git a/apps/api/src/attachments/upload-attachment.dto.ts b/apps/api/src/attachments/upload-attachment.dto.ts index 6b950486ec..3e38d4ec07 100644 --- a/apps/api/src/attachments/upload-attachment.dto.ts +++ b/apps/api/src/attachments/upload-attachment.dto.ts @@ -8,6 +8,7 @@ import { MaxLength, } from 'class-validator'; import { IsMimeTypeField } from '../utils/mime-type.validator'; +import { MAX_UPLOAD_BASE64_LENGTH } from '../uploads/upload-limits'; export class UploadAttachmentDto { @ApiProperty({ @@ -29,15 +30,29 @@ export class UploadAttachmentDto { fileType: string; @ApiProperty({ - description: 'Base64 encoded file data', + description: + 'Base64-encoded file contents. For the web UI / direct callers. AI/MCP clients should instead upload via /v1/uploads/presign (purpose=attachment) and pass `s3Key` — base64 through an LLM is impractically slow and times out. 
Provide exactly one of fileData or s3Key.', + required: false, example: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==', }) + @IsOptional() @IsString() @IsNotEmpty() - @MaxLength(134_217_728) + @MaxLength(MAX_UPLOAD_BASE64_LENGTH) @IsBase64() - fileData: string; + fileData?: string; + + @ApiProperty({ + description: + 'Key of a file already uploaded via /v1/uploads/presign (purpose=attachment). The server fetches the bytes from storage — no base64 needed. Provide exactly one of fileData or s3Key.', + required: false, + example: 'org_abc123/uploads/attachment/1700000000000-rbac-matrix.xlsx', + }) + @IsOptional() + @IsString() + @IsNotEmpty() + s3Key?: string; @ApiProperty({ description: 'Description of the attachment', diff --git a/apps/api/src/knowledge-base/dto/upload-document.dto.ts b/apps/api/src/knowledge-base/dto/upload-document.dto.ts index 5240521176..08088a79ad 100644 --- a/apps/api/src/knowledge-base/dto/upload-document.dto.ts +++ b/apps/api/src/knowledge-base/dto/upload-document.dto.ts @@ -1,18 +1,48 @@ -import { IsOptional, IsString } from 'class-validator'; +import { ApiProperty } from '@nestjs/swagger'; +import { IsBase64, IsOptional, IsString, MaxLength } from 'class-validator'; +import { MAX_UPLOAD_BASE64_LENGTH } from '../../uploads/upload-limits'; export class UploadDocumentDto { + @ApiProperty({ description: 'Organization ID that owns the document' }) @IsString() organizationId!: string; + @ApiProperty({ description: 'File name', example: 'rbac-matrix.xlsx' }) @IsString() fileName!: string; + @ApiProperty({ + description: 'MIME type of the file', + example: + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + }) @IsString() fileType!: string; + @ApiProperty({ + description: + 'Base64-encoded file contents. For the web UI / direct callers. 
AI/MCP clients should instead upload via /v1/uploads/presign (purpose=document) and pass `s3Key` — base64 through an LLM is impractically slow and times out. Provide exactly one of fileData or s3Key.', + required: false, + }) + @IsOptional() + @IsString() + // Cap the inline payload at the validation layer (before it is decoded), + // matching the other migrated upload DTOs. The limit is the base64 length of + // the 100 MiB file ceiling — see upload-limits.ts. + @MaxLength(MAX_UPLOAD_BASE64_LENGTH) + @IsBase64() + fileData?: string; // base64 encoded + + @ApiProperty({ + description: + 'Key of a file already uploaded via /v1/uploads/presign (purpose=document). The server fetches the bytes from storage — no base64 needed. Provide exactly one of fileData or s3Key.', + required: false, + }) + @IsOptional() @IsString() - fileData!: string; // base64 encoded + s3Key?: string; + @ApiProperty({ description: 'Optional description', required: false }) @IsOptional() @IsString() description?: string; diff --git a/apps/api/src/knowledge-base/knowledge-base.module.ts b/apps/api/src/knowledge-base/knowledge-base.module.ts index bce607bd61..2774e28ed3 100644 --- a/apps/api/src/knowledge-base/knowledge-base.module.ts +++ b/apps/api/src/knowledge-base/knowledge-base.module.ts @@ -1,10 +1,11 @@ import { Module } from '@nestjs/common'; import { AuthModule } from '../auth/auth.module'; +import { UploadsModule } from '../uploads/uploads.module'; import { KnowledgeBaseController } from './knowledge-base.controller'; import { KnowledgeBaseService } from './knowledge-base.service'; @Module({ - imports: [AuthModule], + imports: [AuthModule, UploadsModule], controllers: [KnowledgeBaseController], providers: [KnowledgeBaseService], }) diff --git a/apps/api/src/knowledge-base/knowledge-base.service.spec.ts b/apps/api/src/knowledge-base/knowledge-base.service.spec.ts index 3b829abce4..efe03abc0c 100644 --- a/apps/api/src/knowledge-base/knowledge-base.service.spec.ts +++ 
b/apps/api/src/knowledge-base/knowledge-base.service.spec.ts @@ -1,5 +1,11 @@ +import { BadRequestException } from '@nestjs/common'; import { Test, TestingModule } from '@nestjs/testing'; import { KnowledgeBaseService } from './knowledge-base.service'; +import { UploadsService } from '../uploads/uploads.service'; + +const mockUploadsService = { + readUploadAsBase64: jest.fn(), +}; jest.mock('@db', () => ({ db: { @@ -68,7 +74,10 @@ describe('KnowledgeBaseService', () => { beforeEach(async () => { const module: TestingModule = await Test.createTestingModule({ - providers: [KnowledgeBaseService], + providers: [ + KnowledgeBaseService, + { provide: UploadsService, useValue: mockUploadsService }, + ], }).compile(); service = module.get(KnowledgeBaseService); @@ -215,6 +224,49 @@ describe('KnowledgeBaseService', () => { 'base64data', ); }); + + it('resolves content from s3Key (presigned upload) when no fileData', async () => { + mockUploadsService.readUploadAsBase64.mockResolvedValue('fromS3base64'); + (uploadToS3 as jest.Mock).mockResolvedValue({ + s3Key: 'org_1/doc.pdf', + fileSize: 2048, + }); + (mockDb.knowledgeBaseDocument.create as jest.Mock).mockResolvedValue({ + id: 'd2', + name: 'doc.pdf', + s3Key: 'org_1/doc.pdf', + }); + + await service.uploadDocument({ + organizationId: 'org_1', + fileName: 'doc.pdf', + fileType: 'application/pdf', + s3Key: 'org_1/uploads/document/123-doc.pdf', + } as any); + + // Fetched the bytes from the presigned key (org-scoped, no base64 via LLM) + expect(mockUploadsService.readUploadAsBase64).toHaveBeenCalledWith( + 'org_1', + 'org_1/uploads/document/123-doc.pdf', + ); + expect(uploadToS3).toHaveBeenCalledWith( + 'org_1', + 'doc.pdf', + 'application/pdf', + 'fromS3base64', + ); + }); + + it('throws when neither fileData nor s3Key is provided', async () => { + await expect( + service.uploadDocument({ + organizationId: 'org_1', + fileName: 'doc.pdf', + fileType: 'application/pdf', + } as any), + 
).rejects.toBeInstanceOf(BadRequestException); + expect(uploadToS3).not.toHaveBeenCalled(); + }); }); describe('getDownloadUrl', () => { diff --git a/apps/api/src/knowledge-base/knowledge-base.service.ts b/apps/api/src/knowledge-base/knowledge-base.service.ts index 1b754192db..beb2612b98 100644 --- a/apps/api/src/knowledge-base/knowledge-base.service.ts +++ b/apps/api/src/knowledge-base/knowledge-base.service.ts @@ -1,8 +1,9 @@ -import { Injectable, Logger } from '@nestjs/common'; +import { BadRequestException, Injectable, Logger } from '@nestjs/common'; import { db } from '@db'; import { tasks, auth } from '@trigger.dev/sdk'; import { syncManualAnswerToVector } from '@/vector-store/lib'; import { UploadDocumentDto } from './dto/upload-document.dto'; +import { UploadsService } from '../uploads/uploads.service'; import { DeleteDocumentDto } from './dto/delete-document.dto'; import { GetDocumentUrlDto } from './dto/get-document-url.dto'; import { ProcessDocumentsDto } from './dto/process-documents.dto'; @@ -25,6 +26,8 @@ import { export class KnowledgeBaseService { private readonly logger = new Logger(KnowledgeBaseService.name); + constructor(private readonly uploadsService: UploadsService) {} + async listDocuments(organizationId: string) { return db.knowledgeBaseDocument.findMany({ where: { organizationId }, @@ -44,12 +47,30 @@ export class KnowledgeBaseService { } async uploadDocument(dto: UploadDocumentDto) { + // Resolve content from inline base64 (UI/direct) or a presigned-upload + // s3Key (AI/MCP clients — avoids slow base64 through an LLM). The read + // enforces that the key belongs to this org. + const fileData = + dto.fileData ?? + (dto.s3Key + ? 
await this.uploadsService.readUploadAsBase64( + dto.organizationId, + dto.s3Key, + ) + : undefined); + + if (!fileData) { + throw new BadRequestException( + 'Provide either fileData (base64) or s3Key from /v1/uploads/presign.', + ); + } + // Upload to S3 const { s3Key, fileSize } = await uploadToS3( dto.organizationId, dto.fileName, dto.fileType, - dto.fileData, + fileData, ); // Create database record diff --git a/apps/api/src/offboarding-checklist/dto/complete-checklist-item.dto.ts b/apps/api/src/offboarding-checklist/dto/complete-checklist-item.dto.ts index 017b7614a2..fa690233a7 100644 --- a/apps/api/src/offboarding-checklist/dto/complete-checklist-item.dto.ts +++ b/apps/api/src/offboarding-checklist/dto/complete-checklist-item.dto.ts @@ -1,6 +1,7 @@ import { ApiProperty } from '@nestjs/swagger'; import { IsOptional, IsString, MaxLength, IsBase64 } from 'class-validator'; import { IsMimeTypeField } from '../../utils/mime-type.validator'; +import { MAX_UPLOAD_BASE64_LENGTH } from '../../uploads/upload-limits'; export class CompleteChecklistItemDto { @ApiProperty({ description: 'Optional notes', required: false }) @@ -19,12 +20,22 @@ export class CompleteChecklistItemDto { fileType?: string; @ApiProperty({ - description: 'Base64 encoded evidence file', + description: + 'Base64-encoded evidence file. For the web UI / direct callers. AI/MCP clients should instead upload via /v1/uploads/presign (purpose=evidence) and pass `s3Key` — base64 through an LLM is impractically slow and times out. Provide fileData or s3Key (not both).', required: false, }) @IsOptional() @IsString() - @MaxLength(134_217_728) + @MaxLength(MAX_UPLOAD_BASE64_LENGTH) @IsBase64() fileData?: string; + + @ApiProperty({ + description: + 'Key of an evidence file already uploaded via /v1/uploads/presign (purpose=evidence). The server fetches the bytes from storage — no base64 needed. 
Provide fileData or s3Key (not both).', + required: false, + }) + @IsOptional() + @IsString() + s3Key?: string; } diff --git a/apps/api/src/offboarding-checklist/offboarding-checklist.service.ts b/apps/api/src/offboarding-checklist/offboarding-checklist.service.ts index 09a4378478..23fadea7cc 100644 --- a/apps/api/src/offboarding-checklist/offboarding-checklist.service.ts +++ b/apps/api/src/offboarding-checklist/offboarding-checklist.service.ts @@ -13,12 +13,15 @@ interface CompleteChecklistItemDto { fileName?: string; fileType?: string; fileData?: string; + s3Key?: string; } interface UploadEvidenceDto { fileName: string; fileType: string; - fileData: string; + // Either inline base64 (UI/direct) or a presigned-upload s3Key (AI/MCP). + fileData?: string; + s3Key?: string; description?: string; } @@ -211,7 +214,14 @@ export class OffboardingChecklistService { throw new NotFoundException('Template item not found'); } - if (template.evidenceRequired && (!dto.fileData || !dto.fileName || !dto.fileType)) { + // Evidence can arrive as inline base64 (fileData) or a presigned-upload + // s3Key (AI/MCP clients — avoids slow base64 through an LLM). + const hasEvidenceFile = Boolean(dto.fileData || dto.s3Key); + + if ( + template.evidenceRequired && + (!hasEvidenceFile || !dto.fileName || !dto.fileType) + ) { throw new BadRequestException('Evidence is required to complete this item'); } @@ -225,8 +235,10 @@ export class OffboardingChecklistService { }, }); - if (dto.fileName && dto.fileData && dto.fileType) { + if (dto.fileName && hasEvidenceFile && dto.fileType) { try { + // AttachmentsService.uploadAttachment resolves the bytes from whichever + // of fileData / s3Key is provided. 
await this.attachmentsService.uploadAttachment( organizationId, completion.id, @@ -234,6 +246,7 @@ export class OffboardingChecklistService { { fileName: dto.fileName, fileData: dto.fileData, + s3Key: dto.s3Key, fileType: dto.fileType, }, completedById, diff --git a/apps/api/src/openapi-docs.spec.ts b/apps/api/src/openapi-docs.spec.ts index 74f95f62d4..d9d42d89ae 100644 --- a/apps/api/src/openapi-docs.spec.ts +++ b/apps/api/src/openapi-docs.spec.ts @@ -113,6 +113,7 @@ import { AppModule } from './app.module'; import { applyPublicOpenApiMetadata, PUBLIC_OPENAPI_DESCRIPTION, + PUBLIC_OPENAPI_TIMEOUT_MS, PUBLIC_OPENAPI_TITLE, PUBLIC_SERVER_URL, } from './openapi/public-docs-metadata'; @@ -156,6 +157,14 @@ describe('OpenAPI document', () => { ]); }); + it('bakes a finite default request timeout into the generated SDK + MCP server', () => { + // Without x-speakeasy-timeout the generated request funcs use -1 ("no + // timeout") and a hung upstream wedges the MCP connection forever. + expect( + (document as { 'x-speakeasy-timeout'?: number })['x-speakeasy-timeout'], + ).toBe(PUBLIC_OPENAPI_TIMEOUT_MS); + }); + it('keeps the public spec complete, SEO-ready, and free of private surfaces', () => { const issues = collectPublicOpenApiIssues(document); diff --git a/apps/api/src/openapi/public-docs-metadata.ts b/apps/api/src/openapi/public-docs-metadata.ts index d2c8b9ad91..b3454be8d8 100644 --- a/apps/api/src/openapi/public-docs-metadata.ts +++ b/apps/api/src/openapi/public-docs-metadata.ts @@ -26,6 +26,31 @@ export const PUBLIC_OPENAPI_DESCRIPTION = export const PUBLIC_SERVER_URL = 'https://api.trycomp.ai'; +/** + * Default request timeout (ms) baked into the generated SDK + MCP server via the + * `x-speakeasy-timeout` document-root extension. + * + * Speakeasy-generated request functions resolve their timeout to + * `operationTimeoutMs || clientTimeoutMs || -1`, and `-1` means "no timeout". 
+ * Without this, a slow/hung upstream call leaves the MCP server's fetch dangling + * forever; the MCP client eventually gives up and marks the whole connection + * unhealthy (customer-reported wedging). A finite timeout makes the SDK abort + * the request and return a clean error instead, keeping the connection alive. + * + * 120s comfortably covers our slowest endpoints while staying under the ALB's + * 300s idle timeout (comp-private infra/modules/loadbalancer.ts). + */ +export const PUBLIC_OPENAPI_TIMEOUT_MS = 120_000; + +/** + * OpenAPIObject (from @nestjs/swagger) has no index signature for `x-*` + * extensions at the document root, so we widen it locally instead of reaching + * for `as any`. + */ +type OpenApiDocumentWithExtensions = OpenAPIObject & { + 'x-speakeasy-timeout'?: number; +}; + function getVisibilityForOperation( operation: OpenApiOperation, metadata?: PublicOperationMetadata, @@ -284,6 +309,12 @@ export function applyPublicOpenApiMetadata(document: OpenAPIObject): void { }, ]; + // Bake a finite default request timeout into the generated SDK + MCP server + // so a hung upstream call can never wedge the MCP connection. See + // PUBLIC_OPENAPI_TIMEOUT_MS for the full rationale. 
+ (document as OpenApiDocumentWithExtensions)['x-speakeasy-timeout'] = + PUBLIC_OPENAPI_TIMEOUT_MS; + const paths = document.paths as Record< string, Record diff --git a/apps/api/src/tasks/automations/automations.controller.ts b/apps/api/src/tasks/automations/automations.controller.ts index 05ce95b6cc..c07c3f0434 100644 --- a/apps/api/src/tasks/automations/automations.controller.ts +++ b/apps/api/src/tasks/automations/automations.controller.ts @@ -10,6 +10,7 @@ import { UseGuards, } from '@nestjs/common'; import { + ApiBody, ApiOperation, ApiParam, ApiResponse, @@ -22,6 +23,7 @@ import { PermissionGuard } from '../../auth/permission.guard'; import { RequirePermission } from '../../auth/require-permission.decorator'; import { TasksService } from '../tasks.service'; import { AutomationsService } from './automations.service'; +import { CreateVersionDto } from './dto/create-version.dto'; import { UpdateAutomationDto } from './dto/update-automation.dto'; import { AUTOMATION_OPERATIONS } from './schemas/automation-operations'; import { CREATE_AUTOMATION_RESPONSES } from './schemas/create-automation.responses'; @@ -255,11 +257,12 @@ export class AutomationsController { }) @ApiParam({ name: 'taskId', description: 'Task ID' }) @ApiParam({ name: 'automationId', description: 'Automation ID' }) + @ApiBody({ type: CreateVersionDto }) async createVersion( @OrganizationId() organizationId: string, @Param('taskId') taskId: string, @Param('automationId') automationId: string, - @Body() body: { version: number; scriptKey: string; changelog?: string }, + @Body() body: CreateVersionDto, ) { await this.tasksService.verifyTaskAccess(organizationId, taskId); return this.automationsService.createVersion(automationId, body); diff --git a/apps/api/src/tasks/automations/automations.service.spec.ts b/apps/api/src/tasks/automations/automations.service.spec.ts new file mode 100644 index 0000000000..f7a2746ae2 --- /dev/null +++ b/apps/api/src/tasks/automations/automations.service.spec.ts @@ -0,0 
+1,83 @@ +import { ConflictException, NotFoundException } from '@nestjs/common'; + +// Mock the DB layer before importing the service. We also provide a stand-in +// Prisma.PrismaClientKnownRequestError so the service's `instanceof` checks and +// error-code branches can be exercised without a real database. +jest.mock('@db', () => { + class PrismaClientKnownRequestError extends Error { + code: string; + constructor(message: string, { code }: { code: string }) { + super(message); + this.code = code; + this.name = 'PrismaClientKnownRequestError'; + } + } + + return { + db: { + $transaction: jest.fn(), + evidenceAutomationVersion: { create: jest.fn() }, + evidenceAutomation: { update: jest.fn() }, + }, + Prisma: { PrismaClientKnownRequestError }, + }; +}); + +import { db, Prisma } from '@db'; +import { AutomationsService } from './automations.service'; + +const prismaError = (code: string) => + new Prisma.PrismaClientKnownRequestError(code, { + code, + clientVersion: '5.0.0', + }); + +describe('AutomationsService.createVersion — error mapping', () => { + let service: AutomationsService; + const input = { version: 1, scriptKey: 'org_1/tsk_1/aut_1.v1.js' }; + + beforeEach(() => { + jest.clearAllMocks(); + service = new AutomationsService(); + }); + + it('records the version and returns it on success', async () => { + const created = { id: 'eav_1', version: 1, scriptKey: input.scriptKey }; + (db.$transaction as jest.Mock).mockResolvedValue([created, { id: 'aut_1' }]); + + const result = await service.createVersion('aut_1', input); + + expect(result).toEqual({ success: true, version: created }); + }); + + it('maps a duplicate version (P2002) to a 409 ConflictException', async () => { + (db.$transaction as jest.Mock).mockRejectedValue(prismaError('P2002')); + + await expect(service.createVersion('aut_1', input)).rejects.toBeInstanceOf( + ConflictException, + ); + }); + + it('maps a missing automation (P2003 FK violation) to a 404 NotFoundException', async () => { + 
(db.$transaction as jest.Mock).mockRejectedValue(prismaError('P2003')); + + await expect( + service.createVersion('missing', input), + ).rejects.toBeInstanceOf(NotFoundException); + }); + + it('maps a missing automation (P2025 record not found) to a 404 NotFoundException', async () => { + (db.$transaction as jest.Mock).mockRejectedValue(prismaError('P2025')); + + await expect( + service.createVersion('missing', input), + ).rejects.toBeInstanceOf(NotFoundException); + }); + + it('rethrows unexpected errors untouched (no masking real 500s)', async () => { + const boom = new Error('db exploded'); + (db.$transaction as jest.Mock).mockRejectedValue(boom); + + await expect(service.createVersion('aut_1', input)).rejects.toBe(boom); + }); +}); diff --git a/apps/api/src/tasks/automations/automations.service.ts b/apps/api/src/tasks/automations/automations.service.ts index 86a6fd8ee9..1e8bf00254 100644 --- a/apps/api/src/tasks/automations/automations.service.ts +++ b/apps/api/src/tasks/automations/automations.service.ts @@ -1,5 +1,10 @@ -import { Injectable, NotFoundException } from '@nestjs/common'; -import { db } from '@db'; +import { + ConflictException, + Injectable, + NotFoundException, +} from '@nestjs/common'; +import { db, Prisma } from '@db'; +import { CreateVersionDto } from './dto/create-version.dto'; import { UpdateAutomationDto } from './dto/update-automation.dto'; @Injectable() @@ -135,26 +140,40 @@ export class AutomationsService { }; } - async createVersion( - automationId: string, - data: { version: number; scriptKey: string; changelog?: string }, - ) { - const [version] = await db.$transaction([ - db.evidenceAutomationVersion.create({ - data: { - evidenceAutomationId: automationId, - version: data.version, - scriptKey: data.scriptKey, - changelog: data.changelog, - }, - }), - // Enable automation on publish if not already enabled - db.evidenceAutomation.update({ - where: { id: automationId }, - data: { isEnabled: true }, - }), - ]); - return { success: true, 
version }; + async createVersion(automationId: string, data: CreateVersionDto) { + try { + const [version] = await db.$transaction([ + db.evidenceAutomationVersion.create({ + data: { + evidenceAutomationId: automationId, + version: data.version, + scriptKey: data.scriptKey, + changelog: data.changelog, + }, + }), + // Enable automation on publish if not already enabled + db.evidenceAutomation.update({ + where: { id: automationId }, + data: { isEnabled: true }, + }), + ]); + return { success: true, version }; + } catch (err) { + if (err instanceof Prisma.PrismaClientKnownRequestError) { + // Duplicate (evidenceAutomationId, version) — version already published. + if (err.code === 'P2002') { + throw new ConflictException( + `Version ${data.version} already exists for this automation`, + ); + } + // Automation row missing — FK on create (P2003) or update target gone + // (P2025). Surface a clean 404 instead of a raw 500. + if (err.code === 'P2003' || err.code === 'P2025') { + throw new NotFoundException(`Automation ${automationId} not found`); + } + } + throw err; + } } async findRunsByAutomationId(automationId: string) { diff --git a/apps/api/src/tasks/automations/dto/create-version.dto.spec.ts b/apps/api/src/tasks/automations/dto/create-version.dto.spec.ts new file mode 100644 index 0000000000..5396fff702 --- /dev/null +++ b/apps/api/src/tasks/automations/dto/create-version.dto.spec.ts @@ -0,0 +1,66 @@ +import { plainToInstance } from 'class-transformer'; +import { validate } from 'class-validator'; +import { CreateVersionDto } from './create-version.dto'; + +/** + * The original endpoint accepted an inline, untyped `@Body()` — invisible to the + * ValidationPipe — so a missing `version`/`scriptKey` slipped through and blew up + * with a Prisma non-null violation (500). These tests prove the DTO now rejects + * those payloads at the validation layer (400) before they reach the service. 
+ */ +describe('CreateVersionDto', () => { + async function validatePayload(payload: Record) { + return validate(plainToInstance(CreateVersionDto, payload)); + } + + it('accepts a valid payload', async () => { + const errors = await validatePayload({ + version: 1, + scriptKey: 'org_1/tsk_1/aut_1.v1.js', + changelog: 'initial publish', + }); + expect(errors).toHaveLength(0); + }); + + it('rejects a missing version (previously a 500)', async () => { + const errors = await validatePayload({ scriptKey: 'k' }); + expect(errors.some((e) => e.property === 'version')).toBe(true); + }); + + it('rejects a missing scriptKey (previously a 500)', async () => { + const errors = await validatePayload({ version: 1 }); + expect(errors.some((e) => e.property === 'scriptKey')).toBe(true); + }); + + it('rejects a version below 1', async () => { + const errors = await validatePayload({ version: 0, scriptKey: 'k' }); + expect(errors.some((e) => e.property === 'version')).toBe(true); + }); + + it('rejects an empty scriptKey', async () => { + const errors = await validatePayload({ version: 1, scriptKey: '' }); + expect(errors.some((e) => e.property === 'scriptKey')).toBe(true); + }); + + it('rejects a whitespace-only scriptKey (would otherwise persist a blank key)', async () => { + for (const scriptKey of [' ', '\t\n', '   ']) { + const errors = await validatePayload({ version: 1, scriptKey }); + expect(errors.some((e) => e.property === 'scriptKey')).toBe(true); + } + }); + + it('trims surrounding whitespace from a valid scriptKey', async () => { + const dto = plainToInstance(CreateVersionDto, { + version: 1, + scriptKey: ' org_1/tsk_1/aut_1.v1.js ', + }); + const errors = await validate(dto); + expect(errors).toHaveLength(0); + expect(dto.scriptKey).toBe('org_1/tsk_1/aut_1.v1.js'); + }); + + it('treats changelog as optional', async () => { + const errors = await validatePayload({ version: 2, scriptKey: 'k' }); + expect(errors.some((e) => e.property === 'changelog')).toBe(false); + }); 
+}); diff --git a/apps/api/src/tasks/automations/dto/create-version.dto.ts b/apps/api/src/tasks/automations/dto/create-version.dto.ts new file mode 100644 index 0000000000..8bc521f111 --- /dev/null +++ b/apps/api/src/tasks/automations/dto/create-version.dto.ts @@ -0,0 +1,40 @@ +import { ApiProperty } from '@nestjs/swagger'; +import { Transform } from 'class-transformer'; +import { IsInt, IsNotEmpty, IsOptional, IsString, Min } from 'class-validator'; + +/** + * Records that an automation script has been generated + published to S3. + * `version` and `scriptKey` are REQUIRED — the row references an already-stored + * script. The web UI's publish flow supplies both from the enterprise publish + * step; calling this without them used to 500 (Prisma non-null violation). + */ +export class CreateVersionDto { + @ApiProperty({ + description: 'Version number for this published script', + example: 1, + }) + @IsInt() + @Min(1) + version!: number; + + @ApiProperty({ + description: + 'S3 key of the already-generated & published automation script (returned by the publish step).', + example: 'org_abc123/tsk_abc123/aut_abc123.v1.js', + }) + @IsString() + // Trim first so a whitespace-only key collapses to '' and @IsNotEmpty rejects + // it — otherwise a blank key would persist and the automation would later + // fail to fetch a script at that key. Non-strings pass through for @IsString. + @Transform(({ value }) => (typeof value === 'string' ? 
value.trim() : value)) + @IsNotEmpty() + scriptKey!: string; + + @ApiProperty({ + description: 'Optional changelog describing this version', + required: false, + }) + @IsOptional() + @IsString() + changelog?: string; +} diff --git a/apps/api/src/uploads/dto/create-upload-url.dto.ts b/apps/api/src/uploads/dto/create-upload-url.dto.ts index 34776c779b..af2cf2a1e1 100644 --- a/apps/api/src/uploads/dto/create-upload-url.dto.ts +++ b/apps/api/src/uploads/dto/create-upload-url.dto.ts @@ -11,6 +11,7 @@ export enum UploadPurpose { policyPdf = 'policy_pdf', evidence = 'evidence', attachment = 'attachment', + document = 'document', general = 'general', } diff --git a/apps/api/src/uploads/upload-limits.spec.ts b/apps/api/src/uploads/upload-limits.spec.ts new file mode 100644 index 0000000000..83ffe1d4d0 --- /dev/null +++ b/apps/api/src/uploads/upload-limits.spec.ts @@ -0,0 +1,15 @@ +import { MAX_UPLOAD_BASE64_LENGTH, MAX_UPLOAD_BYTES } from './upload-limits'; + +describe('upload-limits', () => { + it('caps decoded uploads at 100 MiB', () => { + expect(MAX_UPLOAD_BYTES).toBe(100 * 1024 * 1024); + }); + + it('allows the base64 of a full 100 MiB file (no false 413 for UI uploads)', () => { + // Regression guard: the previous literal (134_217_728) was the base64 length + // of only 96 MiB, so a 96–100 MiB file the UI/service accept was rejected. 
+ expect(MAX_UPLOAD_BASE64_LENGTH).toBe(139_810_136); + expect(MAX_UPLOAD_BASE64_LENGTH).toBe(Math.ceil(MAX_UPLOAD_BYTES / 3) * 4); + expect(MAX_UPLOAD_BASE64_LENGTH).toBeGreaterThan(134_217_728); + }); +}); diff --git a/apps/api/src/uploads/upload-limits.ts b/apps/api/src/uploads/upload-limits.ts new file mode 100644 index 0000000000..e1372aefd2 --- /dev/null +++ b/apps/api/src/uploads/upload-limits.ts @@ -0,0 +1,27 @@ +/** + * Shared upload size limits for both upload paths: + * - inline base64 `fileData` (web UI / direct callers), capped on the DTO via + * MAX_UPLOAD_BASE64_LENGTH so an oversized payload is rejected at validation + * time (before it is decoded); + * - presigned `s3Key` (AI/MCP clients), capped in UploadsService via a HEAD + * request before the object is downloaded. + * + * Keep these as the single source of truth so the DTO caps, the service caps, + * and the web UI dropzone limits can't drift apart. + */ + +/** Maximum decoded file size accepted by any upload path (100 MiB). */ +export const MAX_UPLOAD_BYTES = 100 * 1024 * 1024; + +/** + * Maximum length of a base64-encoded inline `fileData` field. + * + * Base64 inflates bytes by 4/3 (4 chars per 3 bytes, padded), so this is the + * encoded length of a MAX_UPLOAD_BYTES file: `4 * ceil(bytes / 3)` = 139,810,136. + * + * IMPORTANT: it must be the base64 length of the FULL byte limit, not the byte + * limit itself. The previous literal (134_217_728 = 128 MiB of characters) only + * permitted ~96 MiB of decoded data, so a 96–100 MiB file the UI dropzone and + * the service both allow was wrongly rejected with a 400. 
+ */ +export const MAX_UPLOAD_BASE64_LENGTH = Math.ceil(MAX_UPLOAD_BYTES / 3) * 4; diff --git a/apps/api/src/uploads/uploads.service.spec.ts b/apps/api/src/uploads/uploads.service.spec.ts index 064271003c..25327b2e73 100644 --- a/apps/api/src/uploads/uploads.service.spec.ts +++ b/apps/api/src/uploads/uploads.service.spec.ts @@ -8,12 +8,14 @@ jest.mock('../app/s3', () => ({ s3Client: { send: jest.fn() }, getSignedUrl: jest.fn(async () => 'https://test-bucket.s3.amazonaws.com/signed'), getObjectAsBuffer: jest.fn(), + getObjectContentLength: jest.fn(), })); // eslint-disable-next-line @typescript-eslint/no-require-imports const s3 = require('../app/s3') as { getSignedUrl: jest.Mock; getObjectAsBuffer: jest.Mock; + getObjectContentLength: jest.Mock; }; describe('UploadsService', () => { @@ -83,6 +85,7 @@ describe('UploadsService', () => { describe('readUploadAsBase64', () => { it('fetches the object and returns base64 for a valid org key', async () => { + s3.getObjectContentLength.mockResolvedValueOnce(11); s3.getObjectAsBuffer.mockResolvedValueOnce(Buffer.from('hello world')); const result = await service.readUploadAsBase64( @@ -97,10 +100,58 @@ describe('UploadsService', () => { await expect( service.readUploadAsBase64(orgId, 'other_org/uploads/questionnaire/x.csv'), ).rejects.toThrow(/does not belong to this organization/); + expect(s3.getObjectContentLength).not.toHaveBeenCalled(); expect(s3.getObjectAsBuffer).not.toHaveBeenCalled(); }); - it('throws a clear error when the object is missing in S3', async () => { + it('rejects an oversized object via HEAD, before downloading it', async () => { + // 200MB — over the 100MB default ceiling. + s3.getObjectContentLength.mockResolvedValueOnce(200 * 1024 * 1024); + + await expect( + service.readUploadAsBase64(orgId, `${orgId}/uploads/document/huge.bin`), + ).rejects.toThrow(/maximum allowed size/); + + // The whole point: we must NOT download the body for an oversized file. 
+ expect(s3.getObjectAsBuffer).not.toHaveBeenCalled(); + }); + + it('honors a caller-provided maxBytes ceiling', async () => { + s3.getObjectContentLength.mockResolvedValueOnce(2 * 1024 * 1024); // 2MB + + await expect( + service.readUploadAsBase64( + orgId, + `${orgId}/uploads/document/2mb.bin`, + 1 * 1024 * 1024, // 1MB cap + ), + ).rejects.toThrow(/maximum allowed size/); + expect(s3.getObjectAsBuffer).not.toHaveBeenCalled(); + }); + + it('proceeds when S3 does not report a content length', async () => { + s3.getObjectContentLength.mockResolvedValueOnce(undefined); + s3.getObjectAsBuffer.mockResolvedValueOnce(Buffer.from('abc')); + + const result = await service.readUploadAsBase64( + orgId, + `${orgId}/uploads/document/unknown-size.bin`, + ); + + expect(result).toBe(Buffer.from('abc').toString('base64')); + }); + + it('throws a clear error when the object cannot be stat-ed (HEAD fails)', async () => { + s3.getObjectContentLength.mockRejectedValueOnce(new Error('NoSuchKey')); + + await expect( + service.readUploadAsBase64(orgId, `${orgId}/uploads/questionnaire/missing.csv`), + ).rejects.toThrow(/No file found/); + expect(s3.getObjectAsBuffer).not.toHaveBeenCalled(); + }); + + it('throws a clear error when the object body fails to download', async () => { + s3.getObjectContentLength.mockResolvedValueOnce(100); s3.getObjectAsBuffer.mockRejectedValueOnce(new Error('NoSuchKey')); await expect( diff --git a/apps/api/src/uploads/uploads.service.ts b/apps/api/src/uploads/uploads.service.ts index 518ee993b7..1ee857e304 100644 --- a/apps/api/src/uploads/uploads.service.ts +++ b/apps/api/src/uploads/uploads.service.ts @@ -1,10 +1,17 @@ import { BadRequestException, Injectable, Logger } from '@nestjs/common'; import { PutObjectCommand } from '@aws-sdk/client-s3'; -import { BUCKET_NAME, getObjectAsBuffer, getSignedUrl, s3Client } from '../app/s3'; +import { + BUCKET_NAME, + getObjectAsBuffer, + getObjectContentLength, + getSignedUrl, + s3Client, +} from '../app/s3'; import { 
CreateUploadUrlDto, UploadUrlResponseDto, } from './dto/create-upload-url.dto'; +import { MAX_UPLOAD_BYTES } from './upload-limits'; /** * ============================================================================ @@ -47,6 +54,14 @@ export class UploadsService { * of a leaked URL, long enough for a real upload. */ private static readonly UPLOAD_URL_TTL_SECONDS = 900; + /** + * Default ceiling for files read back from S3 via the presigned flow. A plain + * presigned PUT cannot enforce a size limit, so this is the backstop that + * stops an oversized upload from being loaded into memory. Shares the 100MB + * limit the feature services / DTOs enforce (see upload-limits.ts). + */ + static readonly DEFAULT_MAX_UPLOAD_BYTES = MAX_UPLOAD_BYTES; + /** * Generate a presigned S3 PUT URL plus the org-scoped key the file will land * at. The key prefix is always `{organizationId}/uploads/{purpose}/` so files @@ -97,12 +112,39 @@ export class UploadsService { async readUploadAsBase64( organizationId: string, s3Key: string, + maxBytes: number = UploadsService.DEFAULT_MAX_UPLOAD_BYTES, ): Promise { if (!BUCKET_NAME) { throw new BadRequestException('File storage is not configured'); } this.assertKeyBelongsToOrg(organizationId, s3Key); + // Reject oversized uploads via a HEAD request BEFORE downloading and + // base64-encoding the object. A presigned PUT can't cap upload size, so + // without this an authenticated client could PUT a multi-GB file and have + // the API load it fully into memory (buffer + ~1.33x base64) and OOM. + let contentLength: number | undefined; + try { + contentLength = await getObjectContentLength(BUCKET_NAME, s3Key); + } catch (error) { + this.logger.warn( + `Failed to stat uploaded file ${s3Key}: ${ + error instanceof Error ? 
error.message : 'unknown error' + }`, + ); + throw new BadRequestException( + 'No file found at the given s3Key — upload it via the presigned URL first.', + ); + } + + if (contentLength !== undefined && contentLength > maxBytes) { + throw new BadRequestException( + `File exceeds the maximum allowed size of ${Math.floor( + maxBytes / (1024 * 1024), + )}MB`, + ); + } + try { const buffer = await getObjectAsBuffer(BUCKET_NAME, s3Key); return buffer.toString('base64'); diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/CloudTestsSection.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/CloudTestsSection.tsx index ec3d58383d..1d1b4cc9d8 100644 --- a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/CloudTestsSection.tsx +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/CloudTestsSection.tsx @@ -47,6 +47,7 @@ import type { Finding } from '../types'; import { CheckDefinitionPanel } from './CheckDefinitionPanel'; import { CheckGroupBlock } from './CheckGroupBlock'; import { buildCheckGroups } from './check-groups'; +import { filterFindingsByConnection } from './finding-filters'; import { EvidenceJsonViewer } from './EvidenceJsonViewer'; import { MarkExceptionModal } from './MarkExceptionModal'; import { RemediationSection } from './RemediationSection'; @@ -285,28 +286,26 @@ export function CloudTestsSection({ ); const findings = useMemo(() => { - return allFindings - .filter((f) => f.providerSlug === providerSlug || f.connectionId === connectionId) + // Scope to the selected connection (= the selected cloud account) so + // picking a different account narrows the list to that account's findings. + return filterFindingsByConnection(allFindings, connectionId) .filter((f) => !projectFilter || f.projectDisplayName === projectFilter) .sort( (a, b) => (SEVERITY_ORDER[a.severity ?? 'info'] ?? 5) - (SEVERITY_ORDER[b.severity ?? 'info'] ?? 
5), ); - }, [allFindings, providerSlug, connectionId, projectFilter]); + }, [allFindings, connectionId, projectFilter]); - // Unique project names across all findings (for filter pills) + // Unique project names for the selected connection (for filter pills) const projectNames = useMemo(() => { const names = new Set(); - for (const f of allFindings) { - if ( - (f.providerSlug === providerSlug || f.connectionId === connectionId) && - f.projectDisplayName - ) { + for (const f of filterFindingsByConnection(allFindings, connectionId)) { + if (f.projectDisplayName) { names.add(f.projectDisplayName); } } return [...names].sort((a, b) => a.localeCompare(b)); - }, [allFindings, providerSlug, connectionId]); + }, [allFindings, connectionId]); const failedFindings = findings.filter((f) => f.status === 'failed' || f.status === 'FAILED'); const passedFindings = findings.filter((f) => f.status === 'passed' || f.status === 'success'); diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ProviderTabs.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ProviderTabs.tsx index 516da323d1..3da7e2c21f 100644 --- a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ProviderTabs.tsx +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ProviderTabs.tsx @@ -279,6 +279,15 @@ export function ProviderTabs({ {providerTypes.map((providerType) => { const connections = providerGroups[providerType] || []; const activeConnId = activeConnectionTabs[providerType] || connections[0]?.id; + // The connection selector also filters the findings list, so label it + // with the provider's term for a connection (AWS account, Azure + // subscription, etc.) to make that clear. + const connectionNoun = + providerType === 'aws' + ? 'Account' + : providerType === 'azure' + ? 'Subscription' + : 'Connection'; if (connections.length === 0) { return ; @@ -292,27 +301,32 @@ export function ProviderTabs({ onValueChange={(value) => onConnectionTabChange(providerType, value)} >
- +
+ + {connectionNoun}: + + +
{/* Only show "Add connection" button for providers that support multiple connections */} {canAddConnection !== false && connections.some((c) => c.supportsMultipleConnections) && (