diff --git a/apps/sim/app/api/knowledge/utils.test.ts b/apps/sim/app/api/knowledge/utils.test.ts index 1886c1659e1..326bbee660f 100644 --- a/apps/sim/app/api/knowledge/utils.test.ts +++ b/apps/sim/app/api/knowledge/utils.test.ts @@ -116,6 +116,32 @@ vi.mock('@sim/db', async () => { }, } }, + innerJoin() { + // document × knowledge_base context JOIN — return the first kb and + // doc row merged (covers processDocumentAsync's prefetch). + return { + leftJoin: () => ({ + where: () => ({ + limit: (n: number) => + Promise.resolve( + kbRows.length > 0 && docRows.length > 0 + ? [ + { ...kbRows[0], ...docRows[0], billedAccountUserId: 'billing-user-1' }, + ].slice(0, n) + : [] + ), + }), + }), + where: () => ({ + limit: (n: number) => + Promise.resolve( + kbRows.length > 0 && docRows.length > 0 + ? [{ ...kbRows[0], ...docRows[0] }].slice(0, n) + : [] + ), + }), + } + }, } }, } diff --git a/apps/sim/lib/knowledge/documents/service.ts b/apps/sim/lib/knowledge/documents/service.ts index be71c279717..687f3ef3310 100644 --- a/apps/sim/lib/knowledge/documents/service.ts +++ b/apps/sim/lib/knowledge/documents/service.ts @@ -5,6 +5,7 @@ import { knowledgeBase, knowledgeBaseTagDefinitions, knowledgeConnector, + workspace as workspaceTable, } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { sha256Hex } from '@sim/security/hash' @@ -47,7 +48,6 @@ import type { ProcessedDocumentTags } from '@/lib/knowledge/types' import { estimateTokenCount } from '@/lib/tokenization/estimators' import { deleteFile } from '@/lib/uploads/core/storage-service' import { extractStorageKey } from '@/lib/uploads/utils/file-utils' -import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils' import type { DocumentProcessingPayload, processDocument as processDocumentTask, @@ -111,11 +111,26 @@ interface DocumentTagData { value: string } -async function processDocumentTags( +type TagDefinition = typeof knowledgeBaseTagDefinitions.$inferSelect +type TagDefinitionsByName = Map +type DbExecutor = Pick + +async function loadTagDefinitions( knowledgeBaseId: string, + executor: DbExecutor = db +): Promise { + const defs = await executor + .select() + .from(knowledgeBaseTagDefinitions) + .where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId)) + return new Map(defs.map((def) => [def.displayName, def])) +} + +function resolveDocumentTags( tagData: DocumentTagData[], + tagDefinitions: TagDefinitionsByName, requestId: string -): Promise { +): ProcessedDocumentTags { const setTagValue = ( tags: ProcessedDocumentTags, slot: string, @@ -200,13 +215,6 @@ async function processDocumentTags( return result } - const existingDefinitions = await db - .select() - .from(knowledgeBaseTagDefinitions) - .where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId)) - - const existingByName = new Map(existingDefinitions.map((def) => [def.displayName, def])) - const undefinedTags: string[] = [] const typeErrors: string[] = [] @@ -223,7 +231,7 @@ async function processDocumentTags( if (!hasValue) continue - const existingDef = existingByName.get(tagName) + const existingDef = tagDefinitions.get(tagName) if (!existingDef) { undefinedTags.push(tagName) continue @@ -264,7 +272,7 @@ async function processDocumentTags( if (!hasValue) continue - const existingDef = existingByName.get(tagName) + const existingDef = tagDefinitions.get(tagName) if (!existingDef) continue const targetSlot = existingDef.tagSlot @@ -418,34 +426,66 @@ export async function processDocumentAsync( try { logger.info(`[${documentId}] Starting document processing: ${docData.filename}`) - const kb = await db + // KB config + workspace billing + doc tags in one JOIN (was 3 SELECTs). + const contextRows = await db .select({ userId: knowledgeBase.userId, workspaceId: knowledgeBase.workspaceId, chunkingConfig: knowledgeBase.chunkingConfig, embeddingModel: knowledgeBase.embeddingModel, + billedAccountUserId: workspaceTable.billedAccountUserId, + tag1: document.tag1, + tag2: document.tag2, + tag3: document.tag3, + tag4: document.tag4, + tag5: document.tag5, + tag6: document.tag6, + tag7: document.tag7, + number1: document.number1, + number2: document.number2, + number3: document.number3, + number4: document.number4, + number5: document.number5, + date1: document.date1, + date2: document.date2, + boolean1: document.boolean1, + boolean2: document.boolean2, + boolean3: document.boolean3, }) - .from(knowledgeBase) - .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) + .from(document) + .innerJoin(knowledgeBase, eq(knowledgeBase.id, document.knowledgeBaseId)) + .leftJoin( + workspaceTable, + and(eq(workspaceTable.id, knowledgeBase.workspaceId), isNull(workspaceTable.archivedAt)) + ) + .where( + and( + eq(document.id, documentId), + eq(knowledgeBase.id, knowledgeBaseId), + isNull(document.archivedAt), + isNull(document.deletedAt), + isNull(knowledgeBase.deletedAt) + ) + ) .limit(1) - if (kb.length === 0) { + if (contextRows.length === 0) { logger.warn( - `[${documentId}] Skipping document processing: knowledge base ${knowledgeBaseId} is deleted` + `[${documentId}] Skipping document processing: document or knowledge base ${knowledgeBaseId} no longer exists` ) await db .update(document) .set({ processingStatus: 'failed', - processingError: 'Knowledge base deleted', + processingError: 'Document or knowledge base no longer exists', processingCompletedAt: new Date(), }) - .where( - and(eq(document.id, documentId), isNull(document.archivedAt), isNull(document.deletedAt)) - ) + .where(eq(document.id, documentId)) return } + const ctx = contextRows[0] + await db .update(document) .set({ @@ -460,7 +500,7 @@ export async function processDocumentAsync( logger.info(`[${documentId}] Status updated to 'processing', starting document processor`) - const rawConfig = kb[0].chunkingConfig as { + const rawConfig = ctx.chunkingConfig as { maxSize?: number minSize?: number overlap?: number @@ -473,13 +513,13 @@ export async function processDocumentAsync( overlap: rawConfig?.overlap ?? 200, } - const kbEmbeddingModel = kb[0].embeddingModel - if (!kb[0].workspaceId) { + const kbEmbeddingModel = ctx.embeddingModel + if (!ctx.workspaceId) { throw new Error(`Knowledge base ${knowledgeBaseId} is missing workspace billing context`) } - const billingUserId = await getWorkspaceBilledAccountUserId(kb[0].workspaceId) + const billingUserId = ctx.billedAccountUserId if (!billingUserId) { - throw new Error(`Workspace ${kb[0].workspaceId} is missing billed account`) + throw new Error(`Workspace ${ctx.workspaceId} is missing billed account`) } let totalEmbeddingTokens = 0 let embeddingIsBYOK = false @@ -495,8 +535,8 @@ export async function processDocumentAsync( kbConfig.maxSize, kbConfig.overlap, kbConfig.minSize, - kb[0].userId, - kb[0].workspaceId, + ctx.userId, + ctx.workspaceId, rawConfig?.strategy, rawConfig?.strategyOptions ) @@ -534,7 +574,7 @@ export async function processDocumentAsync( isBYOK, modelName, pricingId, - } = await generateEmbeddings(batch, kbEmbeddingModel, kb[0].workspaceId) + } = await generateEmbeddings(batch, kbEmbeddingModel, ctx.workspaceId) for (const emb of batchEmbeddings) { embeddings.push(emb) } @@ -547,41 +587,10 @@ export async function processDocumentAsync( } } - logger.info(`[${documentId}] Embeddings generated, fetching document tags`) - - const documentRecord = await db - .select({ - tag1: document.tag1, - tag2: document.tag2, - tag3: document.tag3, - tag4: document.tag4, - tag5: document.tag5, - tag6: document.tag6, - tag7: document.tag7, - number1: document.number1, - number2: document.number2, - number3: document.number3, - number4: document.number4, - number5: document.number5, - date1: document.date1, - date2: document.date2, - boolean1: document.boolean1, - boolean2: document.boolean2, - boolean3: document.boolean3, - }) - .from(document) - .where( - and( - eq(document.id, documentId), - isNull(document.archivedAt), - isNull(document.deletedAt) - ) - ) - .limit(1) - - const documentTags = documentRecord[0] || {} + // Tag values prefetched above; reuse for the embedding rows. + const documentTags = ctx - logger.info(`[${documentId}] Creating embedding records with tags`) + logger.info(`[${documentId}] Embeddings generated, creating embedding records with tags`) const tokenizerProvider = getEmbeddingModelInfo(kbEmbeddingModel).tokenizerProvider @@ -686,7 +695,7 @@ export async function processDocumentAsync( if (cost > 0) { await recordUsage({ userId: billingUserId, - workspaceId: kb[0].workspaceId ?? undefined, + workspaceId: ctx.workspaceId ?? undefined, entries: [ { category: 'model', @@ -770,6 +779,12 @@ export async function createDocumentRecords( throw new Error('Knowledge base not found') } + // One load per batch (was N+1); skip entirely if no doc carries tags. + const hasTaggedDocs = documents.some((d) => d.documentTagsData) + const tagDefinitions = hasTaggedDocs + ? await loadTagDefinitions(knowledgeBaseId, tx) + : (new Map() as TagDefinitionsByName) + const now = new Date() const documentRecords = [] const returnData: DocumentData[] = [] @@ -783,7 +798,7 @@ export async function createDocumentRecords( try { const tagData = JSON.parse(docData.documentTagsData) if (Array.isArray(tagData)) { - processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId) + processedTags = resolveDocumentTags(tagData, tagDefinitions, requestId) } } catch (error) { if (error instanceof SyntaxError) { @@ -1277,7 +1292,8 @@ export async function createSingleDocument( try { const tagData = JSON.parse(documentData.documentTagsData) if (Array.isArray(tagData)) { - processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId) + const tagDefinitions = await loadTagDefinitions(knowledgeBaseId) + processedTags = resolveDocumentTags(tagData, tagDefinitions, requestId) } } catch (error) { if (error instanceof SyntaxError) {