Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions backend/src/ai-core/tools/collect-mongo-pipeline-collections.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { CollectQueryTablesResult } from '../../entities/visualizations/panel/utils/collect-query-tables.util.js';
import { getErrorMessage } from '../../helpers/get-error-message.js';

/**
 * Recursively collects collection names referenced by stages that read from
 * other collections (`$lookup`, `$graphLookup`, `$unionWith`) anywhere in a
 * MongoDB aggregation pipeline, including nested sub-pipelines.
 *
 * The walk is purely structural: any object key with one of those names is
 * treated as a stage, wherever it appears. That can over-collect, which is the
 * safe direction for a permission check.
 *
 * @param node - Any value produced by parsing the pipeline JSON (array,
 *   object or primitive). Primitives and `null` are ignored.
 * @param collected - Set that receives every non-empty collection name found.
 */
function collectReferencedCollections(node: unknown, collected: Set<string>): void {
  if (Array.isArray(node)) {
    for (const item of node) {
      collectReferencedCollections(item, collected);
    }
    return;
  }
  if (node === null || typeof node !== 'object') {
    return;
  }
  for (const [key, value] of Object.entries(node as Record<string, unknown>)) {
    if (key === '$lookup' || key === '$graphLookup') {
      // The target lives under `from`; a stage without `from` contributes nothing here.
      const from = value !== null && typeof value === 'object' ? (value as { from?: unknown }).from : undefined;
      addReferencedCollection(from, collected);
    } else if (key === '$unionWith') {
      // `$unionWith` accepts either a collection-name string or `{ coll: <name>, pipeline: [...] }`.
      addReferencedCollection(value, collected);
    }
    // Always descend so nested sub-pipelines (e.g. `$lookup.pipeline`) are inspected too.
    collectReferencedCollections(value, collected);
  }
}

/**
 * Records a collection reference given either as a plain name string or as an
 * object carrying the name under `coll` (e.g. `{ coll }` / `{ db, coll }`).
 * Empty names and any other shape are ignored.
 */
function addReferencedCollection(target: unknown, collected: Set<string>): void {
  const name =
    typeof target === 'string'
      ? target
      : target !== null && typeof target === 'object'
        ? (target as { coll?: unknown }).coll
        : undefined;
  if (typeof name === 'string' && name.length > 0) {
    collected.add(name);
  }
}

/**
 * Resolves the collections a MongoDB aggregation pipeline reads from besides
 * its base collection (the `$lookup` / `$graphLookup` / `$unionWith` targets),
 * so the caller can verify the user has read permission on each.
 *
 * @param pipeline - The aggregation pipeline as a JSON string.
 * @returns `{ kind: 'tables' }` (possibly empty, de-duplicated) when the
 *   pipeline parses to an array or object, and `{ kind: 'indeterminate' }`
 *   when it cannot be parsed or parses to a primitive/`null` — in which case
 *   the caller must fall back to a stricter check rather than assume it is
 *   harmless.
 */
export function collectMongoPipelineCollections(pipeline: string): CollectQueryTablesResult {
  let parsedPipeline: unknown;
  try {
    parsedPipeline = JSON.parse(pipeline);
  } catch (error) {
    return { kind: 'indeterminate', reason: `pipeline parse error: ${getErrorMessage(error)}` };
  }
  // Valid JSON that is a bare string/number/boolean/null cannot describe any
  // stages, so there is nothing to inspect; do not report it as "no joins".
  if (parsedPipeline === null || typeof parsedPipeline !== 'object') {
    return { kind: 'indeterminate', reason: 'pipeline parse error: expected a JSON array of stages' };
  }
  const collected = new Set<string>();
  collectReferencedCollections(parsedPipeline, collected);
  return { kind: 'tables', tables: Array.from(collected) };
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import { BaseMessage } from '@langchain/core/messages';
import { BadRequestException, Inject, Injectable, Logger, NotFoundException, Scope } from '@nestjs/common';
import {
BadRequestException,
ForbiddenException,
Inject,
Injectable,
Logger,
NotFoundException,
Scope,
} from '@nestjs/common';
import { getDataAccessObject } from '@rocketadmin/shared-code/dist/src/data-access-layer/shared/create-data-access-object.js';
import { ConnectionTypesEnum } from '@rocketadmin/shared-code/dist/src/shared/enums/connection-types-enum.js';
import { IDataAccessObject } from '@rocketadmin/shared-code/dist/src/shared/interfaces/data-access-object.interface.js';
Expand All @@ -9,6 +17,7 @@ import { Response } from 'express';
import { AIToolCall, AIToolDefinition } from '../../../ai-core/interfaces/ai-provider.interface.js';
import { AIProviderType } from '../../../ai-core/interfaces/ai-service.interface.js';
import { AICoreService } from '../../../ai-core/services/ai-core.service.js';
import { collectMongoPipelineCollections } from '../../../ai-core/tools/collect-mongo-pipeline-collections.js';
import { createDatabaseTools } from '../../../ai-core/tools/database-tools.js';
import { searchDocumentation } from '../../../ai-core/tools/documentation-search.js';
import { createDatabaseQuerySystemPrompt } from '../../../ai-core/tools/prompts.js';
Expand All @@ -22,7 +31,9 @@ import { Messages } from '../../../exceptions/text/messages.js';
import { getErrorMessage } from '../../../helpers/get-error-message.js';
import { isConnectionTypeAgent } from '../../../helpers/is-connection-entity-agent.js';
import { slackPostMessage } from '../../../helpers/slack/slack-post-message.js';
import { CedarPermissionsService } from '../../cedar-authorization/cedar-permissions.service.js';
import { ConnectionEntity } from '../../connection/connection.entity.js';
import { assertUserCanReadQueryTables } from '../../visualizations/panel/utils/assert-query-tables-readable.util.js';
import { MessageRole } from '../ai-conversation-history/ai-chat-messages/message-role.enum.js';
import { UserAiChatEntity } from '../ai-conversation-history/user-ai-chat/user-ai-chat.entity.js';
import { IRequestInfoFromTableV2 } from '../ai-use-cases.interface.js';
Expand All @@ -41,6 +52,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
@Inject(BaseType.GLOBAL_DB_CONTEXT)
protected _dbContext: IGlobalDatabaseContext,
private readonly aiCoreService: AICoreService,
private readonly cedarPermissions: CedarPermissionsService,
) {
super();
}
Expand Down Expand Up @@ -104,6 +116,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
tableName,
userEmail,
foundConnection,
user_id,
);

if (accumulatedResponse) {
Expand Down Expand Up @@ -132,6 +145,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
inputTableName: string,
userEmail: string,
foundConnection: ConnectionEntity,
userId: string,
): Promise<string> {
let currentMessages = [...messages];
let depth = 0;
Expand Down Expand Up @@ -178,6 +192,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
inputTableName,
userEmail,
foundConnection,
userId,
);

for (const toolResult of toolResults) {
Expand Down Expand Up @@ -226,6 +241,7 @@ export class RequestInfoFromTableWithAIUseCaseV7
inputTableName: string,
userEmail: string,
foundConnection: ConnectionEntity,
userId: string,
): Promise<Array<{ toolCallId: string; result: string }>> {
const results: Array<{ toolCallId: string; result: string }> = [];

Expand All @@ -236,11 +252,13 @@ export class RequestInfoFromTableWithAIUseCaseV7
switch (toolCall.name) {
case 'getTableStructure': {
const tableName = (toolCall.arguments.tableName as string) || inputTableName;
await this.assertUserCanReadTables([tableName], userId, foundConnection.id);
const structureInfo = await this.getTableStructureInfo(
dataAccessObject,
tableName,
userEmail,
foundConnection,
userId,
);
result = encodeToToon(structureInfo);
break;
Expand All @@ -256,6 +274,14 @@ export class RequestInfoFromTableWithAIUseCaseV7
'Invalid SQL query. Please ensure it is a read-only SELECT statement without any forbidden keywords.',
);
}
await assertUserCanReadQueryTables({
query,
connectionType: foundConnection.type as ConnectionTypesEnum,
connectionId: foundConnection.id,
validateTableRead: (referencedTableName) =>
this.cedarPermissions.improvedCheckTableRead(userId, foundConnection.id, referencedTableName),
listAllTableNames: async () => (await dataAccessObject.getTablesFromDB()).map((table) => table.tableName),
});
const wrappedQuery = wrapQueryWithLimit(query, foundConnection.type as ConnectionTypesEnum);
const queryResult = await dataAccessObject.executeRawQuery(wrappedQuery, inputTableName, userEmail);
result = encodeToToon(queryResult);
Expand All @@ -272,6 +298,13 @@ export class RequestInfoFromTableWithAIUseCaseV7
'Invalid MongoDB command. Please ensure it is a read-only aggregation pipeline without any forbidden keywords.',
);
}
await this.assertUserCanReadPipelineCollections(
pipeline,
inputTableName,
userId,
foundConnection.id,
dataAccessObject,
);
const pipelineResult = await dataAccessObject.executeRawQuery(pipeline, inputTableName, userEmail);
Comment on lines 252 to 308
result = encodeToToon(pipelineResult);
break;
Expand Down Expand Up @@ -307,32 +340,50 @@ export class RequestInfoFromTableWithAIUseCaseV7
tableName: string,
userEmail: string,
foundConnection: ConnectionEntity,
userId: string,
) {
const [tableStructure, tableForeignKeys, referencedTableNamesAndColumns] = await Promise.all([
dao.getTableStructure(tableName, userEmail),
dao.getTableForeignKeys(tableName, userEmail),
dao.getReferencedTableNamesAndColumns(tableName, userEmail),
]);

// Only expose the structure of related tables the user is permitted to
// read — otherwise foreign-key traversal would leak the schema of tables
// the user has no access to.
const referencedTablesStructures = [];
const structurePromises = referencedTableNamesAndColumns.flatMap((referencedTable) =>
referencedTable.referenced_by.map((table) =>
dao.getTableStructure(table.table_name, userEmail).then((structure) => ({
tableName: table.table_name,
structure,
})),
),
referencedTable.referenced_by.map(async (table) => {
const canRead = await this.cedarPermissions.improvedCheckTableRead(
userId,
foundConnection.id,
table.table_name,
);
if (!canRead) {
return null;
}
const structure = await dao.getTableStructure(table.table_name, userEmail);
return { tableName: table.table_name, structure };
}),
);
referencedTablesStructures.push(...(await Promise.all(structurePromises)));
referencedTablesStructures.push(...(await Promise.all(structurePromises)).filter((item) => item !== null));

const foreignTablesStructures = [];
const foreignTablesStructurePromises = tableForeignKeys.flatMap((foreignKey) =>
dao.getTableStructure(foreignKey.referenced_table_name, userEmail).then((structure) => ({
tableName: foreignKey.referenced_table_name,
structure,
})),
const foreignTablesStructurePromises = tableForeignKeys.map(async (foreignKey) => {
const canRead = await this.cedarPermissions.improvedCheckTableRead(
userId,
foundConnection.id,
foreignKey.referenced_table_name,
);
if (!canRead) {
return null;
}
const structure = await dao.getTableStructure(foreignKey.referenced_table_name, userEmail);
return { tableName: foreignKey.referenced_table_name, structure };
});
foreignTablesStructures.push(
...(await Promise.all(foreignTablesStructurePromises)).filter((item) => item !== null),
);
foreignTablesStructures.push(...(await Promise.all(foreignTablesStructurePromises)));

return {
tableStructure,
Expand All @@ -345,6 +396,64 @@ export class RequestInfoFromTableWithAIUseCaseV7
};
}

/**
 * Ensures the user may read each of the given tables before the AI is
 * allowed to query or inspect them.
 *
 * Names are trimmed and de-duplicated first; empty or blank entries are
 * dropped. Tables are then checked one at a time, in input order, and the
 * first one the user cannot read is logged and rejected with a
 * `ForbiddenException` (inside the tool loop this surfaces back to the model
 * as a tool error, so the offending query is never executed).
 */
private async assertUserCanReadTables(
  tableNames: Array<string>,
  userId: string,
  connectionId: string,
): Promise<void> {
  const namesToCheck = new Set<string>();
  for (const rawName of tableNames) {
    const trimmedName = rawName?.trim();
    if (trimmedName) {
      namesToCheck.add(trimmedName);
    }
  }

  for (const tableName of namesToCheck) {
    const isReadable = await this.cedarPermissions.improvedCheckTableRead(userId, connectionId, tableName);
    if (isReadable) {
      continue;
    }
    this.logger.warn(
      `AI request blocked for user ${userId} on connection ${connectionId}: no read permission for table "${tableName}"`,
    );
    throw new ForbiddenException(Messages.NO_READ_PERMISSION_FOR_TABLE(tableName));
  }
}

/**
 * Applies table-level read permissions to a MongoDB aggregation pipeline.
 *
 * When the referenced collections can be resolved, the user must be able to
 * read the base collection plus every collection the pipeline pulls in
 * (`$lookup` / `$graphLookup` / `$unionWith`). When they cannot be resolved
 * the pipeline is not trusted to be harmless: a warning is logged and read
 * permission is required on every collection the connection reports.
 */
private async assertUserCanReadPipelineCollections(
  pipeline: string,
  baseCollection: string,
  userId: string,
  connectionId: string,
  dataAccessObject: IDataAccessObject | IDataAccessObjectAgent,
): Promise<void> {
  const resolution = collectMongoPipelineCollections(pipeline);

  if (resolution.kind === 'tables') {
    await this.assertUserCanReadTables([baseCollection, ...resolution.tables], userId, connectionId);
    return;
  }

  // Unresolvable pipeline: fall back to the strictest check available.
  this.logger.warn(
    `AI pipeline permission check could not resolve referenced collections for connection ${connectionId} (reason: ${resolution.reason}); falling back to all-collections read check.`,
  );
  const allTables = await dataAccessObject.getTablesFromDB();
  const allCollectionNames = allTables.map((table) => table.tableName);
  await this.assertUserCanReadTables(allCollectionNames, userId, connectionId);
}

private setupResponseHeaders(response: Response): void {
response.setHeader('Content-Type', 'text/event-stream');
response.setHeader('Cache-Control', 'no-cache');
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import test from 'ava';
import { collectMongoPipelineCollections } from '../../../src/ai-core/tools/collect-mongo-pipeline-collections.js';

/**
 * Test helper: runs the collector on `pipeline` and returns the resolved
 * collection names as a sorted copy, so assertions do not depend on
 * discovery order. Throws if the collector reports an indeterminate result.
 */
function tablesOf(pipeline: string): Array<string> {
  const result = collectMongoPipelineCollections(pipeline);
  if (result.kind === 'tables') {
    return result.tables.slice().sort();
  }
  throw new Error(`expected resolved tables, got indeterminate: ${result.reason}`);
}

// A pipeline made only of stages that stay within the base collection.
test('resolves no referenced collections for a pipeline without joins', (t) => {
  const pipelineWithoutJoins = '[{"$match":{"status":"active"}},{"$group":{"_id":"$type"}}]';
  t.deepEqual(tablesOf(pipelineWithoutJoins), []);
});

// A single `$lookup` stage contributes its `from` collection.
test('resolves a $lookup target collection', (t) => {
  const lookupPipeline = '[{"$lookup":{"from":"salaries","localField":"id","foreignField":"user_id","as":"s"}}]';
  const expectedCollections = ['salaries'];
  t.deepEqual(tablesOf(lookupPipeline), expectedCollections);
});
Comment on lines +16 to +20

// Table of pipelines whose referenced collections must resolve to `expected`
// (compared after sorting, see `tablesOf`). One ava test is registered per row.
const joinResolutionCases: Array<{ title: string; pipeline: string; expected: Array<string> }> = [
  {
    title: 'resolves a $graphLookup target collection',
    pipeline:
      '[{"$graphLookup":{"from":"org_chart","startWith":"$managerId","connectFromField":"managerId","connectToField":"_id","as":"chain"}}]',
    expected: ['org_chart'],
  },
  {
    title: 'resolves a $unionWith string collection',
    pipeline: '[{"$unionWith":"archived_orders"}]',
    expected: ['archived_orders'],
  },
  {
    title: 'resolves a $unionWith object collection',
    pipeline: '[{"$unionWith":{"coll":"audit_log","pipeline":[]}}]',
    expected: ['audit_log'],
  },
  {
    title: 'resolves collections nested inside a $lookup sub-pipeline',
    pipeline:
      '[{"$lookup":{"from":"orders","as":"o","pipeline":[{"$lookup":{"from":"secret_payouts","localField":"a","foreignField":"b","as":"p"}}]}}]',
    expected: ['orders', 'secret_payouts'],
  },
  {
    title: 'deduplicates repeated collection references',
    pipeline:
      '[{"$lookup":{"from":"orders","localField":"a","foreignField":"b","as":"o1"}},{"$lookup":{"from":"orders","localField":"c","foreignField":"d","as":"o2"}}]',
    expected: ['orders'],
  },
];

for (const { title, pipeline, expected } of joinResolutionCases) {
  test(title, (t) => {
    t.deepEqual(tablesOf(pipeline), expected);
  });
}

// Input that is not valid JSON must be reported as indeterminate, with a
// reason that mentions the parse failure.
test('returns indeterminate for an unparseable pipeline', (t) => {
  const outcome = collectMongoPipelineCollections('not valid json {');
  t.is(outcome.kind, 'indeterminate');
  if (outcome.kind !== 'indeterminate') {
    return;
  }
  t.true(outcome.reason.includes('parse error'));
});
Loading