Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions apps/sim/app/api/knowledge/utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,32 @@ vi.mock('@sim/db', async () => {
},
}
},
innerJoin() {
  // Mock of the document × knowledge_base context JOIN (covers
  // processDocumentAsync's prefetch). Each terminal `limit(n)` resolves to at
  // most one row that merges the first kb row with the first doc row, or to
  // an empty list when either backing array is empty.
  const resolveMergedRows = (extraColumns: Record<string, unknown>, n: number) => {
    // Read kbRows/docRows at call time so tests that mutate them after the
    // query chain is built still observe the latest fixtures.
    const haveBoth = kbRows.length > 0 && docRows.length > 0
    const merged = haveBoth ? [{ ...kbRows[0], ...docRows[0], ...extraColumns }] : []
    return Promise.resolve(merged.slice(0, n))
  }

  return {
    // Chain that goes through the extra LEFT JOIN: the merged row additionally
    // carries a fixed billed account user id.
    leftJoin: () => ({
      where: () => ({
        limit: (n: number) => resolveMergedRows({ billedAccountUserId: 'billing-user-1' }, n),
      }),
    }),
    // Chain without the LEFT JOIN: merged kb + doc row only.
    where: () => ({
      limit: (n: number) => resolveMergedRows({}, n),
    }),
  }
},
}
},
}
Expand Down
148 changes: 82 additions & 66 deletions apps/sim/lib/knowledge/documents/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
knowledgeBase,
knowledgeBaseTagDefinitions,
knowledgeConnector,
workspace as workspaceTable,
} from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { sha256Hex } from '@sim/security/hash'
Expand Down Expand Up @@ -47,7 +48,6 @@ import type { ProcessedDocumentTags } from '@/lib/knowledge/types'
import { estimateTokenCount } from '@/lib/tokenization/estimators'
import { deleteFile } from '@/lib/uploads/core/storage-service'
import { extractStorageKey } from '@/lib/uploads/utils/file-utils'
import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils'
import type {
DocumentProcessingPayload,
processDocument as processDocumentTask,
Expand Down Expand Up @@ -111,11 +111,26 @@ interface DocumentTagData {
value: string
}

async function processDocumentTags(
type TagDefinition = typeof knowledgeBaseTagDefinitions.$inferSelect
type TagDefinitionsByName = Map<string, TagDefinition>
type DbExecutor = Pick<typeof db, 'select'>

/**
 * Loads all tag definitions that belong to a knowledge base and indexes them
 * by their display name.
 *
 * @param knowledgeBaseId - Id of the knowledge base whose definitions are read.
 * @param executor - Object exposing `select` used to run the query; defaults
 *   to the module-level `db`, so callers may pass a different handle.
 * @returns A map from `displayName` to the full tag definition row. When
 *   several rows share a display name, the last one returned by the query wins.
 */
async function loadTagDefinitions(
  knowledgeBaseId: string,
  executor: DbExecutor = db
): Promise<TagDefinitionsByName> {
  const belongsToKnowledgeBase = eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId)
  const rows = await executor.select().from(knowledgeBaseTagDefinitions).where(belongsToKnowledgeBase)

  const byDisplayName: TagDefinitionsByName = new Map()
  for (const row of rows) {
    byDisplayName.set(row.displayName, row)
  }
  return byDisplayName
}

function resolveDocumentTags(
tagData: DocumentTagData[],
tagDefinitions: TagDefinitionsByName,
requestId: string
): Promise<ProcessedDocumentTags> {
): ProcessedDocumentTags {
const setTagValue = (
tags: ProcessedDocumentTags,
slot: string,
Expand Down Expand Up @@ -200,13 +215,6 @@ async function processDocumentTags(
return result
}

const existingDefinitions = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))

const existingByName = new Map(existingDefinitions.map((def) => [def.displayName, def]))

const undefinedTags: string[] = []
const typeErrors: string[] = []

Expand All @@ -223,7 +231,7 @@ async function processDocumentTags(

if (!hasValue) continue

const existingDef = existingByName.get(tagName)
const existingDef = tagDefinitions.get(tagName)
if (!existingDef) {
undefinedTags.push(tagName)
continue
Expand Down Expand Up @@ -264,7 +272,7 @@ async function processDocumentTags(

if (!hasValue) continue

const existingDef = existingByName.get(tagName)
const existingDef = tagDefinitions.get(tagName)
if (!existingDef) continue

const targetSlot = existingDef.tagSlot
Expand Down Expand Up @@ -418,34 +426,66 @@ export async function processDocumentAsync(
try {
logger.info(`[${documentId}] Starting document processing: ${docData.filename}`)

const kb = await db
// KB config + workspace billing + doc tags in one JOIN (was 3 SELECTs).
const contextRows = await db
.select({
userId: knowledgeBase.userId,
workspaceId: knowledgeBase.workspaceId,
chunkingConfig: knowledgeBase.chunkingConfig,
embeddingModel: knowledgeBase.embeddingModel,
billedAccountUserId: workspaceTable.billedAccountUserId,
tag1: document.tag1,
tag2: document.tag2,
tag3: document.tag3,
tag4: document.tag4,
tag5: document.tag5,
tag6: document.tag6,
tag7: document.tag7,
number1: document.number1,
number2: document.number2,
number3: document.number3,
number4: document.number4,
number5: document.number5,
date1: document.date1,
date2: document.date2,
boolean1: document.boolean1,
boolean2: document.boolean2,
boolean3: document.boolean3,
})
.from(knowledgeBase)
.where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt)))
.from(document)
.innerJoin(knowledgeBase, eq(knowledgeBase.id, document.knowledgeBaseId))
.leftJoin(
workspaceTable,
and(eq(workspaceTable.id, knowledgeBase.workspaceId), isNull(workspaceTable.archivedAt))
)
.where(
and(
eq(document.id, documentId),
eq(knowledgeBase.id, knowledgeBaseId),
isNull(document.archivedAt),
isNull(document.deletedAt),
isNull(knowledgeBase.deletedAt)
)
)
.limit(1)

if (kb.length === 0) {
if (contextRows.length === 0) {
logger.warn(
`[${documentId}] Skipping document processing: knowledge base ${knowledgeBaseId} is deleted`
`[${documentId}] Skipping document processing: document or knowledge base ${knowledgeBaseId} no longer exists`
)
await db
.update(document)
.set({
processingStatus: 'failed',
processingError: 'Knowledge base deleted',
processingError: 'Document or knowledge base no longer exists',
processingCompletedAt: new Date(),
})
.where(
and(eq(document.id, documentId), isNull(document.archivedAt), isNull(document.deletedAt))
)
.where(eq(document.id, documentId))
return
}

const ctx = contextRows[0]

await db
.update(document)
.set({
Expand All @@ -460,7 +500,7 @@ export async function processDocumentAsync(

logger.info(`[${documentId}] Status updated to 'processing', starting document processor`)

const rawConfig = kb[0].chunkingConfig as {
const rawConfig = ctx.chunkingConfig as {
maxSize?: number
minSize?: number
overlap?: number
Expand All @@ -473,13 +513,13 @@ export async function processDocumentAsync(
overlap: rawConfig?.overlap ?? 200,
}

const kbEmbeddingModel = kb[0].embeddingModel
if (!kb[0].workspaceId) {
const kbEmbeddingModel = ctx.embeddingModel
if (!ctx.workspaceId) {
throw new Error(`Knowledge base ${knowledgeBaseId} is missing workspace billing context`)
}
const billingUserId = await getWorkspaceBilledAccountUserId(kb[0].workspaceId)
const billingUserId = ctx.billedAccountUserId
if (!billingUserId) {
throw new Error(`Workspace ${kb[0].workspaceId} is missing billed account`)
throw new Error(`Workspace ${ctx.workspaceId} is missing billed account`)
}
let totalEmbeddingTokens = 0
let embeddingIsBYOK = false
Expand All @@ -495,8 +535,8 @@ export async function processDocumentAsync(
kbConfig.maxSize,
kbConfig.overlap,
kbConfig.minSize,
kb[0].userId,
kb[0].workspaceId,
ctx.userId,
ctx.workspaceId,
rawConfig?.strategy,
rawConfig?.strategyOptions
)
Expand Down Expand Up @@ -534,7 +574,7 @@ export async function processDocumentAsync(
isBYOK,
modelName,
pricingId,
} = await generateEmbeddings(batch, kbEmbeddingModel, kb[0].workspaceId)
} = await generateEmbeddings(batch, kbEmbeddingModel, ctx.workspaceId)
for (const emb of batchEmbeddings) {
embeddings.push(emb)
}
Expand All @@ -547,41 +587,10 @@ export async function processDocumentAsync(
}
}

logger.info(`[${documentId}] Embeddings generated, fetching document tags`)

const documentRecord = await db
.select({
tag1: document.tag1,
tag2: document.tag2,
tag3: document.tag3,
tag4: document.tag4,
tag5: document.tag5,
tag6: document.tag6,
tag7: document.tag7,
number1: document.number1,
number2: document.number2,
number3: document.number3,
number4: document.number4,
number5: document.number5,
date1: document.date1,
date2: document.date2,
boolean1: document.boolean1,
boolean2: document.boolean2,
boolean3: document.boolean3,
})
.from(document)
.where(
and(
eq(document.id, documentId),
isNull(document.archivedAt),
isNull(document.deletedAt)
)
)
.limit(1)

const documentTags = documentRecord[0] || {}
// Tag values prefetched above; reuse for the embedding rows.
const documentTags = ctx

logger.info(`[${documentId}] Creating embedding records with tags`)
logger.info(`[${documentId}] Embeddings generated, creating embedding records with tags`)

const tokenizerProvider = getEmbeddingModelInfo(kbEmbeddingModel).tokenizerProvider

Expand Down Expand Up @@ -686,7 +695,7 @@ export async function processDocumentAsync(
if (cost > 0) {
await recordUsage({
userId: billingUserId,
workspaceId: kb[0].workspaceId ?? undefined,
workspaceId: ctx.workspaceId ?? undefined,
entries: [
{
category: 'model',
Expand Down Expand Up @@ -770,6 +779,12 @@ export async function createDocumentRecords(
throw new Error('Knowledge base not found')
}

// One load per batch (was N+1); skip entirely if no doc carries tags.
const hasTaggedDocs = documents.some((d) => d.documentTagsData)
const tagDefinitions = hasTaggedDocs
? await loadTagDefinitions(knowledgeBaseId, tx)
: (new Map() as TagDefinitionsByName)

const now = new Date()
const documentRecords = []
const returnData: DocumentData[] = []
Expand All @@ -783,7 +798,7 @@ export async function createDocumentRecords(
try {
const tagData = JSON.parse(docData.documentTagsData)
if (Array.isArray(tagData)) {
processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId)
processedTags = resolveDocumentTags(tagData, tagDefinitions, requestId)
}
} catch (error) {
if (error instanceof SyntaxError) {
Expand Down Expand Up @@ -1277,7 +1292,8 @@ export async function createSingleDocument(
try {
const tagData = JSON.parse(documentData.documentTagsData)
if (Array.isArray(tagData)) {
processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId)
const tagDefinitions = await loadTagDefinitions(knowledgeBaseId)
processedTags = resolveDocumentTags(tagData, tagDefinitions, requestId)
}
} catch (error) {
if (error instanceof SyntaxError) {
Expand Down
Loading