From 681610cba7323c2146455bd14f8b19fd3130b3d1 Mon Sep 17 00:00:00 2001 From: Artem Niehrieiev Date: Tue, 2 Jun 2026 10:15:30 +0000 Subject: [PATCH] feat: implement permission checks for qai requests and add tests --- .../collect-mongo-pipeline-collections.ts | 59 ++++++++ ...est-info-from-table-with-ai-v7.use.case.ts | 137 ++++++++++++++++-- ...collect-mongo-pipeline-collections.test.ts | 57 ++++++++ 3 files changed, 239 insertions(+), 14 deletions(-) create mode 100644 backend/src/ai-core/tools/collect-mongo-pipeline-collections.ts create mode 100644 backend/test/ava-tests/non-saas-tests/non-saas-collect-mongo-pipeline-collections.test.ts diff --git a/backend/src/ai-core/tools/collect-mongo-pipeline-collections.ts b/backend/src/ai-core/tools/collect-mongo-pipeline-collections.ts new file mode 100644 index 000000000..cf4dcf26e --- /dev/null +++ b/backend/src/ai-core/tools/collect-mongo-pipeline-collections.ts @@ -0,0 +1,59 @@ +import { CollectQueryTablesResult } from '../../entities/visualizations/panel/utils/collect-query-tables.util.js'; +import { getErrorMessage } from '../../helpers/get-error-message.js'; + +/** + * Recursively collects collection names referenced by stages that read from + * other collections (`$lookup`, `$graphLookup`, `$unionWith`) anywhere in a + * MongoDB aggregation pipeline, including nested sub-pipelines. Every value is visited whatever its key, so these stage names are matched at any depth. `node` is the parsed JSON value to walk (arrays and objects are traversed, anything else is ignored); `collected` is the set that receives each non-empty collection name and is mutated in place. + */ +function collectReferencedCollections(node: unknown, collected: Set): void { + if (Array.isArray(node)) { + for (const item of node) { + collectReferencedCollections(item, collected); + } + return; + } + if (!node || typeof node !== 'object') { + return; + } + for (const [key, value] of Object.entries(node as Record)) { + if (key === '$lookup' || key === '$graphLookup') { + const from = (value as { from?: unknown })?.from; + if (typeof from === 'string' && from.length > 0) { + collected.add(from); + } + } else if (key === '$unionWith') { + // `$unionWith` accepts either a collection-name string or an object of the form `{ coll: 'name', pipeline: [...] }`; only a non-empty string `coll` is recorded. 
+ if (typeof value === 'string' && value.length > 0) { + collected.add(value); + } else { + const coll = (value as { coll?: unknown })?.coll; + if (typeof coll === 'string' && coll.length > 0) { + collected.add(coll); + } + } + } + collectReferencedCollections(value, collected); /* Descend into every value, including the stage bodies handled above, so nested sub-pipelines are covered. */ + } +} + +/** + * Resolves the collections a MongoDB aggregation pipeline reads from besides + * its base collection (the `$lookup` / `$graphLookup` / `$unionWith` targets), + * so the caller can verify the user has read permission on each. + * + * Returns `{ kind: 'tables' }` (possibly empty) when the pipeline parses, and + * `{ kind: 'indeterminate' }` when it cannot be parsed — in which case the + * caller must fall back to a stricter check rather than assume it is harmless. `pipeline` is parsed with `JSON.parse`, so only strict JSON text resolves to `tables`; the `reason` of an indeterminate result is built from the thrown parse error, and the returned `tables` list is de-duplicated. + */ +export function collectMongoPipelineCollections(pipeline: string): CollectQueryTablesResult { + let parsedPipeline: unknown; + try { + parsedPipeline = JSON.parse(pipeline); + } catch (error) { + return { kind: 'indeterminate', reason: `pipeline parse error: ${getErrorMessage(error)}` }; + } + const collected = new Set(); + collectReferencedCollections(parsedPipeline, collected); + return { kind: 'tables', tables: Array.from(collected) }; +} diff --git a/backend/src/entities/ai/use-cases/request-info-from-table-with-ai-v7.use.case.ts b/backend/src/entities/ai/use-cases/request-info-from-table-with-ai-v7.use.case.ts index a060ebd3a..72efeccf6 100644 --- a/backend/src/entities/ai/use-cases/request-info-from-table-with-ai-v7.use.case.ts +++ b/backend/src/entities/ai/use-cases/request-info-from-table-with-ai-v7.use.case.ts @@ -1,5 +1,13 @@ import { BaseMessage } from '@langchain/core/messages'; -import { BadRequestException, Inject, Injectable, Logger, NotFoundException, Scope } from '@nestjs/common'; +import { + BadRequestException, + ForbiddenException, + Inject, + Injectable, + Logger, + NotFoundException, + Scope, +} from '@nestjs/common'; import { getDataAccessObject } from
'@rocketadmin/shared-code/dist/src/data-access-layer/shared/create-data-access-object.js'; import { ConnectionTypesEnum } from '@rocketadmin/shared-code/dist/src/shared/enums/connection-types-enum.js'; import { IDataAccessObject } from '@rocketadmin/shared-code/dist/src/shared/interfaces/data-access-object.interface.js'; @@ -9,6 +17,7 @@ import { Response } from 'express'; import { AIToolCall, AIToolDefinition } from '../../../ai-core/interfaces/ai-provider.interface.js'; import { AIProviderType } from '../../../ai-core/interfaces/ai-service.interface.js'; import { AICoreService } from '../../../ai-core/services/ai-core.service.js'; +import { collectMongoPipelineCollections } from '../../../ai-core/tools/collect-mongo-pipeline-collections.js'; import { createDatabaseTools } from '../../../ai-core/tools/database-tools.js'; import { searchDocumentation } from '../../../ai-core/tools/documentation-search.js'; import { createDatabaseQuerySystemPrompt } from '../../../ai-core/tools/prompts.js'; @@ -22,7 +31,9 @@ import { Messages } from '../../../exceptions/text/messages.js'; import { getErrorMessage } from '../../../helpers/get-error-message.js'; import { isConnectionTypeAgent } from '../../../helpers/is-connection-entity-agent.js'; import { slackPostMessage } from '../../../helpers/slack/slack-post-message.js'; +import { CedarPermissionsService } from '../../cedar-authorization/cedar-permissions.service.js'; import { ConnectionEntity } from '../../connection/connection.entity.js'; +import { assertUserCanReadQueryTables } from '../../visualizations/panel/utils/assert-query-tables-readable.util.js'; import { MessageRole } from '../ai-conversation-history/ai-chat-messages/message-role.enum.js'; import { UserAiChatEntity } from '../ai-conversation-history/user-ai-chat/user-ai-chat.entity.js'; import { IRequestInfoFromTableV2 } from '../ai-use-cases.interface.js'; @@ -41,6 +52,7 @@ export class RequestInfoFromTableWithAIUseCaseV7 @Inject(BaseType.GLOBAL_DB_CONTEXT) 
protected _dbContext: IGlobalDatabaseContext, private readonly aiCoreService: AICoreService, + private readonly cedarPermissions: CedarPermissionsService, ) { super(); } @@ -104,6 +116,7 @@ export class RequestInfoFromTableWithAIUseCaseV7 tableName, userEmail, foundConnection, + user_id, ); if (accumulatedResponse) { @@ -132,6 +145,7 @@ export class RequestInfoFromTableWithAIUseCaseV7 inputTableName: string, userEmail: string, foundConnection: ConnectionEntity, + userId: string, ): Promise { let currentMessages = [...messages]; let depth = 0; @@ -178,6 +192,7 @@ export class RequestInfoFromTableWithAIUseCaseV7 inputTableName, userEmail, foundConnection, + userId, ); for (const toolResult of toolResults) { @@ -226,6 +241,7 @@ export class RequestInfoFromTableWithAIUseCaseV7 inputTableName: string, userEmail: string, foundConnection: ConnectionEntity, + userId: string, ): Promise> { const results: Array<{ toolCallId: string; result: string }> = []; @@ -236,11 +252,13 @@ export class RequestInfoFromTableWithAIUseCaseV7 switch (toolCall.name) { case 'getTableStructure': { const tableName = (toolCall.arguments.tableName as string) || inputTableName; + await this.assertUserCanReadTables([tableName], userId, foundConnection.id); const structureInfo = await this.getTableStructureInfo( dataAccessObject, tableName, userEmail, foundConnection, + userId, ); result = encodeToToon(structureInfo); break; @@ -256,6 +274,14 @@ export class RequestInfoFromTableWithAIUseCaseV7 'Invalid SQL query. 
Please ensure it is a read-only SELECT statement without any forbidden keywords.', ); } + await assertUserCanReadQueryTables({ + query, + connectionType: foundConnection.type as ConnectionTypesEnum, + connectionId: foundConnection.id, + validateTableRead: (referencedTableName) => + this.cedarPermissions.improvedCheckTableRead(userId, foundConnection.id, referencedTableName), + listAllTableNames: async () => (await dataAccessObject.getTablesFromDB()).map((table) => table.tableName), + }); const wrappedQuery = wrapQueryWithLimit(query, foundConnection.type as ConnectionTypesEnum); const queryResult = await dataAccessObject.executeRawQuery(wrappedQuery, inputTableName, userEmail); result = encodeToToon(queryResult); @@ -272,6 +298,13 @@ export class RequestInfoFromTableWithAIUseCaseV7 'Invalid MongoDB command. Please ensure it is a read-only aggregation pipeline without any forbidden keywords.', ); } + await this.assertUserCanReadPipelineCollections( + pipeline, + inputTableName, + userId, + foundConnection.id, + dataAccessObject, + ); const pipelineResult = await dataAccessObject.executeRawQuery(pipeline, inputTableName, userEmail); result = encodeToToon(pipelineResult); break; @@ -307,6 +340,7 @@ export class RequestInfoFromTableWithAIUseCaseV7 tableName: string, userEmail: string, foundConnection: ConnectionEntity, + userId: string, ) { const [tableStructure, tableForeignKeys, referencedTableNamesAndColumns] = await Promise.all([ dao.getTableStructure(tableName, userEmail), @@ -314,25 +348,42 @@ export class RequestInfoFromTableWithAIUseCaseV7 dao.getReferencedTableNamesAndColumns(tableName, userEmail), ]); + // Only expose the structure of related tables the user is permitted to + // read — otherwise foreign-key traversal would leak the schema of tables + // the user has no access to. 
const referencedTablesStructures = []; const structurePromises = referencedTableNamesAndColumns.flatMap((referencedTable) => - referencedTable.referenced_by.map((table) => - dao.getTableStructure(table.table_name, userEmail).then((structure) => ({ - tableName: table.table_name, - structure, - })), - ), + referencedTable.referenced_by.map(async (table) => { + const canRead = await this.cedarPermissions.improvedCheckTableRead( + userId, + foundConnection.id, + table.table_name, + ); + if (!canRead) { + return null; + } + const structure = await dao.getTableStructure(table.table_name, userEmail); + return { tableName: table.table_name, structure }; + }), ); - referencedTablesStructures.push(...(await Promise.all(structurePromises))); + referencedTablesStructures.push(...(await Promise.all(structurePromises)).filter((item) => item !== null)); const foreignTablesStructures = []; - const foreignTablesStructurePromises = tableForeignKeys.flatMap((foreignKey) => - dao.getTableStructure(foreignKey.referenced_table_name, userEmail).then((structure) => ({ - tableName: foreignKey.referenced_table_name, - structure, - })), + const foreignTablesStructurePromises = tableForeignKeys.map(async (foreignKey) => { + const canRead = await this.cedarPermissions.improvedCheckTableRead( + userId, + foundConnection.id, + foreignKey.referenced_table_name, + ); + if (!canRead) { + return null; + } + const structure = await dao.getTableStructure(foreignKey.referenced_table_name, userEmail); + return { tableName: foreignKey.referenced_table_name, structure }; + }); + foreignTablesStructures.push( + ...(await Promise.all(foreignTablesStructurePromises)).filter((item) => item !== null), ); - foreignTablesStructures.push(...(await Promise.all(foreignTablesStructurePromises))); return { tableStructure, @@ -345,6 +396,64 @@ export class RequestInfoFromTableWithAIUseCaseV7 }; } + /** + * Verifies the user has read permission on every supplied table before the + * AI is allowed to query or inspect
them. Throws a `ForbiddenException` on + * the first unreadable table; inside the tool loop this surfaces back to the + * model as a tool error, so the offending query is never executed. Empty or + * blank names are ignored. Names are trimmed and de-duplicated first, then checked one at a time in order; a warning is logged before the exception is thrown. + */ + private async assertUserCanReadTables( + tableNames: Array, + userId: string, + connectionId: string, + ): Promise { + const uniqueTableNames = Array.from( + new Set(tableNames.map((name) => name?.trim()).filter((name): name is string => Boolean(name))), + ); + + for (const tableName of uniqueTableNames) { + const canRead = await this.cedarPermissions.improvedCheckTableRead(userId, connectionId, tableName); + if (!canRead) { + this.logger.warn( + `AI request blocked for user ${userId} on connection ${connectionId}: ` + + `no read permission for table "${tableName}"`, + ); + throw new ForbiddenException(Messages.NO_READ_PERMISSION_FOR_TABLE(tableName)); + } + } + } + + /** + * Guards a MongoDB aggregation pipeline against table-level read permissions: + * the user must be able to read the base collection and every collection the + * pipeline pulls in (`$lookup` / `$graphLookup` / `$unionWith`). When the + * pipeline cannot be parsed we cannot trust it to be harmless, so we fall + * back to requiring read permission on every collection in the connection. The fallback list comes from `dataAccessObject.getTablesFromDB()`, and taking the fallback is itself logged as a warning; the final check is delegated to `assertUserCanReadTables`.
+ */ + private async assertUserCanReadPipelineCollections( + pipeline: string, + baseCollection: string, + userId: string, + connectionId: string, + dataAccessObject: IDataAccessObject | IDataAccessObjectAgent, + ): Promise { + const collected = collectMongoPipelineCollections(pipeline); + + let collectionsToCheck: Array; + if (collected.kind === 'tables') { + collectionsToCheck = [baseCollection, ...collected.tables]; + } else { + this.logger.warn( + `AI pipeline permission check could not resolve referenced collections for connection ${connectionId} ` + + `(reason: ${collected.reason}); falling back to all-collections read check.`, + ); + collectionsToCheck = (await dataAccessObject.getTablesFromDB()).map((table) => table.tableName); + } + + await this.assertUserCanReadTables(collectionsToCheck, userId, connectionId); + } + private setupResponseHeaders(response: Response): void { response.setHeader('Content-Type', 'text/event-stream'); response.setHeader('Cache-Control', 'no-cache'); diff --git a/backend/test/ava-tests/non-saas-tests/non-saas-collect-mongo-pipeline-collections.test.ts b/backend/test/ava-tests/non-saas-tests/non-saas-collect-mongo-pipeline-collections.test.ts new file mode 100644 index 000000000..2d47b1651 --- /dev/null +++ b/backend/test/ava-tests/non-saas-tests/non-saas-collect-mongo-pipeline-collections.test.ts @@ -0,0 +1,57 @@ +import test from 'ava'; +import { collectMongoPipelineCollections } from '../../../src/ai-core/tools/collect-mongo-pipeline-collections.js'; +/** Test helper: returns the resolved collection names as a sorted copy, and throws if the pipeline was reported as indeterminate. */ +function tablesOf(pipeline: string): Array { + const result = collectMongoPipelineCollections(pipeline); + if (result.kind !== 'tables') { + throw new Error(`expected resolved tables, got indeterminate: ${result.reason}`); + } + return [...result.tables].sort(); +} + +test('resolves no referenced collections for a pipeline without joins', (t) => { + t.deepEqual(tablesOf('[{"$match":{"status":"active"}},{"$group":{"_id":"$type"}}]'), []); +}); + +test('resolves a $lookup target
collection', (t) => { + t.deepEqual(tablesOf('[{"$lookup":{"from":"salaries","localField":"id","foreignField":"user_id","as":"s"}}]'), [ + 'salaries', + ]); +}); + +test('resolves a $graphLookup target collection', (t) => { + t.deepEqual( + tablesOf( + '[{"$graphLookup":{"from":"org_chart","startWith":"$managerId","connectFromField":"managerId","connectToField":"_id","as":"chain"}}]', + ), + ['org_chart'], + ); +}); + +test('resolves a $unionWith string collection', (t) => { + t.deepEqual(tablesOf('[{"$unionWith":"archived_orders"}]'), ['archived_orders']); +}); + +test('resolves a $unionWith object collection', (t) => { + t.deepEqual(tablesOf('[{"$unionWith":{"coll":"audit_log","pipeline":[]}}]'), ['audit_log']); +}); + +test('resolves collections nested inside a $lookup sub-pipeline', (t) => { + const pipeline = + '[{"$lookup":{"from":"orders","as":"o","pipeline":[{"$lookup":{"from":"secret_payouts","localField":"a","foreignField":"b","as":"p"}}]}}]'; + t.deepEqual(tablesOf(pipeline), ['orders', 'secret_payouts']); +}); + +test('deduplicates repeated collection references', (t) => { + const pipeline = + '[{"$lookup":{"from":"orders","localField":"a","foreignField":"b","as":"o1"}},{"$lookup":{"from":"orders","localField":"c","foreignField":"d","as":"o2"}}]'; + t.deepEqual(tablesOf(pipeline), ['orders']); +}); + +test('returns indeterminate for an unparseable pipeline', (t) => { + const result = collectMongoPipelineCollections('not valid json {'); + t.is(result.kind, 'indeterminate'); + if (result.kind === 'indeterminate') { + t.true(result.reason.includes('parse error')); + } +});