From e65c4b63db19668486d629951c70632eea07d579 Mon Sep 17 00:00:00 2001 From: Piyush Singh Gaur Date: Fri, 24 Apr 2026 12:52:57 +0530 Subject: [PATCH 1/3] feat: phase 0,1 complete --- package-lock.json | 98 +++++++++- package.json | 1 + .../unit/db-knowledge-graph.service.unit.ts | 4 +- .../unit/nodes/check-cache.node.unit.ts | 4 +- .../unit/nodes/check-permission.node.unit.ts | 4 +- .../unit/nodes/classify-change.node.unit.ts | 4 +- .../unit/nodes/fix-query.node.unit.ts | 4 +- .../unit/nodes/get-columns.node.unit.ts | 4 +- .../unit/nodes/get-tables.node.unit.ts | 10 +- .../unit/nodes/save-dataset-node.unit.ts | 10 +- .../nodes/semantic-validator.node.unit.ts | 8 +- .../unit/nodes/sql-generation.node.unit.ts | 8 +- .../nodes/syntactic-validator.node.unit.ts | 4 +- src/__tests__/unit/mastra-bridge.unit.ts | 91 ++++++++++ .../unit/nodes/call-llm.node.unit.ts | 4 +- .../unit/nodes/summarise-file.node.unit.ts | 7 +- .../unit/visualizers/bar.visualizer.unit.ts | 6 +- .../unit/visualizers/line.visualizer.unit.ts | 6 +- .../unit/visualizers/pie.visualizer.unit.ts | 6 +- src/component.ts | 14 ++ .../db-query/nodes/check-cache.node.ts | 4 +- .../db-query/nodes/check-permissions.node.ts | 4 +- .../db-query/nodes/check-templates.node.ts | 4 +- .../db-query/nodes/classify-change.node.ts | 4 +- .../db-query/nodes/fix-query.node.ts | 4 +- .../db-query/nodes/generate-checklist.node.ts | 4 +- .../nodes/generate-description.node.ts | 4 +- .../db-query/nodes/get-columns.node.ts | 4 +- .../db-query/nodes/get-tables.node.ts | 6 +- .../db-query/nodes/save-dataset-node.ts | 4 +- .../db-query/nodes/semantic-validator.node.ts | 6 +- .../db-query/nodes/sql-generation.node.ts | 6 +- .../nodes/syntactic-validator.node.ts | 4 +- .../db-query/nodes/verify-checklist.node.ts | 8 +- .../db-knowledge-graph.service.ts | 4 +- .../services/template-helper.service.ts | 4 +- .../db-query/tools/ask-about-dataset.tool.ts | 22 ++- .../tools/get-data-as-dataset.tool.ts | 16 +- 
.../db-query/tools/improve-dataset.tool.ts | 16 +- .../nodes/select-visualization.node.ts | 4 +- .../tools/generate-visualization.tool.ts | 16 +- .../visualizers/bar.visualizer.ts | 4 +- .../visualizers/line.visualizer.ts | 4 +- .../visualizers/pie.visualizer.ts | 4 +- src/graphs/base.graph.ts | 4 +- src/graphs/chat/nodes/call-llm.node.ts | 8 +- src/graphs/chat/nodes/run-tool.node.ts | 11 +- src/graphs/chat/nodes/summarise-file.node.ts | 4 +- src/graphs/types.ts | 93 ++++++++-- src/keys.ts | 33 +++- src/services/index.ts | 2 + src/services/mastra-bridge.observer.ts | 20 +++ src/services/mastra-bridge.service.ts | 167 ++++++++++++++++++ .../anthropic/llms/anthropic.provider.ts | 6 +- .../aws/llms/bedrock-non-thinking.provider.ts | 6 +- .../providers/aws/llms/bedrock.provider.ts | 8 +- .../cerebras/llm/cerebras.provider.ts | 4 +- .../providers/google/llms/gemini.provider.ts | 4 +- .../providers/groq/llms/groq.provider.ts | 6 +- .../providers/openai/llms/openai.provider.ts | 6 +- src/types.ts | 108 ++++++++--- 61 files changed, 759 insertions(+), 188 deletions(-) create mode 100644 src/__tests__/unit/mastra-bridge.unit.ts create mode 100644 src/services/mastra-bridge.observer.ts create mode 100644 src/services/mastra-bridge.service.ts diff --git a/package-lock.json b/package-lock.json index 3cc3c94..8906d8d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "lb4-llm-chat-component", - "version": "2.0.0", + "version": "2.1.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "lb4-llm-chat-component", - "version": "2.0.0", + "version": "2.1.0", "license": "MIT", "dependencies": { "@langchain/community": "^0.3.50", @@ -16,6 +16,7 @@ "@sourceloop/chat-service": "15.0.3", "@sourceloop/core": "17.0.3", "@sourceloop/file-utils": "0.3.7", + "ai": "^6.0.168", "langchain": "^0.3.37", "loopback4-authentication": "^12.2.0", "loopback4-authorization": "^7.0.3", @@ -119,6 +120,52 @@ "dev": true, "license": "MIT" }, + 
"node_modules/@ai-sdk/gateway": { + "version": "3.0.104", + "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-3.0.104.tgz", + "integrity": "sha512-ZKX5n74io8VIRlhIMSLWVlvT3sXC8Z7cZ9GHuWBWZDVi96+62AIsWuLGvMfcBA1STYuSoDrp6rIziZmvrTq0TA==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.23", + "@vercel/oidc": "3.2.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/provider": { + "version": "3.0.8", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.8.tgz", + "integrity": "sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==", + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/provider-utils": { + "version": "4.0.23", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.23.tgz", + "integrity": "sha512-z8GlDaCmRSDlqkMF2f4/RFgWxdarvIbyuk+m6WXT1LYgsnGiXRJGTD2Z1+SDl3LqtFuRtGX1aghYvQLoHL/9pg==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.6" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, "node_modules/@anthropic-ai/sdk": { "version": "0.27.3", "resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.27.3.tgz", @@ -5201,9 +5248,7 @@ "version": "1.9.0", "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz", "integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==", - "devOptional": true, "license": "Apache-2.0", - "peer": true, "engines": { "node": ">=8.0.0" } @@ -7034,6 +7079,12 @@ "@loopback/core": "^6.1.6" } }, + "node_modules/@standard-schema/spec": { + "version": "1.1.0", + "resolved": 
"https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz", + "integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==", + "license": "MIT" + }, "node_modules/@tokenizer/inflate": { "version": "0.4.1", "resolved": "https://registry.npmjs.org/@tokenizer/inflate/-/inflate-0.4.1.tgz", @@ -7697,6 +7748,15 @@ "dev": true, "license": "ISC" }, + "node_modules/@vercel/oidc": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/@vercel/oidc/-/oidc-3.2.0.tgz", + "integrity": "sha512-UycprH3T6n3jH0k44NHMa7pnFHGu/N05MjojYr+Mc6I7obkoLIJujSWwin1pCvdy/eOxrI/l3uDLQsmcrOb4ug==", + "license": "Apache-2.0", + "engines": { + "node": ">= 20" + } + }, "node_modules/abbrev": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/abbrev/-/abbrev-1.1.1.tgz", @@ -7816,6 +7876,24 @@ "node": ">=8" } }, + "node_modules/ai": { + "version": "6.0.168", + "resolved": "https://registry.npmjs.org/ai/-/ai-6.0.168.tgz", + "integrity": "sha512-2HqCJuO+1V2aV7vfYs5LFEUfxbkGX+5oa54q/gCCTL7KLTdbxcCu5D7TdLA5kwsrs3Szgjah9q6D9tpjHM3hUQ==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/gateway": "3.0.104", + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.23", + "@opentelemetry/api": "1.9.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, "node_modules/ajv": { "version": "8.18.0", "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz", @@ -11650,6 +11728,15 @@ "node": ">=0.8.x" } }, + "node_modules/eventsource-parser": { + "version": "3.0.8", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.8.tgz", + "integrity": "sha512-70QWGkr4snxr0OXLRWsFLeRBIRPuQOvt4s8QYjmUlmlkyTZkRqS7EDVRZtzU3TiyDbXSzaOeF0XUKy8PchzukQ==", + "license": "MIT", + "engines": { + "node": ">=18.0.0" + } + }, "node_modules/execa": { "version": "5.1.1", "resolved": "https://registry.npmjs.org/execa/-/execa-5.1.1.tgz", @@ -14811,8 +14898,7 @@ 
"version": "0.4.0", "resolved": "https://registry.npmjs.org/json-schema/-/json-schema-0.4.0.tgz", "integrity": "sha512-es94M3nTIfsEPisRafak+HDLfHXnKBhV3vU5eqPcS3flIWqcxJWgXHXiey3YrpaNsanY5ei1VoYEbOzijuq9BA==", - "license": "(AFL-2.1 OR BSD-3-Clause)", - "peer": true + "license": "(AFL-2.1 OR BSD-3-Clause)" }, "node_modules/json-schema-compare": { "version": "0.2.2", diff --git a/package.json b/package.json index e8e143a..c47d8b6 100644 --- a/package.json +++ b/package.json @@ -134,6 +134,7 @@ "@sourceloop/chat-service": "15.0.3", "@sourceloop/core": "17.0.3", "@sourceloop/file-utils": "0.3.7", + "ai": "^6.0.168", "langchain": "^0.3.37", "loopback4-authentication": "^12.2.0", "loopback4-authorization": "^7.0.3", diff --git a/src/__tests__/db-query/unit/db-knowledge-graph.service.unit.ts b/src/__tests__/db-query/unit/db-knowledge-graph.service.unit.ts index ed9616b..4da5395 100644 --- a/src/__tests__/db-query/unit/db-knowledge-graph.service.unit.ts +++ b/src/__tests__/db-query/unit/db-knowledge-graph.service.unit.ts @@ -1,6 +1,6 @@ import {expect, sinon} from '@loopback/testlab'; import {DbKnowledgeGraphService} from '../../../components'; -import {EmbeddingProvider, LLMProvider} from '../../../types'; +import {EmbeddingProvider, RuntimeLLMProvider} from '../../../types'; describe(`DbKnowledgeGraphService Unit`, function () { let service: DbKnowledgeGraphService; @@ -11,7 +11,7 @@ describe(`DbKnowledgeGraphService Unit`, function () { llmStub = sinon.stub(); embedStub = sinon.stub(); service = new DbKnowledgeGraphService( - llmStub as unknown as LLMProvider, + llmStub as unknown as RuntimeLLMProvider, { embedDocuments: embedStub, } as unknown as EmbeddingProvider, diff --git a/src/__tests__/db-query/unit/nodes/check-cache.node.unit.ts b/src/__tests__/db-query/unit/nodes/check-cache.node.unit.ts index 5aa66d5..c3de19b 100644 --- a/src/__tests__/db-query/unit/nodes/check-cache.node.unit.ts +++ b/src/__tests__/db-query/unit/nodes/check-cache.node.unit.ts @@ -13,7 +13,7 
@@ import { DbQueryState, QueryCacheMetadata, } from '../../../../components'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; describe('CheckCacheNode Unit', function () { let node: CheckCacheNode; @@ -28,7 +28,7 @@ describe('CheckCacheNode Unit', function () { const cache = { invoke: cacheStub, } as unknown as BaseRetriever; - const llm = llmStub as unknown as LLMProvider; + const llm = llmStub as unknown as RuntimeLLMProvider; node = new CheckCacheNode(cache, llm, datasetHelperStub); datasetHelperStub.stubs.checkPermissions.resolves([]); diff --git a/src/__tests__/db-query/unit/nodes/check-permission.node.unit.ts b/src/__tests__/db-query/unit/nodes/check-permission.node.unit.ts index 94cab60..835ae8c 100644 --- a/src/__tests__/db-query/unit/nodes/check-permission.node.unit.ts +++ b/src/__tests__/db-query/unit/nodes/check-permission.node.unit.ts @@ -6,7 +6,7 @@ import { Errors, PermissionHelper, } from '../../../../components'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {Currency, Employee, ExchangeRate} from '../../../fixtures/models'; describe('CheckPermissionsNode Unit', function () { @@ -15,7 +15,7 @@ describe('CheckPermissionsNode Unit', function () { beforeEach(() => { llmStub = sinon.stub(); - const llm = llmStub as unknown as LLMProvider; + const llm = llmStub as unknown as RuntimeLLMProvider; const permissionHelper = new PermissionHelper( { models: [ diff --git a/src/__tests__/db-query/unit/nodes/classify-change.node.unit.ts b/src/__tests__/db-query/unit/nodes/classify-change.node.unit.ts index 8c315ca..5ef9746 100644 --- a/src/__tests__/db-query/unit/nodes/classify-change.node.unit.ts +++ b/src/__tests__/db-query/unit/nodes/classify-change.node.unit.ts @@ -2,7 +2,7 @@ import {expect, sinon} from '@loopback/testlab'; import {LangGraphRunnableConfig} from '@langchain/langgraph'; import {ChangeType, ClassifyChangeNode} from 
'../../../../components'; import {DbQueryState} from '../../../../components/db-query/state'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; describe('ClassifyChangeNode Unit', function () { let node: ClassifyChangeNode; @@ -10,7 +10,7 @@ describe('ClassifyChangeNode Unit', function () { beforeEach(() => { llmStub = sinon.stub(); - const llm = llmStub as unknown as LLMProvider; + const llm = llmStub as unknown as RuntimeLLMProvider; node = new ClassifyChangeNode(llm); }); diff --git a/src/__tests__/db-query/unit/nodes/fix-query.node.unit.ts b/src/__tests__/db-query/unit/nodes/fix-query.node.unit.ts index dc7518b..4bac2ae 100644 --- a/src/__tests__/db-query/unit/nodes/fix-query.node.unit.ts +++ b/src/__tests__/db-query/unit/nodes/fix-query.node.unit.ts @@ -7,7 +7,7 @@ import { } from '../../../../components'; import {DbSchemaHelperService} from '../../../../components/db-query/services'; import {DbQueryState} from '../../../../components/db-query/state'; -import {LLMProvider, SupportedDBs} from '../../../../types'; +import {RuntimeLLMProvider, SupportedDBs} from '../../../../types'; describe('FixQueryNode Unit', function () { let node: FixQueryNode; @@ -16,7 +16,7 @@ describe('FixQueryNode Unit', function () { beforeEach(() => { llmStub = sinon.stub(); - const llm = llmStub as unknown as LLMProvider; + const llm = llmStub as unknown as RuntimeLLMProvider; schemaHelper = { asString: sinon.stub().returns('CREATE TABLE users (id INT, name TEXT);'), getTablesContext: sinon.stub().returns([]), diff --git a/src/__tests__/db-query/unit/nodes/get-columns.node.unit.ts b/src/__tests__/db-query/unit/nodes/get-columns.node.unit.ts index 467f886..5c20ece 100644 --- a/src/__tests__/db-query/unit/nodes/get-columns.node.unit.ts +++ b/src/__tests__/db-query/unit/nodes/get-columns.node.unit.ts @@ -7,7 +7,7 @@ import { GetColumnsNode, SqliteConnector, } from '../../../../components'; -import {LLMProvider} from 
'../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {Employee, ExchangeRate} from '../../../fixtures/models'; import {IAuthUserWithPermissions} from 'loopback4-authorization'; @@ -18,7 +18,7 @@ describe('GetColumnsNode Unit', function () { beforeEach(async () => { llmStub = sinon.stub(); - const llm = llmStub as unknown as LLMProvider; + const llm = llmStub as unknown as RuntimeLLMProvider; schemaHelper = new DbSchemaHelperService( new SqliteConnector( diff --git a/src/__tests__/db-query/unit/nodes/get-tables.node.unit.ts b/src/__tests__/db-query/unit/nodes/get-tables.node.unit.ts index 1c30254..2cab51d 100644 --- a/src/__tests__/db-query/unit/nodes/get-tables.node.unit.ts +++ b/src/__tests__/db-query/unit/nodes/get-tables.node.unit.ts @@ -13,7 +13,7 @@ import { SqliteConnector, TableSearchService, } from '../../../../components'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import { Currency, Employee, @@ -34,7 +34,7 @@ describe('GetTablesNode Unit', function () { beforeEach(async () => { smartllmStub = sinon.stub(); dumbllmStub = sinon.stub(); - const llm = dumbllmStub as unknown as LLMProvider; + const llm = dumbllmStub as unknown as RuntimeLLMProvider; schemaHelper = new DbSchemaHelperService( new SqliteConnector( @@ -52,7 +52,7 @@ describe('GetTablesNode Unit', function () { tableSearchStub = createStubInstance(TableSearchService); node = new GetTablesNode( llm, - dumbllmStub as unknown as LLMProvider, + dumbllmStub as unknown as RuntimeLLMProvider, { models: [], }, @@ -134,8 +134,8 @@ failed attempt: reason for failure it('should return state with minimal schema based on prompt and table search with smart llm', async () => { node = new GetTablesNode( - dumbllmStub as unknown as LLMProvider, - smartllmStub as unknown as LLMProvider, + dumbllmStub as unknown as RuntimeLLMProvider, + smartllmStub as unknown as RuntimeLLMProvider, { models: [], nodes: { diff --git 
a/src/__tests__/db-query/unit/nodes/save-dataset-node.unit.ts b/src/__tests__/db-query/unit/nodes/save-dataset-node.unit.ts index 4c19a9c..312d4ad 100644 --- a/src/__tests__/db-query/unit/nodes/save-dataset-node.unit.ts +++ b/src/__tests__/db-query/unit/nodes/save-dataset-node.unit.ts @@ -13,7 +13,7 @@ import { SaveDataSetNode, } from '../../../../components'; import {DataSet} from '../../../../components/db-query/models'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {buildDatasetStoreStub} from '../../../test-helper'; describe('SaveDataSetNode Unit', function () { @@ -24,7 +24,7 @@ describe('SaveDataSetNode Unit', function () { beforeEach(() => { llmStub = sinon.stub(); - const llm = llmStub as unknown as LLMProvider; + const llm = llmStub as unknown as RuntimeLLMProvider; store = buildDatasetStoreStub(); helper = createStubInstance(DbSchemaHelperService); node = new SaveDataSetNode( @@ -69,7 +69,7 @@ describe('SaveDataSetNode Unit', function () { it('should return state with dataset id and result array if readAccessForAI is true', async () => { node = new SaveDataSetNode( - llmStub as unknown as LLMProvider, + llmStub as unknown as RuntimeLLMProvider, store, {models: [], readAccessForAI: true, maxRowsForAI: 50}, { @@ -109,7 +109,7 @@ describe('SaveDataSetNode Unit', function () { }); it('should throw error if user does not have tenantId', async () => { - const llm = llmStub as unknown as LLMProvider; + const llm = llmStub as unknown as RuntimeLLMProvider; node = new SaveDataSetNode( llm, store, @@ -138,7 +138,7 @@ describe('SaveDataSetNode Unit', function () { }); it('should throw error if sql is not present in state', async () => { - const llm = llmStub as unknown as LLMProvider; + const llm = llmStub as unknown as RuntimeLLMProvider; node = new SaveDataSetNode( llm, store, diff --git a/src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts 
b/src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts index 5df193a..a6fe36b 100644 --- a/src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts +++ b/src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts @@ -8,7 +8,7 @@ import { DbSchemaHelperService, TableSearchService, } from '../../../../components/db-query/services'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; describe('SemanticValidatorNode Unit', function () { let node: SemanticValidatorNode; @@ -17,7 +17,7 @@ describe('SemanticValidatorNode Unit', function () { beforeEach(() => { llmStub = sinon.stub(); - const llm = llmStub as unknown as LLMProvider; + const llm = llmStub as unknown as RuntimeLLMProvider; const schemaHelper = { asString: sinon.stub().returns(''), } as unknown as DbSchemaHelperService; @@ -183,8 +183,8 @@ describe('SemanticValidatorNode Unit', function () { } as unknown as DbSchemaHelperService; const nodeWithTables = new SemanticValidatorNode( - llmStub as unknown as LLMProvider, - llmStub as unknown as LLMProvider, + llmStub as unknown as RuntimeLLMProvider, + llmStub as unknown as RuntimeLLMProvider, {models: []}, tableSearchStub, schemaHelper, diff --git a/src/__tests__/db-query/unit/nodes/sql-generation.node.unit.ts b/src/__tests__/db-query/unit/nodes/sql-generation.node.unit.ts index c3e65fd..7243638 100644 --- a/src/__tests__/db-query/unit/nodes/sql-generation.node.unit.ts +++ b/src/__tests__/db-query/unit/nodes/sql-generation.node.unit.ts @@ -6,7 +6,7 @@ import { SqlGenerationNode, SqliteConnector, } from '../../../../components'; -import {LLMProvider, SupportedDBs} from '../../../../types'; +import {RuntimeLLMProvider, SupportedDBs} from '../../../../types'; import {IAuthUserWithPermissions} from 'loopback4-authorization'; describe('SqlGenerationNode Unit', function () { @@ -16,7 +16,7 @@ describe('SqlGenerationNode Unit', function () { beforeEach(() => { llmStub = sinon.stub(); - const 
llm = llmStub as unknown as LLMProvider; + const llm = llmStub as unknown as RuntimeLLMProvider; schemaHelper = new DbSchemaHelperService( new SqliteConnector( @@ -507,8 +507,8 @@ It should have no other character or symbol or character that is not part of SQL cheapLLMStub = sinon.stub(); originalEnv = process.env.OPTIMIZE_CACHED_QUERIES; - const smartLLM = smartLLMStub as unknown as LLMProvider; - const cheapLLM = cheapLLMStub as unknown as LLMProvider; + const smartLLM = smartLLMStub as unknown as RuntimeLLMProvider; + const cheapLLM = cheapLLMStub as unknown as RuntimeLLMProvider; nodeWithTwoLLMs = new SqlGenerationNode( smartLLM, diff --git a/src/__tests__/db-query/unit/nodes/syntactic-validator.node.unit.ts b/src/__tests__/db-query/unit/nodes/syntactic-validator.node.unit.ts index 1403441..7d38f6a 100644 --- a/src/__tests__/db-query/unit/nodes/syntactic-validator.node.unit.ts +++ b/src/__tests__/db-query/unit/nodes/syntactic-validator.node.unit.ts @@ -7,7 +7,7 @@ import { SqliteConnector, SyntacticValidatorNode, } from '../../../../components'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {IAuthUserWithPermissions} from 'loopback4-authorization'; describe('SyntacticValidatorNode Unit', function () { @@ -17,7 +17,7 @@ describe('SyntacticValidatorNode Unit', function () { beforeEach(async () => { llmStub = sinon.stub(); - const llm = llmStub as unknown as LLMProvider; + const llm = llmStub as unknown as RuntimeLLMProvider; const ds = new juggler.DataSource({ connector: 'sqlite3', diff --git a/src/__tests__/unit/mastra-bridge.unit.ts b/src/__tests__/unit/mastra-bridge.unit.ts new file mode 100644 index 0000000..08185ab --- /dev/null +++ b/src/__tests__/unit/mastra-bridge.unit.ts @@ -0,0 +1,91 @@ +import {Context} from '@loopback/core'; +import {AnyObject} from '@loopback/repository'; +import {expect} from '@loopback/testlab'; +import { + GRAPH_NODE_NAME, + GRAPH_NODE_TAG, + TOOL_NAME, + TOOL_TAG, +} 
from '../../constant'; +import {IGraphNode, IGraphTool} from '../../graphs/types'; +import { + MastraBridgeService, + MastraRuntimeFactory, +} from '../../services/mastra-bridge.service'; + +describe('MastraBridgeService Unit', () => { + it('collects tagged node and tool bindings during initialization', async () => { + const context = new Context('mastra-bridge-context'); + + const node: IGraphNode = { + execute: async () => ({}), + }; + const tool: IGraphTool = { + key: 'fake-tool', + build: async () => { + throw new Error('Not implemented for this unit test'); + }, + }; + + context + .bind('services.fake-node') + .to(node) + .tag({ + [GRAPH_NODE_TAG]: true, + [GRAPH_NODE_NAME]: 'FakeNode', + }); + + context + .bind('services.fake-tool') + .to(tool) + .tag({ + [TOOL_TAG]: true, + [TOOL_NAME]: 'FakeTool', + }); + + const bridge = new MastraBridgeService(context); + await bridge.initialize(); + + const snapshot = bridge.getBootstrapSnapshot(); + + expect(snapshot.nodes).to.have.length(1); + expect(snapshot.nodes[0].key).to.equal('FakeNode'); + + expect(snapshot.tools).to.have.length(1); + expect(snapshot.tools[0].name).to.equal('FakeTool'); + + const resolvedNode = await snapshot.nodes[0].resolve(); + const resolvedTool = await snapshot.tools[0].resolve(); + + expect(resolvedNode).to.equal(node); + expect(resolvedTool).to.equal(tool); + }); + + it('initializes the runtime adapter only once', async () => { + const context = new Context('mastra-bridge-runtime-context'); + let callCount = 0; + + const adapter = { + getAgent: () => ({name: 'agent'}) as T, + getWorkflow: () => ({name: 'workflow'}) as T, + }; + + const runtimeFactory: MastraRuntimeFactory = async () => { + callCount += 1; + return adapter; + }; + + const bridge = new MastraBridgeService(context, runtimeFactory); + + await bridge.initialize(); + await bridge.initialize(); + + expect(callCount).to.equal(1); + expect(bridge.getAgent<{name: string}>('chat-agent')?.name).to.equal( + 'agent', + ); + 
expect(bridge.getWorkflow<{name: string}>('db-workflow')?.name).to.equal( + 'workflow', + ); + }); +}); diff --git a/src/__tests__/unit/nodes/call-llm.node.unit.ts b/src/__tests__/unit/nodes/call-llm.node.unit.ts index 9d14b22..91d9c4b 100644 --- a/src/__tests__/unit/nodes/call-llm.node.unit.ts +++ b/src/__tests__/unit/nodes/call-llm.node.unit.ts @@ -6,7 +6,7 @@ import {CallLLMNode, ChatStore, RunnableConfig} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; import {Chat} from '../../../models'; import {ChatRepository, MessageRepository} from '../../../repositories'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {setupChats, setupMessages, stubUser} from '../../test-helper'; describe('CallLLMNode Unit', function () { @@ -24,7 +24,7 @@ describe('CallLLMNode Unit', function () { invoke: llmStub, }; }), - } as unknown as LLMProvider; + } as unknown as RuntimeLLMProvider; const context = new Context('test-context'); context.bind('services.CallLLMNode').toClass(CallLLMNode); context.bind('services.ChatStore').toClass(ChatStore); diff --git a/src/__tests__/unit/nodes/summarise-file.node.unit.ts b/src/__tests__/unit/nodes/summarise-file.node.unit.ts index 9598b3b..8a44d17 100644 --- a/src/__tests__/unit/nodes/summarise-file.node.unit.ts +++ b/src/__tests__/unit/nodes/summarise-file.node.unit.ts @@ -8,7 +8,7 @@ import { } from '@loopback/testlab'; import {ChatState, ChatStore, SummariseFileNode} from '../../../graphs'; import {Message} from '../../../models'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {buildFileStub} from '../../test-helper'; describe(`SummariseFileNode Unit`, function () { @@ -29,7 +29,10 @@ describe(`SummariseFileNode Unit`, function () { llmStub = sinon.stub(); writerStub = sinon.stub(); chatStore = createStubInstance(ChatStore); - node = new SummariseFileNode(llmStub as unknown as LLMProvider, chatStore); + 
node = new SummariseFileNode( + llmStub as unknown as RuntimeLLMProvider, + chatStore, + ); }); it('should throw an error if no chat ID is found in state', async () => { diff --git a/src/__tests__/visualization/unit/visualizers/bar.visualizer.unit.ts b/src/__tests__/visualization/unit/visualizers/bar.visualizer.unit.ts index 6885692..c60e480 100644 --- a/src/__tests__/visualization/unit/visualizers/bar.visualizer.unit.ts +++ b/src/__tests__/visualization/unit/visualizers/bar.visualizer.unit.ts @@ -1,12 +1,12 @@ import {expect, sinon} from '@loopback/testlab'; import {BarVisualizer} from '../../../../components/visualization/visualizers/bar.visualizer'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {fail} from 'assert'; import {VisualizationGraphState} from '../../../../components'; describe('BarVisualizer Unit', function () { let visualizer: BarVisualizer; - let llmProvider: sinon.SinonStubbedInstance; + let llmProvider: sinon.SinonStubbedInstance; let withStructuredOutputStub: sinon.SinonStub; beforeEach(() => { @@ -14,7 +14,7 @@ describe('BarVisualizer Unit', function () { withStructuredOutputStub = sinon.stub(); llmProvider = { withStructuredOutput: withStructuredOutputStub, - } as sinon.SinonStubbedInstance; + } as sinon.SinonStubbedInstance; visualizer = new BarVisualizer(llmProvider); }); diff --git a/src/__tests__/visualization/unit/visualizers/line.visualizer.unit.ts b/src/__tests__/visualization/unit/visualizers/line.visualizer.unit.ts index 9195bc6..e5dc3ef 100644 --- a/src/__tests__/visualization/unit/visualizers/line.visualizer.unit.ts +++ b/src/__tests__/visualization/unit/visualizers/line.visualizer.unit.ts @@ -1,12 +1,12 @@ import {expect, sinon} from '@loopback/testlab'; import {LineVisualizer} from '../../../../components/visualization/visualizers/line.visualizer'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {fail} from 
'assert'; import {VisualizationGraphState} from '../../../../components'; describe('LineVisualizer Unit', function () { let visualizer: LineVisualizer; - let llmProvider: sinon.SinonStubbedInstance; + let llmProvider: sinon.SinonStubbedInstance; let withStructuredOutputStub: sinon.SinonStub; beforeEach(() => { @@ -14,7 +14,7 @@ describe('LineVisualizer Unit', function () { withStructuredOutputStub = sinon.stub(); llmProvider = { withStructuredOutput: withStructuredOutputStub, - } as sinon.SinonStubbedInstance; + } as sinon.SinonStubbedInstance; visualizer = new LineVisualizer(llmProvider); }); diff --git a/src/__tests__/visualization/unit/visualizers/pie.visualizer.unit.ts b/src/__tests__/visualization/unit/visualizers/pie.visualizer.unit.ts index c76405b..87000a9 100644 --- a/src/__tests__/visualization/unit/visualizers/pie.visualizer.unit.ts +++ b/src/__tests__/visualization/unit/visualizers/pie.visualizer.unit.ts @@ -1,12 +1,12 @@ import {expect, sinon} from '@loopback/testlab'; import {PieVisualizer} from '../../../../components/visualization/visualizers/pie.visualizer'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {fail} from 'assert'; import {VisualizationGraphState} from '../../../../components'; describe('PieVisualizer Unit', function () { let visualizer: PieVisualizer; - let llmProvider: sinon.SinonStubbedInstance; + let llmProvider: sinon.SinonStubbedInstance; let withStructuredOutputStub: sinon.SinonStub; beforeEach(() => { @@ -14,7 +14,7 @@ describe('PieVisualizer Unit', function () { withStructuredOutputStub = sinon.stub(); llmProvider = { withStructuredOutput: withStructuredOutputStub, - } as sinon.SinonStubbedInstance; + } as sinon.SinonStubbedInstance; visualizer = new PieVisualizer(llmProvider); }); diff --git a/src/component.ts b/src/component.ts index a6f1e1d..619ced2 100644 --- a/src/component.ts +++ b/src/component.ts @@ -2,10 +2,12 @@ import { Binding, BindingScope, Component, + 
Constructor, ControllerClass, CoreBindings, createBindingFromClass, inject, + LifeCycleObserver, ProviderMap, ServiceOrProviderClass, } from '@loopback/core'; @@ -50,6 +52,8 @@ import {ChatRepository, MessageRepository} from './repositories'; import { ChatCountStrategy, GenerationService, + MastraBridgeObserver, + MastraBridgeService, TokenCountPerUserStrategy, TokenCountStrategy, } from './services'; @@ -76,6 +80,10 @@ export class AiIntegrationsComponent implements Component { createBindingFromClass(RedisCache, { key: AiIntegrationBindings.Cache.key, }), + createBindingFromClass(MastraBridgeService, { + key: AiIntegrationBindings.MastraBridge.key, + defaultScope: BindingScope.SINGLETON, + }), ]; this.providers = { @@ -100,6 +108,7 @@ export class AiIntegrationsComponent implements Component { ]; this.controllers = [GenerationController, ChatController]; + this.lifeCycleObservers = [MastraBridgeObserver]; this.models = [Chat, Message, CacheModel]; this.repositories = [ ChatRepository, @@ -214,6 +223,11 @@ export class AiIntegrationsComponent implements Component { */ controllers?: ControllerClass[]; + /** + * An optional list of lifecycle observers. 
+ */ + lifeCycleObservers?: Constructor[]; + /** * Setup ServiceSequence by default if no other sequnce provided * diff --git a/src/components/db-query/nodes/check-cache.node.ts b/src/components/db-query/nodes/check-cache.node.ts index 2247962..6846b70 100644 --- a/src/components/db-query/nodes/check-cache.node.ts +++ b/src/components/db-query/nodes/check-cache.node.ts @@ -11,7 +11,7 @@ import { ToolStatus, } from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {DbQueryAIExtensionBindings} from '../keys'; import {DbQueryNodes} from '../nodes.enum'; @@ -26,7 +26,7 @@ export class CheckCacheNode implements IGraphNode { @inject(DbQueryAIExtensionBindings.QueryCache) private readonly cache: BaseRetriever, @inject(AiIntegrationBindings.CheapLLM) - private readonly smartLLM: LLMProvider, + private readonly smartLLM: RuntimeLLMProvider, @service(DataSetHelper) private readonly dataSetHelper: DataSetHelper, ) {} diff --git a/src/components/db-query/nodes/check-permissions.node.ts b/src/components/db-query/nodes/check-permissions.node.ts index 8d59761..d6ba80e 100644 --- a/src/components/db-query/nodes/check-permissions.node.ts +++ b/src/components/db-query/nodes/check-permissions.node.ts @@ -5,7 +5,7 @@ import {service} from '@loopback/core'; import {graphNode} from '../../../decorators'; import {IGraphNode, RunnableConfig} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {DbQueryNodes} from '../nodes.enum'; import {PermissionHelper} from '../services'; @@ -16,7 +16,7 @@ import {Errors} from '../types'; export class CheckPermissionsNode implements IGraphNode { constructor( @inject(AiIntegrationBindings.CheapLLM) - 
private readonly llm: LLMProvider, // Replace with actual type if available + private readonly llm: RuntimeLLMProvider, // Replace with actual type if available @service(PermissionHelper) private readonly permissions: PermissionHelper, diff --git a/src/components/db-query/nodes/check-templates.node.ts b/src/components/db-query/nodes/check-templates.node.ts index 0948764..2aa127c 100644 --- a/src/components/db-query/nodes/check-templates.node.ts +++ b/src/components/db-query/nodes/check-templates.node.ts @@ -5,7 +5,7 @@ import {inject, service} from '@loopback/core'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType, RunnableConfig} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {DbQueryAIExtensionBindings} from '../keys'; import {DbQueryNodes} from '../nodes.enum'; @@ -21,7 +21,7 @@ export class CheckTemplatesNode implements IGraphNode { @inject(DbQueryAIExtensionBindings.TemplateCache) private readonly templateCache: BaseRetriever, @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, @service(PermissionHelper) private readonly permissionHelper: PermissionHelper, @service(TemplateHelper) diff --git a/src/components/db-query/nodes/classify-change.node.ts b/src/components/db-query/nodes/classify-change.node.ts index 94aa648..d0b325f 100644 --- a/src/components/db-query/nodes/classify-change.node.ts +++ b/src/components/db-query/nodes/classify-change.node.ts @@ -5,7 +5,7 @@ import {inject} from '@loopback/core'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import 
{stripThinkingTokens} from '../../../utils'; import {DbQueryNodes} from '../nodes.enum'; import {DbQueryState} from '../state'; @@ -39,7 +39,7 @@ Do not include any other text, explanation, or formatting. constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, ) {} async execute( diff --git a/src/components/db-query/nodes/fix-query.node.ts b/src/components/db-query/nodes/fix-query.node.ts index fb294d5..bd2e9ff 100644 --- a/src/components/db-query/nodes/fix-query.node.ts +++ b/src/components/db-query/nodes/fix-query.node.ts @@ -5,7 +5,7 @@ import {inject, service} from '@loopback/core'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider, SupportedDBs} from '../../../types'; +import {RuntimeLLMProvider, SupportedDBs} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {DbQueryAIExtensionBindings} from '../keys'; import {DbQueryNodes} from '../nodes.enum'; @@ -65,7 +65,7 @@ It should have no other character or symbol or character that is not part of SQL constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, @inject(DbQueryAIExtensionBindings.Config) private readonly config: DbQueryConfig, @service(DbSchemaHelperService) diff --git a/src/components/db-query/nodes/generate-checklist.node.ts b/src/components/db-query/nodes/generate-checklist.node.ts index 58ff5aa..d79d03b 100644 --- a/src/components/db-query/nodes/generate-checklist.node.ts +++ b/src/components/db-query/nodes/generate-checklist.node.ts @@ -5,7 +5,7 @@ import {inject, service} from '@loopback/core'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from 
'../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {AIMessage} from '@langchain/core/messages'; import {DbQueryAIExtensionBindings} from '../keys'; @@ -18,7 +18,7 @@ import {DbQueryConfig} from '../types'; export class GenerateChecklistNode implements IGraphNode { constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, @inject(DbQueryAIExtensionBindings.Config) private readonly config: DbQueryConfig, @service(DbSchemaHelperService) diff --git a/src/components/db-query/nodes/generate-description.node.ts b/src/components/db-query/nodes/generate-description.node.ts index 4c5499e..41704ff 100644 --- a/src/components/db-query/nodes/generate-description.node.ts +++ b/src/components/db-query/nodes/generate-description.node.ts @@ -5,7 +5,7 @@ import {inject, service} from '@loopback/core'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {DbQueryAIExtensionBindings} from '../keys'; import {DbQueryNodes} from '../nodes.enum'; import {DbSchemaHelperService} from '../services'; @@ -16,7 +16,7 @@ import {DbQueryConfig} from '../types'; export class GenerateDescriptionNode implements IGraphNode { constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, @inject(DbQueryAIExtensionBindings.Config) private readonly config: DbQueryConfig, @service(DbSchemaHelperService) diff --git a/src/components/db-query/nodes/get-columns.node.ts b/src/components/db-query/nodes/get-columns.node.ts index efbbe2d..71b98c3 100644 --- a/src/components/db-query/nodes/get-columns.node.ts +++ b/src/components/db-query/nodes/get-columns.node.ts @@ -5,7 +5,7 @@ 
import {service} from '@loopback/core'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType, RunnableConfig} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {DbQueryAIExtensionBindings} from '../keys'; import {DbQueryNodes} from '../nodes.enum'; @@ -23,7 +23,7 @@ import { export class GetColumnsNode implements IGraphNode { constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, @service(DbSchemaHelperService) private readonly schemaHelper: DbSchemaHelperService, @inject(DbQueryAIExtensionBindings.Config) diff --git a/src/components/db-query/nodes/get-tables.node.ts b/src/components/db-query/nodes/get-tables.node.ts index 147c1ef..d342758 100644 --- a/src/components/db-query/nodes/get-tables.node.ts +++ b/src/components/db-query/nodes/get-tables.node.ts @@ -5,7 +5,7 @@ import {service} from '@loopback/core'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType, RunnableConfig} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {DbQueryAIExtensionBindings} from '../keys'; import {DbQueryNodes} from '../nodes.enum'; @@ -19,9 +19,9 @@ import {DatabaseSchema, DbQueryConfig, GenerationError} from '../types'; export class GetTablesNode implements IGraphNode { constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llmCheap: LLMProvider, + private readonly llmCheap: RuntimeLLMProvider, @inject(AiIntegrationBindings.SmartLLM) - private readonly llmSmart: LLMProvider, + private readonly llmSmart: RuntimeLLMProvider, 
@inject(DbQueryAIExtensionBindings.Config) private readonly config: DbQueryConfig, @service(DbSchemaHelperService) diff --git a/src/components/db-query/nodes/save-dataset-node.ts b/src/components/db-query/nodes/save-dataset-node.ts index 30d412b..97692f0 100644 --- a/src/components/db-query/nodes/save-dataset-node.ts +++ b/src/components/db-query/nodes/save-dataset-node.ts @@ -9,7 +9,7 @@ import {AuthenticationBindings} from 'loopback4-authentication'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType, ToolStatus} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {DbQueryAIExtensionBindings} from '../keys'; import {DbQueryNodes} from '../nodes.enum'; @@ -23,7 +23,7 @@ import {DbSchemaHelperService} from '../services'; export class SaveDataSetNode implements IGraphNode { constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, @inject(DbQueryAIExtensionBindings.DatasetStore) private readonly store: IDataSetStore, @inject(DbQueryAIExtensionBindings.Config) diff --git a/src/components/db-query/nodes/semantic-validator.node.ts b/src/components/db-query/nodes/semantic-validator.node.ts index 80aa625..1e8d7a6 100644 --- a/src/components/db-query/nodes/semantic-validator.node.ts +++ b/src/components/db-query/nodes/semantic-validator.node.ts @@ -5,7 +5,7 @@ import {inject, service} from '@loopback/core'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {DbQueryAIExtensionBindings} from '../keys'; import 
{DbQueryNodes} from '../nodes.enum'; @@ -21,9 +21,9 @@ import {DbQueryConfig, EvaluationResult} from '../types'; export class SemanticValidatorNode implements IGraphNode { constructor( @inject(AiIntegrationBindings.SmartLLM) - private readonly smartllm: LLMProvider, + private readonly smartllm: RuntimeLLMProvider, @inject(AiIntegrationBindings.CheapLLM) - private readonly cheapllm: LLMProvider, + private readonly cheapllm: RuntimeLLMProvider, @inject(DbQueryAIExtensionBindings.Config) private readonly config: DbQueryConfig, @service(TableSearchService) diff --git a/src/components/db-query/nodes/sql-generation.node.ts b/src/components/db-query/nodes/sql-generation.node.ts index b96f68c..3abd79a 100644 --- a/src/components/db-query/nodes/sql-generation.node.ts +++ b/src/components/db-query/nodes/sql-generation.node.ts @@ -5,7 +5,7 @@ import {inject, service} from '@loopback/core'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider, SupportedDBs} from '../../../types'; +import {RuntimeLLMProvider, SupportedDBs} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {DbQueryAIExtensionBindings} from '../keys'; import {DbQueryNodes} from '../nodes.enum'; @@ -75,9 +75,9 @@ In the last attempt, you generated this SQL query - `); constructor( @inject(AiIntegrationBindings.SmartLLM) - private readonly sqlLLM: LLMProvider, + private readonly sqlLLM: RuntimeLLMProvider, @inject(AiIntegrationBindings.CheapLLM) - private readonly cheapllm: LLMProvider, + private readonly cheapllm: RuntimeLLMProvider, @inject(DbQueryAIExtensionBindings.Config) private readonly config: DbQueryConfig, @service(DbSchemaHelperService) diff --git a/src/components/db-query/nodes/syntactic-validator.node.ts b/src/components/db-query/nodes/syntactic-validator.node.ts index 9b779e6..0574ac8 100644 --- 
a/src/components/db-query/nodes/syntactic-validator.node.ts +++ b/src/components/db-query/nodes/syntactic-validator.node.ts @@ -5,7 +5,7 @@ import {inject} from '@loopback/context'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {DbQueryAIExtensionBindings} from '../keys'; import {DbQueryNodes} from '../nodes.enum'; @@ -16,7 +16,7 @@ import {EvaluationResult, IDbConnector} from '../types'; export class SyntacticValidatorNode implements IGraphNode { constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, @inject(DbQueryAIExtensionBindings.Connector) private readonly connector: IDbConnector, ) {} diff --git a/src/components/db-query/nodes/verify-checklist.node.ts b/src/components/db-query/nodes/verify-checklist.node.ts index 2965229..14b816e 100644 --- a/src/components/db-query/nodes/verify-checklist.node.ts +++ b/src/components/db-query/nodes/verify-checklist.node.ts @@ -6,7 +6,7 @@ import {inject, service} from '@loopback/core'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import {DbQueryAIExtensionBindings} from '../keys'; import {DbQueryNodes} from '../nodes.enum'; @@ -18,7 +18,7 @@ import {DbQueryConfig} from '../types'; export class VerifyChecklistNode implements IGraphNode { constructor( @inject(AiIntegrationBindings.SmartLLM) - private readonly smartLlm: LLMProvider, + private readonly smartLlm: RuntimeLLMProvider, 
@inject(DbQueryAIExtensionBindings.Config) private readonly config: DbQueryConfig, @service(DbSchemaHelperService) @@ -26,10 +26,10 @@ export class VerifyChecklistNode implements IGraphNode { @inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true}) private readonly checks?: string[], @inject(AiIntegrationBindings.SmartNonThinkingLLM, {optional: true}) - private readonly smartNonThinkingLlm?: LLMProvider, + private readonly smartNonThinkingLlm?: RuntimeLLMProvider, ) {} - private get llm(): LLMProvider { + private get llm(): RuntimeLLMProvider { return this.smartNonThinkingLlm ?? this.smartLlm; } diff --git a/src/components/db-query/services/knowledge-graph/db-knowledge-graph.service.ts b/src/components/db-query/services/knowledge-graph/db-knowledge-graph.service.ts index 8fae90f..1caaa77 100644 --- a/src/components/db-query/services/knowledge-graph/db-knowledge-graph.service.ts +++ b/src/components/db-query/services/knowledge-graph/db-knowledge-graph.service.ts @@ -2,7 +2,7 @@ import {RunnableSequence} from '@langchain/core/runnables'; import {BindingScope, inject, injectable} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {AiIntegrationBindings} from '../../../../keys'; -import {EmbeddingProvider, LLMProvider} from '../../../../types'; +import {EmbeddingProvider, RuntimeLLMProvider} from '../../../../types'; import {stripThinkingTokens} from '../../../../utils'; import {DbQueryAIExtensionBindings} from '../../keys'; import {DatabaseSchema, DbQueryConfig, TableSchema} from '../../types'; @@ -35,7 +35,7 @@ export class DbKnowledgeGraphService implements KnowledgeGraph< constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, @inject(AiIntegrationBindings.EmbeddingModel) private readonly embeddingModel: EmbeddingProvider, @inject(DbQueryAIExtensionBindings.Config) diff --git a/src/components/db-query/services/template-helper.service.ts 
b/src/components/db-query/services/template-helper.service.ts index a1fc058..aabf390 100644 --- a/src/components/db-query/services/template-helper.service.ts +++ b/src/components/db-query/services/template-helper.service.ts @@ -2,7 +2,7 @@ import {PromptTemplate} from '@langchain/core/prompts'; import {RunnableSequence} from '@langchain/core/runnables'; import {inject} from '@loopback/core'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import { DatabaseSchema, @@ -22,7 +22,7 @@ type ResolvedTemplate = { export class TemplateHelper { constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, ) {} extractionPrompt = PromptTemplate.fromTemplate(` diff --git a/src/components/db-query/tools/ask-about-dataset.tool.ts b/src/components/db-query/tools/ask-about-dataset.tool.ts index ededecc..9808c2f 100644 --- a/src/components/db-query/tools/ask-about-dataset.tool.ts +++ b/src/components/db-query/tools/ask-about-dataset.tool.ts @@ -1,14 +1,14 @@ import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence, RunnableToolLike} from '@langchain/core/runnables'; -import {StructuredToolInterface, tool} from '@langchain/core/tools'; +import {RunnableSequence} from '@langchain/core/runnables'; +import {tool} from '@langchain/core/tools'; import {inject} from '@loopback/context'; import {service} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import z from 'zod'; import {graphTool} from '../../../decorators'; -import {IGraphTool} from '../../../graphs'; +import {IGraphTool, IRuntimeTool} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; 
import {DbQueryAIExtensionBindings} from '../keys'; import {DbSchemaHelperService} from '../services'; @@ -21,7 +21,7 @@ export class AskAboutDatasetTool implements IGraphTool { @inject(DbQueryAIExtensionBindings.DatasetStore) private readonly store: IDataSetStore, @inject(AiIntegrationBindings.CheapLLM) - private readonly sqlllm: LLMProvider, + private readonly sqlllm: RuntimeLLMProvider, @service(DbSchemaHelperService) private readonly dbSchemaHelper: DbSchemaHelperService, @service(SchemaStore) @@ -48,7 +48,10 @@ export class AskAboutDatasetTool implements IGraphTool { and here is the user's question - {question}`); - async build(): Promise { + /** + * Creates a runtime-agnostic tool that answers questions about an existing dataset. + */ + async createTool(): Promise { const chain = RunnableSequence.from([ this.prompt, this.sqlllm, @@ -87,4 +90,11 @@ export class AskAboutDatasetTool implements IGraphTool { }, ); } + + /** + * @deprecated Use createTool(). + */ + async build(): Promise { + return this.createTool(); + } } diff --git a/src/components/db-query/tools/get-data-as-dataset.tool.ts b/src/components/db-query/tools/get-data-as-dataset.tool.ts index a65c10e..71a3579 100644 --- a/src/components/db-query/tools/get-data-as-dataset.tool.ts +++ b/src/components/db-query/tools/get-data-as-dataset.tool.ts @@ -2,11 +2,9 @@ import {inject, service} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {z} from 'zod'; import {graphTool} from '../../../decorators'; -import {IGraphTool, ToolStatus} from '../../../graphs'; +import {IGraphTool, IRuntimeTool, ToolStatus} from '../../../graphs'; import {DbQueryGraph} from '../db-query.graph'; import {DbQueryConfig, Errors, GenerationError} from '../types'; -import {StructuredToolInterface} from '@langchain/core/tools'; -import {RunnableToolLike} from '@langchain/core/runnables'; import {DbQueryAIExtensionBindings} from '../keys'; import {DEFAULT_MAX_READ_ROWS_FOR_AI} from '../constant'; @@ -42,7 
+40,10 @@ export class GetDataAsDatasetTool implements IGraphTool { }; } - async build(): Promise { + /** + * Creates a runtime-agnostic tool for dataset generation. + */ + async createTool(): Promise { const graph = await this.queryPipeline.build(); const schema = z.object({ prompt: z @@ -60,4 +61,11 @@ export class GetDataAsDatasetTool implements IGraphTool { schema, }); } + + /** + * @deprecated Use createTool(). + */ + async build(): Promise { + return this.createTool(); + } } diff --git a/src/components/db-query/tools/improve-dataset.tool.ts b/src/components/db-query/tools/improve-dataset.tool.ts index 599c9bc..4e51310 100644 --- a/src/components/db-query/tools/improve-dataset.tool.ts +++ b/src/components/db-query/tools/improve-dataset.tool.ts @@ -2,11 +2,9 @@ import {inject, service} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {z} from 'zod'; import {graphTool} from '../../../decorators'; -import {IGraphTool, ToolStatus} from '../../../graphs'; +import {IGraphTool, IRuntimeTool, ToolStatus} from '../../../graphs'; import {DbQueryGraph} from '../db-query.graph'; import {DbQueryConfig, Errors, GenerationError} from '../types'; -import {StructuredToolInterface} from '@langchain/core/tools'; -import {RunnableToolLike} from '@langchain/core/runnables'; import {DbQueryAIExtensionBindings} from '../keys'; import {DEFAULT_MAX_READ_ROWS_FOR_AI} from '../constant'; @@ -42,7 +40,10 @@ export class ImproveDatasetTool implements IGraphTool { }; } - async build(): Promise { + /** + * Creates a runtime-agnostic tool for dataset improvement. + */ + async createTool(): Promise { const graph = await this.queryPipeline.build(); const schema = z.object({ datasetId: z @@ -61,4 +62,11 @@ export class ImproveDatasetTool implements IGraphTool { schema, }); } + + /** + * @deprecated Use createTool(). 
+ */ + async build(): Promise { + return this.createTool(); + } } diff --git a/src/components/visualization/nodes/select-visualization.node.ts b/src/components/visualization/nodes/select-visualization.node.ts index 0e38cef..8b8ab70 100644 --- a/src/components/visualization/nodes/select-visualization.node.ts +++ b/src/components/visualization/nodes/select-visualization.node.ts @@ -2,7 +2,7 @@ import {Context, inject} from '@loopback/context'; import {graphNode} from '../../../decorators'; import {IGraphNode, LLMStreamEventType, RunnableConfig} from '../../../graphs'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {VisualizationGraphState} from '../state'; import {VisualizationGraphNodes} from '../nodes.enum'; import {PromptTemplate} from '@langchain/core/prompts'; @@ -57,7 +57,7 @@ none: reason why the visualization is not possible with the current prompt. `); constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, @inject.context() private readonly context: Context, ) {} diff --git a/src/components/visualization/tools/generate-visualization.tool.ts b/src/components/visualization/tools/generate-visualization.tool.ts index 4c349ac..454a837 100644 --- a/src/components/visualization/tools/generate-visualization.tool.ts +++ b/src/components/visualization/tools/generate-visualization.tool.ts @@ -2,9 +2,7 @@ import {Context, inject, service} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {z} from 'zod'; import {graphTool} from '../../../decorators'; -import {IGraphTool, ToolStatus} from '../../../graphs'; -import {StructuredToolInterface} from '@langchain/core/tools'; -import {RunnableToolLike} from '@langchain/core/runnables'; +import {IGraphTool, IRuntimeTool, ToolStatus} from '../../../graphs'; import {VisualizationGraph} from '../visualization.graph'; 
import {VISUALIZATION_KEY} from '../keys'; import {IVisualizer} from '../types'; @@ -40,7 +38,10 @@ export class GenerateVisualizationTool implements IGraphTool { }; } - async build(): Promise { + /** + * Creates a runtime-agnostic visualization tool. + */ + async createTool(): Promise { const visualizations = await this._getVisualizations(); const graph = await this.visualizationGraph.build(); const schema = z.object({ @@ -73,6 +74,13 @@ It supports the following types of visualizations: ${visualizations.map(v => v.n }); } + /** + * @deprecated Use createTool(). + */ + async build(): Promise { + return this.createTool(); + } + private async _getVisualizations() { const bindings = this.context.findByTag({ [VISUALIZATION_KEY]: true, diff --git a/src/components/visualization/visualizers/bar.visualizer.ts b/src/components/visualization/visualizers/bar.visualizer.ts index 25fb6f1..cd44ad5 100644 --- a/src/components/visualization/visualizers/bar.visualizer.ts +++ b/src/components/visualization/visualizers/bar.visualizer.ts @@ -1,7 +1,7 @@ import {PromptTemplate} from '@langchain/core/prompts'; import {IVisualizer} from '../types'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {inject} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {VisualizationGraphState} from '../state'; @@ -53,7 +53,7 @@ You are an expert data visualization assistant. 
Your task is to create a bar cha constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, ) {} async getConfig(state: VisualizationGraphState): Promise { diff --git a/src/components/visualization/visualizers/line.visualizer.ts b/src/components/visualization/visualizers/line.visualizer.ts index 5acc79f..d0356a5 100644 --- a/src/components/visualization/visualizers/line.visualizer.ts +++ b/src/components/visualization/visualizers/line.visualizer.ts @@ -1,7 +1,7 @@ import {PromptTemplate} from '@langchain/core/prompts'; import {IVisualizer} from '../types'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {inject} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {VisualizationGraphState} from '../state'; @@ -58,7 +58,7 @@ You are an expert data visualization assistant. Your task is to create a line ch constructor( @inject(AiIntegrationBindings.SmartNonThinkingLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, ) {} async getConfig(state: VisualizationGraphState): Promise { diff --git a/src/components/visualization/visualizers/pie.visualizer.ts b/src/components/visualization/visualizers/pie.visualizer.ts index 0fd2f65..4004ce5 100644 --- a/src/components/visualization/visualizers/pie.visualizer.ts +++ b/src/components/visualization/visualizers/pie.visualizer.ts @@ -1,7 +1,7 @@ import {PromptTemplate} from '@langchain/core/prompts'; import {IVisualizer} from '../types'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {inject} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {VisualizationGraphState} from '../state'; @@ -47,7 +47,7 @@ You are an expert data visualization assistant. 
Your task is to create a pie cha constructor( @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, ) {} async getConfig(state: VisualizationGraphState): Promise { diff --git a/src/graphs/base.graph.ts b/src/graphs/base.graph.ts index 01d5290..cc61e4c 100644 --- a/src/graphs/base.graph.ts +++ b/src/graphs/base.graph.ts @@ -2,7 +2,7 @@ import {CompiledGraph} from '@langchain/langgraph'; import {Context, inject} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {GRAPH_NODE_NAME} from '../constant'; -import {IGraphNode} from './types'; +import {IGraphNode, resolveNodeExecution} from './types'; export abstract class BaseGraph { @inject.context() @@ -22,6 +22,6 @@ export abstract class BaseGraph { } const binding = bindings[0]; const node = await this.context.get>(binding.key); - return node.execute.bind(node); + return resolveNodeExecution(node); } } diff --git a/src/graphs/chat/nodes/call-llm.node.ts b/src/graphs/chat/nodes/call-llm.node.ts index cfddd02..b065798 100644 --- a/src/graphs/chat/nodes/call-llm.node.ts +++ b/src/graphs/chat/nodes/call-llm.node.ts @@ -4,11 +4,11 @@ import {service} from '@loopback/core'; import {HttpErrors} from '@loopback/rest'; import {graphNode} from '../../../decorators'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider, ToolStore} from '../../../types'; +import {RuntimeLLMProvider, ToolStore} from '../../../types'; import {getTextContent} from '../../../utils'; import {LLMStreamEventType} from '../../event.types'; import {ChatState} from '../../state'; -import {IGraphNode, RunnableConfig} from '../../types'; +import {IGraphNode, resolveGraphTool, RunnableConfig} from '../../types'; import {ChatStore} from '../chat.store'; import {ChatNodes} from '../nodes.enum'; @@ -18,7 +18,7 @@ const debug = require('debug')('ai-integration:chat:call-llm.node'); export class CallLLMNode implements IGraphNode { constructor( 
@inject(AiIntegrationBindings.ChatLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, @inject(AiIntegrationBindings.Tools) private readonly tools: ToolStore, @service(ChatStore) @@ -27,7 +27,7 @@ export class CallLLMNode implements IGraphNode { async execute(state: ChatState, config: RunnableConfig): Promise { const tools = await Promise.all( - this.tools.list.map(tool => tool.build(config)), + this.tools.list.map(tool => resolveGraphTool(tool, config)), ); debug( 'Calling LLM with tools:', diff --git a/src/graphs/chat/nodes/run-tool.node.ts b/src/graphs/chat/nodes/run-tool.node.ts index 3568fdd..382723f 100644 --- a/src/graphs/chat/nodes/run-tool.node.ts +++ b/src/graphs/chat/nodes/run-tool.node.ts @@ -7,7 +7,12 @@ import {AiIntegrationBindings} from '../../../keys'; import {ToolStore} from '../../../types'; import {LLMStreamEventType} from '../../event.types'; import {ChatState} from '../../state'; -import {IGraphNode, RunnableConfig, ToolStatus} from '../../types'; +import { + IGraphNode, + resolveGraphTool, + RunnableConfig, + ToolStatus, +} from '../../types'; import {ChatStore} from '../chat.store'; import {ChatNodes} from '../nodes.enum'; @@ -55,7 +60,7 @@ export class RunToolNode implements IGraphNode { }, }); const toolObj = tools[toolCall.name as keyof typeof tools]; - const tool = await toolObj.build(config); + const tool = await resolveGraphTool(toolObj, config); config.writer?.({ type: LLMStreamEventType.Log, data: `Running tool: ${toolCall.name} with args: ${JSON.stringify(toolCall.args, undefined, 2)}`, @@ -70,7 +75,7 @@ export class RunToolNode implements IGraphNode { }); const toolMessage = new ToolMessage({ name: toolCall.name, - content: output, + content: String(output), // eslint-disable-next-line @typescript-eslint/naming-convention tool_call_id: toolCall.id!, }); diff --git a/src/graphs/chat/nodes/summarise-file.node.ts b/src/graphs/chat/nodes/summarise-file.node.ts index 40f73ba..657fd00 100644 --- 
a/src/graphs/chat/nodes/summarise-file.node.ts +++ b/src/graphs/chat/nodes/summarise-file.node.ts @@ -8,7 +8,7 @@ import {AnyObject} from '@loopback/repository'; import {HttpErrors} from '@loopback/rest'; import {graphNode} from '../../../decorators'; import {AiIntegrationBindings} from '../../../keys'; -import {LLMProvider} from '../../../types'; +import {RuntimeLLMProvider} from '../../../types'; import {mergeAttachments, stripThinkingTokens} from '../../../utils'; import {LLMStreamEventType} from '../../event.types'; import {ChatState} from '../../state'; @@ -22,7 +22,7 @@ const debug = require('debug')('ai-integration:chat:summarise-file.node'); export class SummariseFileNode implements IGraphNode { constructor( @inject(AiIntegrationBindings.FileLLM) - private readonly llm: LLMProvider, + private readonly llm: RuntimeLLMProvider, @service(ChatStore) private readonly chatStore: ChatStore, ) {} diff --git a/src/graphs/types.ts b/src/graphs/types.ts index 4b324e2..b727039 100644 --- a/src/graphs/types.ts +++ b/src/graphs/types.ts @@ -1,30 +1,99 @@ import {AIMessage, HumanMessage, ToolMessage} from '@langchain/core/messages'; -import {RunnableToolLike} from '@langchain/core/runnables'; -import {StructuredToolInterface} from '@langchain/core/tools'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; import {AnyObject, Command} from '@loopback/repository'; -import {LLMStreamEvent} from './event.types'; -export type RunnableConfig = LangGraphRunnableConfig & { - writer?: (event: LLMStreamEvent) => void; +/** + * Runtime-agnostic execution config that can carry stream writers and runtime metadata. + */ +export type RunnableConfig = { + configurable?: Record; + signal?: AbortSignal; + writer?: (chunk: unknown) => void; }; +/** + * Node step execution function compatible with Mastra-like step execution. 
+ */ +export type GraphStepExecuteFn = ( + state: T, + config: RunnableConfig, +) => Promise | Command>; + +/** + * Minimal step contract required by Phase 1 interface migration. + */ +export interface IGraphStep { + execute: GraphStepExecuteFn; +} + +/** + * Graph node contract supporting both legacy `execute` and Mastra-style `createStep`. + */ export interface IGraphNode { - execute: (state: T, config: RunnableConfig) => Promise | Command>; + createStep?(config?: RunnableConfig): Promise> | IGraphStep; + execute?: GraphStepExecuteFn; } export type SavedMessage = HumanMessage | AIMessage | ToolMessage; +/** + * Minimal runtime tool contract shared across LangGraph and Mastra-compatible tooling. + */ +export interface IRuntimeTool { + name: string; + invoke(input: TArgs): Promise; +} + +/** + * Tool contract supporting Mastra-style `createTool` and legacy `build` for compatibility. + */ export interface IGraphTool { key: string; - build( - config: LangGraphRunnableConfig, - ): Promise; - getValue?(result: Record): string; - getMetadata?(result: Record): AnyObject; + createTool?(config: RunnableConfig): Promise; + /** + * @deprecated Use `createTool()`. + */ + build?(config: RunnableConfig): Promise; + getValue?(result: unknown): string; + getMetadata?(result: unknown): AnyObject; needsReview?: boolean; } +/** + * Resolves the executable function for a node, preferring `execute` and falling back to `createStep`. + */ +export async function resolveNodeExecution( + node: IGraphNode, +): Promise> { + if (node.execute) { + return node.execute.bind(node); + } + + if (node.createStep) { + const step = await node.createStep(); + return step.execute.bind(step); + } + + throw new Error('Node must implement either execute() or createStep().'); +} + +/** + * Resolves a runtime tool from the migrated contract while preserving legacy fallback. 
+ */ +export async function resolveGraphTool( + tool: IGraphTool, + config: RunnableConfig, +): Promise { + if (tool.createTool) { + return tool.createTool(config); + } + + if (tool.build) { + return tool.build(config); + } + + throw new Error(`Tool ${tool.key} does not implement createTool().`); +} + export type IGraphDirectEdge = { from: string; to: string; diff --git a/src/keys.ts b/src/keys.ts index 22371a1..cb0beb2 100644 --- a/src/keys.ts +++ b/src/keys.ts @@ -1,12 +1,16 @@ import {VectorStore as VectorStoreType} from '@langchain/core/vectorstores'; -import {BaseCheckpointSaver} from '@langchain/langgraph'; import {BindingKey} from '@loopback/context'; +import { + IMastraBridge, + MastraRuntimeFactory as MastraRuntimeFactoryType, +} from './services/mastra-bridge.service'; import {ITransport} from './transports/types'; import { AIIntegrationConfig, EmbeddingProvider, ICache, - LLMProvider, + IWorkflowPersistence, + RuntimeLLMProvider, ToolStore, } from './types'; import {ILimitStrategy} from './services/limit-strategies/types'; @@ -15,27 +19,31 @@ export namespace AiIntegrationBindings { export const Config = BindingKey.create( 'services.ai-reporting.config', ); - export const SmartLLM = BindingKey.create( + export const SmartLLM = BindingKey.create( 'services.ai-reporting.smartLLMProvider', ); - export const CheapLLM = BindingKey.create( + export const CheapLLM = BindingKey.create( 'services.ai-reporting.cheapLLMProvider', ); - export const FileLLM = BindingKey.create( + export const FileLLM = BindingKey.create( 'services.ai-reporting.fileLLMProvider', ); - export const ChatLLM = BindingKey.create( + export const ChatLLM = BindingKey.create( 'services.ai-reporting.chatLLMProvider', ); - export const SmartNonThinkingLLM = BindingKey.create( + export const SmartNonThinkingLLM = BindingKey.create( 'services.ai-reporting.smartNonThinkingLLMProvider', ); export const EmbeddingModel = BindingKey.create( 'services.ai-reporting.embeddingModel', ); - export const 
Checkpointer = BindingKey.create( - 'services.ai-reporting.checkpointer', + export const WorkflowPersistence = BindingKey.create( + 'services.ai-reporting.workflow-persistence', ); + /** + * @deprecated Use `WorkflowPersistence`. + */ + export const Checkpointer = WorkflowPersistence; export const Tools = BindingKey.create( 'services.ai-reporting.tool-store', ); @@ -52,6 +60,13 @@ export namespace AiIntegrationBindings { export const ObfHandler = BindingKey.create( 'services.ai-reporting.obf-handler', ); + export const MastraBridge = BindingKey.create( + 'services.ai-reporting.mastra-bridge', + ); + export const MastraRuntimeFactory = + BindingKey.create( + 'services.ai-reporting.mastra-runtime-factory', + ); export const SystemContext = BindingKey.create( `services.ai-reporting.system-context`, ); diff --git a/src/services/index.ts b/src/services/index.ts index cffa417..7c60c0f 100644 --- a/src/services/index.ts +++ b/src/services/index.ts @@ -1,3 +1,5 @@ export * from './generation.service'; +export * from './mastra-bridge.observer'; +export * from './mastra-bridge.service'; export * from './token-counter.service'; export * from './limit-strategies'; diff --git a/src/services/mastra-bridge.observer.ts b/src/services/mastra-bridge.observer.ts new file mode 100644 index 0000000..26a7cdc --- /dev/null +++ b/src/services/mastra-bridge.observer.ts @@ -0,0 +1,20 @@ +import {inject, LifeCycleObserver} from '@loopback/core'; +import {AiIntegrationBindings} from '../keys'; +import {IMastraBridge} from './mastra-bridge.service'; + +/** + * Initializes the Phase 0 bridge at application startup. + */ +export class MastraBridgeObserver implements LifeCycleObserver { + constructor( + @inject(AiIntegrationBindings.MastraBridge) + private readonly mastraBridge: IMastraBridge, + ) {} + + /** + * Bootstraps bridge discovery once during app start. 
+ */ + async start(): Promise { + await this.mastraBridge.initialize(); + } +} diff --git a/src/services/mastra-bridge.service.ts b/src/services/mastra-bridge.service.ts new file mode 100644 index 0000000..cf7a632 --- /dev/null +++ b/src/services/mastra-bridge.service.ts @@ -0,0 +1,167 @@ +import { + GRAPH_NODE_NAME, + GRAPH_NODE_TAG, + TOOL_NAME, + TOOL_TAG, +} from '../constant'; +import {IGraphNode, IGraphTool} from '../graphs/types'; +import {AiIntegrationBindings} from '../keys'; +import {AnyObject} from '@loopback/repository'; +import {Context, inject, injectable, BindingScope} from '@loopback/core'; + +/** + * Lazy resolver for LoopBack-managed instances. + */ +export type BindingResolver = () => Promise; + +/** + * Descriptor for a graph node registered in LoopBack's IoC container. + */ +export interface GraphNodeBindingDescriptor { + bindingKey: string; + key: string; + resolve: BindingResolver>; +} + +/** + * Descriptor for a graph tool registered in LoopBack's IoC container. + */ +export interface GraphToolBindingDescriptor { + bindingKey: string; + key: string; + name: string; + resolve: BindingResolver; +} + +/** + * Payload used to initialize a Mastra runtime instance from LoopBack bindings. + */ +export interface MastraBootstrapPayload { + nodes: GraphNodeBindingDescriptor[]; + tools: GraphToolBindingDescriptor[]; +} + +/** + * Minimal runtime contract required by the migration bridge. + */ +export interface MastraRuntimeAdapter { + getAgent(name: string): T | undefined; + getWorkflow(name: string): T | undefined; +} + +/** + * Factory contract for creating a runtime adapter from LoopBack-registered artifacts. + */ +export type MastraRuntimeFactory = ( + payload: MastraBootstrapPayload, +) => Promise | MastraRuntimeAdapter; + +/** + * Public bridge contract exposed through LoopBack bindings. 
+ */ +export interface IMastraBridge { + initialize(): Promise; + getAgent(name: string): T | undefined; + getWorkflow(name: string): T | undefined; + getBootstrapSnapshot(): MastraBootstrapPayload; +} + +/** + * Default no-op runtime adapter used until a real Mastra runtime is bound. + */ +class NoopMastraRuntimeAdapter implements MastraRuntimeAdapter { + /** + * Returns undefined for all agents in no-op mode. + */ + getAgent(_name: string): T | undefined { + return undefined; + } + + /** + * Returns undefined for all workflows in no-op mode. + */ + getWorkflow(_name: string): T | undefined { + return undefined; + } +} + +/** + * Phase 0 migration bridge that discovers LoopBack graph artifacts and creates + * a Mastra runtime adapter without changing current LangGraph behavior. + */ +@injectable({scope: BindingScope.SINGLETON}) +export class MastraBridgeService implements IMastraBridge { + private runtime: MastraRuntimeAdapter = new NoopMastraRuntimeAdapter(); + private payload: MastraBootstrapPayload = {nodes: [], tools: []}; + private initialized = false; + + constructor( + @inject.context() + private readonly context: Context, + @inject(AiIntegrationBindings.MastraRuntimeFactory, {optional: true}) + private readonly runtimeFactory?: MastraRuntimeFactory, + ) {} + + /** + * Initializes the bridge once by collecting tagged bindings and creating the runtime adapter. + */ + async initialize(): Promise { + if (this.initialized) { + return; + } + + this.payload = this.collectBootstrapPayload(); + this.runtime = this.runtimeFactory + ? await this.runtimeFactory(this.payload) + : new NoopMastraRuntimeAdapter(); + this.initialized = true; + } + + /** + * Returns a typed agent instance from the runtime adapter. + */ + getAgent(name: string): T | undefined { + return this.runtime.getAgent(name); + } + + /** + * Returns a typed workflow instance from the runtime adapter. 
+ */ + getWorkflow(name: string): T | undefined { + return this.runtime.getWorkflow(name); + } + + /** + * Returns a snapshot of discovered LoopBack artifacts registered for runtime bootstrap. + */ + getBootstrapSnapshot(): MastraBootstrapPayload { + return this.payload; + } + + /** + * Discovers all node and tool bindings and exposes lazy resolvers for runtime construction. + */ + private collectBootstrapPayload(): MastraBootstrapPayload { + const nodeBindings = this.context.findByTag({ + [GRAPH_NODE_TAG]: true, + }); + const toolBindings = this.context.findByTag({ + [TOOL_TAG]: true, + }); + + const nodes: GraphNodeBindingDescriptor[] = nodeBindings.map(binding => ({ + bindingKey: binding.key, + key: String(binding.tagMap?.[GRAPH_NODE_NAME] ?? binding.key), + resolve: async () => this.context.get>(binding.key), + })); + + const tools: GraphToolBindingDescriptor[] = toolBindings.map(binding => ({ + bindingKey: binding.key, + key: String(binding.key), + name: String(binding.tagMap?.[TOOL_NAME] ?? 
binding.key), + resolve: async () => this.context.get(binding.key), + })); + + return {nodes, tools}; + } +} diff --git a/src/sub-modules/providers/anthropic/llms/anthropic.provider.ts b/src/sub-modules/providers/anthropic/llms/anthropic.provider.ts index 85786aa..c44ebde 100644 --- a/src/sub-modules/providers/anthropic/llms/anthropic.provider.ts +++ b/src/sub-modules/providers/anthropic/llms/anthropic.provider.ts @@ -1,10 +1,10 @@ import {AnthropicInput, ChatAnthropic} from '@langchain/anthropic'; import {Provider, ValueOrPromise} from '@loopback/core'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {BaseChatModelParams} from '@langchain/core/language_models/chat_models'; -export class Claude implements Provider { - value(): ValueOrPromise { +export class Claude implements Provider { + value(): ValueOrPromise { if (!process.env.CLAUDE_MODEL || !process.env.CLAUDE_API_KEY) { throw new Error( 'CLAUDE_MODEL and CLAUDE_API_KEY environment variables must be set', diff --git a/src/sub-modules/providers/aws/llms/bedrock-non-thinking.provider.ts b/src/sub-modules/providers/aws/llms/bedrock-non-thinking.provider.ts index fe7fe79..06fcd75 100644 --- a/src/sub-modules/providers/aws/llms/bedrock-non-thinking.provider.ts +++ b/src/sub-modules/providers/aws/llms/bedrock-non-thinking.provider.ts @@ -1,12 +1,12 @@ import {Provider, ValueOrPromise} from '@loopback/core'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {Bedrock} from './bedrock.provider'; export class BedrockNonThinking extends Bedrock - implements Provider + implements Provider { - value(): ValueOrPromise { + value(): ValueOrPromise { return this._createdInstance(false); } } diff --git a/src/sub-modules/providers/aws/llms/bedrock.provider.ts b/src/sub-modules/providers/aws/llms/bedrock.provider.ts index 2ab9f88..d82da1d 100644 --- a/src/sub-modules/providers/aws/llms/bedrock.provider.ts +++ 
b/src/sub-modules/providers/aws/llms/bedrock.provider.ts @@ -1,13 +1,13 @@ import {ChatBedrockConverse, ChatBedrockConverseInput} from '@langchain/aws'; import {Provider, ValueOrPromise} from '@loopback/core'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {sanitizeFilenameForAwsConverse} from '../utils'; import {BedrockInstanceConfig} from '../types'; -export class Bedrock implements Provider { +export class Bedrock implements Provider { static createInstance(config: BedrockInstanceConfig): ChatBedrockConverse { const client = new ChatBedrockConverse(config); - (client as unknown as LLMProvider).getFile = ( + (client as unknown as RuntimeLLMProvider).getFile = ( file: Express.Multer.File, ) => { return { @@ -23,7 +23,7 @@ export class Bedrock implements Provider { }; return client; } - value(): ValueOrPromise { + value(): ValueOrPromise { return this._createdInstance(true); } diff --git a/src/sub-modules/providers/cerebras/llm/cerebras.provider.ts b/src/sub-modules/providers/cerebras/llm/cerebras.provider.ts index 7f14a77..ee03ac7 100644 --- a/src/sub-modules/providers/cerebras/llm/cerebras.provider.ts +++ b/src/sub-modules/providers/cerebras/llm/cerebras.provider.ts @@ -1,8 +1,8 @@ import {ChatCerebras, ChatCerebrasInput} from '@langchain/cerebras'; import {Provider} from '@loopback/core'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; -export class Cerebras implements Provider { +export class Cerebras implements Provider { value() { if (!process.env.CEREBRAS_MODEL || !process.env.CEREBRAS_KEY) { throw new Error( diff --git a/src/sub-modules/providers/google/llms/gemini.provider.ts b/src/sub-modules/providers/google/llms/gemini.provider.ts index ea1c58b..2c91cb8 100644 --- a/src/sub-modules/providers/google/llms/gemini.provider.ts +++ b/src/sub-modules/providers/google/llms/gemini.provider.ts @@ -1,8 +1,8 @@ import {Provider} from 
'@loopback/core'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {ChatGoogleGenerativeAI} from '@langchain/google-genai'; -export class Gemini implements Provider { +export class Gemini implements Provider { value() { if (!process.env.GOOGLE_CHAT_MODEL || !process.env.GOOGLE_API_KEY) { throw new Error( diff --git a/src/sub-modules/providers/groq/llms/groq.provider.ts b/src/sub-modules/providers/groq/llms/groq.provider.ts index 60e8c5a..5818ae7 100644 --- a/src/sub-modules/providers/groq/llms/groq.provider.ts +++ b/src/sub-modules/providers/groq/llms/groq.provider.ts @@ -1,9 +1,9 @@ import {Provider} from '@loopback/core'; import {ChatGroq} from '@langchain/groq'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; -export class Groq implements Provider { - value(): LLMProvider { +export class Groq implements Provider { + value(): RuntimeLLMProvider { if (!process.env.GROQ_MODEL || !process.env.GROQ_API_KEY) { throw new Error( 'GROQ_MODEL and GROQ_API_KEY environment variable is not set.', diff --git a/src/sub-modules/providers/openai/llms/openai.provider.ts b/src/sub-modules/providers/openai/llms/openai.provider.ts index 5bd2217..a194263 100644 --- a/src/sub-modules/providers/openai/llms/openai.provider.ts +++ b/src/sub-modules/providers/openai/llms/openai.provider.ts @@ -1,16 +1,16 @@ import {Provider} from '@loopback/core'; -import {LLMProvider} from '../../../../types'; +import {RuntimeLLMProvider} from '../../../../types'; import {ChatOpenAI} from '@langchain/openai'; import {OpenAIInstanceConfig} from '../types'; -export class OpenAI implements Provider { +export class OpenAI implements Provider { static createInstance(config: OpenAIInstanceConfig): ChatOpenAI { return new ChatOpenAI({ model: config.model, ...config.config, }); } - value(): LLMProvider { + value(): RuntimeLLMProvider { return OpenAI.createInstance({ model: 
process.env.OPENAI_MODEL!, config: { diff --git a/src/types.ts b/src/types.ts index 32b4582..b7295f9 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,24 +1,24 @@ -import {ChatAnthropic} from '@langchain/anthropic'; -import {BedrockEmbeddings, ChatBedrockConverse} from '@langchain/aws'; -import {ChatCerebras} from '@langchain/cerebras'; -import { - ChatGoogleGenerativeAI, - GoogleGenerativeAIEmbeddings, -} from '@langchain/google-genai'; -import {BaseCheckpointSaver} from '@langchain/langgraph'; -import {ChatOllama, OllamaEmbeddings} from '@langchain/ollama'; -import {ChatOpenAI, OpenAIEmbeddings} from '@langchain/openai'; +import {BedrockEmbeddings} from '@langchain/aws'; +import {GoogleGenerativeAIEmbeddings} from '@langchain/google-genai'; +import {OllamaEmbeddings} from '@langchain/ollama'; +import {OpenAIEmbeddings} from '@langchain/openai'; +import {LanguageModel} from 'ai'; import {Provider} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; +import {AIMessage} from '@langchain/core/messages'; +import {RunnableConfig, RunnableInterface} from '@langchain/core/runnables'; import {IGraphTool} from './graphs/types'; -import {ChatGroq} from '@langchain/groq'; export enum SupportedDBs { PostgreSQL = 'PostgreSQL', SQLite = 'SQLite', } +/** + * Global component configuration consumed by the LoopBack integration component. + */ export type AIIntegrationConfig = { + runtime?: RuntimeEngine; useCustomSequence?: boolean; mountCore?: boolean; mountFileUtils?: boolean; @@ -34,20 +34,68 @@ export type AIIntegrationConfig = { }; }; +/** + * Runtime engine selector used for phased migration and rollbacks. + */ +export type RuntimeEngine = 'langgraph' | 'mastra'; + export type FileMessageBuilder = (file: Express.Multer.File) => AnyObject; -export type LLMProviderType = - | ChatOllama - | ChatCerebras - | ChatOpenAI - | ChatAnthropic - | ChatBedrockConverse - | ChatGoogleGenerativeAI - | ChatGroq; +/** + * Primary provider contract for Phase 1. 
This maps directly to the AI SDK model contract. + */ +export type LLMProvider = LanguageModel; -export type LLMProvider = LLMProviderType & { +/** + * Legacy LangGraph-compatible LLM contract used by existing graph implementations. + * + * The structure intentionally mirrors the methods used by current nodes so concrete + * LangChain chat models remain assignable without direct class dependencies. + */ +export type LegacyLLMProvider = { + bindTools( + tools: unknown[], + ): RunnableInterface< + unknown, + AIMessage, + RunnableConfig> + >; + invoke(input: unknown): Promise; + withStructuredOutput( + schema: unknown, + ): RunnableInterface>>; getFile?: FileMessageBuilder; -}; +} & RunnableInterface< + unknown, + AIMessage, + RunnableConfig> +>; + +/** + * Adapter contract for converting an AI SDK model into the legacy tool-calling interface + * while LangGraph execution remains active. + */ +export interface ILegacyLLMProviderAdapter { + toLegacyLLMProvider(): LegacyLLMProvider; +} + +/** + * Runtime-compatible union used by existing LangGraph execution paths during migration. + */ +export type RuntimeLLMProvider = LegacyLLMProvider & + Partial; + +/** + * Resolves a runtime-compatible provider into a legacy execution contract. + */ +export function resolveLegacyLLMProvider( + provider: RuntimeLLMProvider, +): LegacyLLMProvider { + if (provider.toLegacyLLMProvider) { + return provider.toLegacyLLMProvider(); + } + return provider; +} export type EmbeddingProvider = | OpenAIEmbeddings @@ -55,7 +103,23 @@ export type EmbeddingProvider = | BedrockEmbeddings | GoogleGenerativeAIEmbeddings; -export type CheckpointerProvider = Provider; +/** + * Runtime persistence contract used by workflow/checkpoint adapters. + */ +export interface IWorkflowPersistence { + save(runId: string, state: AnyObject): Promise; + load(runId: string): Promise; +} + +/** + * Provider contract for workflow persistence adapters. 
+ */ +export type WorkflowPersistenceProvider = Provider; + +/** + * @deprecated Use `WorkflowPersistenceProvider`. + */ +export type CheckpointerProvider = WorkflowPersistenceProvider; export type ToolStore = { list: IGraphTool[]; From 56f57cac460eebe7328f31f55c9e997e5ffac09f Mon Sep 17 00:00:00 2001 From: Piyush Singh Gaur Date: Wed, 29 Apr 2026 10:06:44 +0530 Subject: [PATCH 2/3] feat: phase 2complete chatGraph migration added --- .../generation.service.integration.ts | 8 +- src/__tests__/unit/mastra-bridge.unit.ts | 8 +- src/component.ts | 3 + .../db-query/tools/ask-about-dataset.tool.ts | 48 ++- .../tools/get-data-as-dataset.tool.ts | 25 +- .../db-query/tools/improve-dataset.tool.ts | 27 +- .../tools/generate-visualization.tool.ts | 53 +++- src/decorators/tool.decorator.ts | 22 +- src/graphs/types.ts | 27 ++ src/index.ts | 1 + src/mastra/chat/mappers/message.mapper.ts | 60 ++++ src/mastra/chat/mastra-chat.agent.ts | 273 ++++++++++++++++++ .../chat/steps/context-compression.step.ts | 47 +++ src/mastra/chat/steps/init-session.step.ts | 59 ++++ src/mastra/chat/steps/save-step.step.ts | 66 +++++ src/mastra/chat/steps/stream-handler.step.ts | 268 +++++++++++++++++ src/mastra/chat/steps/summarise-file.step.ts | 99 +++++++ src/mastra/chat/types/chat.types.ts | 23 ++ src/mastra/chat/utils/safe-json.util.ts | 11 + .../chat/utils/token-accumulator.util.ts | 20 ++ src/mastra/index.ts | 3 + src/mastra/request-tool-store.ts | 34 +++ src/mastra/types.ts | 146 ++++++++++ src/providers/tools.provider.ts | 7 + src/services/generation.service.ts | 49 +++- src/services/mastra-bridge.service.ts | 95 +++++- 26 files changed, 1445 insertions(+), 37 deletions(-) create mode 100644 src/mastra/chat/mappers/message.mapper.ts create mode 100644 src/mastra/chat/mastra-chat.agent.ts create mode 100644 src/mastra/chat/steps/context-compression.step.ts create mode 100644 src/mastra/chat/steps/init-session.step.ts create mode 100644 src/mastra/chat/steps/save-step.step.ts create mode 100644 
src/mastra/chat/steps/stream-handler.step.ts create mode 100644 src/mastra/chat/steps/summarise-file.step.ts create mode 100644 src/mastra/chat/types/chat.types.ts create mode 100644 src/mastra/chat/utils/safe-json.util.ts create mode 100644 src/mastra/chat/utils/token-accumulator.util.ts create mode 100644 src/mastra/index.ts create mode 100644 src/mastra/request-tool-store.ts create mode 100644 src/mastra/types.ts diff --git a/src/__tests__/integration/generation.service.integration.ts b/src/__tests__/integration/generation.service.integration.ts index 0a4a625..3e332e7 100644 --- a/src/__tests__/integration/generation.service.integration.ts +++ b/src/__tests__/integration/generation.service.integration.ts @@ -8,6 +8,7 @@ import { } from '@loopback/testlab'; import {PassThrough} from 'stream'; import {ChatGraph, LLMStreamEvent} from '../../graphs'; +import {MastraChatAgent} from '../../mastra'; import {GenerationService} from '../../services'; import {HttpTransport, SSETransport} from '../../transports'; @@ -16,10 +17,12 @@ describe(`GenerationService Integration`, () => { let dummyRequest: Request; let dummyResponse: Response; let graph: StubbedInstanceWithSinonAccessor; + let mastraAgent: StubbedInstanceWithSinonAccessor; describe('with SSETransport', () => { beforeEach(() => { graph = createStubInstance(ChatGraph); + mastraAgent = createStubInstance(MastraChatAgent); dummyResponse = { write: sinon.stub(), end: sinon.stub(), @@ -30,7 +33,7 @@ describe(`GenerationService Integration`, () => { once: sinon.stub(), } as unknown as Request; const transport = new SSETransport(dummyResponse, dummyRequest); - service = new GenerationService(graph, transport); + service = new GenerationService(graph, mastraAgent, transport, undefined); }); it('should handle generation request and return response', async () => { const dummyStream = new PassThrough({objectMode: true}); @@ -140,6 +143,7 @@ describe(`GenerationService Integration`, () => { describe('with HttpTransport', () 
=> { beforeEach(() => { graph = createStubInstance(ChatGraph); + mastraAgent = createStubInstance(MastraChatAgent); dummyResponse = { write: sinon.stub(), end: sinon.stub(), @@ -150,7 +154,7 @@ describe(`GenerationService Integration`, () => { once: sinon.stub(), } as unknown as Request; const transport = new HttpTransport(dummyResponse, dummyRequest); - service = new GenerationService(graph, transport); + service = new GenerationService(graph, mastraAgent, transport, undefined); }); it('should handle generation request and return response', async () => { const dummyStream = new PassThrough({objectMode: true}); diff --git a/src/__tests__/unit/mastra-bridge.unit.ts b/src/__tests__/unit/mastra-bridge.unit.ts index 08185ab..bf605ae 100644 --- a/src/__tests__/unit/mastra-bridge.unit.ts +++ b/src/__tests__/unit/mastra-bridge.unit.ts @@ -81,11 +81,11 @@ describe('MastraBridgeService Unit', () => { await bridge.initialize(); expect(callCount).to.equal(1); - expect(bridge.getAgent<{name: string}>('chat-agent')?.name).to.equal( + expect(bridge.getTypedAgent<{name: string}>('chat-agent')?.name).to.equal( 'agent', ); - expect(bridge.getWorkflow<{name: string}>('db-workflow')?.name).to.equal( - 'workflow', - ); + expect( + bridge.getTypedWorkflow<{name: string}>('db-workflow')?.name, + ).to.equal('workflow'); }); }); diff --git a/src/component.ts b/src/component.ts index 619ced2..df0450f 100644 --- a/src/component.ts +++ b/src/component.ts @@ -49,6 +49,7 @@ import {Chat, Message} from './models'; import {CacheModel, ToolsProvider} from './providers'; import {RedisCache, RedisCacheRepository} from './providers/cache/redis'; import {ChatRepository, MessageRepository} from './repositories'; +import {MastraChatAgent} from './mastra'; import { ChatCountStrategy, GenerationService, @@ -96,6 +97,8 @@ export class AiIntegrationsComponent implements Component { TokenCounter, GenerationService, ChatStore, + // mastra + MastraChatAgent, // graph ChatGraph, // nodes diff --git 
a/src/components/db-query/tools/ask-about-dataset.tool.ts b/src/components/db-query/tools/ask-about-dataset.tool.ts index 9808c2f..c09c132 100644 --- a/src/components/db-query/tools/ask-about-dataset.tool.ts +++ b/src/components/db-query/tools/ask-about-dataset.tool.ts @@ -1,7 +1,7 @@ import {PromptTemplate} from '@langchain/core/prompts'; import {RunnableSequence} from '@langchain/core/runnables'; import {tool} from '@langchain/core/tools'; -import {inject} from '@loopback/context'; +import {Context, inject} from '@loopback/context'; import {service} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import z from 'zod'; @@ -15,7 +15,18 @@ import {DbSchemaHelperService} from '../services'; import {SchemaStore} from '../services/schema.store'; import {IDataSetStore} from '../types'; -@graphTool() +@graphTool({ + description: + 'Tool for answering questions about an existing dataset, note that it can only answer questions about the dataset definition, not the data it contains. Call this only if you have a valid dataset ID available.', + inputSchema: z.object({ + datasetId: z + .string() + .describe('uuid ID of the dataset to answer the question for'), + question: z + .string() + .describe('The question that the user asked about the query.'), + }), +}) export class AskAboutDatasetTool implements IGraphTool { constructor( @inject(DbQueryAIExtensionBindings.DatasetStore) @@ -26,12 +37,25 @@ export class AskAboutDatasetTool implements IGraphTool { private readonly dbSchemaHelper: DbSchemaHelperService, @service(SchemaStore) private readonly schemaStore: SchemaStore, - @inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true}) - private readonly checks?: string[], + // Use context injection so GlobalContext is resolved lazily at call time, + // not at construction time. This allows the tool to be instantiated from + // the application (singleton) context without requiring a live request. 
+ @inject.context() + private readonly _ctx: Context, ) {} key = 'ask-about-dataset'; needsReview = false; + description = + 'Tool for answering questions about an existing dataset, note that it can only answer questions about the dataset definition, not the data it contains. Call this only if you have a valid dataset ID available.'; + inputSchema = z.object({ + datasetId: z + .string() + .describe('uuid ID of the dataset to answer the question for'), + question: z + .string() + .describe('The question that the user asked about the query.'), + }); private readonly prompt = PromptTemplate.fromTemplate(`You are an AI assistant that answers questions about a query, without revealing any technical details, you need to answer the question the user's question. @@ -52,6 +76,20 @@ export class AskAboutDatasetTool implements IGraphTool { * Creates a runtime-agnostic tool that answers questions about an existing dataset. */ async createTool(): Promise { + // Resolve GlobalContext lazily. When called from a request context the + // checks will be populated; when called from the application context + // (e.g. during Mastra bridge startup) the resolution may fail and we + // gracefully fall back to an empty list. + let checks: string[] | undefined; + try { + checks = await this._ctx.get( + DbQueryAIExtensionBindings.GlobalContext, + {optional: true}, + ); + } catch { + checks = undefined; + } + const chain = RunnableSequence.from([ this.prompt, this.sqlllm, @@ -76,7 +114,7 @@ export class AskAboutDatasetTool implements IGraphTool { question: args.question, schema: compressedSchema, context: [ - ...(this.checks ?? []), + ...(checks ?? 
[]), ...this.dbSchemaHelper.getTablesContext(compressedSchema), ].join('\n'), }); diff --git a/src/components/db-query/tools/get-data-as-dataset.tool.ts b/src/components/db-query/tools/get-data-as-dataset.tool.ts index 71a3579..dda742e 100644 --- a/src/components/db-query/tools/get-data-as-dataset.tool.ts +++ b/src/components/db-query/tools/get-data-as-dataset.tool.ts @@ -8,10 +8,33 @@ import {DbQueryConfig, Errors, GenerationError} from '../types'; import {DbQueryAIExtensionBindings} from '../keys'; import {DEFAULT_MAX_READ_ROWS_FOR_AI} from '../constant'; -@graphTool() +@graphTool({ + description: `Query tool for generating SQL queries for a users request. Use it only when the user needs raw tabular data from the database. + Do not use this tool if the user's request involves trends, growth, decline, comparisons, distributions, patterns, or any form of analytical insight — use the 'generate-visualization' tool instead. + Note that it does not return the query, instead only a dataset ID that is not relevant to the user. + It internally fires an event that renders a grid for the dataset on the UI for the user to see.`, + inputSchema: z.object({ + prompt: z + .string() + .describe( + `Prompt from the user that will be used for generating an SQL query and create a dataset from it.`, + ), + }), +}) export class GetDataAsDatasetTool implements IGraphTool { needsReview = false; key = 'get-data-as-dataset'; + description = `Query tool for generating SQL queries for a users request. Use it only when the user needs raw tabular data from the database. + Do not use this tool if the user's request involves trends, growth, decline, comparisons, distributions, patterns, or any form of analytical insight — use the 'generate-visualization' tool instead. + Note that it does not return the query, instead only a dataset ID that is not relevant to the user. 
+ It internally fires an event that renders a grid for the dataset on the UI for the user to see.`; + inputSchema = z.object({ + prompt: z + .string() + .describe( + `Prompt from the user that will be used for generating an SQL query and create a dataset from it.`, + ), + }); constructor( @service(DbQueryGraph) private readonly queryPipeline: DbQueryGraph, diff --git a/src/components/db-query/tools/improve-dataset.tool.ts b/src/components/db-query/tools/improve-dataset.tool.ts index 4e51310..39791da 100644 --- a/src/components/db-query/tools/improve-dataset.tool.ts +++ b/src/components/db-query/tools/improve-dataset.tool.ts @@ -8,10 +8,35 @@ import {DbQueryConfig, Errors, GenerationError} from '../types'; import {DbQueryAIExtensionBindings} from '../keys'; import {DEFAULT_MAX_READ_ROWS_FOR_AI} from '../constant'; -@graphTool() +@graphTool({ + description: + 'Tool for improving an existing dataset based on user feedback. It takes a dataset ID and a prompt describing the desired changes, and returns an updated dataset. Call this only if you have a valid dataset ID available.', + inputSchema: z.object({ + datasetId: z + .string() + .describe(`UUID ID of the existing dataset to improve`), + prompt: z + .string() + .describe( + `A description of what changes or improvements the user wants in the existing dataset.`, + ), + }), +}) export class ImproveDatasetTool implements IGraphTool { needsReview = false; key = 'improve-dataset'; + description = + 'Tool for improving an existing dataset based on user feedback. It takes a dataset ID and a prompt describing the desired changes, and returns an updated dataset. 
Call this only if you have a valid dataset ID available.'; + inputSchema = z.object({ + datasetId: z + .string() + .describe(`UUID ID of the existing dataset to improve`), + prompt: z + .string() + .describe( + `A description of what changes or improvements the user wants in the existing dataset.`, + ), + }); constructor( @service(DbQueryGraph) private readonly queryPipeline: DbQueryGraph, diff --git a/src/components/visualization/tools/generate-visualization.tool.ts b/src/components/visualization/tools/generate-visualization.tool.ts index 454a837..42eda85 100644 --- a/src/components/visualization/tools/generate-visualization.tool.ts +++ b/src/components/visualization/tools/generate-visualization.tool.ts @@ -7,10 +7,61 @@ import {VisualizationGraph} from '../visualization.graph'; import {VISUALIZATION_KEY} from '../keys'; import {IVisualizer} from '../types'; -@graphTool() +@graphTool({ + description: `Generates a visualization for the user's request. It takes in a prompt and an optional dataset ID. +If the user's request involves trends, growth, decline, comparisons, distributions, patterns, correlations, or any analytical insight, ALWAYS use this tool instead of 'get-data-as-dataset'. +No need to call 'get-data-as-dataset' tool before this — if the dataset ID is not provided, this tool will internally fetch the data to be visualized. +It does not return anything, instead it fires an event internally that renders the visualization on the UI for the user to see.`, + inputSchema: z.object({ + prompt: z + .string() + .describe( + `Prompt from the user that will be used for generating the visualization.`, + ), + datasetId: z + .string() + .optional() + .describe( + `ID of the dataset that needs to be visualized. Use the dataset ID from 'get-data-as-dataset' or 'improve-dataset' tool if available. If not provided, the tool will internally fetch the data.`, + ), + type: z + .string() + .optional() + .describe( + `Type of visualization to be generated (e.g. 
bar, line, pie). If not provided, the system will decide the best visualization based on the data and prompt.`, + ), + }), +}) export class GenerateVisualizationTool implements IGraphTool { needsReview = false; key = 'generate-visualization'; + // Note: the `type` field enum values are populated dynamically from available + // visualizer bindings at request time. The static schema here omits the enum + // constraint so the Mastra agent can be registered without resolving the + // full visualization graph at startup. + description = `Generates a visualization for the user's request. It takes in a prompt and an optional dataset ID. +If the user's request involves trends, growth, decline, comparisons, distributions, patterns, correlations, or any analytical insight, ALWAYS use this tool instead of 'get-data-as-dataset'. +No need to call 'get-data-as-dataset' tool before this — if the dataset ID is not provided, this tool will internally fetch the data to be visualized. +It does not return anything, instead it fires an event internally that renders the visualization on the UI for the user to see.`; + inputSchema = z.object({ + prompt: z + .string() + .describe( + `Prompt from the user that will be used for generating the visualization.`, + ), + datasetId: z + .string() + .optional() + .describe( + `ID of the dataset that needs to be visualized. Use the dataset ID from 'get-data-as-dataset' or 'improve-dataset' tool if available. If not provided, the tool will internally fetch the data.`, + ), + type: z + .string() + .optional() + .describe( + `Type of visualization to be generated (e.g. bar, line, pie). 
If not provided, the system will decide the best visualization based on the data and prompt.`, + ), + }); constructor( @service(VisualizationGraph) private readonly visualizationGraph: VisualizationGraph, diff --git a/src/decorators/tool.decorator.ts b/src/decorators/tool.decorator.ts index 95c4e2e..470155e 100644 --- a/src/decorators/tool.decorator.ts +++ b/src/decorators/tool.decorator.ts @@ -1,12 +1,32 @@ import {injectable} from '@loopback/core'; import {TOOL_NAME, TOOL_TAG} from '../constant'; -export function graphTool(): ClassDecorator { +export type GraphToolMetadata = { + /** + * Human-readable description shown to the LLM when deciding which tool to call. + * Stored as a binding tag so the Mastra bridge factory can read it at startup + * WITHOUT resolving the tool instance (which may have request-scoped dependencies). + */ + description?: string; + /** + * Zod schema describing the tool's input, stored as a binding tag for the same reason. + * Typed as `unknown` to avoid a hard dependency on `zod` in the decorator module. + */ + inputSchema?: unknown; +}; + +export function graphTool(metadata?: GraphToolMetadata): ClassDecorator { return function (target: T) { injectable({ tags: { [TOOL_NAME]: target.name, [TOOL_TAG]: true, + ...(metadata?.description !== undefined && { + toolDescription: metadata.description, + }), + ...(metadata?.inputSchema !== undefined && { + toolInputSchema: metadata.inputSchema, + }), }, })(target); }; diff --git a/src/graphs/types.ts b/src/graphs/types.ts index b727039..2249256 100644 --- a/src/graphs/types.ts +++ b/src/graphs/types.ts @@ -37,9 +37,19 @@ export type SavedMessage = HumanMessage | AIMessage | ToolMessage; /** * Minimal runtime tool contract shared across LangGraph and Mastra-compatible tooling. + * + * `description` and `schema` are optional but MUST be populated by any tool that + * needs to work with Mastra. 
LangChain StructuredTool instances (returned by + * `build()` / `createTool()`) already carry these properties; they just were not + * exposed through this interface previously. */ export interface IRuntimeTool { name: string; + /** Human-readable description shown to the LLM when deciding which tool to call. */ + description?: string; + /** Zod schema describing the tool's input. Typed as `unknown` to avoid a hard + * dependency on `zod` in consuming packages; cast to `ZodObject` at the call site. */ + schema?: unknown; invoke(input: TArgs): Promise; } @@ -48,6 +58,23 @@ export interface IRuntimeTool { */ export interface IGraphTool { key: string; + /** + * Human-readable description exposed at Mastra agent registration time. + * + * Populate this as a class property so the Mastra bridge factory can read it + * WITHOUT calling `createTool()` / `build()`. This avoids resolving the full + * dependency tree (e.g. graph nodes with request-scoped dependencies) at + * application startup. + */ + description?: string; + /** + * Zod schema for the tool's input, exposed at Mastra agent registration time. + * + * Same rationale as `description` — must be accessible without a `createTool()` call. + * Type is `unknown` to avoid a hard dependency on `zod` in consuming code; cast to + * `ZodObject` at the call site. + */ + inputSchema?: unknown; createTool?(config: RunnableConfig): Promise; /** * @deprecated Use `createTool()`. 
diff --git a/src/index.ts b/src/index.ts index 8d9e020..d2358ec 100644 --- a/src/index.ts +++ b/src/index.ts @@ -5,6 +5,7 @@ export * from './controllers'; export * from './decorators'; export * from './graphs'; export * from './keys'; +export * from './mastra'; export * from './providers'; export * from './services'; export * from './transports'; diff --git a/src/mastra/chat/mappers/message.mapper.ts b/src/mastra/chat/mappers/message.mapper.ts new file mode 100644 index 0000000..759d8f3 --- /dev/null +++ b/src/mastra/chat/mappers/message.mapper.ts @@ -0,0 +1,60 @@ +import {AIMessage, BaseMessage, ToolMessage} from '@langchain/core/messages'; +import {getTextContent} from '../../../utils'; +import {MastraAgentMessage, MastraAssistantContentPart} from '../../types'; + +/** + * Converts a LangChain `BaseMessage[]` to the `MastraAgentMessage[]` format + * expected by `IMastraChatAgentRunnable.stream()`. + * + * The output shapes are compatible with the AI SDK `CoreMessage` type so + * real `@mastra/core` Agent instances accept them without further adaptation. + */ +export function toMastraMessages( + messages: BaseMessage[], +): MastraAgentMessage[] { + const result: MastraAgentMessage[] = []; + for (const msg of messages) { + const type = msg._getType(); + if (type === 'system') { + result.push({role: 'system', content: getTextContent(msg.content)}); + } else if (type === 'human') { + result.push({role: 'user', content: getTextContent(msg.content)}); + } else if (type === 'ai') { + const aiMsg = msg as AIMessage; + if (aiMsg.tool_calls?.length) { + const parts: MastraAssistantContentPart[] = []; + const text = getTextContent(aiMsg.content); + if (text.trim()) parts.push({type: 'text', text}); + for (const tc of aiMsg.tool_calls) { + parts.push({ + type: 'tool-call', + toolCallId: tc.id ?? 
'', + toolName: tc.name, + args: tc.args as Record, + }); + } + result.push({role: 'assistant', content: parts}); + } else { + result.push({ + role: 'assistant', + content: getTextContent(aiMsg.content), + }); + } + } else if (type === 'tool') { + const toolMsg = msg as ToolMessage; + result.push({ + role: 'tool', + content: [ + { + type: 'tool-result', + toolCallId: toolMsg.tool_call_id, + toolName: toolMsg.name ?? '', + result: toolMsg.content, + }, + ], + }); + } + // Other message types are dropped — not used in this flow + } + return result; +} diff --git a/src/mastra/chat/mastra-chat.agent.ts b/src/mastra/chat/mastra-chat.agent.ts new file mode 100644 index 0000000..d5f1a99 --- /dev/null +++ b/src/mastra/chat/mastra-chat.agent.ts @@ -0,0 +1,273 @@ +import {BaseMessage, HumanMessage} from '@langchain/core/messages'; +import {inject, injectable, BindingScope, service} from '@loopback/core'; +import {HttpErrors} from '@loopback/rest'; +import {LLMStreamEvent, LLMStreamEventType} from '../../graphs/event.types'; +import {ChatStore} from '../../graphs/chat/chat.store'; +import {AiIntegrationBindings} from '../../keys'; +import {IMastraBridge} from '../../services/mastra-bridge.service'; +import {AIIntegrationConfig, RuntimeLLMProvider, ToolStore} from '../../types'; +import {IMastraChatAgentRunnable} from '../types'; +import { + mastraRequestToolStore, + mastraRequestWriterStore, +} from '../request-tool-store'; +import {compressContextIfNeeded} from './steps/context-compression.step'; +import {initSession} from './steps/init-session.step'; +import {summariseOneFile} from './steps/summarise-file.step'; +import {handleStream} from './steps/stream-handler.step'; +import {toMastraMessages} from './mappers/message.mapper'; +import {accumulateUsage} from './utils/token-accumulator.util'; +import {TokenAccumulator} from './types/chat.types'; + +const debug = require('debug')('ai-integration:mastra:chat-agent'); + +/** + * Registered name used to retrieve the chat agent 
via the Mastra bridge. + */ +export const MASTRA_CHAT_AGENT_NAME = 'chat-agent'; + +/** + * Mastra-runtime chat execution service. + * + * Orchestrates the full chat pipeline by delegating to focused step modules: + * + * 1. `initSession` — load/create chat, persist human message, build history + * 2. `summariseOneFile`* — pre-process uploaded files before the agent sees them + * 3. `compressContextIfNeeded`— trim history if it exceeds the token budget + * 4. `toMastraMessages` — convert LangChain messages → Mastra/AI SDK format + * 5. Bridge agent execution — Mastra Agent owns the CallLLM ↔ RunTool loop + * 6. `handleStream` — adapt Mastra events → LLMStreamEvent, persist steps + * 7. EndSession — emit TokenCount, update DB + */ +@injectable({scope: BindingScope.REQUEST}) +export class MastraChatAgent { + constructor( + @inject(AiIntegrationBindings.MastraBridge) + private readonly mastraBridge: IMastraBridge, + @inject(AiIntegrationBindings.FileLLM) + private readonly fileLLM: RuntimeLLMProvider, + @inject(AiIntegrationBindings.Config) + private readonly aiConfig: AIIntegrationConfig, + @inject(AiIntegrationBindings.Tools) + private readonly tools: ToolStore, + @inject(AiIntegrationBindings.SystemContext, {optional: true}) + private readonly systemContext: string[] | undefined, + @service(ChatStore) + private readonly chatStore: ChatStore, + ) {} + + // --------------------------------------------------------------------------- + // Public API + // --------------------------------------------------------------------------- + + /** + * Runs the full Mastra chat pipeline and yields `LLMStreamEvent` values + * that are 100 % compatible with the existing SSE transport. + */ + async *execute( + prompt: string, + files: Express.Multer.File[] | undefined, + abort: AbortSignal, + id?: string, + ): AsyncGenerator { + files = files ?? 
[]; + // ── Step 1: InitSession ────────────────────────────────────────────────── + const {chatId, baseMessages, userMessage} = await initSession( + prompt, + id, + this.chatStore, + this._buildSystemPrompt(), + ); + if (!id) { + yield {type: LLMStreamEventType.Init, data: {sessionId: chatId}}; + } + + // ── Step 2: SummariseFile (pre-processing — outside the agent) ─────────── + let finalPrompt = prompt; + const tokens: TokenAccumulator = {input: 0, output: 0, map: {}}; + + for (const file of files) { + yield { + type: LLMStreamEventType.Log, + data: `Processing file: ${file.originalname}`, + }; + yield { + type: LLMStreamEventType.Status, + data: `Reading file: ${file.originalname}`, + }; + finalPrompt = await summariseOneFile({ + file, + currentPrompt: finalPrompt, + chatId, + userMessage, + tokens, + fileLLM: this.fileLLM, + chatStore: this.chatStore, + }); + } + + // ── Step 3: Build message list + fallback context compression ──────────── + const rawMessages: BaseMessage[] = [ + ...baseMessages, + new HumanMessage({content: finalPrompt}), + ]; + const compressedMessages = await compressContextIfNeeded( + rawMessages, + this.aiConfig.maxTokenCount, + ); + + // ── Step 4: Map messages to Mastra format ───────────────────────────────── + const agentMessages = toMastraMessages(compressedMessages); + + // ── Step 5: Obtain agent from bridge ───────────────────────────────────── + const agent = this.mastraBridge.getTypedAgent( + MASTRA_CHAT_AGENT_NAME, + ); + if (!agent) { + throw new HttpErrors.NotImplemented( + `Mastra chat agent '${MASTRA_CHAT_AGENT_NAME}' is not registered. 
` + + 'Bind a MastraRuntimeFactory at AiIntegrationBindings.MastraRuntimeFactory ' + + "that registers a chat agent under the name 'chat-agent'.", + ); + } + + // ── Step 5b: Build per-request tool map and register for bridge tools ───── + const requestToolMap = new Map< + string, + import('../../graphs/types').IRuntimeTool + >(); + for (const graphTool of this.tools.list) { + try { + // Build tools at request time (they may have request-scoped dependencies). + const rt = graphTool.createTool + ? await graphTool.createTool({}) + : graphTool.build + ? await graphTool.build({}) + : null; + if (rt) { + // Wrap invoke to inject the lazy writer into the LangGraph config so + // internal graph nodes (e.g. RenderVisualizationNode, SaveDatasetNode) + // get config.writer and their ToolStatus events reach the SSE stream. + // Tools built with createTool() ignore the config param, but LangGraph + // tool.invoke(input, { writer }) passes it straight to every node. + const lazyWriter = { + writer: (event: unknown) => + mastraRequestWriterStore.get(chatId)?.( + event as import('../../graphs/event.types').LLMStreamEvent, + ), + }; + const wrappedRt: import('../../graphs/types').IRuntimeTool = { + name: rt.name, + description: rt.description, + schema: rt.schema, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + invoke: (input: unknown) => (rt as any).invoke(input, lazyWriter), + }; + + // Store by kebab key (e.g. 'get-data-as-dataset') — LangGraph path + requestToolMap.set(graphTool.key, wrappedRt); + // Also store by class name (e.g. 
'GetDataAsDatasetTool') — Mastra factory path + const className = (graphTool as object).constructor?.name; + if (className && className !== graphTool.key) { + requestToolMap.set(className, wrappedRt); + } + } + } catch (err) { + debug( + 'Could not build tool %s for request registry: %o', + graphTool.key, + err, + ); + } + } + mastraRequestToolStore.set(chatId, requestToolMap); + debug( + 'Registered %d tools for chatId %s: %s', + requestToolMap.size, + chatId, + [...requestToolMap.keys()].join(', '), + ); + + // ── Step 6: Stream from bridge agent, adapt events to LLMStreamEvent ───── + debug( + 'Delegating to Mastra bridge agent — %d messages, %d tools', + agentMessages.length, + requestToolMap.size, + ); + try { + const agentStream = await agent.stream(agentMessages, { + signal: abort, + threadId: chatId, + }); + + for await (const event of handleStream({ + agentStream, + abort, + tools: this.tools, + chatId, + chatStore: this.chatStore, + tokens, + })) { + yield event; + } + + // Fallback: use stream-level usage promise when no per-step usage arrived + if (tokens.input === 0 && tokens.output === 0) { + try { + const streamUsage = await agentStream.usage; + if (streamUsage) { + // Mastra LanguageModelUsage uses inputTokens/outputTokens + accumulateUsage( + { + promptTokens: (streamUsage as unknown as {inputTokens?: number}) + .inputTokens, + completionTokens: ( + streamUsage as unknown as {outputTokens?: number} + ).outputTokens, + }, + 'mastra-chat', + tokens, + ); + } + } catch { + // usage not available — proceed without it + } + } + + // ── Step 7: EndSession ───────────────────────────────────────────────── + yield { + type: LLMStreamEventType.TokenCount, + data: {inputTokens: tokens.input, outputTokens: tokens.output}, + }; + await this.chatStore.updateCounts( + chatId, + tokens.input, + tokens.output, + tokens.map, + ); + } finally { + // Always clean up per-request stores so memory doesn't leak. 
+ mastraRequestToolStore.delete(chatId); + mastraRequestWriterStore.delete(chatId); + } + } + + // --------------------------------------------------------------------------- + // Private helpers (instance-state dependent — kept in agent) + // --------------------------------------------------------------------------- + + private _buildSystemPrompt(): string { + return [ + `You are a helpful AI assistant. You MUST always use one of the available tools to handle the user's request. Never respond with just text on the first message — always call the closest matching tool, even if you are unsure. The tool will reject the request if it is not suitable.`, + `If you are not sure about the result, you can ask the user to review the result and provide feedback.`, + `Only use a single tool in a single message, but you can use multiple tools over subsequent messages if it could help with the user's requirements.`, + `If the user provides feedback, you can use that feedback to improve the result.`, + `Do not write any redundant messages before or after tool calls, be as concise as possible.`, + `Do not hallucinate details or make up information.`, + `Do not make assumptions about user's intent beyond what is explicitly provided in the prompt, and keep this in mind while calling tools.`, + `Do not use technical jargon in the response, show any internal IDs, or implementation details to the user.`, + `Current date is ${new Date().toDateString()}`, + ...(this.systemContext ?? 
[]), + ].join('\n'); + } +} diff --git a/src/mastra/chat/steps/context-compression.step.ts b/src/mastra/chat/steps/context-compression.step.ts new file mode 100644 index 0000000..0e19a11 --- /dev/null +++ b/src/mastra/chat/steps/context-compression.step.ts @@ -0,0 +1,47 @@ +import {BaseMessage, trimMessages} from '@langchain/core/messages'; +import {DEFAULT_MAX_TOKEN_COUNT} from '../../../constant'; +import {approxTokenCounter} from '../../../utils'; + +const debug = require('debug')('ai-integration:mastra:chat-agent'); + +/** + * Mirrors `ContextCompressionNode`: trims the message list to `maxTokenCount` + * using a `last` strategy (keeps the most recent messages and always retains + * the system message). + * + * This is a pre-call guard applied once before the agent call. The Mastra + * Agent may apply its own internal compression; this prevents oversized + * initial prompts from being sent at all. + * + * @param messages Full message list including the new human message. + * @param maxTokenCount Token budget from `AIIntegrationConfig`; falls back to + * `MAX_TOKEN_COUNT` env var or the package default. + */ +export async function compressContextIfNeeded( + messages: BaseMessage[], + maxTokenCount: number | undefined, +): Promise { + const limit = +( + maxTokenCount ?? + process.env.MAX_TOKEN_COUNT ?? 
+ DEFAULT_MAX_TOKEN_COUNT + ); + const tokenCount = messages.reduce( + (acc, m) => acc + approxTokenCounter(m.content), + 0, + ); + if (tokenCount > limit) { + debug( + 'Compressing context before agent call: %d tokens > limit %d', + tokenCount, + limit, + ); + return trimMessages(messages, { + maxTokens: limit, + strategy: 'last', + tokenCounter: approxTokenCounter, + includeSystem: true, + }); + } + return messages; +} diff --git a/src/mastra/chat/steps/init-session.step.ts b/src/mastra/chat/steps/init-session.step.ts new file mode 100644 index 0000000..b361ad9 --- /dev/null +++ b/src/mastra/chat/steps/init-session.step.ts @@ -0,0 +1,59 @@ +import { + BaseMessage, + HumanMessage, + SystemMessage, +} from '@langchain/core/messages'; +import {ChatStore} from '../../../graphs/chat/chat.store'; +import {Message} from '../../../models'; + +/** + * Result returned by `initSession`. + */ +export interface InitSessionResult { + chatId: string; + baseMessages: BaseMessage[]; + userMessage: Message; +} + +/** + * Mirrors `InitSessionNode`: loads or creates the chat, persists the human + * message, and rebuilds message history from the DB. + * + * @param prompt The raw user prompt for this turn. + * @param id Existing chat ID when continuing a session; undefined for new. + * @param chatStore LoopBack chat persistence service. + * @param systemPrompt Pre-built system prompt string. + */ +export async function initSession( + prompt: string, + id: string | undefined, + chatStore: ChatStore, + systemPrompt: string, +): Promise { + const chat = await chatStore.init(prompt, id); + const savedUserMessage = await chatStore.addHumanMessage( + chat.id, + new HumanMessage({content: prompt}), + ); + const history = await formatHistory(chat.messages ?? 
[], chatStore); + const systemMessage = new SystemMessage({content: systemPrompt}); + return { + chatId: chat.id, + baseMessages: [systemMessage, ...history], + userMessage: savedUserMessage, + }; +} + +/** + * Converts DB `Message` rows back to LangChain `BaseMessage` instances. + * Undefined entries (unsupported message roles) are filtered out. + */ +async function formatHistory( + dbMessages: Message[], + chatStore: ChatStore, +): Promise { + const converted = await Promise.all( + dbMessages.map(m => chatStore.toMessage(m)), + ); + return converted.filter((m): m is BaseMessage => m !== undefined); +} diff --git a/src/mastra/chat/steps/save-step.step.ts b/src/mastra/chat/steps/save-step.step.ts new file mode 100644 index 0000000..b3eee8f --- /dev/null +++ b/src/mastra/chat/steps/save-step.step.ts @@ -0,0 +1,66 @@ +import {AIMessage, ToolMessage} from '@langchain/core/messages'; +import {ChatStore} from '../../../graphs/chat/chat.store'; +import {ToolStore} from '../../../types'; +import {StepBuffer} from '../types/chat.types'; + +const debug = require('debug')('ai-integration:mastra:chat-agent'); + +/** + * Persists a completed agent step (AI message + tool messages) to the DB. + * + * Called at every `step-finish` event in the stream handler. Steps that + * contain neither text output nor tool calls are silently skipped. + * + * @param chatId Active chat session identifier. + * @param step Buffered data accumulated since the last `step-finish`. + * @param tools Tool registry — used to extract display values and metadata. + * @param chatStore LoopBack chat persistence service. 
+ */ +export async function saveStep( + chatId: string, + step: StepBuffer, + tools: ToolStore, + chatStore: ChatStore, +): Promise { + const text = step.textChunks.join(''); + const hasToolCalls = step.toolCalls.length > 0; + if (!text.trim() && !hasToolCalls) return; + + const aiMsg = new AIMessage({ + content: text || ' ', + // eslint-disable-next-line @typescript-eslint/naming-convention + tool_calls: hasToolCalls + ? step.toolCalls.map(tc => ({ + id: tc.id, + name: tc.name, + args: tc.args, + type: 'tool_call' as const, + })) + : [], + }); + const savedAiMsg = await chatStore.addAIMessage(chatId, aiMsg); + + for (const toolCall of step.toolCalls) { + const toolResult = step.toolResults.get(toolCall.id); + if (!toolResult) continue; + const toolDef = tools.map[toolCall.name]; + if (!toolDef) { + debug('Unknown tool during save: %s', toolCall.name); + } + const output = toolDef?.getValue?.(toolResult.result) ?? toolResult.result; + const metadata = toolDef?.getMetadata?.(toolResult.result) ?? 
{}; + const toolMsg = new ToolMessage({ + name: toolCall.name, + content: String(output), + // eslint-disable-next-line @typescript-eslint/naming-convention + tool_call_id: toolCall.id, + }); + await chatStore.addToolMessage( + chatId, + toolMsg, + metadata, + savedAiMsg, + toolCall.args, + ); + } +} diff --git a/src/mastra/chat/steps/stream-handler.step.ts b/src/mastra/chat/steps/stream-handler.step.ts new file mode 100644 index 0000000..eaae622 --- /dev/null +++ b/src/mastra/chat/steps/stream-handler.step.ts @@ -0,0 +1,268 @@ +import {LLMStreamEvent, LLMStreamEventType} from '../../../graphs/event.types'; +import {ToolStatus} from '../../../graphs/types'; +import {ToolStore} from '../../../types'; +import {ChatStore} from '../../../graphs/chat/chat.store'; +import {MastraAgentStreamOutput} from '../../types'; +import {StepBuffer, TokenAccumulator} from '../types/chat.types'; +import {accumulateUsage} from '../utils/token-accumulator.util'; +import {safeStringify} from '../utils/safe-json.util'; +import {saveStep} from './save-step.step'; +import {mastraRequestWriterStore} from '../../request-tool-store'; + +const debug = require('debug')('ai-integration:mastra:chat-agent'); + +/** + * Parameters accepted by `handleStream`. + */ +export interface HandleStreamParams { + /** Streaming output returned by the Mastra bridge agent. */ + agentStream: MastraAgentStreamOutput; + /** Forwarded abort signal — aborts the iteration early when fired. */ + abort: AbortSignal; + /** Tool registry — used for display-value extraction and DB persistence. */ + tools: ToolStore; + /** Active chat session identifier — used for DB persistence. */ + chatId: string; + /** LoopBack chat persistence service. */ + chatStore: ChatStore; + /** + * Token accumulator shared with the caller. + * Mutated in place as usage events arrive. 
+ */ + tokens: TokenAccumulator; +} + +function emptyStep(): StepBuffer { + return { + textChunks: [], + toolCalls: [], + toolResults: new Map(), + pendingToolEvents: [], + }; +} + +/** + * Iterates `agentStream.fullStream`, adapts Mastra events to `LLMStreamEvent`, + * persists each completed step to the DB, and accumulates token usage. + * + * ### Event ordering + * Text-deltas are buffered and emitted as a single `Message` event at + * `step-finish`. This matches LangGraph's behaviour (one `Message` event per + * LLM generation) so the frontend renders one bubble per step rather than one + * bubble per SSE chunk. + * + * `Tool` (running indicator) events are also buffered until `step-finish` so + * they appear *after* the preamble text bubble rather than interspersed. + * + * `ToolStatus` events emitted by internal tool graphs (e.g. VisualizationGraph) + * via `config.writer` are captured through `mastraRequestWriterStore` and + * drained at `tool-result` time so they arrive in the SSE stream in the correct + * order: text → tool indicator → tool status (visualisation config, dataset id…). + * + * The generator terminates when: + * - the full stream is exhausted, OR + * - `abort` is fired. + */ +export async function* handleStream( + params: HandleStreamParams, +): AsyncGenerator { + const {agentStream, abort, tools, chatId, chatStore, tokens} = params; + + // Reverse-map: Mastra class name (e.g. 'GetDataAsDatasetTool') → kebab key + // (e.g. 'get-data-as-dataset') so the `tool` SSE event matches what the + // frontend expects (the name used when the tool was originally registered). + const toolClassToKey = new Map(); + for (const t of tools.list) { + const cn = (t as object).constructor?.name; + if (cn && cn !== t.key) toolClassToKey.set(cn, t.key); + } + + // ── Writer queue: captures ToolStatus events from internal tool graphs ──── + // Tools are built with a lazy writer that pushes here. 
We drain it at + // tool-result time (tools have finished executing by then) and re-emit the + // events in the correct position: after the Tool(Running) indicator. + const writerQueue: LLMStreamEvent[] = []; + mastraRequestWriterStore.set(chatId, (event: LLMStreamEvent) => + writerQueue.push(event), + ); + + // Holds ToolStatus events captured from the writerQueue until step-finish + // so they are emitted after the preamble text and Tool(Running) indicators. + let pendingToolStatusEvents: LLMStreamEvent[] = []; + + let step: StepBuffer = emptyStep(); + + try { + for await (const event of agentStream.fullStream) { + if (abort.aborted) { + debug('Stream aborted'); + break; + } + + // Mastra wraps all event data under `event.payload` + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const p = (event as any).payload ?? {}; + + switch (event.type) { + case 'text-delta': { + const delta = String(p.text ?? ''); + if (delta) { + step.textChunks.push(delta); + // Do NOT yield here — accumulate until step-finish so the entire + // step's text is emitted as one Message event (one bubble). + } + break; + } + + case 'tool-call': { + const toolCallId = String(p.toolCallId ?? ''); + const toolName = String(p.toolName ?? ''); + // Map class name back to kebab key for frontend compatibility. + const toolKey = toolClassToKey.get(toolName) ?? toolName; + const args = (p.args ?? {}) as Record; + step.toolCalls.push({id: toolCallId, name: toolName, args}); + debug('Tool call: %s (%s)', toolName, toolCallId); + + // Buffer the Tool(Running) SSE event — emit at step-finish so it + // appears after the preamble text bubble. 
+ step.pendingToolEvents.push({ + type: LLMStreamEventType.Tool, + data: { + id: toolCallId, + tool: toolKey, + data: args, + status: ToolStatus.Running, + }, + } as LLMStreamEvent); + step.pendingToolEvents.push({ + type: LLMStreamEventType.Log, + data: `Running tool: ${toolKey} with args: ${safeStringify(args)}`, + }); + break; + } + + case 'tool-result': { + const toolCallId = String(p.toolCallId ?? ''); + const toolName = String(p.toolName ?? ''); + const result = p.result; + step.toolResults.set(toolCallId, {result, toolName}); + + // The tool graph has finished executing by this point. Drain the + // writer queue so ToolStatus events (e.g. visualisation config, + // dataset ID) are captured and held for emission at step-finish. + // Patch the toolCallId as `id` into any ToolStatus events that lack + // it — the graph nodes (SaveDataSetNode, etc.) don't know the call ID. + const drained = writerQueue.splice(0).map(ev => { + if ( + ev.type === LLMStreamEventType.ToolStatus && + !(ev.data as {id?: string}).id + ) { + return { + ...ev, + data: {...(ev.data as object), id: toolCallId}, + } as LLMStreamEvent; + } + return ev; + }); + pendingToolStatusEvents.push(...drained); + + const toolDef = tools.map[toolName]; + if (!toolDef) { + debug('Unknown tool: %s', toolName); + } + const output = toolDef?.getValue?.(result) ?? result; + debug('Tool result for %s: %j', toolName, output); + // Log is filtered by SSE transport — no bubble created. + step.pendingToolEvents.push({ + type: LLMStreamEventType.Log, + data: `Tool output: ${safeStringify(output)}`, + }); + break; + } + + case 'step-finish': { + // 1. Emit text as a single Message bubble, but ONLY when this step + // has no tool calls — preamble text before a tool call (e.g. "I" + // before calling GetDataAsDatasetTool) is suppressed to avoid + // creating a separate empty-looking bubble. 
+ const text = step.textChunks.join(''); + if (text.trim() && step.toolCalls.length === 0) { + yield { + type: LLMStreamEventType.Message, + data: {message: text}, + }; + } + + // 2. Emit buffered Tool(Running) + Log + ToolStatus events. + for (const ev of step.pendingToolEvents) { + yield ev; + } + for (const ev of pendingToolStatusEvents) { + yield ev; + } + pendingToolStatusEvents = []; + + // 3. Persist to DB. + await saveStep(chatId, step, tools, chatStore); + + // 4. Collect per-step token usage. + const stepUsage = p.output?.usage as + | {inputTokens?: number; outputTokens?: number} + | undefined; + if (stepUsage) { + accumulateUsage( + { + promptTokens: stepUsage.inputTokens, + completionTokens: stepUsage.outputTokens, + }, + 'mastra-chat', + tokens, + ); + } + + step = emptyStep(); + break; + } + + case 'finish': { + const finishUsage = p.output?.usage as + | {inputTokens?: number; outputTokens?: number} + | undefined; + if (finishUsage && tokens.input === 0 && tokens.output === 0) { + accumulateUsage( + { + promptTokens: finishUsage.inputTokens, + completionTokens: finishUsage.outputTokens, + }, + 'mastra-chat', + tokens, + ); + } + break; + } + + default: + break; + } + } + + // Flush any partial step that didn't receive a `step-finish` event. + const remainingText = step.textChunks.join(''); + if (remainingText.trim()) { + yield {type: LLMStreamEventType.Message, data: {message: remainingText}}; + } + for (const ev of step.pendingToolEvents) { + yield ev; + } + for (const ev of pendingToolStatusEvents) { + yield ev; + } + if (step.textChunks.length || step.toolCalls.length) { + await saveStep(chatId, step, tools, chatStore); + } + } finally { + // Unregister writer — the agent's finally block also calls delete, but + // removing it here too ensures cleanup even if iteration is aborted. 
+ mastraRequestWriterStore.delete(chatId); + } +} diff --git a/src/mastra/chat/steps/summarise-file.step.ts b/src/mastra/chat/steps/summarise-file.step.ts new file mode 100644 index 0000000..2690d46 --- /dev/null +++ b/src/mastra/chat/steps/summarise-file.step.ts @@ -0,0 +1,99 @@ +import {AIMessage} from '@langchain/core/messages'; +import {ChatStore} from '../../../graphs/chat/chat.store'; +import {Message} from '../../../models'; +import {resolveLegacyLLMProvider, RuntimeLLMProvider} from '../../../types'; +import {mergeAttachments, stripThinkingTokens} from '../../../utils'; +import {TokenAccumulator} from '../types/chat.types'; +import {accumulateUsage} from '../utils/token-accumulator.util'; + +/** + * Parameters for `summariseOneFile`. + */ +export interface SummariseFileParams { + file: Express.Multer.File; + currentPrompt: string; + chatId: string; + userMessage: Message; + /** Mutated in place — updated with file-summarisation token usage. */ + tokens: TokenAccumulator; + fileLLM: RuntimeLLMProvider; + chatStore: ChatStore; +} + +/** + * Mirrors `SummariseFileNode` for a single file. + * + * Invokes the file LLM to produce a concise summary of the file in the + * context of the user prompt, persists the attachment message, and returns + * an updated prompt that embeds the summary. 
+ */ +export async function summariseOneFile( + params: SummariseFileParams, +): Promise { + const {file, currentPrompt, chatId, userMessage, tokens, fileLLM, chatStore} = + params; + + const llm = resolveLegacyLLMProvider(fileLLM); + const fileContent = buildFileContent(file, fileLLM); + const messages = [ + { + role: 'system' as const, + content: buildFileSummaryPrompt(currentPrompt), + }, + { + role: 'user' as const, + content: [{type: 'text', text: currentPrompt}, fileContent], + }, + ]; + + const aiResponse = (await llm.invoke(messages)) as AIMessage; + const usage = aiResponse.usage_metadata; + if (usage) { + accumulateUsage( + { + promptTokens: usage.input_tokens, + completionTokens: usage.output_tokens, + }, + 'mastra-file', + tokens, + ); + } + + const summary = stripThinkingTokens(aiResponse); + await chatStore.addAttachmentMessage(chatId, userMessage, file, summary); + return mergeAttachments(currentPrompt, file.originalname, summary); +} + +// --------------------------------------------------------------------------- +// Private helpers +// --------------------------------------------------------------------------- + +function buildFileSummaryPrompt(userPrompt: string): string { + return `You are an AI assistant that summarizes file content keeping all the important details in mind. + Make sure that you don't miss any important details and summarize the content in a concise manner. + While summarizing the content, make sure that you keep the user's prompt in mind and summarize the content in a way that it can be used to answer the user's query. + You will be provided with user's original prompt and one file among the files that user provided. + You will summarize the one file at a time so don't worry about the other files mentioned in the user's prompt. + The summary should be relatively short and only contain the important details that are relevant to the user's query. 
+ The output should just be a plain text string without any additional markdown syntax or any special formatting. + Here is the user's prompt: + ${userPrompt} + `; +} + +function buildFileContent( + file: Express.Multer.File, + fileLLM: RuntimeLLMProvider, +): object { + if (fileLLM.getFile) { + return fileLLM.getFile(file); + } + return { + type: 'file', + // eslint-disable-next-line @typescript-eslint/naming-convention + source_type: 'base64', + data: file.buffer?.toString('base64') ?? '', + // eslint-disable-next-line @typescript-eslint/naming-convention + mime_type: file.mimetype, + }; +} diff --git a/src/mastra/chat/types/chat.types.ts b/src/mastra/chat/types/chat.types.ts new file mode 100644 index 0000000..9a65026 --- /dev/null +++ b/src/mastra/chat/types/chat.types.ts @@ -0,0 +1,23 @@ +import {LLMStreamEvent} from '../../../graphs/event.types'; +import {TokenMetadata} from '../../../types'; + +/** + * Accumulates token usage across all LLM calls within a single request cycle. + */ +export interface TokenAccumulator { + input: number; + output: number; + map: TokenMetadata; +} + +/** + * Buffers data within a single agent step so we can persist it atomically once + * the `step-finish` event arrives. + */ +export interface StepBuffer { + textChunks: string[]; + toolCalls: Array<{id: string; name: string; args: Record}>; + toolResults: Map; + /** Buffered `Tool` SSE events — emitted at step-finish after the text bubble. */ + pendingToolEvents: LLMStreamEvent[]; +} diff --git a/src/mastra/chat/utils/safe-json.util.ts b/src/mastra/chat/utils/safe-json.util.ts new file mode 100644 index 0000000..01f52d1 --- /dev/null +++ b/src/mastra/chat/utils/safe-json.util.ts @@ -0,0 +1,11 @@ +/** + * Safely serialises an unknown value to a JSON string. + * Returns a placeholder when the value is not serialisable. 
+ */ +export function safeStringify(obj: unknown): string { + try { + return JSON.stringify(obj, undefined, 2); + } catch { + return '[Unserializable args]'; + } +} diff --git a/src/mastra/chat/utils/token-accumulator.util.ts b/src/mastra/chat/utils/token-accumulator.util.ts new file mode 100644 index 0000000..68d2632 --- /dev/null +++ b/src/mastra/chat/utils/token-accumulator.util.ts @@ -0,0 +1,20 @@ +import {TokenAccumulator} from '../types/chat.types'; + +/** + * Adds per-step token usage into the running accumulator, keyed by model/phase. + * Mutates `tokens` in place. + */ +export function accumulateUsage( + usage: {promptTokens?: number; completionTokens?: number}, + modelKey: string, + tokens: TokenAccumulator, +): void { + const input = usage.promptTokens ?? 0; + const output = usage.completionTokens ?? 0; + tokens.input += input; + tokens.output += output; + const prev = tokens.map[modelKey] ?? {inputTokens: 0, outputTokens: 0}; + prev.inputTokens += input; + prev.outputTokens += output; + tokens.map[modelKey] = prev; +} diff --git a/src/mastra/index.ts b/src/mastra/index.ts new file mode 100644 index 0000000..10128b2 --- /dev/null +++ b/src/mastra/index.ts @@ -0,0 +1,3 @@ +export * from './chat/mastra-chat.agent'; +export * from './request-tool-store'; +export * from './types'; diff --git a/src/mastra/request-tool-store.ts b/src/mastra/request-tool-store.ts new file mode 100644 index 0000000..7b7d31f --- /dev/null +++ b/src/mastra/request-tool-store.ts @@ -0,0 +1,34 @@ +import {LLMStreamEvent} from '../graphs/event.types'; +import {IRuntimeTool} from '../graphs/types'; + +/** + * Per-request IRuntimeTool registry. + * + * `MastraChatAgent` populates this map at the start of each request (keyed by + * `chatId`) and removes the entry when the request finishes. The Mastra tool + * `execute()` callbacks in the host app look up the correct per-request tool + * instance here using the `threadId` that is forwarded through + * `agent.stream({ threadId })`. 
+ * + * This sidesteps the limitation that Mastra tools are registered once at agent + * construction time, while LoopBack tools may have request-scoped dependencies. + */ +export const mastraRequestToolStore = new Map< + string, + Map +>(); + +/** + * Per-request SSE writer registry. + * + * Internal tool graphs (e.g. DbQueryGraph, VisualizationGraph) emit + * `ToolStatus` events via `config.writer` as they execute their nodes. + * `MastraChatAgent` registers a writer callback here (keyed by `chatId`) + * before streaming starts. Tools are built with a lazy writer that delegates + * to this store, so their internal `ToolStatus` events flow back into the + * SSE stream even though the tool itself runs inside the Mastra agent loop. + */ +export const mastraRequestWriterStore = new Map< + string, + (event: LLMStreamEvent) => void +>(); diff --git a/src/mastra/types.ts b/src/mastra/types.ts new file mode 100644 index 0000000..cc436d0 --- /dev/null +++ b/src/mastra/types.ts @@ -0,0 +1,146 @@ +import {AnyObject} from '@loopback/repository'; + +/** + * A single message in the format accepted by Mastra (compatible with AI SDK `CoreMessage`). + * + * The host application's `MastraRuntimeFactory` must return an adapter whose + * `getAgent('chat-agent')` call yields an object implementing `IMastraChatAgentRunnable`. + * That object's `stream()` input is typed against these message shapes. 
+ */ +export type MastraAgentMessage = + | {role: 'system'; content: string} + | {role: 'user'; content: string} + | {role: 'assistant'; content: string | MastraAssistantContentPart[]} + | {role: 'tool'; content: MastraToolResultPart[]}; + +export type MastraAssistantContentPart = + | {type: 'text'; text: string} + | { + type: 'tool-call'; + toolCallId: string; + toolName: string; + args: Record; + }; + +export type MastraToolResultPart = { + type: 'tool-result'; + toolCallId: string; + toolName: string; + result: unknown; +}; + +/** + * Options accepted as the second argument of `IMastraChatAgentRunnable.stream()`. + */ +export interface MastraAgentInput { + signal?: AbortSignal; + /** + * Optional thread / chat identifier forwarded to `agent.stream({ threadId })`. + * Used by the host-app tool `execute()` callbacks to look up the correct + * per-request `IRuntimeTool` instances from `mastraRequestToolStore`. + */ + threadId?: string; + [key: string]: unknown; +} + +/** + * Typed union of all events emitted by the Mastra agent's `fullStream`. + * + * Mastra wraps all event data under a `payload` property. + * The `from`, `runId` fields are always present but not used by our handler. 
+ */ +export type MastraStreamEvent = + | { + type: 'text-delta'; + payload: {text: string; id: string; [key: string]: unknown}; + [key: string]: unknown; + } + | { + type: 'tool-call'; + payload: { + toolCallId: string; + toolName: string; + args?: Record; + [key: string]: unknown; + }; + [key: string]: unknown; + } + | { + type: 'tool-result'; + payload: { + toolCallId: string; + toolName: string; + result: unknown; + args?: Record; + [key: string]: unknown; + }; + [key: string]: unknown; + } + | { + type: 'step-finish'; + payload: { + output: { + usage: { + inputTokens?: number; + outputTokens?: number; + [key: string]: unknown; + }; + [key: string]: unknown; + }; + [key: string]: unknown; + }; + [key: string]: unknown; + } + | { + type: 'finish'; + payload: { + output: { + usage: { + inputTokens?: number; + outputTokens?: number; + [key: string]: unknown; + }; + [key: string]: unknown; + }; + [key: string]: unknown; + }; + [key: string]: unknown; + } + | {type: string; payload?: unknown; [key: string]: unknown}; + +/** + * The async stream object returned by `IMastraChatAgentRunnable.stream()`. + */ +export interface MastraAgentStreamOutput { + /** + * Emits every stream event — text deltas, tool calls, tool results, step + * boundaries, and the final finish event. + */ + fullStream: AsyncIterable; + /** + * Resolves to the aggregate token usage once the full stream is consumed. + * Mastra returns `LanguageModelUsage` with `inputTokens`/`outputTokens`. + * May resolve to `undefined` if the agent does not report usage. + */ + usage: Promise<{inputTokens?: number; outputTokens?: number} | undefined>; +} + +/** + * Minimal contract for a Mastra agent capable of multi-step chat with tool calling. + * + * The library depends only on this interface — it does NOT import `@mastra/core` + * directly. The host application's `MastraRuntimeFactory` is responsible for + * creating and returning an object that satisfies this contract. 
+ * + * A real `@mastra/core` `Agent` instance is directly assignable to this interface. + */ +export interface IMastraChatAgentRunnable { + /** + * Matches the real `@mastra/core` `Agent.stream(messages, options)` signature. + * Messages are passed as the first argument; options (threadId, signal, etc.) as second. + */ + stream( + messages: MastraAgentMessage[], + options?: MastraAgentInput, + ): Promise; +} diff --git a/src/providers/tools.provider.ts b/src/providers/tools.provider.ts index 492f325..96f17b3 100644 --- a/src/providers/tools.provider.ts +++ b/src/providers/tools.provider.ts @@ -29,7 +29,14 @@ export class ToolsProvider implements Provider { for (const binding of bindings) { const toolInstance = await this.context.get(binding.key); tools.push(toolInstance); + // Index by kebab key (LangGraph path, e.g. 'get-data-as-dataset') toolMap[toolInstance.key] = toolInstance; + // Also index by class name (Mastra path, e.g. 'GetDataAsDatasetTool') + // so saveStep / stream-handler can look up tool definitions by either name. 
+ const className = (toolInstance as object).constructor?.name; + if (className && className !== toolInstance.key) { + toolMap[className] = toolInstance; + } } return { list: tools, diff --git a/src/services/generation.service.ts b/src/services/generation.service.ts index 7d5f1c4..ff563d0 100644 --- a/src/services/generation.service.ts +++ b/src/services/generation.service.ts @@ -1,7 +1,9 @@ import {BindingScope, inject, injectable, service} from '@loopback/core'; +import {MastraChatAgent} from '../mastra'; import {ChatGraph} from '../graphs/chat/chat.graph'; import {AiIntegrationBindings} from '../keys'; import {ITransport} from '../transports/types'; +import {AIIntegrationConfig} from '../types'; import {ILimitStrategy} from './limit-strategies/types'; @injectable({scope: BindingScope.REQUEST}) @@ -9,11 +11,16 @@ export class GenerationService { constructor( @service(ChatGraph) private readonly chatGraph: ChatGraph, + @service(MastraChatAgent) + private readonly mastraChatAgent: MastraChatAgent, @inject(AiIntegrationBindings.Transport) private readonly transport: ITransport, + @inject(AiIntegrationBindings.Config, {optional: true}) + private readonly aiConfig: AIIntegrationConfig | undefined, @inject(AiIntegrationBindings.LimitStrategy, {optional: true}) private readonly limiter?: ILimitStrategy, ) {} + async generate(prompt: string, files: Express.Multer.File[], id?: string) { await this.limiter?.check(); const abortController = new AbortController(); @@ -21,13 +28,21 @@ export class GenerationService { this.transport.onCancel(() => { abortController.abort(); }); - const stream = await this.chatGraph.execute( - prompt, - files, - abortController.signal, - id, - ); + if (this.aiConfig?.runtime === 'mastra') { + await this._runMastraFlow(prompt, files, abortController.signal, id); + } else { + await this._runLangGraphFlow(prompt, files, abortController.signal, id); + } + } + + private async _runLangGraphFlow( + prompt: string, + files: Express.Multer.File[], + 
abort: AbortSignal, + id?: string, + ) { + const stream = await this.chatGraph.execute(prompt, files, abort, id); try { for await (const chunk of stream) { await this.transport.send(chunk); @@ -38,4 +53,26 @@ export class GenerationService { throw error; } } + + private async _runMastraFlow( + prompt: string, + files: Express.Multer.File[], + abort: AbortSignal, + id?: string, + ) { + try { + for await (const chunk of this.mastraChatAgent.execute( + prompt, + files, + abort, + id, + )) { + await this.transport.send(chunk); + } + await this.transport.end(); + } catch (error) { + await this.transport.end(error); + throw error; + } + } } diff --git a/src/services/mastra-bridge.service.ts b/src/services/mastra-bridge.service.ts index cf7a632..f6fcbfb 100644 --- a/src/services/mastra-bridge.service.ts +++ b/src/services/mastra-bridge.service.ts @@ -30,6 +30,16 @@ export interface GraphToolBindingDescriptor { bindingKey: string; key: string; name: string; + /** + * Human-readable description stored in the binding tag by `@graphTool({ description })`. + * Available at startup without resolving the tool instance. + */ + description?: string; + /** + * Zod input schema stored in the binding tag by `@graphTool({ inputSchema })`. + * Available at startup without resolving the tool instance. + */ + inputSchema?: unknown; resolve: BindingResolver; } @@ -43,10 +53,14 @@ export interface MastraBootstrapPayload { /** * Minimal runtime contract required by the migration bridge. + * + * Returns `unknown` intentionally so that host-app factory implementations can + * return concrete agent / workflow types without a generic mismatch. Callers + * should cast the return value to the expected interface at the call site. 
*/ export interface MastraRuntimeAdapter { - getAgent(name: string): T | undefined; - getWorkflow(name: string): T | undefined; + getAgent(name: string): unknown; + getWorkflow(name: string): unknown; } /** @@ -61,8 +75,25 @@ export type MastraRuntimeFactory = ( */ export interface IMastraBridge { initialize(): Promise; - getAgent(name: string): T | undefined; - getWorkflow(name: string): T | undefined; + /** + * Returns the underlying runtime adapter. Use this when you need to access + * the adapter beyond the typed `getTypedAgent` / `getTypedWorkflow` shortcuts. + */ + getRuntime(): MastraRuntimeAdapter; + /** Returns the raw (untyped) agent instance registered under the given name. */ + getAgent(name: string): unknown; + /** Returns the raw (untyped) workflow instance registered under the given name. */ + getWorkflow(name: string): unknown; + /** + * Type-safe accessor for a registered agent. The cast is performed internally; + * callers receive `T | undefined` without any explicit cast at the call site. + */ + getTypedAgent(name: string): T | undefined; + /** + * Type-safe accessor for a registered workflow. The cast is performed internally; + * callers receive `T | undefined` without any explicit cast at the call site. + */ + getTypedWorkflow(name: string): T | undefined; getBootstrapSnapshot(): MastraBootstrapPayload; } @@ -73,14 +104,14 @@ class NoopMastraRuntimeAdapter implements MastraRuntimeAdapter { /** * Returns undefined for all agents in no-op mode. */ - getAgent(_name: string): T | undefined { + getAgent(_name: string): unknown { return undefined; } /** * Returns undefined for all workflows in no-op mode. */ - getWorkflow(_name: string): T | undefined { + getWorkflow(_name: string): unknown { return undefined; } } @@ -111,24 +142,54 @@ export class MastraBridgeService implements IMastraBridge { } this.payload = this.collectBootstrapPayload(); - this.runtime = this.runtimeFactory - ? 
await this.runtimeFactory(this.payload) - : new NoopMastraRuntimeAdapter(); - this.initialized = true; + try { + this.runtime = await (this.runtimeFactory + ? this.runtimeFactory(this.payload) + : new NoopMastraRuntimeAdapter()); + this.initialized = true; + } catch (error) { + throw new Error(`Mastra runtime initialization failed: ${error}`); + } + } + + /** + * Returns the underlying runtime adapter. + */ + getRuntime(): MastraRuntimeAdapter { + if (!this.initialized) { + throw new Error('MastraBridgeService not initialized'); + } + return this.runtime; + } + + /** + * Returns the raw (untyped) agent instance registered under the given name. + */ + getAgent(name: string): unknown { + return this.getRuntime().getAgent(name); + } + + /** + * Returns the raw (untyped) workflow instance registered under the given name. + */ + getWorkflow(name: string): unknown { + return this.getRuntime().getWorkflow(name); } /** - * Returns a typed agent instance from the runtime adapter. + * Type-safe accessor: retrieves the agent and narrows it to `T`. + * The single internal cast is contained here so call sites remain cast-free. */ - getAgent(name: string): T | undefined { - return this.runtime.getAgent(name); + getTypedAgent(name: string): T | undefined { + return this.getRuntime().getAgent(name) as T | undefined; } /** - * Returns a typed workflow instance from the runtime adapter. + * Type-safe accessor: retrieves the workflow and narrows it to `T`. + * The single internal cast is contained here so call sites remain cast-free. */ - getWorkflow(name: string): T | undefined { - return this.runtime.getWorkflow(name); + getTypedWorkflow(name: string): T | undefined { + return this.getRuntime().getWorkflow(name) as T | undefined; } /** @@ -159,6 +220,8 @@ export class MastraBridgeService implements IMastraBridge { bindingKey: binding.key, key: String(binding.key), name: String(binding.tagMap?.[TOOL_NAME] ?? 
binding.key), + description: binding.tagMap?.['toolDescription'] as string | undefined, + inputSchema: binding.tagMap?.['toolInputSchema'], resolve: async () => this.context.get(binding.key), })); From 7d81e482e6e5803ddccf4e0edc67a7ab44617857 Mon Sep 17 00:00:00 2001 From: Piyush Singh Gaur Date: Mon, 4 May 2026 14:05:52 +0530 Subject: [PATCH 3/3] feat: migrated all graphs and sdks --- package-lock.json | 611 +++++++++++- package.json | 19 +- .../generation.controllers.acceptance.ts | 2 +- .../acceptance/db-query.graph.acceptance.ts | 95 -- .../nodes/get-tables-node.acceptance.ts | 72 -- .../unit/db-knowledge-graph.service.unit.ts | 64 +- .../db-query/unit/db-query.graph.unit.ts | 257 ----- .../mastra-steps/check-cache.step.unit.ts | 143 +++ .../check-permissions.step.unit.ts | 72 ++ .../mastra-steps/classify-change.step.unit.ts | 69 ++ .../db-query.workflow.integration.ts | 460 +++++++++ .../unit/mastra-steps/failed.step.unit.ts | 70 ++ .../generate-description.step.unit.ts | 129 +++ .../mastra-steps/is-improvement.step.unit.ts | 45 + .../mastra-steps/post-validation.step.unit.ts | 125 +++ .../mastra-steps/sql-generation.step.unit.ts | 123 +++ .../syntactic-validator.step.unit.ts | 91 ++ .../unit/nodes/check-cache.node.unit.ts | 305 ------ .../unit/nodes/check-permission.node.unit.ts | 78 -- .../unit/nodes/classify-change.node.unit.ts | 159 --- .../db-query/unit/nodes/failed.node.unit.ts | 60 -- .../unit/nodes/fix-query.node.unit.ts | 379 -------- .../unit/nodes/get-columns.node.unit.ts | 169 ---- .../unit/nodes/get-tables.node.unit.ts | 276 ------ .../unit/nodes/is-improvement.node.unit.ts | 53 - .../unit/nodes/save-dataset-node.unit.ts | 166 ---- .../nodes/semantic-validator.node.unit.ts | 237 ----- .../unit/nodes/sql-generation.node.unit.ts | 919 ------------------ .../nodes/syntactic-validator.node.unit.ts | 99 -- src/__tests__/fixtures/fake-ai-models.ts | 134 +++ src/__tests__/fixtures/test-app.ts | 43 +- .../generation.service.integration.ts | 26 +- 
src/__tests__/unit/chat.graph.unit.ts | 160 --- .../unit/langfuse-core.provider.unit.ts | 65 ++ src/__tests__/unit/mastra-bridge.unit.ts | 2 +- .../unit/nodes/call-llm.node.unit.ts | 100 -- .../nodes/context-compression.node.unit.ts | 69 -- .../unit/nodes/end-session.node.unit.ts | 110 --- .../unit/nodes/init-session.node.unit.ts | 93 -- .../unit/nodes/run-tool.node.unit.ts | 125 --- .../unit/nodes/summarise-file.node.unit.ts | 142 --- src/__tests__/unit/pgvector-sdk.store.unit.ts | 266 +++++ src/__tests__/unit/token-counter.unit.ts | 78 ++ .../select-and-render.step.unit.ts | 178 ++++ .../unit/visualizers/bar.visualizer.unit.ts | 196 ---- .../unit/visualizers/line.visualizer.unit.ts | 236 ----- .../unit/visualizers/pie.visualizer.unit.ts | 236 ----- src/component.ts | 24 +- .../controller/template.controller.ts | 20 +- src/components/db-query/db-query.component.ts | 63 +- src/components/db-query/db-query.graph.ts | 246 ----- src/components/db-query/index.ts | 2 - .../db-query/nodes/check-cache.node.ts | 201 ---- .../db-query/nodes/check-permissions.node.ts | 76 -- .../db-query/nodes/check-templates.node.ts | 188 ---- .../db-query/nodes/classify-change.node.ts | 80 -- src/components/db-query/nodes/failed.node.ts | 27 - .../db-query/nodes/fix-query.node.ts | 195 ---- .../db-query/nodes/generate-checklist.node.ts | 153 --- .../nodes/generate-description.node.ts | 110 --- .../db-query/nodes/get-columns.node.ts | 332 ------- .../db-query/nodes/get-tables.node.ts | 225 ----- src/components/db-query/nodes/index.ts | 16 - .../db-query/nodes/is-improvement.node.ts | 32 - .../db-query/nodes/save-dataset-node.ts | 145 --- .../db-query/nodes/semantic-validator.node.ts | 177 ---- .../db-query/nodes/sql-generation.node.ts | 229 ----- .../nodes/syntactic-validator.node.ts | 105 -- .../db-query/nodes/verify-checklist.node.ts | 194 ---- .../db-query/providers/datasets.retriever.ts | 37 - src/components/db-query/providers/index.ts | 3 +- 
.../db-query/providers/templates.retriever.ts | 37 - .../services/dataset-helper.service.ts | 6 +- .../db-knowledge-graph.service.ts | 32 +- .../services/template-helper.service.ts | 58 +- src/components/db-query/state.ts | 64 +- .../testing/db-query.graph.builder.ts | 39 - .../testing/generation.acceptance.builder.ts | 2 +- .../testing/get-table.node.builder.ts | 49 - src/components/db-query/testing/index.ts | 3 +- .../db-query/tools/ask-about-dataset.tool.ts | 138 --- .../tools/get-data-as-dataset.tool.ts | 48 +- .../db-query/tools/improve-dataset.tool.ts | 52 +- src/components/db-query/tools/index.ts | 1 - src/components/visualization/index.ts | 2 - .../nodes/call-query-generation.node.ts | 54 - .../nodes/get-dataset-data.node.ts | 38 - src/components/visualization/nodes/index.ts | 4 - .../nodes/render-visualization.node.ts | 47 - .../nodes/select-visualization.node.ts | 136 --- src/components/visualization/state.ts | 35 +- .../tools/generate-visualization.tool.ts | 78 +- .../visualization/visualization.graph.ts | 63 -- .../visualization/visualizer.component.ts | 33 +- .../visualizers/bar.visualizer.ts | 79 -- .../visualization/visualizers/index.ts | 3 - .../visualizers/line.visualizer.ts | 94 -- .../visualizers/pie.visualizer.ts | 73 -- src/controllers/chat.controller.ts | 2 +- src/graphs/base.graph.ts | 27 - src/graphs/chat/chat.graph.ts | 144 --- src/graphs/chat/index.ts | 5 - src/graphs/chat/nodes.enum.ts | 8 - src/graphs/chat/nodes/call-llm.node.ts | 57 -- .../chat/nodes/context-compression.node.ts | 50 - src/graphs/chat/nodes/end-session.node.ts | 43 - src/graphs/chat/nodes/index.ts | 6 - src/graphs/chat/nodes/init-session.node.ts | 81 -- src/graphs/chat/nodes/run-tool.node.ts | 94 -- src/graphs/chat/nodes/summarise-file.node.ts | 148 --- src/graphs/index.ts | 5 - src/graphs/state.ts | 12 - src/graphs/types.ts | 143 --- src/index.ts | 3 +- src/keys.ts | 91 +- src/mastra/chat/mappers/message.mapper.ts | 60 -- src/mastra/chat/mastra-chat.agent.ts | 130 ++- 
.../chat/steps/context-compression.step.ts | 59 +- src/mastra/chat/steps/init-session.step.ts | 31 +- src/mastra/chat/steps/save-step.step.ts | 34 +- src/mastra/chat/steps/stream-handler.step.ts | 6 +- src/mastra/chat/steps/summarise-file.step.ts | 87 +- src/mastra/chat/types/chat.types.ts | 2 +- src/mastra/chat/utils/adapt-stream.util.ts | 103 ++ .../chat/utils/normalize-messages.util.ts | 85 ++ src/mastra/db-query/index.ts | 15 + .../db-query/mastra-db-query.workflow.ts | 470 +++++++++ .../services/dataset-search.service.ts | 59 ++ src/mastra/db-query/services/index.ts | 3 + .../mastra-template-helper.service.ts | 317 ++++++ .../services/template-search.service.ts | 58 ++ src/mastra/db-query/types/db-query.types.ts | 52 + src/mastra/db-query/utils/index.ts | 2 + src/mastra/db-query/utils/prompt.util.ts | 22 + src/mastra/db-query/utils/thinking.util.ts | 24 + .../conditions/db-query.conditions.ts | 62 ++ .../db-query/workflow/conditions/index.ts | 8 + src/mastra/db-query/workflow/index.ts | 2 + .../workflow/steps/check-cache.step.ts | 198 ++++ .../workflow/steps/check-permissions.step.ts | 90 ++ .../workflow/steps/check-templates.step.ts | 193 ++++ .../workflow/steps/classify-change.step.ts | 98 ++ .../db-query/workflow/steps/failed.step.ts | 32 + .../db-query/workflow/steps/fix-query.step.ts | 199 ++++ .../workflow/steps/generate-checklist.step.ts | 173 ++++ .../steps/generate-description.step.ts | 122 +++ .../workflow/steps/get-columns.step.ts | 299 ++++++ .../workflow/steps/get-tables.step.ts | 246 +++++ src/mastra/db-query/workflow/steps/index.ts | 32 + .../workflow/steps/is-improvement.step.ts | 42 + .../workflow/steps/post-validation.step.ts | 100 ++ .../workflow/steps/save-dataset.step.ts | 153 +++ .../workflow/steps/semantic-validator.step.ts | 191 ++++ .../workflow/steps/sql-generation.step.ts | 221 +++++ .../steps/syntactic-validator.step.ts | 122 +++ .../workflow/steps/verify-checklist.step.ts | 191 ++++ src/mastra/index.ts | 2 + 
src/mastra/request-tool-store.ts | 4 +- src/mastra/types.ts | 2 - src/mastra/visualization/index.ts | 9 + .../mastra-visualization.workflow.ts | 204 ++++ .../services/bar.visualizer.service.ts | 130 +++ src/mastra/visualization/services/index.ts | 3 + .../services/line.visualizer.service.ts | 151 +++ .../services/pie.visualizer.service.ts | 120 +++ .../types/visualization.types.ts | 136 +++ .../workflow/conditions/index.ts | 8 + .../conditions/visualization.conditions.ts | 55 ++ src/mastra/visualization/workflow/index.ts | 4 + .../steps/call-query-generation.step.ts | 84 ++ .../workflow/steps/get-dataset-data.step.ts | 51 + .../visualization/workflow/steps/index.ts | 11 + .../steps/render-visualization.step.ts | 73 ++ .../steps/select-visualization.step.ts | 148 +++ src/models/message.model.ts | 2 +- src/providers/tools.provider.ts | 2 +- .../vector-stores/inmemory.vector.ts | 60 +- .../chat => services}/chat-metadata.type.ts | 0 src/{graphs/chat => services}/chat.store.ts | 127 ++- src/services/generation.service.ts | 27 +- src/services/index.ts | 2 + src/services/mastra-bridge.service.ts | 2 +- src/services/token-counter.service.ts | 50 +- .../db/postgresql/vector-store/index.ts | 2 +- .../vector-store/pgvector-sdk.store.ts | 217 +++++ .../postgresql/vector-store/pgvector.store.ts | 53 - src/sub-modules/obf/langfuse/index.ts | 2 + .../obf/langfuse/langfuse-core.provider.ts | 38 + .../obf/langfuse/langfuse-mastra.component.ts | 43 + .../obf/langfuse/langfuse.provider.ts | 24 +- .../anthropic/llms/anthropic.provider.ts | 29 - .../anthropic/llms/claude-sdk.provider.ts | 32 + .../providers/anthropic/llms/index.ts | 3 +- .../bedrock-embedding-sdk.provider.ts | 25 + .../embedding/bedrock-embedding.provider.ts | 21 - .../providers/aws/embedding/index.ts | 3 +- src/sub-modules/providers/aws/index.ts | 1 - .../llms/bedrock-non-thinking-sdk.provider.ts | 19 + .../aws/llms/bedrock-non-thinking.provider.ts | 12 - .../aws/llms/bedrock-sdk.provider.ts | 39 + 
.../providers/aws/llms/bedrock.provider.ts | 61 -- src/sub-modules/providers/aws/llms/index.ts | 6 +- src/sub-modules/providers/aws/types.ts | 6 - .../cerebras/llm/cerebras-sdk.provider.ts | 31 + .../cerebras/llm/cerebras.provider.ts | 25 - .../providers/cerebras/llm/index.ts | 3 +- .../gemini-embedding-sdk.provider.ts | 20 + .../embedding/gemini-embedding.provider.ts | 20 - .../providers/google/embedding/index.ts | 3 +- .../google/llms/gemini-sdk.provider.ts | 28 + .../providers/google/llms/gemini.provider.ts | 17 - .../providers/google/llms/index.ts | 3 +- .../providers/groq/llms/groq-sdk.provider.ts | 30 + .../providers/groq/llms/groq.provider.ts | 18 - src/sub-modules/providers/groq/llms/index.ts | 2 +- .../providers/ollama/embedding/index.ts | 3 +- .../ollama-embedding-sdk.provider.ts | 27 + .../embedding/ollama-embedding.provider.ts | 15 - .../providers/ollama/llms/index.ts | 3 +- .../ollama/llms/ollama-sdk.provider.ts | 30 + .../providers/ollama/llms/ollama.provider.ts | 16 - .../providers/openai/llms/index.ts | 3 +- .../openai/llms/openai-sdk.provider.ts | 57 ++ .../providers/openai/llms/openai.provider.ts | 24 - src/sub-modules/providers/openai/types.ts | 6 - src/transports/http.transport.ts | 2 +- src/transports/sse.transport.ts | 2 +- src/transports/types.ts | 2 +- src/types.ts | 99 +- .../event.types.ts => types/events.ts} | 9 + src/types/tool.ts | 103 ++ src/utils.ts | 69 +- 232 files changed, 9408 insertions(+), 10431 deletions(-) delete mode 100644 src/__tests__/db-query/acceptance/db-query.graph.acceptance.ts delete mode 100644 src/__tests__/db-query/acceptance/nodes/get-tables-node.acceptance.ts delete mode 100644 src/__tests__/db-query/unit/db-query.graph.unit.ts create mode 100644 src/__tests__/db-query/unit/mastra-steps/check-cache.step.unit.ts create mode 100644 src/__tests__/db-query/unit/mastra-steps/check-permissions.step.unit.ts create mode 100644 src/__tests__/db-query/unit/mastra-steps/classify-change.step.unit.ts create mode 100644 
src/__tests__/db-query/unit/mastra-steps/db-query.workflow.integration.ts create mode 100644 src/__tests__/db-query/unit/mastra-steps/failed.step.unit.ts create mode 100644 src/__tests__/db-query/unit/mastra-steps/generate-description.step.unit.ts create mode 100644 src/__tests__/db-query/unit/mastra-steps/is-improvement.step.unit.ts create mode 100644 src/__tests__/db-query/unit/mastra-steps/post-validation.step.unit.ts create mode 100644 src/__tests__/db-query/unit/mastra-steps/sql-generation.step.unit.ts create mode 100644 src/__tests__/db-query/unit/mastra-steps/syntactic-validator.step.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/check-cache.node.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/check-permission.node.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/classify-change.node.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/failed.node.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/fix-query.node.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/get-columns.node.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/get-tables.node.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/is-improvement.node.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/save-dataset-node.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/sql-generation.node.unit.ts delete mode 100644 src/__tests__/db-query/unit/nodes/syntactic-validator.node.unit.ts create mode 100644 src/__tests__/fixtures/fake-ai-models.ts delete mode 100644 src/__tests__/unit/chat.graph.unit.ts create mode 100644 src/__tests__/unit/langfuse-core.provider.unit.ts delete mode 100644 src/__tests__/unit/nodes/call-llm.node.unit.ts delete mode 100644 src/__tests__/unit/nodes/context-compression.node.unit.ts delete mode 100644 src/__tests__/unit/nodes/end-session.node.unit.ts delete mode 100644 
src/__tests__/unit/nodes/init-session.node.unit.ts delete mode 100644 src/__tests__/unit/nodes/run-tool.node.unit.ts delete mode 100644 src/__tests__/unit/nodes/summarise-file.node.unit.ts create mode 100644 src/__tests__/unit/pgvector-sdk.store.unit.ts create mode 100644 src/__tests__/unit/token-counter.unit.ts create mode 100644 src/__tests__/visualization/unit/mastra-steps/select-and-render.step.unit.ts delete mode 100644 src/__tests__/visualization/unit/visualizers/bar.visualizer.unit.ts delete mode 100644 src/__tests__/visualization/unit/visualizers/line.visualizer.unit.ts delete mode 100644 src/__tests__/visualization/unit/visualizers/pie.visualizer.unit.ts delete mode 100644 src/components/db-query/db-query.graph.ts delete mode 100644 src/components/db-query/nodes/check-cache.node.ts delete mode 100644 src/components/db-query/nodes/check-permissions.node.ts delete mode 100644 src/components/db-query/nodes/check-templates.node.ts delete mode 100644 src/components/db-query/nodes/classify-change.node.ts delete mode 100644 src/components/db-query/nodes/failed.node.ts delete mode 100644 src/components/db-query/nodes/fix-query.node.ts delete mode 100644 src/components/db-query/nodes/generate-checklist.node.ts delete mode 100644 src/components/db-query/nodes/generate-description.node.ts delete mode 100644 src/components/db-query/nodes/get-columns.node.ts delete mode 100644 src/components/db-query/nodes/get-tables.node.ts delete mode 100644 src/components/db-query/nodes/index.ts delete mode 100644 src/components/db-query/nodes/is-improvement.node.ts delete mode 100644 src/components/db-query/nodes/save-dataset-node.ts delete mode 100644 src/components/db-query/nodes/semantic-validator.node.ts delete mode 100644 src/components/db-query/nodes/sql-generation.node.ts delete mode 100644 src/components/db-query/nodes/syntactic-validator.node.ts delete mode 100644 src/components/db-query/nodes/verify-checklist.node.ts delete mode 100644 
src/components/db-query/providers/datasets.retriever.ts delete mode 100644 src/components/db-query/providers/templates.retriever.ts delete mode 100644 src/components/db-query/testing/db-query.graph.builder.ts delete mode 100644 src/components/db-query/testing/get-table.node.builder.ts delete mode 100644 src/components/db-query/tools/ask-about-dataset.tool.ts delete mode 100644 src/components/visualization/nodes/call-query-generation.node.ts delete mode 100644 src/components/visualization/nodes/get-dataset-data.node.ts delete mode 100644 src/components/visualization/nodes/index.ts delete mode 100644 src/components/visualization/nodes/render-visualization.node.ts delete mode 100644 src/components/visualization/nodes/select-visualization.node.ts delete mode 100644 src/components/visualization/visualization.graph.ts delete mode 100644 src/components/visualization/visualizers/bar.visualizer.ts delete mode 100644 src/components/visualization/visualizers/index.ts delete mode 100644 src/components/visualization/visualizers/line.visualizer.ts delete mode 100644 src/components/visualization/visualizers/pie.visualizer.ts delete mode 100644 src/graphs/base.graph.ts delete mode 100644 src/graphs/chat/chat.graph.ts delete mode 100644 src/graphs/chat/index.ts delete mode 100644 src/graphs/chat/nodes.enum.ts delete mode 100644 src/graphs/chat/nodes/call-llm.node.ts delete mode 100644 src/graphs/chat/nodes/context-compression.node.ts delete mode 100644 src/graphs/chat/nodes/end-session.node.ts delete mode 100644 src/graphs/chat/nodes/index.ts delete mode 100644 src/graphs/chat/nodes/init-session.node.ts delete mode 100644 src/graphs/chat/nodes/run-tool.node.ts delete mode 100644 src/graphs/chat/nodes/summarise-file.node.ts delete mode 100644 src/graphs/index.ts delete mode 100644 src/graphs/state.ts delete mode 100644 src/graphs/types.ts delete mode 100644 src/mastra/chat/mappers/message.mapper.ts create mode 100644 src/mastra/chat/utils/adapt-stream.util.ts create mode 100644 
src/mastra/chat/utils/normalize-messages.util.ts create mode 100644 src/mastra/db-query/index.ts create mode 100644 src/mastra/db-query/mastra-db-query.workflow.ts create mode 100644 src/mastra/db-query/services/dataset-search.service.ts create mode 100644 src/mastra/db-query/services/index.ts create mode 100644 src/mastra/db-query/services/mastra-template-helper.service.ts create mode 100644 src/mastra/db-query/services/template-search.service.ts create mode 100644 src/mastra/db-query/types/db-query.types.ts create mode 100644 src/mastra/db-query/utils/index.ts create mode 100644 src/mastra/db-query/utils/prompt.util.ts create mode 100644 src/mastra/db-query/utils/thinking.util.ts create mode 100644 src/mastra/db-query/workflow/conditions/db-query.conditions.ts create mode 100644 src/mastra/db-query/workflow/conditions/index.ts create mode 100644 src/mastra/db-query/workflow/index.ts create mode 100644 src/mastra/db-query/workflow/steps/check-cache.step.ts create mode 100644 src/mastra/db-query/workflow/steps/check-permissions.step.ts create mode 100644 src/mastra/db-query/workflow/steps/check-templates.step.ts create mode 100644 src/mastra/db-query/workflow/steps/classify-change.step.ts create mode 100644 src/mastra/db-query/workflow/steps/failed.step.ts create mode 100644 src/mastra/db-query/workflow/steps/fix-query.step.ts create mode 100644 src/mastra/db-query/workflow/steps/generate-checklist.step.ts create mode 100644 src/mastra/db-query/workflow/steps/generate-description.step.ts create mode 100644 src/mastra/db-query/workflow/steps/get-columns.step.ts create mode 100644 src/mastra/db-query/workflow/steps/get-tables.step.ts create mode 100644 src/mastra/db-query/workflow/steps/index.ts create mode 100644 src/mastra/db-query/workflow/steps/is-improvement.step.ts create mode 100644 src/mastra/db-query/workflow/steps/post-validation.step.ts create mode 100644 src/mastra/db-query/workflow/steps/save-dataset.step.ts create mode 100644 
src/mastra/db-query/workflow/steps/semantic-validator.step.ts create mode 100644 src/mastra/db-query/workflow/steps/sql-generation.step.ts create mode 100644 src/mastra/db-query/workflow/steps/syntactic-validator.step.ts create mode 100644 src/mastra/db-query/workflow/steps/verify-checklist.step.ts create mode 100644 src/mastra/visualization/index.ts create mode 100644 src/mastra/visualization/mastra-visualization.workflow.ts create mode 100644 src/mastra/visualization/services/bar.visualizer.service.ts create mode 100644 src/mastra/visualization/services/index.ts create mode 100644 src/mastra/visualization/services/line.visualizer.service.ts create mode 100644 src/mastra/visualization/services/pie.visualizer.service.ts create mode 100644 src/mastra/visualization/types/visualization.types.ts create mode 100644 src/mastra/visualization/workflow/conditions/index.ts create mode 100644 src/mastra/visualization/workflow/conditions/visualization.conditions.ts create mode 100644 src/mastra/visualization/workflow/index.ts create mode 100644 src/mastra/visualization/workflow/steps/call-query-generation.step.ts create mode 100644 src/mastra/visualization/workflow/steps/get-dataset-data.step.ts create mode 100644 src/mastra/visualization/workflow/steps/index.ts create mode 100644 src/mastra/visualization/workflow/steps/render-visualization.step.ts create mode 100644 src/mastra/visualization/workflow/steps/select-visualization.step.ts rename src/{graphs/chat => services}/chat-metadata.type.ts (100%) rename src/{graphs/chat => services}/chat.store.ts (65%) create mode 100644 src/sub-modules/db/postgresql/vector-store/pgvector-sdk.store.ts delete mode 100644 src/sub-modules/db/postgresql/vector-store/pgvector.store.ts create mode 100644 src/sub-modules/obf/langfuse/langfuse-core.provider.ts create mode 100644 src/sub-modules/obf/langfuse/langfuse-mastra.component.ts delete mode 100644 src/sub-modules/providers/anthropic/llms/anthropic.provider.ts create mode 100644 
src/sub-modules/providers/anthropic/llms/claude-sdk.provider.ts create mode 100644 src/sub-modules/providers/aws/embedding/bedrock-embedding-sdk.provider.ts delete mode 100644 src/sub-modules/providers/aws/embedding/bedrock-embedding.provider.ts create mode 100644 src/sub-modules/providers/aws/llms/bedrock-non-thinking-sdk.provider.ts delete mode 100644 src/sub-modules/providers/aws/llms/bedrock-non-thinking.provider.ts create mode 100644 src/sub-modules/providers/aws/llms/bedrock-sdk.provider.ts delete mode 100644 src/sub-modules/providers/aws/llms/bedrock.provider.ts delete mode 100644 src/sub-modules/providers/aws/types.ts create mode 100644 src/sub-modules/providers/cerebras/llm/cerebras-sdk.provider.ts delete mode 100644 src/sub-modules/providers/cerebras/llm/cerebras.provider.ts create mode 100644 src/sub-modules/providers/google/embedding/gemini-embedding-sdk.provider.ts delete mode 100644 src/sub-modules/providers/google/embedding/gemini-embedding.provider.ts create mode 100644 src/sub-modules/providers/google/llms/gemini-sdk.provider.ts delete mode 100644 src/sub-modules/providers/google/llms/gemini.provider.ts create mode 100644 src/sub-modules/providers/groq/llms/groq-sdk.provider.ts delete mode 100644 src/sub-modules/providers/groq/llms/groq.provider.ts create mode 100644 src/sub-modules/providers/ollama/embedding/ollama-embedding-sdk.provider.ts delete mode 100644 src/sub-modules/providers/ollama/embedding/ollama-embedding.provider.ts create mode 100644 src/sub-modules/providers/ollama/llms/ollama-sdk.provider.ts delete mode 100644 src/sub-modules/providers/ollama/llms/ollama.provider.ts create mode 100644 src/sub-modules/providers/openai/llms/openai-sdk.provider.ts delete mode 100644 src/sub-modules/providers/openai/llms/openai.provider.ts delete mode 100644 src/sub-modules/providers/openai/types.ts rename src/{graphs/event.types.ts => types/events.ts} (80%) create mode 100644 src/types/tool.ts diff --git a/package-lock.json b/package-lock.json index 
8906d8d..cf62d4c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -24,6 +24,12 @@ "winston": "^3.15.0" }, "devDependencies": { + "@ai-sdk/amazon-bedrock": "^4.0.97", + "@ai-sdk/anthropic": "^3.0.72", + "@ai-sdk/cerebras": "^2.0.46", + "@ai-sdk/google": "^3.0.65", + "@ai-sdk/groq": "^3.0.36", + "@ai-sdk/openai": "^3.0.54", "@commitlint/cli": "^19.8.1", "@commitlint/config-conventional": "^16.2.1", "@google/generative-ai": "^0.24.1", @@ -58,6 +64,7 @@ "jsonwebtoken": "^9.0.2", "loopback-connector-sqlite3": "^3.0.0", "mochawesome": "^7.1.3", + "ollama-ai-provider": "^1.2.0", "pg": "^8.16.3", "semantic-release": "^25.0.1", "source-map-support": "^0.5.21", @@ -120,6 +127,200 @@ "dev": true, "license": "MIT" }, + "node_modules/@ai-sdk/amazon-bedrock": { + "version": "4.0.97", + "resolved": "https://registry.npmjs.org/@ai-sdk/amazon-bedrock/-/amazon-bedrock-4.0.97.tgz", + "integrity": "sha512-T0uI1FeZ3IHM6n8x9Taylh1wOY1BYQ86AlLxcVs+oONLHTetSpev9JxJa/QID/3a5L6ve8p1+oXxtyVRrv7N7A==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/anthropic": "3.0.72", + "@ai-sdk/provider": "3.0.9", + "@ai-sdk/provider-utils": "4.0.24", + "@smithy/eventstream-codec": "^4.0.1", + "@smithy/util-utf8": "^4.0.0", + "aws4fetch": "^1.0.20" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/amazon-bedrock/node_modules/@ai-sdk/provider": { + "version": "3.0.9", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.9.tgz", + "integrity": "sha512-/ngMKqKdL9dSlY/eQ3NFDzzFyw0Hix+cbFFlyuKEKcOgpHdBt/spKUvX/i0wGrDLFPYJeVvv3N0j92LxWRL7yQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/amazon-bedrock/node_modules/@ai-sdk/provider-utils": { + "version": "4.0.24", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.24.tgz", + 
"integrity": "sha512-oXIw1oLmuBILuvHgSj6w5LOV8oSnFRouPSv0MGkG9sRMeukZ9JnMF17kldaRQaRq8lSJIxo6aS3NzWlVmSb+4Q==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.8" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/amazon-bedrock/node_modules/@aws-crypto/crc32": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/crc32/-/crc32-5.2.0.tgz", + "integrity": "sha512-nLbCWqQNgUiwwtFsen1AdzAtvuLRsQS8rYgMuxCrdKf9kOssamGLuPwyTY9wyYblNr9+1XM8v6zoDTPPSIeANg==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/util": "^5.2.0", + "@aws-sdk/types": "^3.222.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/@ai-sdk/amazon-bedrock/node_modules/@smithy/eventstream-codec": { + "version": "4.2.14", + "resolved": "https://registry.npmjs.org/@smithy/eventstream-codec/-/eventstream-codec-4.2.14.tgz", + "integrity": "sha512-erZq0nOIpzfeZdCyzZjdJb4nVSKLUmSkaQUVkRGQTXs30gyUGeKnrYEg+Xe1W5gE3aReS7IgsvANwVPxSzY6Pw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/crc32": "5.2.0", + "@smithy/types": "^4.14.1", + "@smithy/util-hex-encoding": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@ai-sdk/amazon-bedrock/node_modules/@smithy/util-utf8": { + "version": "4.2.2", + "resolved": "https://registry.npmjs.org/@smithy/util-utf8/-/util-utf8-4.2.2.tgz", + "integrity": "sha512-75MeYpjdWRe8M5E3AW0O4Cx3UadweS+cwdXjwYGBW5h/gxxnbeZ877sLPX/ZJA9GVTlL/qG0dXP29JWFCD1Ayw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@smithy/util-buffer-from": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@ai-sdk/anthropic": { + "version": "3.0.72", + "resolved": 
"https://registry.npmjs.org/@ai-sdk/anthropic/-/anthropic-3.0.72.tgz", + "integrity": "sha512-t0j9mggxylA9uP0hi12NlRk2npYh4QkE7JIpws2MdV/18QzHcsT6+TNGIjbOPayLQDjrmRKx78Ym7iZkg9qRxQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + "@ai-sdk/provider-utils": "4.0.24" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/anthropic/node_modules/@ai-sdk/provider": { + "version": "3.0.9", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.9.tgz", + "integrity": "sha512-/ngMKqKdL9dSlY/eQ3NFDzzFyw0Hix+cbFFlyuKEKcOgpHdBt/spKUvX/i0wGrDLFPYJeVvv3N0j92LxWRL7yQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/anthropic/node_modules/@ai-sdk/provider-utils": { + "version": "4.0.24", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.24.tgz", + "integrity": "sha512-oXIw1oLmuBILuvHgSj6w5LOV8oSnFRouPSv0MGkG9sRMeukZ9JnMF17kldaRQaRq8lSJIxo6aS3NzWlVmSb+4Q==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.8" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/cerebras": { + "version": "2.0.46", + "resolved": "https://registry.npmjs.org/@ai-sdk/cerebras/-/cerebras-2.0.46.tgz", + "integrity": "sha512-MK8uzm5k+G0HM/pjHfrQGvPsIUy1zSqat716UNpuGo33GQGsABf3jorXyuKmmsGhQpyoKB03DqKZs821FID3Iw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/openai-compatible": "2.0.42", + "@ai-sdk/provider": "3.0.9", + "@ai-sdk/provider-utils": "4.0.24" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + 
"node_modules/@ai-sdk/cerebras/node_modules/@ai-sdk/provider": { + "version": "3.0.9", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.9.tgz", + "integrity": "sha512-/ngMKqKdL9dSlY/eQ3NFDzzFyw0Hix+cbFFlyuKEKcOgpHdBt/spKUvX/i0wGrDLFPYJeVvv3N0j92LxWRL7yQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/cerebras/node_modules/@ai-sdk/provider-utils": { + "version": "4.0.24", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.24.tgz", + "integrity": "sha512-oXIw1oLmuBILuvHgSj6w5LOV8oSnFRouPSv0MGkG9sRMeukZ9JnMF17kldaRQaRq8lSJIxo6aS3NzWlVmSb+4Q==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.8" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, "node_modules/@ai-sdk/gateway": { "version": "3.0.104", "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-3.0.104.tgz", @@ -137,6 +338,198 @@ "zod": "^3.25.76 || ^4.1.8" } }, + "node_modules/@ai-sdk/google": { + "version": "3.0.65", + "resolved": "https://registry.npmjs.org/@ai-sdk/google/-/google-3.0.65.tgz", + "integrity": "sha512-SwdaJ6IqguyiVuDRgiRM4sHj7uUO4AETlQFFLF3jcEvu/3yrgIHfw2aM6bBNKSdalw0j25Pedx6qyHc2DWJwrg==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + "@ai-sdk/provider-utils": "4.0.24" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/google/node_modules/@ai-sdk/provider": { + "version": "3.0.9", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.9.tgz", + "integrity": "sha512-/ngMKqKdL9dSlY/eQ3NFDzzFyw0Hix+cbFFlyuKEKcOgpHdBt/spKUvX/i0wGrDLFPYJeVvv3N0j92LxWRL7yQ==", + "dev": true, + "license": "Apache-2.0", + 
"dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/google/node_modules/@ai-sdk/provider-utils": { + "version": "4.0.24", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.24.tgz", + "integrity": "sha512-oXIw1oLmuBILuvHgSj6w5LOV8oSnFRouPSv0MGkG9sRMeukZ9JnMF17kldaRQaRq8lSJIxo6aS3NzWlVmSb+4Q==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.8" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/groq": { + "version": "3.0.36", + "resolved": "https://registry.npmjs.org/@ai-sdk/groq/-/groq-3.0.36.tgz", + "integrity": "sha512-77onMo3RZg6wG9qZQuekqS18YDS1znRZNN6PuBOsSm/TjryttUq4VOhk1HXogc3QWoEaTYGFxjZgZGp3wTJuSg==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + "@ai-sdk/provider-utils": "4.0.24" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/groq/node_modules/@ai-sdk/provider": { + "version": "3.0.9", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.9.tgz", + "integrity": "sha512-/ngMKqKdL9dSlY/eQ3NFDzzFyw0Hix+cbFFlyuKEKcOgpHdBt/spKUvX/i0wGrDLFPYJeVvv3N0j92LxWRL7yQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/groq/node_modules/@ai-sdk/provider-utils": { + "version": "4.0.24", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.24.tgz", + "integrity": "sha512-oXIw1oLmuBILuvHgSj6w5LOV8oSnFRouPSv0MGkG9sRMeukZ9JnMF17kldaRQaRq8lSJIxo6aS3NzWlVmSb+4Q==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + "@standard-schema/spec": "^1.1.0", + 
"eventsource-parser": "^3.0.8" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/openai": { + "version": "3.0.54", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-3.0.54.tgz", + "integrity": "sha512-j1qrNe/ebUKuE+fETzS+CVnczs11jQBR9y9M6aoKtJZAosg6SZnPC1Bb92e2u6yaSK+88TZoFhiY67uYphPitw==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + "@ai-sdk/provider-utils": "4.0.24" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/openai-compatible": { + "version": "2.0.42", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai-compatible/-/openai-compatible-2.0.42.tgz", + "integrity": "sha512-hjq485U/dpi6Hvjzw5+F1vohCrB1kibGHlUFknYGa4nOoCnSvFM1lTXEIyTAkjK1uXgTbNk8vw66lbEyWT12jg==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + "@ai-sdk/provider-utils": "4.0.24" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/openai-compatible/node_modules/@ai-sdk/provider": { + "version": "3.0.9", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.9.tgz", + "integrity": "sha512-/ngMKqKdL9dSlY/eQ3NFDzzFyw0Hix+cbFFlyuKEKcOgpHdBt/spKUvX/i0wGrDLFPYJeVvv3N0j92LxWRL7yQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/openai-compatible/node_modules/@ai-sdk/provider-utils": { + "version": "4.0.24", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.24.tgz", + "integrity": "sha512-oXIw1oLmuBILuvHgSj6w5LOV8oSnFRouPSv0MGkG9sRMeukZ9JnMF17kldaRQaRq8lSJIxo6aS3NzWlVmSb+4Q==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + 
"@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.8" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/openai/node_modules/@ai-sdk/provider": { + "version": "3.0.9", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.9.tgz", + "integrity": "sha512-/ngMKqKdL9dSlY/eQ3NFDzzFyw0Hix+cbFFlyuKEKcOgpHdBt/spKUvX/i0wGrDLFPYJeVvv3N0j92LxWRL7yQ==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/openai/node_modules/@ai-sdk/provider-utils": { + "version": "4.0.24", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.24.tgz", + "integrity": "sha512-oXIw1oLmuBILuvHgSj6w5LOV8oSnFRouPSv0MGkG9sRMeukZ9JnMF17kldaRQaRq8lSJIxo6aS3NzWlVmSb+4Q==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.9", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.8" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, "node_modules/@ai-sdk/provider": { "version": "3.0.8", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.8.tgz", @@ -199,6 +592,40 @@ "license": "MIT", "peer": true }, + "node_modules/@aws-crypto/crc32": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/crc32/-/crc32-3.0.0.tgz", + "integrity": "sha512-IzSgsrxUcsrejQbPVilIKy16kAT52EwB6zSaI+M3xxIhKh5+aldEyvI+z6erM7TCLB2BJsFrtHjp6/4/sr+3dA==", + "license": "Apache-2.0", + "optional": true, + "peer": true, + "dependencies": { + "@aws-crypto/util": "^3.0.0", + "@aws-sdk/types": "^3.222.0", + "tslib": "^1.11.1" + } + }, + "node_modules/@aws-crypto/crc32/node_modules/@aws-crypto/util": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@aws-crypto/util/-/util-3.0.0.tgz", + "integrity": 
"sha512-2OJlpeJpCR48CC8r+uKVChzs9Iungj9wkZrl8Z041DWEWvyIHILYKCPNzJghKsivj+S3mLo6BVc7mBNzdxA46w==", + "license": "Apache-2.0", + "optional": true, + "peer": true, + "dependencies": { + "@aws-sdk/types": "^3.222.0", + "@aws-sdk/util-utf8-browser": "^3.0.0", + "tslib": "^1.11.1" + } + }, + "node_modules/@aws-crypto/crc32/node_modules/tslib": { + "version": "1.14.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", + "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==", + "license": "0BSD", + "optional": true, + "peer": true + }, "node_modules/@aws-crypto/crc32c": { "version": "5.2.0", "resolved": "https://registry.npmjs.org/@aws-crypto/crc32c/-/crc32c-5.2.0.tgz", @@ -1004,15 +1431,15 @@ } }, "node_modules/@aws-sdk/eventstream-handler-node/node_modules/@smithy/eventstream-codec": { - "version": "4.2.8", - "resolved": "https://registry.npmjs.org/@smithy/eventstream-codec/-/eventstream-codec-4.2.8.tgz", - "integrity": "sha512-jS/O5Q14UsufqoGhov7dHLOPCzkYJl9QDzusI2Psh4wyYx/izhzvX9P4D69aTxcdfVhEPhjK+wYyn/PzLjKbbw==", + "version": "4.2.14", + "resolved": "https://registry.npmjs.org/@smithy/eventstream-codec/-/eventstream-codec-4.2.14.tgz", + "integrity": "sha512-erZq0nOIpzfeZdCyzZjdJb4nVSKLUmSkaQUVkRGQTXs30gyUGeKnrYEg+Xe1W5gE3aReS7IgsvANwVPxSzY6Pw==", "devOptional": true, "license": "Apache-2.0", "dependencies": { "@aws-crypto/crc32": "5.2.0", - "@smithy/types": "^4.12.0", - "@smithy/util-hex-encoding": "^4.2.0", + "@smithy/types": "^4.14.1", + "@smithy/util-hex-encoding": "^4.2.2", "tslib": "^2.6.2" }, "engines": { @@ -1482,15 +1909,15 @@ } }, "node_modules/@aws-sdk/middleware-websocket/node_modules/@smithy/eventstream-codec": { - "version": "4.2.8", - "resolved": "https://registry.npmjs.org/@smithy/eventstream-codec/-/eventstream-codec-4.2.8.tgz", - "integrity": "sha512-jS/O5Q14UsufqoGhov7dHLOPCzkYJl9QDzusI2Psh4wyYx/izhzvX9P4D69aTxcdfVhEPhjK+wYyn/PzLjKbbw==", + "version": "4.2.14", + "resolved": 
"https://registry.npmjs.org/@smithy/eventstream-codec/-/eventstream-codec-4.2.14.tgz", + "integrity": "sha512-erZq0nOIpzfeZdCyzZjdJb4nVSKLUmSkaQUVkRGQTXs30gyUGeKnrYEg+Xe1W5gE3aReS7IgsvANwVPxSzY6Pw==", "devOptional": true, "license": "Apache-2.0", "dependencies": { "@aws-crypto/crc32": "5.2.0", - "@smithy/types": "^4.12.0", - "@smithy/util-hex-encoding": "^4.2.0", + "@smithy/types": "^4.14.1", + "@smithy/util-hex-encoding": "^4.2.2", "tslib": "^2.6.2" }, "engines": { @@ -1958,6 +2385,17 @@ } } }, + "node_modules/@aws-sdk/util-utf8-browser": { + "version": "3.259.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/util-utf8-browser/-/util-utf8-browser-3.259.0.tgz", + "integrity": "sha512-UvFa/vR+e19XookZF8RzFZBrw2EUkQWxiBW0yYQAhvk3C+QVGl0H3ouca8LDBlBfQKXwmW3huo/59H8rwb1wJw==", + "license": "Apache-2.0", + "optional": true, + "peer": true, + "dependencies": { + "tslib": "^2.3.1" + } + }, "node_modules/@aws-sdk/xml-builder": { "version": "3.972.2", "resolved": "https://registry.npmjs.org/@aws-sdk/xml-builder/-/xml-builder-3.972.2.tgz", @@ -6160,6 +6598,48 @@ "node": ">=18.0.0" } }, + "node_modules/@smithy/eventstream-codec": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/@smithy/eventstream-codec/-/eventstream-codec-2.2.0.tgz", + "integrity": "sha512-8janZoJw85nJmQZc4L8TuePp2pk1nxLgkxIR0TUjKJ5Dkj5oelB9WtiSSGXCQvNsJl0VSTvK/2ueMXxvpa9GVw==", + "license": "Apache-2.0", + "optional": true, + "peer": true, + "dependencies": { + "@aws-crypto/crc32": "3.0.0", + "@smithy/types": "^2.12.0", + "@smithy/util-hex-encoding": "^2.2.0", + "tslib": "^2.6.2" + } + }, + "node_modules/@smithy/eventstream-codec/node_modules/@smithy/types": { + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz", + "integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==", + "license": "Apache-2.0", + "optional": true, + "peer": true, + "dependencies": { + "tslib": "^2.6.2" + }, + 
"engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@smithy/eventstream-codec/node_modules/@smithy/util-hex-encoding": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/@smithy/util-hex-encoding/-/util-hex-encoding-2.2.0.tgz", + "integrity": "sha512-7iKXR+/4TpLK194pVjKiasIyqMtTYJsgKgM242Y9uzt5dhHnUDvMNb+3xIhRJ9QhvqGii/5cRUt4fJn3dtXNHQ==", + "license": "Apache-2.0", + "optional": true, + "peer": true, + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/@smithy/eventstream-serde-browser": { "version": "4.2.8", "resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-browser/-/eventstream-serde-browser-4.2.8.tgz", @@ -6230,14 +6710,14 @@ } }, "node_modules/@smithy/eventstream-serde-universal/node_modules/@smithy/eventstream-codec": { - "version": "4.2.8", - "resolved": "https://registry.npmjs.org/@smithy/eventstream-codec/-/eventstream-codec-4.2.8.tgz", - "integrity": "sha512-jS/O5Q14UsufqoGhov7dHLOPCzkYJl9QDzusI2Psh4wyYx/izhzvX9P4D69aTxcdfVhEPhjK+wYyn/PzLjKbbw==", + "version": "4.2.14", + "resolved": "https://registry.npmjs.org/@smithy/eventstream-codec/-/eventstream-codec-4.2.14.tgz", + "integrity": "sha512-erZq0nOIpzfeZdCyzZjdJb4nVSKLUmSkaQUVkRGQTXs30gyUGeKnrYEg+Xe1W5gE3aReS7IgsvANwVPxSzY6Pw==", "license": "Apache-2.0", "dependencies": { "@aws-crypto/crc32": "5.2.0", - "@smithy/types": "^4.12.0", - "@smithy/util-hex-encoding": "^4.2.0", + "@smithy/types": "^4.14.1", + "@smithy/util-hex-encoding": "^4.2.2", "tslib": "^2.6.2" }, "engines": { @@ -6660,9 +7140,9 @@ } }, "node_modules/@smithy/types": { - "version": "4.12.0", - "resolved": "https://registry.npmjs.org/@smithy/types/-/types-4.12.0.tgz", - "integrity": "sha512-9YcuJVTOBDjg9LWo23Qp0lTQ3D7fQsQtwle0jVfpbUHy9qBwCEgKuVH4FqFB3VYu0nwdHKiEMA+oXz7oV8X1kw==", + "version": "4.14.1", + "resolved": "https://registry.npmjs.org/@smithy/types/-/types-4.14.1.tgz", + "integrity": 
"sha512-59b5HtSVrVR/eYNei3BUj3DCPKD/G7EtDDe7OEJE7i7FtQFugYo6MxbotS8mVJkLNVf8gYaAlEBwwtJ9HzhWSg==", "license": "Apache-2.0", "dependencies": { "tslib": "^2.6.2" @@ -6737,12 +7217,12 @@ } }, "node_modules/@smithy/util-buffer-from": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-4.2.0.tgz", - "integrity": "sha512-kAY9hTKulTNevM2nlRtxAG2FQ3B2OR6QIrPY3zE5LqJy1oxzmgBGsHLWTcNhWXKchgA0WHW+mZkQrng/pgcCew==", + "version": "4.2.2", + "resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-4.2.2.tgz", + "integrity": "sha512-FDXD7cvUoFWwN6vtQfEta540Y/YBe5JneK3SoZg9bThSoOAC/eGeYEua6RkBgKjGa/sz6Y+DuBZj3+YEY21y4Q==", "license": "Apache-2.0", "dependencies": { - "@smithy/is-array-buffer": "^4.2.0", + "@smithy/is-array-buffer": "^4.2.2", "tslib": "^2.6.2" }, "engines": { @@ -6750,9 +7230,9 @@ } }, "node_modules/@smithy/util-buffer-from/node_modules/@smithy/is-array-buffer": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-4.2.0.tgz", - "integrity": "sha512-DZZZBvC7sjcYh4MazJSGiWMI2L7E0oCiRHREDzIxi/M2LY79/21iXt6aPLHge82wi5LsuRF5A06Ds3+0mlh6CQ==", + "version": "4.2.2", + "resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-4.2.2.tgz", + "integrity": "sha512-n6rQ4N8Jj4YTQO3YFrlgZuwKodf4zUFs7EJIWH86pSCWBaAtAGBFfCM7Wx6D2bBJ2xqFNxGBSrUWswT3M0VJow==", "license": "Apache-2.0", "dependencies": { "tslib": "^2.6.2" @@ -6821,9 +7301,9 @@ } }, "node_modules/@smithy/util-hex-encoding": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/@smithy/util-hex-encoding/-/util-hex-encoding-4.2.0.tgz", - "integrity": "sha512-CCQBwJIvXMLKxVbO88IukazJD9a4kQ9ZN7/UMGBjBcJYvatpWk+9g870El4cB8/EJxfe+k+y0GmR9CAzkF+Nbw==", + "version": "4.2.2", + "resolved": "https://registry.npmjs.org/@smithy/util-hex-encoding/-/util-hex-encoding-4.2.2.tgz", + "integrity": 
"sha512-Qcz3W5vuHK4sLQdyT93k/rfrUwdJ8/HZ+nMUOyGdpeGA1Wxt65zYwi3oEl9kOM+RswvYq90fzkNDahPS8K0OIg==", "license": "Apache-2.0", "dependencies": { "tslib": "^2.6.2" @@ -8249,6 +8729,13 @@ "license": "MIT", "peer": true }, + "node_modules/aws4fetch": { + "version": "1.0.20", + "resolved": "https://registry.npmjs.org/aws4fetch/-/aws4fetch-1.0.20.tgz", + "integrity": "sha512-/djoAN709iY65ETD6LKCtyyEI04XIBP5xVvfmNxsEP0uJB5tyaGBztSryRr4HqMStr9R06PisQE7m9zDTXKu6g==", + "dev": true, + "license": "MIT" + }, "node_modules/axios": { "version": "1.13.6", "resolved": "https://registry.npmjs.org/axios/-/axios-1.13.6.tgz", @@ -20602,6 +21089,60 @@ "whatwg-fetch": "^3.6.20" } }, + "node_modules/ollama-ai-provider": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/ollama-ai-provider/-/ollama-ai-provider-1.2.0.tgz", + "integrity": "sha512-jTNFruwe3O/ruJeppI/quoOUxG7NA6blG3ZyQj3lei4+NnJo7bi3eIRWqlVpRlu/mbzbFXeJSBuYQWF6pzGKww==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "^1.0.0", + "@ai-sdk/provider-utils": "^2.0.0", + "partial-json": "0.1.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + }, + "peerDependenciesMeta": { + "zod": { + "optional": true + } + } + }, + "node_modules/ollama-ai-provider/node_modules/@ai-sdk/provider": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", + "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/ollama-ai-provider/node_modules/@ai-sdk/provider-utils": { + "version": "2.2.8", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.8.tgz", + "integrity": "sha512-fqhG+4sCVv8x7nFzYnFo19ryhAa3w096Kmc3hWxMQfW/TubPOmt3A6tYZhl4mUfQWWQMsuSkLrtjlWuXBVSGQA==", + "dev": true, + 
"license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "1.1.3", + "nanoid": "^3.3.8", + "secure-json-parse": "^2.7.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.23.8" + } + }, "node_modules/on-finished": { "version": "2.4.1", "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz", @@ -21343,6 +21884,13 @@ "node": ">= 0.8" } }, + "node_modules/partial-json": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/partial-json/-/partial-json-0.1.7.tgz", + "integrity": "sha512-Njv/59hHaokb/hRUjce3Hdv12wd60MtM9Z5Olmn+nehe0QDAsRtRbJPvJ0Z91TusF0SuZRIvnM+S4l6EIP8leA==", + "dev": true, + "license": "MIT" + }, "node_modules/pascal-case": { "version": "3.1.2", "resolved": "https://registry.npmjs.org/pascal-case/-/pascal-case-3.1.2.tgz", @@ -23067,6 +23615,13 @@ "node": ">=6" } }, + "node_modules/secure-json-parse": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/secure-json-parse/-/secure-json-parse-2.7.0.tgz", + "integrity": "sha512-6aU+Rwsezw7VR8/nyvKTx8QpWH9FrcYiXXlqC4z5d5XQBDRqtbfsRjnwGyqbi3gddNtWHuEk9OANUotL26qKUw==", + "dev": true, + "license": "BSD-3-Clause" + }, "node_modules/semantic-release": { "version": "25.0.2", "resolved": "https://registry.npmjs.org/semantic-release/-/semantic-release-25.0.2.tgz", diff --git a/package.json b/package.json index c47d8b6..b4699a7 100644 --- a/package.json +++ b/package.json @@ -127,34 +127,28 @@ "@loopback/core": "^6.1.11" }, "dependencies": { - "@langchain/community": "^0.3.50", - "@langchain/core": "^0.3.80", - "@langchain/langgraph": "^0.4.4", "@loopback/repository": "^7.0.14", "@sourceloop/chat-service": "15.0.3", "@sourceloop/core": "17.0.3", "@sourceloop/file-utils": "0.3.7", "ai": "^6.0.168", - "langchain": "^0.3.37", "loopback4-authentication": "^12.2.0", "loopback4-authorization": "^7.0.3", "tslib": "^2.8.0", "winston": "^3.15.0" }, "devDependencies": { + "@ai-sdk/amazon-bedrock": "^4.0.97", + "@ai-sdk/anthropic": "^3.0.72", 
+ "@ai-sdk/cerebras": "^2.0.46", + "@ai-sdk/google": "^3.0.65", + "@ai-sdk/groq": "^3.0.36", + "@ai-sdk/openai": "^3.0.54", "@commitlint/cli": "^19.8.1", "@commitlint/config-conventional": "^16.2.1", "@google/generative-ai": "^0.24.1", "@istanbuljs/nyc-config-typescript": "^1.0.2", - "@langchain/anthropic": "^0.3.26", - "@langchain/aws": "^0.1.13", - "@langchain/cerebras": "^0.0.2", - "@langchain/google-genai": "^0.2.16", - "@langchain/groq": "^0.2.3", - "@langchain/ollama": "^0.2.3", - "@langchain/openai": "^0.6.11", "@langfuse/core": "^4.4.2", - "@langfuse/langchain": "^4.4.2", "@loopback/build": "^11.0.9", "@loopback/eslint-config": "^15.0.5", "@loopback/testlab": "^7.0.9", @@ -176,6 +170,7 @@ "jsonwebtoken": "^9.0.2", "loopback-connector-sqlite3": "^3.0.0", "mochawesome": "^7.1.3", + "ollama-ai-provider": "^1.2.0", "pg": "^8.16.3", "semantic-release": "^25.0.1", "source-map-support": "^0.5.21", diff --git a/src/__tests__/acceptance/generation.controllers.acceptance.ts b/src/__tests__/acceptance/generation.controllers.acceptance.ts index 3fb4bf6..f599ff8 100644 --- a/src/__tests__/acceptance/generation.controllers.acceptance.ts +++ b/src/__tests__/acceptance/generation.controllers.acceptance.ts @@ -1,7 +1,7 @@ import {BindingScope} from '@loopback/core'; import {Client, expect} from '@loopback/testlab'; import {DbQueryAIExtensionBindings, IDataSetStore} from '../../components'; -import {LLMStreamEvent, LLMStreamEventType} from '../../graphs'; +import {LLMStreamEvent, LLMStreamEventType} from '../../types/events'; import {AiIntegrationBindings} from '../../keys'; import {PermissionKey} from '../../permissions'; import {HttpTransport} from '../../transports'; diff --git a/src/__tests__/db-query/acceptance/db-query.graph.acceptance.ts b/src/__tests__/db-query/acceptance/db-query.graph.acceptance.ts deleted file mode 100644 index 4d4a310..0000000 --- a/src/__tests__/db-query/acceptance/db-query.graph.acceptance.ts +++ /dev/null @@ -1,95 +0,0 @@ -import {Context} 
from '@loopback/core'; -import {AuthenticationBindings} from 'loopback4-authentication'; -import {IAuthUserWithPermissions} from 'loopback4-authorization'; -import { - DatabaseSchema, - DataSetHelper, - DbQueryAIExtensionBindings, - DbQueryGraph, - SchemaStore, -} from '../../../components'; -import {TestApp} from '../../fixtures/test-app'; -import { - seedCurrencies, - seedEmployees, - seedExchangeRates, - setupApplication, -} from '../../test-helper'; -import {dbQueryToolTests} from '../../../components/db-query/testing/db-query.graph.builder'; - -describe(`DB Query Graph Acceptance`, () => { - let app: TestApp; - let schema: DatabaseSchema; - let graphBuilder: DbQueryGraph; - let datasetHelper: DataSetHelper; - - before('checkIfCanRun', function () { - if (process.env.RUN_WITH_LLM !== 'true') { - // eslint-disable-next-line @typescript-eslint/no-invalid-this - this.skip(); - } - }); - - before('setupApplication', async () => { - ({app} = await setupApplication({})); - await seedEmployees(app); - await seedCurrencies(app); - await seedExchangeRates(app); - app - .bind(DbQueryAIExtensionBindings.GlobalContext) - .to([ - `Every value with currency_id should be converted to USD before returning to the user.`, - ]); - const schemaService = await app.get(`services.SchemaStore`); - schema = schemaService.get(); - }); - - after(async () => { - if (app) { - await app.stop(); - } - }); - - beforeEach(async () => { - const ctx = new Context(app, 'newCtx'); - ctx.bind(AuthenticationBindings.CURRENT_USER).to({ - userTenantId: 'test-tenant', - tenantId: 'test-tenant', - permissions: ['1', '2', '3'], - } as unknown as IAuthUserWithPermissions); - graphBuilder = await ctx.get(`services.DbQueryGraph`); - datasetHelper = await ctx.get(`services.DataSetHelper`); - }); - - const cases = dbQueryToolTests([ - { - prompt: - 'Show the salary of the employee Charlie White in USD, the result should just have one column named "salary" with 2 decimal places', - result: [ - { - salary: 
9952.61, - }, - ], - }, - { - prompt: - 'Show all the employees who have salary greater than 8000 USD, the result should have just 1 column `name`, results ordered by name in ascending order', - result: [ - { - name: 'Charlie White', - }, - { - name: 'Nameless Gonbei', - }, - ], - }, - ]); - - for (const testCase of cases) { - it(`should execute the graph for ${testCase.desc}`, async () => { - await testCase.fn(schema, graphBuilder, async id => { - return datasetHelper.getDataFromDataset(id); - }); - }); - } -}); diff --git a/src/__tests__/db-query/acceptance/nodes/get-tables-node.acceptance.ts b/src/__tests__/db-query/acceptance/nodes/get-tables-node.acceptance.ts deleted file mode 100644 index 287ae45..0000000 --- a/src/__tests__/db-query/acceptance/nodes/get-tables-node.acceptance.ts +++ /dev/null @@ -1,72 +0,0 @@ -import {Context} from '@loopback/core'; -import {AuthenticationBindings} from 'loopback4-authentication'; -import {IAuthUserWithPermissions} from 'loopback4-authorization'; -import { - DatabaseSchema, - DbQueryAIExtensionBindings, - GetTablesNode, - SchemaStore, -} from '../../../../components'; -import {TestApp} from '../../../fixtures/test-app'; -import {setupApplication} from '../../../test-helper'; -import {getTableNodeTests} from '../../../../components/db-query/testing/get-table.node.builder'; - -describe('GetTablesNode Acceptance', function () { - let app: TestApp; - let node: GetTablesNode; - let schema: DatabaseSchema; - - before('checkIfCanRun', function () { - if (process.env.RUN_WITH_LLM !== 'true') { - // eslint-disable-next-line @typescript-eslint/no-invalid-this - this.skip(); - } - }); - - before('setupApplication', async () => { - ({app} = await setupApplication({})); - app.bind(DbQueryAIExtensionBindings.GlobalContext).to([]); - const schemaService = await app.get(`services.SchemaStore`); - schema = schemaService.get(); - }); - - after(async () => { - if (app) { - await app.stop(); - } - }); - - beforeEach(async () => { - const ctx = 
new Context(app, 'newCtx'); - ctx.bind(AuthenticationBindings.CURRENT_USER).to({ - userTenantId: 'test-tenant', - } as unknown as IAuthUserWithPermissions); - node = await ctx.get(`services.GetTablesNode`); - }); - - const cases = getTableNodeTests([ - { - query: 'Find all the resources that joined in the last month', - expectedTables: ['employees'], - }, - { - query: 'Find all the resources that have salary greater than 1000 USD', - expectedTables: ['employees', 'exchange_rates'], - }, - { - query: 'Show all the currencies that do not have any exchange rates', - expectedTables: ['currencies'], - }, - { - query: - 'Show the latest exchange rate for each currency with currency name', - expectedTables: ['currencies', 'exchange_rates'], - }, - ]); - - for (const test of cases) { - it(`should return tables for - ${test.desc}`, async () => { - await test.fn(schema, node); - }); - } -}); diff --git a/src/__tests__/db-query/unit/db-knowledge-graph.service.unit.ts b/src/__tests__/db-query/unit/db-knowledge-graph.service.unit.ts index 4da5395..ca0f95a 100644 --- a/src/__tests__/db-query/unit/db-knowledge-graph.service.unit.ts +++ b/src/__tests__/db-query/unit/db-knowledge-graph.service.unit.ts @@ -1,20 +1,44 @@ -import {expect, sinon} from '@loopback/testlab'; +import {expect} from '@loopback/testlab'; import {DbKnowledgeGraphService} from '../../../components'; -import {EmbeddingProvider, RuntimeLLMProvider} from '../../../types'; +import {EmbeddingProvider, LLMProvider} from '../../../types'; +import {EmbeddingModelV2} from '@ai-sdk/provider'; +import {createFakeStreamingLanguageModel} from '../../fixtures/fake-ai-models'; + +function createTableEmbeddingModel(): EmbeddingModelV2 { + return { + specificationVersion: 'v2', + provider: 'test', + modelId: 'test-embedding', + maxEmbeddingsPerCall: null, + supportsParallelCalls: true, + doEmbed: async ({values}: {values: string[]}) => { + const embeddings = values.map((v: string) => { + if (v.startsWith('employee_salaries')) 
return [0.1, 0.2, 0.3]; + if (v.startsWith('employees')) return [0.1, 0.2, 0.3]; + if (v.startsWith('orders')) return [0.9, 0.8, 0.7]; + return [0.1, 0.2, 0.6]; + }); + return {embeddings, warnings: []}; + }, + } as unknown as EmbeddingModelV2; +} describe(`DbKnowledgeGraphService Unit`, function () { let service: DbKnowledgeGraphService; - let llmStub: sinon.SinonStub; - let embedStub: sinon.SinonStub; + let fakeModel: LLMProvider; beforeEach(() => { - llmStub = sinon.stub(); - embedStub = sinon.stub(); + fakeModel = createFakeStreamingLanguageModel( + JSON.stringify({ + concept: 'employees', + description: 'test description', + domain: 'test domain', + confidence: 0.9, + }), + ) as unknown as LLMProvider; service = new DbKnowledgeGraphService( - llmStub as unknown as RuntimeLLMProvider, - { - embedDocuments: embedStub, - } as unknown as EmbeddingProvider, + fakeModel, + createTableEmbeddingModel() as unknown as EmbeddingProvider, { models: [], knowledgeGraph: { @@ -28,26 +52,6 @@ describe(`DbKnowledgeGraphService Unit`, function () { }); it('should generate a knowledge graph for a schema and should be able to find from it', async () => { - embedStub.callsFake(async doc => { - if (doc[0].startsWith('employee_salaries')) { - return [[0.1, 0.2, 0.3]]; - } - if (doc[0].startsWith('employees')) { - return [[0.1, 0.2, 0.3]]; - } - if (doc[0].startsWith('orders')) { - return [[0.9, 0.8, 0.7]]; - } - return [[0.1, 0.2, 0.6]]; - }); - llmStub.resolves({ - content: JSON.stringify({ - concept: 'employees', - description: 'test description', - domain: 'test domain', - confidence: 0.9, - }), - }); const schema = { tables: { // eslint-disable-next-line @typescript-eslint/naming-convention diff --git a/src/__tests__/db-query/unit/db-query.graph.unit.ts b/src/__tests__/db-query/unit/db-query.graph.unit.ts deleted file mode 100644 index 04c69c7..0000000 --- a/src/__tests__/db-query/unit/db-query.graph.unit.ts +++ /dev/null @@ -1,257 +0,0 @@ -import {Context} from 
'@loopback/core'; -import {expect, sinon} from '@loopback/testlab'; -import { - DbQueryGraph, - DbQueryNodes, - EvaluationResult, - MAX_ATTEMPTS, -} from '../../../components'; -import {GRAPH_NODE_NAME} from '../../../constant'; -import {buildNodeStub} from '../../test-helper'; - -describe(`DbQueryGraph Unit`, function () { - let graph: DbQueryGraph; - let stubMap: Record; - - beforeEach(async () => { - const context = new Context('test-context'); - context.bind('DbQueryGraph').toClass(DbQueryGraph); - stubMap = {} as Record; - for (const key of Object.values(DbQueryNodes)) { - const stub = buildNodeStub(); - context - .bind(`services.${key}`) - .to(stub) - .tag({ - [GRAPH_NODE_NAME]: key, - }); - stubMap[key] = stub.execute; - } - // Parallel branches must return partial state to avoid LastValue conflicts - stubMap[DbQueryNodes.GetTables].callsFake(async () => ({})); - stubMap[DbQueryNodes.CheckCache].callsFake(async () => ({})); - stubMap[DbQueryNodes.GetColumns].callsFake(async () => ({})); - stubMap[DbQueryNodes.ClassifyChange].callsFake(async () => ({})); - stubMap[DbQueryNodes.FixQuery].callsFake(async () => ({})); - // Checklist + Description run in parallel — must return partial state - stubMap[DbQueryNodes.GenerateChecklist].callsFake(async () => ({ - validationChecklist: '1. Test check', - })); - stubMap[DbQueryNodes.GenerateDescription].callsFake( - async (state: Record) => - state.description ? 
{} : {description: 'Test description'}, - ); - // VerifyChecklist runs in parallel with SqlGeneration — must return partial state - stubMap[DbQueryNodes.VerifyChecklist].callsFake(async () => ({})); - // Validators run in parallel — must return partial state - stubMap[DbQueryNodes.SyntacticValidator].callsFake(async () => ({ - syntacticStatus: EvaluationResult.Pass, - })); - stubMap[DbQueryNodes.SemanticValidator].callsFake(async () => ({ - semanticStatus: EvaluationResult.Pass, - })); - graph = await context.get('DbQueryGraph'); - }); - - it('should follow the ideal flow of the graph for proper SQL generation', async () => { - const compiledGraph = await graph.build(); - - await compiledGraph.invoke( - { - prompt: 'test prompt', - schema: { - tables: {}, - relations: [], - }, - }, - {recursionLimit: 100}, - ); - - expect(stubMap[DbQueryNodes.IsImprovement].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.CheckCache].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.GetTables].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.SqlGeneration].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.SyntacticValidator].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.SemanticValidator].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.SaveDataset].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.Failed].called).to.be.false(); - }); - - it('should fix query via FixQuery if syntactic validation fails with query error', async () => { - const compiledGraph = await graph.build(); - let syntacticRetryCount = 0; - stubMap[DbQueryNodes.SyntacticValidator].callsFake(async () => { - if (syntacticRetryCount < 1) { - syntacticRetryCount++; - return { - syntacticStatus: EvaluationResult.QueryError, - syntacticFeedback: 'Syntactic validation failed', - }; - } - return {syntacticStatus: EvaluationResult.Pass}; - }); - - await compiledGraph.invoke( - { - prompt: 'test prompt', - schema: { - tables: {}, - relations: [], - }, - }, - 
{recursionLimit: 100}, - ); - - expect(stubMap[DbQueryNodes.IsImprovement].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.CheckCache].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.GetTables].calledOnce).to.be.true(); - // SqlGeneration called once; FixQuery handles the retry - expect(stubMap[DbQueryNodes.SqlGeneration].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.FixQuery].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.SyntacticValidator].calledTwice).to.be.true(); - // Semantic runs in parallel with syntactic on both attempts - expect(stubMap[DbQueryNodes.SemanticValidator].calledTwice).to.be.true(); - expect(stubMap[DbQueryNodes.SaveDataset].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.Failed].called).to.be.false(); - }); - - it('should retry table select if syntactic validation fails with table error', async () => { - const compiledGraph = await graph.build(); - let syntacticRetryCount = 0; - stubMap[DbQueryNodes.SyntacticValidator].callsFake(async () => { - if (syntacticRetryCount < 1) { - syntacticRetryCount++; - return { - syntacticStatus: EvaluationResult.TableError, - syntacticFeedback: 'Table not found', - }; - } - return {syntacticStatus: EvaluationResult.Pass}; - }); - - await compiledGraph.invoke( - { - prompt: 'test prompt', - schema: { - tables: {}, - relations: [], - }, - }, - {recursionLimit: 100}, - ); - - expect(stubMap[DbQueryNodes.IsImprovement].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.CheckCache].calledOnce).to.be.true(); - // GetTables called twice: initial + retry after table error - expect(stubMap[DbQueryNodes.GetTables].calledTwice).to.be.true(); - // SqlGeneration called twice: once per full pipeline pass - expect(stubMap[DbQueryNodes.SqlGeneration].calledTwice).to.be.true(); - expect(stubMap[DbQueryNodes.SyntacticValidator].calledTwice).to.be.true(); - expect(stubMap[DbQueryNodes.SaveDataset].calledOnce).to.be.true(); - 
expect(stubMap[DbQueryNodes.Failed].called).to.be.false(); - }); - - it('should fail if syntactic validation fails more than max attempts allowed', async () => { - const compiledGraph = await graph.build(); - stubMap[DbQueryNodes.SyntacticValidator].callsFake(async () => ({ - syntacticStatus: EvaluationResult.QueryError, - syntacticFeedback: 'Syntactic validation failed', - })); - - await compiledGraph.invoke( - { - prompt: 'test prompt', - schema: { - tables: {}, - relations: [], - }, - }, - {recursionLimit: 100}, - ); - - expect(stubMap[DbQueryNodes.IsImprovement].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.CheckCache].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.GetTables].calledOnce).to.be.true(); - // SqlGeneration runs once; FixQuery handles subsequent retries - expect(stubMap[DbQueryNodes.SqlGeneration].calledOnce).to.be.true(); - expect( - stubMap[DbQueryNodes.SyntacticValidator].getCalls().length, - ).to.be.eql(MAX_ATTEMPTS); - // FixQuery called MAX_ATTEMPTS - 1 times (first attempt via SqlGeneration) - expect(stubMap[DbQueryNodes.FixQuery].getCalls().length).to.be.eql( - MAX_ATTEMPTS - 1, - ); - expect(stubMap[DbQueryNodes.Failed].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.SaveDataset].called).to.be.false(); - }); - - it('should fix query via FixQuery if semantic validation fails with query error', async () => { - const compiledGraph = await graph.build(); - let semanticRetryCount = 0; - stubMap[DbQueryNodes.SemanticValidator].callsFake(async () => { - if (semanticRetryCount < 1) { - semanticRetryCount++; - return { - semanticStatus: EvaluationResult.QueryError, - semanticFeedback: 'Semantic validation failed', - }; - } - return {semanticStatus: EvaluationResult.Pass}; - }); - - await compiledGraph.invoke( - { - prompt: 'test prompt', - schema: { - tables: {}, - relations: [], - }, - }, - {recursionLimit: 100}, - ); - - expect(stubMap[DbQueryNodes.IsImprovement].calledOnce).to.be.true(); - 
expect(stubMap[DbQueryNodes.CheckCache].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.GetTables].calledOnce).to.be.true(); - // SqlGeneration called once; FixQuery handles the retry - expect(stubMap[DbQueryNodes.SqlGeneration].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.FixQuery].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.SyntacticValidator].calledTwice).to.be.true(); - expect(stubMap[DbQueryNodes.SemanticValidator].calledTwice).to.be.true(); - expect(stubMap[DbQueryNodes.SaveDataset].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.Failed].called).to.be.false(); - }); - - it('should fail if validation fails more than max attempts allowed', async () => { - const compiledGraph = await graph.build(); - stubMap[DbQueryNodes.SyntacticValidator].callsFake(async () => ({ - syntacticStatus: EvaluationResult.QueryError, - syntacticFeedback: 'Syntactic validation failed', - })); - stubMap[DbQueryNodes.SemanticValidator].callsFake(async () => ({ - semanticStatus: EvaluationResult.QueryError, - semanticFeedback: 'Semantic validation failed', - })); - - await compiledGraph.invoke( - { - prompt: 'test prompt', - schema: { - tables: {}, - relations: [], - }, - }, - {recursionLimit: 100}, - ); - - expect(stubMap[DbQueryNodes.IsImprovement].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.CheckCache].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.GetTables].calledOnce).to.be.true(); - // SqlGeneration runs once; FixQuery handles retries - expect(stubMap[DbQueryNodes.SqlGeneration].calledOnce).to.be.true(); - // With both validators failing, feedbacks grow by 2 per iteration - // so it reaches MAX_ATTEMPTS faster - expect(stubMap[DbQueryNodes.Failed].calledOnce).to.be.true(); - expect(stubMap[DbQueryNodes.SaveDataset].called).to.be.false(); - }); -}); diff --git a/src/__tests__/db-query/unit/mastra-steps/check-cache.step.unit.ts b/src/__tests__/db-query/unit/mastra-steps/check-cache.step.unit.ts new file mode 100644 index 
0000000..8cab0a2 --- /dev/null +++ b/src/__tests__/db-query/unit/mastra-steps/check-cache.step.unit.ts @@ -0,0 +1,143 @@ +import {expect, sinon} from '@loopback/testlab'; +import {DataSetHelper} from '../../../../components/db-query/services'; +import {DatasetActionType} from '../../../../components/db-query/constant'; +import {DbQueryState} from '../../../../components/db-query/state'; +import { + CacheResults, + QueryCacheMetadata, +} from '../../../../components/db-query/types'; +import {IVectorStoreDocument, LLMProvider} from '../../../../types'; +import {DatasetSearchService} from '../../../../mastra/db-query/services/dataset-search.service'; +import {checkCacheStep} from '../../../../mastra/db-query/workflow/steps/check-cache.step'; +import {MastraDbQueryContext} from '../../../../mastra/db-query/types/db-query.types'; +import {createFakeLanguageModel} from '../../../fixtures/fake-ai-models'; + +describe('checkCacheStep (Mastra)', function () { + let datasetSearchStub: {search: sinon.SinonStub}; + let dataSetHelperStub: { + checkPermissions: sinon.SinonStub; + find: sinon.SinonStub; + }; + let onUsageSpy: sinon.SinonSpy; + let context: MastraDbQueryContext; + + const baseState = { + prompt: 'What is the salary of John?', + schema: {tables: {}, relations: []}, + } as unknown as DbQueryState; + + beforeEach(() => { + datasetSearchStub = {search: sinon.stub()}; + dataSetHelperStub = {checkPermissions: sinon.stub(), find: sinon.stub()}; + onUsageSpy = sinon.spy(); + context = {onUsage: onUsageSpy}; + }); + + it('returns {} when sampleSql already set (fast exit)', async () => { + const state = { + ...baseState, + sampleSql: 'SELECT 1', + } as unknown as DbQueryState; + + const result = await checkCacheStep(state, context, { + datasetSearch: datasetSearchStub as unknown as DatasetSearchService, + llm: createFakeLanguageModel( + CacheResults.NotRelevant, + ) as unknown as LLMProvider, + dataSetHelper: dataSetHelperStub as unknown as DataSetHelper, + }); + + 
expect(result).to.deepEqual({}); + sinon.assert.notCalled(datasetSearchStub.search); + }); + + it('returns {} when no documents found in cache', async () => { + datasetSearchStub.search.resolves([]); + + const result = await checkCacheStep(baseState, context, { + datasetSearch: datasetSearchStub as unknown as DatasetSearchService, + llm: createFakeLanguageModel( + CacheResults.NotRelevant, + ) as unknown as LLMProvider, + dataSetHelper: dataSetHelperStub as unknown as DataSetHelper, + }); + + expect(result).to.deepEqual({}); + }); + + it('returns {} when LLM classifies as NotRelevant and calls onUsage', async () => { + datasetSearchStub.search.resolves([ + { + pageContent: 'Some past prompt', + metadata: {description: 'some description'}, + } as IVectorStoreDocument, + ]); + + const result = await checkCacheStep(baseState, context, { + datasetSearch: datasetSearchStub as unknown as DatasetSearchService, + llm: createFakeLanguageModel( + CacheResults.NotRelevant, + 5, + 2, + ) as unknown as LLMProvider, + dataSetHelper: dataSetHelperStub as unknown as DataSetHelper, + }); + + expect(result).to.deepEqual({}); + sinon.assert.calledOnce(onUsageSpy); + const [inputTokens, outputTokens] = onUsageSpy.firstCall.args; + expect(inputTokens).to.be.a.Number(); + expect(outputTokens).to.be.a.Number(); + }); + + it('returns sampleSql when LLM classifies as Similar', async () => { + const sql = "SELECT salary FROM employees WHERE name = 'John'"; + datasetSearchStub.search.resolves([ + { + pageContent: 'What is the salary of John?', + metadata: {description: 'Returns salary', query: sql, datasetId: '123'}, + } as unknown as IVectorStoreDocument, + ]); + + const result = await checkCacheStep(baseState, context, { + datasetSearch: datasetSearchStub as unknown as DatasetSearchService, + llm: createFakeLanguageModel( + `${CacheResults.Similar} 1`, + ) as unknown as LLMProvider, + dataSetHelper: dataSetHelperStub as unknown as DataSetHelper, + }); + + 
expect(result.sampleSql).to.equal(sql); + sinon.assert.calledOnce(onUsageSpy); + }); + + it('returns fromCache=true when LLM classifies as AsIs and permissions pass', async () => { + const sql = "SELECT salary FROM employees WHERE name = 'John'"; + datasetSearchStub.search.resolves([ + { + pageContent: 'What is the salary of John?', + metadata: {description: 'Returns salary', query: sql, datasetId: '42'}, + } as unknown as IVectorStoreDocument, + ]); + dataSetHelperStub.checkPermissions.resolves([]); + dataSetHelperStub.find.resolves([ + { + id: '42', + query: sql, + actions: [{type: DatasetActionType.Liked}], + }, + ]); + + const result = await checkCacheStep(baseState, context, { + datasetSearch: datasetSearchStub as unknown as DatasetSearchService, + llm: createFakeLanguageModel( + `${CacheResults.AsIs} 1`, + ) as unknown as LLMProvider, + dataSetHelper: dataSetHelperStub as unknown as DataSetHelper, + }); + + expect(result.fromCache).to.be.true(); + expect(result.datasetId).to.equal('42'); + sinon.assert.calledOnce(onUsageSpy); + }); +}); diff --git a/src/__tests__/db-query/unit/mastra-steps/check-permissions.step.unit.ts b/src/__tests__/db-query/unit/mastra-steps/check-permissions.step.unit.ts new file mode 100644 index 0000000..5548ee9 --- /dev/null +++ b/src/__tests__/db-query/unit/mastra-steps/check-permissions.step.unit.ts @@ -0,0 +1,72 @@ +import {expect, sinon} from '@loopback/testlab'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {Errors} from '../../../../components/db-query/types'; +import {checkPermissionsStep} from '../../../../mastra/db-query/workflow/steps/check-permissions.step'; +import {MastraDbQueryContext} from '../../../../mastra/db-query/types/db-query.types'; +import {LLMProvider} from '../../../../types'; +import {createFakeLanguageModel} from '../../../fixtures/fake-ai-models'; + +describe('checkPermissionsStep (Mastra)', function () { + const baseState = { + prompt: 'Show all salaries', + schema: { + 
tables: { + 'public.employees': {columns: []}, + 'public.salaries': {columns: []}, + }, + relations: [], + }, + } as unknown as DbQueryState; + + let onUsageSpy: sinon.SinonSpy; + let context: MastraDbQueryContext; + + beforeEach(() => { + onUsageSpy = sinon.spy(); + context = {onUsage: onUsageSpy} as unknown as MastraDbQueryContext; + }); + + it('returns empty when all permissions are granted', async () => { + const permissions = {findMissingPermissions: sinon.stub().returns([])}; + + const result = await checkPermissionsStep(baseState, context, { + llm: createFakeLanguageModel('unused') as unknown as LLMProvider, + permissions: permissions as never, + }); + + expect(result).to.deepEqual({}); + sinon.assert.notCalled(onUsageSpy); + }); + + it('sets PermissionError status when missing permissions', async () => { + const permissions = { + findMissingPermissions: sinon.stub().returns(['salaries_read']), + }; + + const result = await checkPermissionsStep(baseState, context, { + llm: createFakeLanguageModel( + 'You do not have access to salary data', + ) as unknown as LLMProvider, + permissions: permissions as never, + }); + + expect(result.status).to.equal(Errors.PermissionError); + expect(result.replyToUser).to.equal( + 'You do not have access to salary data', + ); + sinon.assert.calledOnce(onUsageSpy); + }); + + it('calls findMissingPermissions with lowercase table names (without schema prefix)', async () => { + const permissions = {findMissingPermissions: sinon.stub().returns([])}; + + await checkPermissionsStep(baseState, context, { + llm: createFakeLanguageModel('unused') as unknown as LLMProvider, + permissions: permissions as never, + }); + + const tableNames: string[] = + permissions.findMissingPermissions.firstCall.args[0]; + expect(tableNames).to.containDeep(['employees', 'salaries']); + }); +}); diff --git a/src/__tests__/db-query/unit/mastra-steps/classify-change.step.unit.ts b/src/__tests__/db-query/unit/mastra-steps/classify-change.step.unit.ts new file 
mode 100644 index 0000000..2313096 --- /dev/null +++ b/src/__tests__/db-query/unit/mastra-steps/classify-change.step.unit.ts @@ -0,0 +1,69 @@ +import {expect, sinon} from '@loopback/testlab'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {ChangeType} from '../../../../components/db-query/types'; +import {classifyChangeStep} from '../../../../mastra/db-query/workflow/steps/classify-change.step'; +import {MastraDbQueryContext} from '../../../../mastra/db-query/types/db-query.types'; +import {LLMProvider} from '../../../../types'; +import {createFakeLanguageModel} from '../../../fixtures/fake-ai-models'; + +describe('classifyChangeStep (Mastra)', function () { + const baseState = { + prompt: 'Also filter by department', + sampleSql: 'SELECT * FROM employees', + sampleSqlPrompt: 'Show all employees', + schema: {tables: {}, relations: []}, + } as unknown as DbQueryState; + + let onUsageSpy: sinon.SinonSpy; + let context: MastraDbQueryContext; + + beforeEach(() => { + onUsageSpy = sinon.spy(); + context = {onUsage: onUsageSpy} as unknown as MastraDbQueryContext; + }); + + it('returns empty when no sampleSql in state', async () => { + const state = {prompt: 'New query'} as unknown as DbQueryState; + const result = await classifyChangeStep(state, context, { + llm: createFakeLanguageModel('minor') as unknown as LLMProvider, + }); + + expect(result).to.deepEqual({}); + sinon.assert.notCalled(onUsageSpy); + }); + + it('classifies as minor when LLM returns "minor"', async () => { + const result = await classifyChangeStep(baseState, context, { + llm: createFakeLanguageModel('minor') as unknown as LLMProvider, + }); + + expect(result.changeType).to.equal(ChangeType.Minor); + sinon.assert.calledOnce(onUsageSpy); + }); + + it('classifies as rewrite when LLM returns "rewrite"', async () => { + const result = await classifyChangeStep(baseState, context, { + llm: createFakeLanguageModel('rewrite') as unknown as LLMProvider, + }); + + 
expect(result.changeType).to.equal(ChangeType.Rewrite); + }); + + it('defaults to major when LLM returns unrecognized text', async () => { + const result = await classifyChangeStep(baseState, context, { + llm: createFakeLanguageModel( + 'unknown classification', + ) as unknown as LLMProvider, + }); + + expect(result.changeType).to.equal(ChangeType.Major); + }); + + it('defaults to major when LLM returns "major"', async () => { + const result = await classifyChangeStep(baseState, context, { + llm: createFakeLanguageModel('major') as unknown as LLMProvider, + }); + + expect(result.changeType).to.equal(ChangeType.Major); + }); +}); diff --git a/src/__tests__/db-query/unit/mastra-steps/db-query.workflow.integration.ts b/src/__tests__/db-query/unit/mastra-steps/db-query.workflow.integration.ts new file mode 100644 index 0000000..90fcb54 --- /dev/null +++ b/src/__tests__/db-query/unit/mastra-steps/db-query.workflow.integration.ts @@ -0,0 +1,460 @@ +/** + * DbQuery Mastra Workflow — Integration Test + * + * Verifies the workflow orchestration end-to-end using real step functions and + * real condition functions. Each test wires steps together exactly as + * `MastraDbQueryWorkflow.run()` does, using stubs for LLM and service deps. + * + * Scenarios tested: + * 1. Happy path — SQL generated, both validators pass → accepted on first try + * 2. Retry loop — syntactic failure on attempt 1, fix + revalidate → accepted + * 3. Cache hit (fromCache) — workflow short-circuits at routing point + * 4. Template hit (fromTemplate) — routed to saveDataset immediately + * 5. Max attempts exceeded — failedStep called after MAX_ATTEMPTS feedbacks + * 6. 
Condition functions — pure routing logic tested independently + */ + +import {expect, sinon} from '@loopback/testlab'; +import {DbQueryState} from '../../../../components/db-query/state'; +import { + EvaluationResult, + GenerationError, + ChangeType, +} from '../../../../components/db-query/types'; +import {MAX_ATTEMPTS} from '../../../../components/db-query/constant'; +import {LLMProvider} from '../../../../types'; +import {MastraDbQueryContext} from '../../../../mastra/db-query/types/db-query.types'; +import { + checkPostCacheAndTablesConditions, + checkPostValidationConditions, +} from '../../../../mastra/db-query/workflow/conditions/db-query.conditions'; +import {mergeValidationResults} from '../../../../mastra/db-query/workflow/steps/post-validation.step'; +import {sqlGenerationStep} from '../../../../mastra/db-query/workflow/steps/sql-generation.step'; +import {syntacticValidatorStep} from '../../../../mastra/db-query/workflow/steps/syntactic-validator.step'; +import {semanticValidatorStep} from '../../../../mastra/db-query/workflow/steps/semantic-validator.step'; +import {checkCacheStep} from '../../../../mastra/db-query/workflow/steps/check-cache.step'; +import {classifyChangeStep} from '../../../../mastra/db-query/workflow/steps/classify-change.step'; +import {failedStep} from '../../../../mastra/db-query/workflow/steps/failed.step'; +import {createFakeLanguageModel} from '../../../fixtures/fake-ai-models'; + +// ── Shared stub factories ──────────────────────────────────────────────────── + +function makeSchemaHelper() { + return { + asString: sinon.stub().returns('employees(id, name, salary)'), + getTablesContext: sinon.stub().returns([]), + buildSchema: sinon.stub().returns({tables: {}, relations: []}), + }; +} + +function makeConnector( + shouldFail = false, + errorMsg = 'syntax error near SELECT', +) { + return { + validate: shouldFail + ? 
sinon.stub().rejects(new Error(errorMsg)) + : sinon.stub().resolves(undefined), + execute: sinon.stub().resolves([]), + }; +} + +function makeTableSearchService() { + return { + search: sinon.stub().resolves([]), + getTables: sinon.stub().resolves([]), + }; +} + +function makeDatasetSearchService() { + return { + search: sinon.stub().resolves([]), + }; +} + +function makeDataSetHelper() { + return { + checkPermissions: sinon.stub().resolves([]), + find: sinon.stub().resolves(null), + }; +} + +const BASE_SCHEMA = { + tables: {employees: {}, departments: {}}, + relations: [], +}; + +function makeBaseState(overrides: Partial = {}): DbQueryState { + return { + prompt: 'Get all employee names', + schema: BASE_SCHEMA, + feedbacks: [], + directCall: false, + } as unknown as DbQueryState & typeof overrides; +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +describe('DbQuery Mastra Workflow — Integration', function () { + let onUsageSpy: sinon.SinonSpy; + let context: MastraDbQueryContext; + + beforeEach(() => { + onUsageSpy = sinon.spy(); + context = {onUsage: onUsageSpy}; + }); + + // ── 1. 
Happy path ────────────────────────────────────────────────────────── + describe('happy path: SQL generated and validated on first attempt', function () { + it('returns accepted routing decision and correct SQL', async () => { + const state = makeBaseState(); + const fakeLlm = createFakeLanguageModel( + 'SELECT name FROM employees', + ) as unknown as LLMProvider; + + // Step: SQL generation + const sqlResult = await sqlGenerationStep(state, context, { + sqlLLM: fakeLlm, + cheapLLM: fakeLlm, + config: {db: {dialect: 'pg'}} as never, + schemaHelper: makeSchemaHelper() as never, + }); + const stateAfterSql = Object.assign({}, state, sqlResult); + + expect(stateAfterSql.sql).to.equal('SELECT name FROM employees'); + expect(stateAfterSql.status).to.equal(EvaluationResult.Pass); + + // Step: Syntactic validation — connector succeeds + const syntacticResult = await syntacticValidatorStep( + stateAfterSql, + context, + { + llm: fakeLlm, + connector: makeConnector(false) as never, + }, + ); + const stateAfterSyntactic = Object.assign( + {}, + stateAfterSql, + syntacticResult, + ); + expect(stateAfterSyntactic.syntacticStatus).to.equal( + EvaluationResult.Pass, + ); + + // Step: Semantic validation — LLM returns (what the step looks for) + const semanticLlm = createFakeLanguageModel( + '', + ) as unknown as LLMProvider; + const semanticResult = await semanticValidatorStep( + stateAfterSyntactic, + context, + { + smartLlm: semanticLlm, + cheapLlm: semanticLlm, + config: { + db: {dialect: 'pg'}, + nodes: {semanticValidatorNode: {useSmartLLM: false}}, + } as never, + tableSearchService: makeTableSearchService() as never, + schemaHelper: makeSchemaHelper() as never, + }, + ); + const stateAfterSemantic = Object.assign( + {}, + stateAfterSyntactic, + semanticResult, + ); + + // PostValidation merge + const merged = mergeValidationResults(stateAfterSemantic); + const finalState = Object.assign({}, stateAfterSemantic, merged); + + // Routing decision + const condition = 
checkPostValidationConditions(finalState); + expect(condition).to.equal('accepted'); + expect(finalState.status).to.equal(EvaluationResult.Pass); + + // Usage was tracked + sinon.assert.called(onUsageSpy); + }); + }); + + // ── 2. Retry loop ────────────────────────────────────────────────────────── + describe('retry loop: syntax failure on attempt 1, pass on attempt 2', function () { + it('routes to fixSql then accepted after fix', async () => { + const state = makeBaseState(); + const fakeSqlResponse = 'SELECT name FROM employees'; + + // ── Attempt 1: syntactic validation fails ──────────────────────────── + const syntacticErrorLlm = createFakeLanguageModel( + // Categorize prompt output: query_erroremployees + 'query_erroremployees', + ) as unknown as LLMProvider; + + // Simulate SQL generation setting status to Pass + const stateWithSql: DbQueryState = Object.assign({}, state, { + sql: fakeSqlResponse, + status: EvaluationResult.Pass, + }) as unknown as DbQueryState; + + const syntacticResult1 = await syntacticValidatorStep( + stateWithSql, + context, + { + llm: syntacticErrorLlm, + connector: makeConnector(true) as never, + }, + ); + const stateAfterSyntacticFail = Object.assign( + {}, + stateWithSql, + syntacticResult1, + {semanticStatus: EvaluationResult.Pass, feedbacks: []}, + ) as unknown as DbQueryState; + + const merged1 = mergeValidationResults(stateAfterSyntacticFail); + const stateAfterMerge1 = Object.assign( + {}, + stateAfterSyntacticFail, + merged1, + ) as unknown as DbQueryState; + + // Condition should route to fixSql + const condition1 = checkPostValidationConditions(stateAfterMerge1); + expect(condition1).to.equal('fixSql'); + expect(stateAfterMerge1.feedbacks?.length).to.be.greaterThan(0); + + // ── Attempt 2: connector succeeds after fix ────────────────────────── + const stateForRetry: DbQueryState = Object.assign({}, stateAfterMerge1, { + sql: 'SELECT name FROM employees WHERE active = true', + }) as unknown as DbQueryState; + + const 
syntacticResult2 = await syntacticValidatorStep( + stateForRetry, + context, + { + llm: syntacticErrorLlm, + connector: makeConnector(false) as never, // now passes + }, + ); + const semanticPassLlm = createFakeLanguageModel( + '', + ) as unknown as LLMProvider; + const semanticResult2 = await semanticValidatorStep( + Object.assign( + {}, + stateForRetry, + syntacticResult2, + ) as unknown as DbQueryState, + context, + { + smartLlm: semanticPassLlm, + cheapLlm: semanticPassLlm, + config: { + db: {dialect: 'pg'}, + nodes: {semanticValidatorNode: {useSmartLLM: false}}, + } as never, + tableSearchService: makeTableSearchService() as never, + schemaHelper: makeSchemaHelper() as never, + }, + ); + + const stateAfterRetry = Object.assign( + {}, + stateForRetry, + syntacticResult2, + semanticResult2, + ) as unknown as DbQueryState; + + const merged2 = mergeValidationResults(stateAfterRetry); + const finalState = Object.assign({}, stateAfterRetry, merged2); + + const condition2 = checkPostValidationConditions(finalState); + expect(condition2).to.equal('accepted'); + + // Confirm we went through 1 retry (feedbacks from first round survive) + expect(stateAfterMerge1.feedbacks?.length).to.be.greaterThan(0); + }); + }); + + // ── 3. 
Cache hit ─────────────────────────────────────────────────────────── + describe('cache hit: fromCache short-circuits workflow', function () { + it('checkPostCacheAndTablesConditions returns fromCache when fromCache=true', async () => { + // checkCacheStep parses response as " " — 'as-is 1' means + // AsIs match at index 1 (first document in relevantDocs) + const cacheLlm = createFakeLanguageModel( + 'as-is 1', + ) as unknown as LLMProvider; + const datasetSearchStub = makeDatasetSearchService(); + datasetSearchStub.search.resolves([ + { + pageContent: 'Get all employee names', + metadata: { + description: 'Returns all employee names', + datasetId: '42', + sql: 'SELECT name FROM employees', + }, + }, + ]); + const dataSetHelperStub = makeDataSetHelper(); + // find returns an array; step destructures: const [dataset] = await find() + dataSetHelperStub.find.resolves([ + { + id: '42', + query: 'SELECT name FROM employees', + prompt: 'Get all employee names', + actions: [], + }, + ]); + + const state = makeBaseState(); + const cacheResult = await checkCacheStep(state, context, { + datasetSearch: datasetSearchStub as never, + llm: cacheLlm, + dataSetHelper: dataSetHelperStub as never, + }); + + const stateAfterCache = Object.assign({}, state, cacheResult); + + // Routing condition + const condition = checkPostCacheAndTablesConditions(stateAfterCache); + expect(condition).to.equal('fromCache'); + expect(stateAfterCache.fromCache).to.be.true(); + }); + }); + + // ── 4. 
Template hit ───────────────────────────────────────────────────────── + describe('template hit: fromTemplate short-circuits workflow', function () { + it('checkPostCacheAndTablesConditions returns fromTemplate when fromTemplate=true', () => { + const state = Object.assign({}, makeBaseState(), { + fromTemplate: true, + sql: 'SELECT * FROM templates', + }) as unknown as DbQueryState; + + const condition = checkPostCacheAndTablesConditions(state); + expect(condition).to.equal('fromTemplate'); + }); + }); + + // ── 5. Max attempts exceeded ─────────────────────────────────────────────── + describe('max attempts guard: failedStep called when feedbacks >= MAX_ATTEMPTS', function () { + it('emits Failed ToolStatus and returns replyToUser', async () => { + const writerSpy = sinon.spy(); + const ctx: MastraDbQueryContext = { + onUsage: onUsageSpy, + writer: writerSpy, + }; + + const feedbacks = Array.from( + {length: MAX_ATTEMPTS}, + (_, i) => `Round ${i + 1}: query_error`, + ); + const state = Object.assign({}, makeBaseState(), { + sql: 'SELECT INVALID', + feedbacks, + status: GenerationError.Failed, + }) as unknown as DbQueryState; + + // Simulate the guard condition check that precedes failedStep + const shouldFail = (state.feedbacks?.length ?? 0) >= MAX_ATTEMPTS; + expect(shouldFail).to.be.true(); + + const result = await failedStep(state, ctx); + expect(result.replyToUser).to.be.a.String(); + expect(result.replyToUser).to.match(/not able to generate/i); + + // Writer must have been called with ToolStatus.Failed + sinon.assert.called(writerSpy); + type WriterEvent = {type: string; data: {status: string}}; + const writtenEvents: WriterEvent[] = writerSpy.args.map( + (args: unknown[]) => args[0] as WriterEvent, + ); + const failedEvent = writtenEvents.find(e => e.data?.status === 'failed'); + expect(failedEvent).to.not.be.undefined(); + expect(failedEvent!.type).to.equal('tool-status'); + }); + }); + + // ── 6. 
Condition functions (pure) ────────────────────────────────────────── + describe('checkPostCacheAndTablesConditions — pure routing function', function () { + it('returns continue when no special conditions are set', () => { + const state = makeBaseState(); + expect(checkPostCacheAndTablesConditions(state)).to.equal('continue'); + }); + + it('returns failed when status is GenerationError.Failed', () => { + const state = Object.assign({}, makeBaseState(), { + status: GenerationError.Failed, + }) as unknown as DbQueryState; + expect(checkPostCacheAndTablesConditions(state)).to.equal('failed'); + }); + + it('prefers fromTemplate over fromCache', () => { + const state = Object.assign({}, makeBaseState(), { + fromTemplate: true, + fromCache: true, + }) as unknown as DbQueryState; + expect(checkPostCacheAndTablesConditions(state)).to.equal('fromTemplate'); + }); + }); + + describe('checkPostValidationConditions — pure routing function', function () { + it('returns accepted on Pass', () => { + const state = Object.assign({}, makeBaseState(), { + status: EvaluationResult.Pass, + }) as unknown as DbQueryState; + expect(checkPostValidationConditions(state)).to.equal('accepted'); + }); + + it('returns reselectTables on TableError', () => { + const state = Object.assign({}, makeBaseState(), { + status: EvaluationResult.TableError, + }) as unknown as DbQueryState; + expect(checkPostValidationConditions(state)).to.equal('reselectTables'); + }); + + it('returns fixSql on QueryError', () => { + const state = Object.assign({}, makeBaseState(), { + status: EvaluationResult.QueryError, + }) as unknown as DbQueryState; + expect(checkPostValidationConditions(state)).to.equal('fixSql'); + }); + + it('returns failed on unknown status', () => { + const state = Object.assign({}, makeBaseState(), { + status: 'unknown_status', + }) as unknown as DbQueryState; + expect(checkPostValidationConditions(state)).to.equal('failed'); + }); + }); + + // ── 7. 
classifyChangeStep routing decision ───────────────────────────────── + describe('classifyChangeStep — LLM routes to ChangeType', function () { + it('returns ChangeType.Minor when LLM responds minor', async () => { + const state = Object.assign({}, makeBaseState(), { + sampleSql: 'SELECT name FROM employees', + sampleSqlPrompt: 'Get all employees', + description: 'Returns all employee names', + }) as unknown as DbQueryState; + + const fakeLlm = createFakeLanguageModel( + ChangeType.Minor, + ) as unknown as LLMProvider; + + const result = await classifyChangeStep(state, context, {llm: fakeLlm}); + expect(result.changeType).to.equal(ChangeType.Minor); + }); + + it('returns undefined changeType when no sampleSql (fresh query)', async () => { + const state = makeBaseState(); // no sampleSql + const fakeLlm = createFakeLanguageModel( + 'major', + ) as unknown as LLMProvider; + + const result = await classifyChangeStep(state, context, {llm: fakeLlm}); + // With no sampleSql the step skips the LLM call + expect(result.changeType).to.be.undefined(); + }); + }); +}); diff --git a/src/__tests__/db-query/unit/mastra-steps/failed.step.unit.ts b/src/__tests__/db-query/unit/mastra-steps/failed.step.unit.ts new file mode 100644 index 0000000..d5266f8 --- /dev/null +++ b/src/__tests__/db-query/unit/mastra-steps/failed.step.unit.ts @@ -0,0 +1,70 @@ +import {expect, sinon} from '@loopback/testlab'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {LLMStreamEventType, ToolStatus} from '../../../../types/events'; +import {failedStep} from '../../../../mastra/db-query/workflow/steps/failed.step'; +import {MastraDbQueryContext} from '../../../../mastra/db-query/types/db-query.types'; + +describe('failedStep (Mastra)', function () { + let writerSpy: sinon.SinonSpy; + let context: MastraDbQueryContext; + + beforeEach(() => { + writerSpy = sinon.spy(); + context = {writer: writerSpy} as unknown as MastraDbQueryContext; + }); + + it('emits a ToolStatus.Failed 
writer event', async () => { + const state = {} as DbQueryState; + await failedStep(state, context); + + sinon.assert.calledOnce(writerSpy); + const call = writerSpy.firstCall.args[0]; + expect(call.type).to.equal(LLMStreamEventType.ToolStatus); + expect(call.data.status).to.equal(ToolStatus.Failed); + }); + + it('returns the existing replyToUser when already set', async () => { + const state = { + replyToUser: 'Custom error message', + } as unknown as DbQueryState; + const result = await failedStep(state, context); + + expect(result.replyToUser).to.equal('Custom error message'); + }); + + it('generates a default replyToUser when not set', async () => { + const state = {} as DbQueryState; + const result = await failedStep(state, context); + + expect(result.replyToUser).to.match(/not able to generate a valid SQL/); + }); + + it('includes feedbacks in the default message when provided', async () => { + const state = { + feedbacks: ['Table not found', 'Syntax error'], + } as unknown as DbQueryState; + const result = await failedStep(state, context); + + expect(result.replyToUser).to.match(/Table not found/); + expect(result.replyToUser).to.match(/Syntax error/); + }); + + it('produces an empty error list when feedbacks is an empty array', async () => { + const state = {feedbacks: []} as unknown as DbQueryState; + const result = await failedStep(state, context); + + expect(result.replyToUser).to.match(/I am sorry/); + // feedbacks.join('\n') on an empty array produces '' — the trailing newline is present + expect(result.replyToUser).to.match( + /These were the errors I encountered:\n$/, + ); + }); + + it('works without a writer in context', async () => { + const ctxNoWriter = {} as MastraDbQueryContext; + const state = {} as DbQueryState; + const result = await failedStep(state, ctxNoWriter); + + expect(result.replyToUser).to.be.a.String(); + }); +}); diff --git a/src/__tests__/db-query/unit/mastra-steps/generate-description.step.unit.ts 
b/src/__tests__/db-query/unit/mastra-steps/generate-description.step.unit.ts new file mode 100644 index 0000000..35042fc --- /dev/null +++ b/src/__tests__/db-query/unit/mastra-steps/generate-description.step.unit.ts @@ -0,0 +1,129 @@ +import {expect, sinon} from '@loopback/testlab'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {LLMStreamEventType} from '../../../../types/events'; +import {generateDescriptionStep} from '../../../../mastra/db-query/workflow/steps/generate-description.step'; +import {MastraDbQueryContext} from '../../../../mastra/db-query/types/db-query.types'; +import {LLMProvider} from '../../../../types'; +import {createFakeStreamingLanguageModel} from '../../../fixtures/fake-ai-models'; + +describe('generateDescriptionStep (Mastra)', function () { + const baseState = { + prompt: 'Show me all employees', + sql: 'SELECT * FROM employees', + schema: {tables: {}, relations: []}, + } as unknown as DbQueryState; + + const fakeSchemaHelper = { + asString: sinon + .stub() + .returns('CREATE TABLE employees (id INT, name TEXT)'), + getTablesContext: sinon.stub().returns(['No special rules']), + }; + + let writerSpy: sinon.SinonSpy; + let onUsageSpy: sinon.SinonSpy; + let context: MastraDbQueryContext; + + beforeEach(() => { + writerSpy = sinon.spy(); + onUsageSpy = sinon.spy(); + context = { + writer: writerSpy, + onUsage: onUsageSpy, + } as unknown as MastraDbQueryContext; + (fakeSchemaHelper.asString as sinon.SinonStub).resetHistory(); + (fakeSchemaHelper.getTablesContext as sinon.SinonStub).resetHistory(); + }); + + it('returns empty when generateDescription is explicitly disabled', async () => { + const result = await generateDescriptionStep(baseState, context, { + llm: createFakeStreamingLanguageModel( + 'ignored', + ) as unknown as LLMProvider, + config: { + nodes: {sqlGenerationNode: {generateDescription: false}}, + } as never, + schemaHelper: fakeSchemaHelper as never, + }); + + expect(result).to.deepEqual({}); + 
sinon.assert.notCalled(onUsageSpy); + }); + + it('returns empty when sql is absent from state', async () => { + const stateNoSql = { + ...baseState, + sql: undefined, + } as unknown as DbQueryState; + + const result = await generateDescriptionStep(stateNoSql, context, { + llm: createFakeStreamingLanguageModel( + 'ignored', + ) as unknown as LLMProvider, + config: {} as never, + schemaHelper: fakeSchemaHelper as never, + }); + + expect(result).to.deepEqual({}); + sinon.assert.notCalled(onUsageSpy); + }); + + it('streams description and returns it in state', async () => { + const result = await generateDescriptionStep(baseState, context, { + llm: createFakeStreamingLanguageModel( + 'Retrieves all employees', + ) as unknown as LLMProvider, + config: {} as never, + schemaHelper: fakeSchemaHelper as never, + }); + + expect(result.description).to.equal('Retrieves all employees'); + }); + + it('calls onUsage with token counts from the stream', async () => { + await generateDescriptionStep(baseState, context, { + llm: createFakeStreamingLanguageModel( + 'desc text', + 20, + 8, + ) as unknown as LLMProvider, + config: {} as never, + schemaHelper: fakeSchemaHelper as never, + }); + + sinon.assert.calledOnce(onUsageSpy); + const [inputTokens, outputTokens] = onUsageSpy.firstCall.args; + expect(inputTokens).to.equal(20); + expect(outputTokens).to.equal(8); + }); + + it('emits ToolStatus writer events for each streamed chunk', async () => { + await generateDescriptionStep(baseState, context, { + llm: createFakeStreamingLanguageModel( + 'hello world', + ) as unknown as LLMProvider, + config: {} as never, + schemaHelper: fakeSchemaHelper as never, + }); + + const toolStatusCalls = writerSpy.args.filter( + args => args[0]?.type === LLMStreamEventType.ToolStatus, + ); + expect(toolStatusCalls.length).to.be.greaterThan(0); + const chunk = toolStatusCalls[0][0].data.thinkingToken; + expect(chunk).to.be.a.String(); + }); + + it('calls schemaHelper.asString and getTablesContext', async 
() => { + await generateDescriptionStep(baseState, context, { + llm: createFakeStreamingLanguageModel('desc') as unknown as LLMProvider, + config: {} as never, + schemaHelper: fakeSchemaHelper as never, + }); + + sinon.assert.calledOnce(fakeSchemaHelper.asString as sinon.SinonStub); + sinon.assert.calledOnce( + fakeSchemaHelper.getTablesContext as sinon.SinonStub, + ); + }); +}); diff --git a/src/__tests__/db-query/unit/mastra-steps/is-improvement.step.unit.ts b/src/__tests__/db-query/unit/mastra-steps/is-improvement.step.unit.ts new file mode 100644 index 0000000..3855e5e --- /dev/null +++ b/src/__tests__/db-query/unit/mastra-steps/is-improvement.step.unit.ts @@ -0,0 +1,45 @@ +import {expect, sinon} from '@loopback/testlab'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {IDataSetStore} from '../../../../components/db-query/types'; +import {isImprovementStep} from '../../../../mastra/db-query/workflow/steps/is-improvement.step'; +import {MastraDbQueryContext} from '../../../../mastra/db-query/types/db-query.types'; + +describe('isImprovementStep (Mastra)', function () { + const context = {} as MastraDbQueryContext; + + it('returns empty when no datasetId in state', async () => { + const state = {prompt: 'Show employees'} as unknown as DbQueryState; + const store = {findById: sinon.stub()} as unknown as IDataSetStore; + + const result = await isImprovementStep(state, context, {store}); + + expect(result).to.deepEqual({}); + sinon.assert.notCalled(store.findById as sinon.SinonStub); + }); + + it('loads dataset and enriches state when datasetId is set', async () => { + const state = { + datasetId: 'ds-1', + prompt: 'also add department filter', + schema: {tables: {}, relations: []}, + } as unknown as DbQueryState; + + const store = { + findById: sinon.stub().resolves({ + query: 'SELECT * FROM employees', + prompt: 'Show all employees', + }), + } as unknown as IDataSetStore; + + const result = await isImprovementStep(state, context, 
{store}); + + expect(result.sampleSql).to.equal('SELECT * FROM employees'); + expect(result.sampleSqlPrompt).to.equal('Show all employees'); + expect(result.prompt).to.match(/Show all employees/); + expect(result.prompt).to.match(/also add department filter/); + sinon.assert.calledOnceWithExactly( + store.findById as sinon.SinonStub, + 'ds-1', + ); + }); +}); diff --git a/src/__tests__/db-query/unit/mastra-steps/post-validation.step.unit.ts b/src/__tests__/db-query/unit/mastra-steps/post-validation.step.unit.ts new file mode 100644 index 0000000..ee2b58d --- /dev/null +++ b/src/__tests__/db-query/unit/mastra-steps/post-validation.step.unit.ts @@ -0,0 +1,125 @@ +import {expect} from '@loopback/testlab'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {EvaluationResult} from '../../../../components/db-query/types'; +import {mergeValidationResults} from '../../../../mastra/db-query/workflow/steps/post-validation.step'; + +describe('mergeValidationResults (Mastra)', function () { + it('returns Pass and clears per-round fields when both validators pass', () => { + const state = { + syntacticStatus: EvaluationResult.Pass, + semanticStatus: EvaluationResult.Pass, + feedbacks: ['old feedback'], + } as unknown as DbQueryState; + + const result = mergeValidationResults(state); + + expect(result.status).to.equal(EvaluationResult.Pass); + expect(result.syntacticStatus).to.be.undefined(); + expect(result.semanticStatus).to.be.undefined(); + expect(result.syntacticFeedback).to.be.undefined(); + expect(result.semanticFeedback).to.be.undefined(); + }); + + it('clears "Query Validation Failed" feedbacks on pass', () => { + const state = { + syntacticStatus: EvaluationResult.Pass, + semanticStatus: EvaluationResult.Pass, + feedbacks: ['Query Validation Failed: syntax error', 'other feedback'], + } as unknown as DbQueryState; + + const result = mergeValidationResults(state); + + expect(result.feedbacks).to.not.containEql( + 'Query Validation Failed: 
syntax error', + ); + expect(result.feedbacks).to.containEql('other feedback'); + }); + + it('uses syntactic status when syntactic validator fails', () => { + const state = { + syntacticStatus: 'query_error', + syntacticFeedback: 'Query Validation Failed: bad syntax', + feedbacks: [], + } as unknown as DbQueryState; + + const result = mergeValidationResults(state); + + expect(result.status).to.equal('query_error'); + expect(result.feedbacks).to.containEql( + 'Query Validation Failed: bad syntax', + ); + }); + + it('uses semantic status when only semantic validator fails', () => { + const state = { + syntacticStatus: EvaluationResult.Pass, + semanticStatus: 'wrong_result', + semanticFeedback: 'Data does not match expected output', + feedbacks: [], + } as unknown as DbQueryState; + + const result = mergeValidationResults(state); + + expect(result.status).to.equal('wrong_result'); + expect(result.feedbacks).to.containEql( + 'Data does not match expected output', + ); + }); + + it('prefers syntactic over semantic when both fail', () => { + const state = { + syntacticStatus: 'query_error', + syntacticFeedback: 'Syntactic issue', + semanticStatus: 'wrong_result', + semanticFeedback: 'Semantic issue', + feedbacks: [], + } as unknown as DbQueryState; + + const result = mergeValidationResults(state); + + expect(result.status).to.equal('query_error'); + expect(result.feedbacks).to.containEql('Syntactic issue'); + expect(result.feedbacks).to.containEql('Semantic issue'); + }); + + it('accumulates feedbacks from previous rounds', () => { + const state = { + syntacticStatus: 'query_error', + syntacticFeedback: 'New error', + feedbacks: ['Previous round error'], + } as unknown as DbQueryState; + + const result = mergeValidationResults(state); + + expect(result.feedbacks).to.containEql('Previous round error'); + expect(result.feedbacks).to.containEql('New error'); + }); + + it('merges errorTables from both validators', () => { + const state = { + syntacticStatus: 'query_error', + 
syntacticErrorTables: ['employees'], + semanticErrorTables: ['departments'], + feedbacks: [], + } as unknown as DbQueryState; + + const result = mergeValidationResults(state); + + expect(result.syntacticErrorTables).to.containDeep([ + 'employees', + 'departments', + ]); + expect(result.semanticErrorTables).to.containDeep([ + 'employees', + 'departments', + ]); + }); + + it('handles state with no validators run (both undefined)', () => { + const state = {feedbacks: []} as unknown as DbQueryState; + const result = mergeValidationResults(state); + + // Both undefined → treated as pass + expect(result.status).to.equal(EvaluationResult.Pass); + }); +}); diff --git a/src/__tests__/db-query/unit/mastra-steps/sql-generation.step.unit.ts b/src/__tests__/db-query/unit/mastra-steps/sql-generation.step.unit.ts new file mode 100644 index 0000000..4c71563 --- /dev/null +++ b/src/__tests__/db-query/unit/mastra-steps/sql-generation.step.unit.ts @@ -0,0 +1,123 @@ +import {expect, sinon} from '@loopback/testlab'; +import {DbQueryState} from '../../../../components/db-query/state'; +import { + ChangeType, + EvaluationResult, + GenerationError, +} from '../../../../components/db-query/types'; +import {LLMProvider} from '../../../../types'; +import {sqlGenerationStep} from '../../../../mastra/db-query/workflow/steps/sql-generation.step'; +import {MastraDbQueryContext} from '../../../../mastra/db-query/types/db-query.types'; +import {createFakeLanguageModel} from '../../../fixtures/fake-ai-models'; + +const fakeSchemaHelper = { + asString: () => 'employees(id, name)', + getTablesContext: () => [], +}; + +describe('sqlGenerationStep (Mastra)', function () { + let onUsageSpy: sinon.SinonSpy; + let context: MastraDbQueryContext; + + const baseState = { + prompt: 'Get all employee names', + schema: {tables: {employees: {}}, relations: []}, + } as unknown as DbQueryState; + + beforeEach(() => { + onUsageSpy = sinon.spy(); + context = {onUsage: onUsageSpy}; + }); + + it('generates SQL 
successfully and calls onUsage', async () => { + const state = { + ...baseState, + schema: {tables: {employees: {}, departments: {}}, relations: []}, + } as unknown as DbQueryState; + + const result = await sqlGenerationStep(state, context, { + sqlLLM: createFakeLanguageModel( + 'SELECT name FROM employees', + ) as unknown as LLMProvider, + cheapLLM: createFakeLanguageModel( + 'SELECT name FROM employees', + ) as unknown as LLMProvider, + config: {db: {dialect: 'pg'}} as never, + schemaHelper: fakeSchemaHelper as never, + }); + + expect(result.sql).to.equal('SELECT name FROM employees'); + expect(result.status).to.equal(EvaluationResult.Pass); + sinon.assert.calledOnce(onUsageSpy); + const [inputTokens, outputTokens, model] = onUsageSpy.firstCall.args; + expect(inputTokens).to.be.a.Number(); + expect(outputTokens).to.be.a.Number(); + expect(model).to.be.a.String(); + }); + + it('uses cheapLLM for single-table queries', async () => { + const cheapModel = createFakeLanguageModel('SELECT name FROM employees'); + const smartModel = createFakeLanguageModel('SELECT name FROM employees'); + const cheapSpy = sinon.spy(cheapModel, 'doGenerate'); + const smartSpy = sinon.spy(smartModel, 'doGenerate'); + + await sqlGenerationStep(baseState, context, { + sqlLLM: smartModel as unknown as LLMProvider, + cheapLLM: cheapModel as unknown as LLMProvider, + config: {db: {dialect: 'pg'}} as never, + schemaHelper: fakeSchemaHelper as never, + }); + + expect(cheapSpy.calledOnce).to.be.true(); + expect(smartSpy.called).to.be.false(); + }); + + it('uses cheapLLM for ChangeType.Minor', async () => { + const state = { + ...baseState, + schema: {tables: {employees: {}, departments: {}}, relations: []}, + changeType: ChangeType.Minor, + } as unknown as DbQueryState; + const cheapModel = createFakeLanguageModel('SELECT name FROM employees'); + const smartModel = createFakeLanguageModel('SELECT name FROM employees'); + const cheapSpy = sinon.spy(cheapModel, 'doGenerate'); + const smartSpy = 
sinon.spy(smartModel, 'doGenerate'); + + await sqlGenerationStep(state, context, { + sqlLLM: smartModel as unknown as LLMProvider, + cheapLLM: cheapModel as unknown as LLMProvider, + config: {db: {dialect: 'pg'}} as never, + schemaHelper: fakeSchemaHelper as never, + }); + + expect(cheapSpy.calledOnce).to.be.true(); + expect(smartSpy.called).to.be.false(); + }); + + it('returns Failed status when LLM returns empty SQL', async () => { + const result = await sqlGenerationStep(baseState, context, { + sqlLLM: createFakeLanguageModel(' ') as unknown as LLMProvider, + cheapLLM: createFakeLanguageModel(' ') as unknown as LLMProvider, + config: {db: {dialect: 'pg'}} as never, + schemaHelper: fakeSchemaHelper as never, + }); + + expect(result.status).to.equal(GenerationError.Failed); + expect(result.sql).to.be.undefined(); + }); + + it('strips markdown code fences from SQL output', async () => { + const result = await sqlGenerationStep(baseState, context, { + sqlLLM: createFakeLanguageModel( + '```sql\nSELECT name FROM employees\n```', + ) as unknown as LLMProvider, + cheapLLM: createFakeLanguageModel( + '```sql\nSELECT name FROM employees\n```', + ) as unknown as LLMProvider, + config: {db: {dialect: 'pg'}} as never, + schemaHelper: fakeSchemaHelper as never, + }); + + expect(result.sql).to.equal('SELECT name FROM employees'); + }); +}); diff --git a/src/__tests__/db-query/unit/mastra-steps/syntactic-validator.step.unit.ts b/src/__tests__/db-query/unit/mastra-steps/syntactic-validator.step.unit.ts new file mode 100644 index 0000000..b1cd6c5 --- /dev/null +++ b/src/__tests__/db-query/unit/mastra-steps/syntactic-validator.step.unit.ts @@ -0,0 +1,91 @@ +import {expect, sinon} from '@loopback/testlab'; +import {DbQueryState} from '../../../../components/db-query/state'; +import { + EvaluationResult, + IDbConnector, +} from '../../../../components/db-query/types'; +import {LLMProvider} from '../../../../types'; +import {syntacticValidatorStep} from 
'../../../../mastra/db-query/workflow/steps/syntactic-validator.step'; +import {MastraDbQueryContext} from '../../../../mastra/db-query/types/db-query.types'; +import {createFakeLanguageModel} from '../../../fixtures/fake-ai-models'; + +describe('syntacticValidatorStep (Mastra)', function () { + let connectorStub: {validate: sinon.SinonStub}; + let onUsageSpy: sinon.SinonSpy; + let context: MastraDbQueryContext; + + const baseState = { + sql: 'SELECT name FROM employees', + schema: {tables: {employees: {}, departments: {}}, relations: []}, + } as unknown as DbQueryState; + + beforeEach(() => { + connectorStub = {validate: sinon.stub()}; + onUsageSpy = sinon.spy(); + context = {onUsage: onUsageSpy}; + }); + + it('returns Pass when connector validates successfully', async () => { + connectorStub.validate.resolves(); + + const result = await syntacticValidatorStep(baseState, context, { + llm: createFakeLanguageModel('') as unknown as LLMProvider, + connector: connectorStub as unknown as IDbConnector, + }); + + expect(result.syntacticStatus).to.equal(EvaluationResult.Pass); + sinon.assert.notCalled(onUsageSpy); + }); + + it('calls LLM to categorize error when connector throws', async () => { + connectorStub.validate.rejects( + new Error('relation "employees" does not exist'), + ); + + const result = await syntacticValidatorStep(baseState, context, { + llm: createFakeLanguageModel( + 'table_not_foundemployees', + ) as unknown as LLMProvider, + connector: connectorStub as unknown as IDbConnector, + }); + + expect(result.syntacticStatus).to.equal('table_not_found'); + expect(result.syntacticErrorTables).to.deepEqual(['employees']); + sinon.assert.calledOnce(onUsageSpy); + const [inputTokens, outputTokens] = onUsageSpy.firstCall.args; + expect(inputTokens).to.be.a.Number(); + expect(outputTokens).to.be.a.Number(); + }); + + it('categorizes as query_error and parses tables correctly', async () => { + connectorStub.validate.rejects( + new Error('syntax error at or near 
"SELCT"'), + ); + + const result = await syntacticValidatorStep(baseState, context, { + llm: createFakeLanguageModel( + 'query_erroremployees, departments', + ) as unknown as LLMProvider, + connector: connectorStub as unknown as IDbConnector, + }); + + expect(result.syntacticStatus).to.equal('query_error'); + expect(result.syntacticErrorTables).to.deepEqual([ + 'employees', + 'departments', + ]); + }); + + it('includes syntacticFeedback in the result on failure', async () => { + connectorStub.validate.rejects(new Error('some db error')); + + const result = await syntacticValidatorStep(baseState, context, { + llm: createFakeLanguageModel( + 'query_error', + ) as unknown as LLMProvider, + connector: connectorStub as unknown as IDbConnector, + }); + + expect(result.syntacticFeedback).to.match(/Query Validation Failed/); + }); +}); diff --git a/src/__tests__/db-query/unit/nodes/check-cache.node.unit.ts b/src/__tests__/db-query/unit/nodes/check-cache.node.unit.ts deleted file mode 100644 index c3de19b..0000000 --- a/src/__tests__/db-query/unit/nodes/check-cache.node.unit.ts +++ /dev/null @@ -1,305 +0,0 @@ -import {BaseRetriever} from '@langchain/core/retrievers'; -import { - createStubInstance, - expect, - sinon, - StubbedInstanceWithSinonAccessor, -} from '@loopback/testlab'; -import { - CacheResults, - CheckCacheNode, - DatasetActionType, - DataSetHelper, - DbQueryState, - QueryCacheMetadata, -} from '../../../../components'; -import {RuntimeLLMProvider} from '../../../../types'; - -describe('CheckCacheNode Unit', function () { - let node: CheckCacheNode; - let cacheStub: sinon.SinonStub; - let llmStub: sinon.SinonStub; - let datasetHelperStub: StubbedInstanceWithSinonAccessor; - - beforeEach(() => { - cacheStub = sinon.stub(); - llmStub = sinon.stub(); - datasetHelperStub = createStubInstance(DataSetHelper); - const cache = { - invoke: cacheStub, - } as unknown as BaseRetriever; - const llm = llmStub as unknown as RuntimeLLMProvider; - - node = new 
CheckCacheNode(cache, llm, datasetHelperStub); - datasetHelperStub.stubs.checkPermissions.resolves([]); - }); - - it('should return state as it is if no relevant query found in cache', async () => { - llmStub.resolves({ - content: CacheResults.NotRelevant, - }); - cacheStub.resolves([]); - const state = { - prompt: 'What is the salary of Akshat?', - } as unknown as DbQueryState; - - const result = await node.execute(state, {}); - - expect(result).to.deepEqual({}); - }); - - it('should return state with sampleSql if relevant query found in cache', async () => { - llmStub.resolves({ - content: CacheResults.Similar + ' 1', - }); - cacheStub.resolves([ - { - pageContent: 'What is the salary of Akshat?', - metadata: {query: `SELECT * FROM employees WHERE name = 'Akshat'`}, - }, - ]); - const state = { - prompt: 'What is the salary of Dhruv?', - } as unknown as DbQueryState; - - const result = await node.execute(state, {}); - - expect(result).to.deepEqual({ - sampleSql: "SELECT * FROM employees WHERE name = 'Akshat'", - sampleSqlPrompt: 'What is the salary of Akshat?', - }); - }); - - it('should return state with datasetId and fromCache true if exact query found in cache with matching permissions, and if user has liked it in the past', async () => { - llmStub.resolves({ - content: CacheResults.AsIs + ' 1', - }); - datasetHelperStub.stubs.checkPermissions.resolves([]); - datasetHelperStub.stubs.find.resolves([ - { - id: '123', - description: 'What is the salary of Akshat?', - query: `SELECT * FROM employees WHERE name = 'Akshat'`, - prompt: 'What is the salary of Akshat?', - createdBy: 'test-user', - votes: 0, - tables: ['employees'], - schemaHash: 'hash', - tenantId: 'test-tenant', - actions: undefined, - }, - ]); - cacheStub.resolves([ - { - pageContent: 'What is the salary of Akshat?', - metadata: { - query: `SELECT * FROM employees WHERE name = 'Akshat'`, - datasetId: '123', - }, - }, - ]); - const state = { - prompt: 'What is the salary of Akshat?', - } as unknown as 
DbQueryState; - - const result = await node.execute(state, {}); - - expect(result).to.deepEqual({ - fromCache: true, - datasetId: '123', - replyToUser: `I found this dataset in the cache - What is the salary of Akshat?`, - }); - }); - - it('should return state with datasetId and fromCache true if exact query found in cache with matching permissions, and if user has not seen it in past', async () => { - llmStub.resolves({ - content: CacheResults.AsIs + ' 1', - }); - datasetHelperStub.stubs.checkPermissions.resolves([]); - datasetHelperStub.stubs.find.resolves([ - { - id: '123', - description: 'What is the salary of Akshat?', - query: `SELECT * FROM employees WHERE name = 'Akshat'`, - prompt: 'What is the salary of Akshat?', - createdBy: 'test-user', - votes: 0, - tables: ['employees'], - schemaHash: 'hash', - tenantId: 'test-tenant', - actions: [ - { - action: DatasetActionType.Liked, - datasetId: '123', - userId: 'test-user', - }, - ], - }, - ]); - cacheStub.resolves([ - { - pageContent: 'What is the salary of Akshat?', - metadata: { - query: `SELECT * FROM employees WHERE name = 'Akshat'`, - datasetId: '123', - }, - }, - ]); - const state = { - prompt: 'What is the salary of Akshat?', - } as unknown as DbQueryState; - - const result = await node.execute(state, {}); - - expect(result).to.deepEqual({ - fromCache: true, - datasetId: '123', - replyToUser: `I found this dataset in the cache - What is the salary of Akshat?`, - }); - }); - - it('should not return state with datasetId and fromCache true even if exact query found in cache with matching permissions, if it was disliked by the user', async () => { - llmStub.resolves({ - content: CacheResults.AsIs + ' 1', - }); - datasetHelperStub.stubs.checkPermissions.resolves([]); - datasetHelperStub.stubs.find.resolves([ - { - id: '123', - description: 'What is the salary of Akshat?', - query: `SELECT * FROM employees WHERE name = 'Akshat'`, - prompt: 'What is the salary of Akshat?', - createdBy: 'test-user', - votes: 0, - 
tables: ['employees'], - schemaHash: 'hash', - tenantId: 'test-tenant', - actions: [ - { - action: DatasetActionType.Disliked, - datasetId: '123', - userId: 'test-user', - }, - ], - }, - ]); - cacheStub.resolves([ - { - pageContent: 'What is the salary of Akshat?', - metadata: { - query: `SELECT * FROM employees WHERE name = 'Akshat'`, - datasetId: '123', - }, - }, - ]); - const state = { - prompt: 'What is the salary of Akshat?', - } as unknown as DbQueryState; - - const result = await node.execute(state, {}); - - expect(result).to.deepEqual({}); - }); - - it('should return existing state if exact query found in cache but with missing permissions', async () => { - llmStub.resolves({ - content: `${CacheResults.AsIs} 1`, - }); - datasetHelperStub.stubs.checkPermissions.resolves(['some permission']); - cacheStub.resolves([ - { - pageContent: 'What is the salary of Akshat?', - metadata: { - query: `SELECT * FROM employees WHERE name = 'Akshat'`, - datasetId: '123', - }, - }, - ]); - const state = { - prompt: 'What is the salary of Akshat?', - } as unknown as DbQueryState; - - const result = await node.execute(state, {}); - - expect(result).to.deepEqual({}); - }); - - it('should return state as is if sampleSql already exists', async () => { - const state = { - prompt: 'What is the salary of Akshat?', - sampleSql: 'SELECT salary FROM employees WHERE name = "existing"', - } as unknown as DbQueryState; - - const result = await node.execute(state, {}); - - expect(result).to.deepEqual({}); - sinon.assert.notCalled(cacheStub); - sinon.assert.notCalled(llmStub); - }); - - it('should return state as is if LLM returns invalid index', async () => { - llmStub.resolves({ - content: `${CacheResults.AsIs} 5`, - }); // Index out of bounds - cacheStub.resolves([ - { - pageContent: 'What is the salary of Akshat?', - metadata: { - query: `SELECT * FROM employees WHERE name = 'Akshat'`, - datasetId: '123', - }, - }, - ]); - const state = { - prompt: 'What is the salary of Akshat?', - } 
as unknown as DbQueryState; - - const result = await node.execute(state, {}); - - expect(result).to.deepEqual({}); - }); - - it('should return state as is if LLM returns non-numeric index', async () => { - llmStub.resolves({ - content: `${CacheResults.AsIs} abc`, - }); - cacheStub.resolves([ - { - pageContent: 'What is the salary of Akshat?', - metadata: { - query: `SELECT * FROM employees WHERE name = 'Akshat'`, - datasetId: '123', - }, - }, - ]); - const state = { - prompt: 'What is the salary of Akshat?', - } as unknown as DbQueryState; - - const result = await node.execute(state, {}); - - expect(result).to.deepEqual({}); - }); - - it('should return state as is if LLM returns not-relevant', async () => { - llmStub.resolves({ - content: `${CacheResults.NotRelevant} 1`, - }); - cacheStub.resolves([ - { - pageContent: 'What is the salary of Akshat?', - metadata: { - query: `SELECT * FROM employees WHERE name = 'Akshat'`, - datasetId: '123', - }, - }, - ]); - const state = { - prompt: 'What is the salary of Dhruv?', - } as unknown as DbQueryState; - - const result = await node.execute(state, {}); - - expect(result).to.deepEqual({}); - }); -}); diff --git a/src/__tests__/db-query/unit/nodes/check-permission.node.unit.ts b/src/__tests__/db-query/unit/nodes/check-permission.node.unit.ts deleted file mode 100644 index 835ae8c..0000000 --- a/src/__tests__/db-query/unit/nodes/check-permission.node.unit.ts +++ /dev/null @@ -1,78 +0,0 @@ -import {expect, sinon} from '@loopback/testlab'; -import {IAuthUserWithPermissions} from 'loopback4-authorization'; -import { - CheckPermissionsNode, - DbQueryState, - Errors, - PermissionHelper, -} from '../../../../components'; -import {RuntimeLLMProvider} from '../../../../types'; -import {Currency, Employee, ExchangeRate} from '../../../fixtures/models'; - -describe('CheckPermissionsNode Unit', function () { - let node: CheckPermissionsNode; - let llmStub: sinon.SinonStub; - - beforeEach(() => { - llmStub = sinon.stub(); - const llm = 
llmStub as unknown as RuntimeLLMProvider; - const permissionHelper = new PermissionHelper( - { - models: [ - { - model: Employee, - readPermissionKey: '1', - }, - { - model: ExchangeRate, - readPermissionKey: '2', - }, - { - model: Currency, - readPermissionKey: '3', - }, - ], - }, - { - tenantId: 'test-tenant', - userTenantId: 'test-tenant', - permissions: ['1'], - } as unknown as IAuthUserWithPermissions, - ); - node = new CheckPermissionsNode(llm, permissionHelper); - }); - - it('should return state as it is if no permission is missing', async () => { - const state = { - schema: { - tables: { - employees: {}, - }, - }, - } as unknown as DbQueryState; - const result = await node.execute(state, {}); - expect(result).to.deepEqual(state); - }); - - it('should permission error status when a permission is missing', async () => { - llmStub.resolves({ - content: - 'You do not have permissions to access the required tables and cannot proceed with the request. Please provide a new request.', - }); - const state = { - schema: { - tables: { - // eslint-disable-next-line @typescript-eslint/naming-convention - exchange_rates: {}, - }, - }, - } as unknown as DbQueryState; - const result = await node.execute(state, {}); - expect(result).to.deepEqual({ - ...state, - status: Errors.PermissionError, - replyToUser: - 'You do not have permissions to access the required tables and cannot proceed with the request. 
Please provide a new request.', - }); - }); -}); diff --git a/src/__tests__/db-query/unit/nodes/classify-change.node.unit.ts b/src/__tests__/db-query/unit/nodes/classify-change.node.unit.ts deleted file mode 100644 index 5ef9746..0000000 --- a/src/__tests__/db-query/unit/nodes/classify-change.node.unit.ts +++ /dev/null @@ -1,159 +0,0 @@ -import {expect, sinon} from '@loopback/testlab'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {ChangeType, ClassifyChangeNode} from '../../../../components'; -import {DbQueryState} from '../../../../components/db-query/state'; -import {RuntimeLLMProvider} from '../../../../types'; - -describe('ClassifyChangeNode Unit', function () { - let node: ClassifyChangeNode; - let llmStub: sinon.SinonStub; - - beforeEach(() => { - llmStub = sinon.stub(); - const llm = llmStub as unknown as RuntimeLLMProvider; - node = new ClassifyChangeNode(llm); - }); - - afterEach(() => { - sinon.restore(); - }); - - it('should return empty state when sampleSql is not present', async () => { - const state = { - prompt: 'Get all users', - schema: {tables: {}, relations: []}, - sampleSql: undefined, - sampleSqlPrompt: undefined, - } as unknown as DbQueryState; - - const result = await node.execute(state, {} as LangGraphRunnableConfig); - - expect(result).to.deepEqual({}); - sinon.assert.notCalled(llmStub); - }); - - it('should classify as Minor for small changes', async () => { - llmStub.resolves({ - content: 'minor', - }); - - const state = { - prompt: 'Get users with age > 25', - schema: {tables: {}, relations: []}, - sampleSql: 'SELECT * FROM users WHERE age > 20', - sampleSqlPrompt: 'Get users with age > 20', - } as unknown as DbQueryState; - - const result = await node.execute(state, {} as LangGraphRunnableConfig); - - expect(result.changeType).to.equal(ChangeType.Minor); - sinon.assert.calledOnce(llmStub); - }); - - it('should classify as Major for structural changes', async () => { - llmStub.resolves({ - content: 'major', - 
}); - - const state = { - prompt: 'Get users with their orders and total amount', - schema: {tables: {}, relations: []}, - sampleSql: 'SELECT * FROM users', - sampleSqlPrompt: 'Get all users', - } as unknown as DbQueryState; - - const result = await node.execute(state, {} as LangGraphRunnableConfig); - - expect(result.changeType).to.equal(ChangeType.Major); - sinon.assert.calledOnce(llmStub); - }); - - it('should classify as Rewrite for fundamentally different queries', async () => { - llmStub.resolves({ - content: 'rewrite', - }); - - const state = { - prompt: 'Get monthly revenue breakdown by product category', - schema: {tables: {}, relations: []}, - sampleSql: 'SELECT * FROM users', - sampleSqlPrompt: 'Get all users', - } as unknown as DbQueryState; - - const result = await node.execute(state, {} as LangGraphRunnableConfig); - - expect(result.changeType).to.equal(ChangeType.Rewrite); - sinon.assert.calledOnce(llmStub); - }); - - it('should default to Major for unrecognized LLM responses', async () => { - llmStub.resolves({ - content: 'something unexpected', - }); - - const state = { - prompt: 'Get users', - schema: {tables: {}, relations: []}, - sampleSql: 'SELECT * FROM users', - sampleSqlPrompt: 'Get all users', - } as unknown as DbQueryState; - - const result = await node.execute(state, {} as LangGraphRunnableConfig); - - expect(result.changeType).to.equal(ChangeType.Major); - }); - - it('should pass original and new descriptions to the LLM', async () => { - llmStub.resolves({ - content: 'minor', - }); - - const state = { - prompt: 'Get users with age > 30', - schema: {tables: {}, relations: []}, - sampleSql: 'SELECT * FROM users WHERE age > 20', - sampleSqlPrompt: 'Get users with age > 20', - } as unknown as DbQueryState; - - await node.execute(state, {} as LangGraphRunnableConfig); - - const prompt = llmStub.firstCall.args[0]; - expect(prompt.value).to.containEql('Get users with age > 20'); - expect(prompt.value).to.containEql('Get users with age > 30'); - 
}); - - it('should handle empty sampleSqlPrompt gracefully', async () => { - llmStub.resolves({ - content: 'major', - }); - - const state = { - prompt: 'Get all users', - schema: {tables: {}, relations: []}, - sampleSql: 'SELECT * FROM users', - sampleSqlPrompt: undefined, - } as unknown as DbQueryState; - - const result = await node.execute(state, {} as LangGraphRunnableConfig); - - expect(result.changeType).to.equal(ChangeType.Major); - sinon.assert.calledOnce(llmStub); - }); - - it('should handle LLM response with extra whitespace and casing', async () => { - llmStub.resolves({ - content: ' Minor \n', - }); - - const state = { - prompt: 'Get users with age > 25', - schema: {tables: {}, relations: []}, - sampleSql: 'SELECT * FROM users WHERE age > 20', - sampleSqlPrompt: 'Get users with age > 20', - } as unknown as DbQueryState; - - const result = await node.execute(state, {} as LangGraphRunnableConfig); - - expect(result.changeType).to.equal(ChangeType.Minor); - }); -}); diff --git a/src/__tests__/db-query/unit/nodes/failed.node.unit.ts b/src/__tests__/db-query/unit/nodes/failed.node.unit.ts deleted file mode 100644 index f0a20c0..0000000 --- a/src/__tests__/db-query/unit/nodes/failed.node.unit.ts +++ /dev/null @@ -1,60 +0,0 @@ -import {expect} from '@loopback/testlab'; -import {DbQueryState, FailedNode} from '../../../../components'; - -describe('FailedNode Unit', function () { - let node: FailedNode; - - beforeEach(() => { - node = new FailedNode(); - }); - - it('should return state as it is if it has replyToUser set', async () => { - const state = { - schema: { - tables: { - employees: {}, - }, - }, - replyToUser: 'Test reply', - } as unknown as DbQueryState; - const result = await node.execute(state, {}); - expect(result).to.deepEqual(state); - }); - - it('should return state with feedbacks based response if replyToUser is not set', async () => { - const state = { - schema: { - tables: { - employees: {}, - }, - }, - feedbacks: ['Error 1', 'Error 2'], - } as 
unknown as DbQueryState; - const result = await node.execute(state, {}); - expect(result).to.deepEqual({ - ...state, - replyToUser: - 'I am sorry, I was not able to generate a valid SQL query for your request. Please try again with a more detailed or a more specific prompt.\n' + - 'These were the errors I encountered:\n' + - 'Error 1\n' + - 'Error 2', - }); - }); - - it('should return state with default feedbacks if no feedbacks are provided', async () => { - const state = { - schema: { - tables: { - employees: {}, - }, - }, - } as unknown as DbQueryState; - const result = await node.execute(state, {}); - expect(result).to.deepEqual({ - ...state, - replyToUser: - 'I am sorry, I was not able to generate a valid SQL query for your request. Please try again with a more detailed or a more specific prompt.\n' + - 'These were the errors I encountered:\nNo errors reported.', - }); - }); -}); diff --git a/src/__tests__/db-query/unit/nodes/fix-query.node.unit.ts b/src/__tests__/db-query/unit/nodes/fix-query.node.unit.ts deleted file mode 100644 index 4bac2ae..0000000 --- a/src/__tests__/db-query/unit/nodes/fix-query.node.unit.ts +++ /dev/null @@ -1,379 +0,0 @@ -import {expect, sinon} from '@loopback/testlab'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import { - EvaluationResult, - FixQueryNode, - GenerationError, -} from '../../../../components'; -import {DbSchemaHelperService} from '../../../../components/db-query/services'; -import {DbQueryState} from '../../../../components/db-query/state'; -import {RuntimeLLMProvider, SupportedDBs} from '../../../../types'; - -describe('FixQueryNode Unit', function () { - let node: FixQueryNode; - let llmStub: sinon.SinonStub; - let schemaHelper: DbSchemaHelperService; - - beforeEach(() => { - llmStub = sinon.stub(); - const llm = llmStub as unknown as RuntimeLLMProvider; - schemaHelper = { - asString: sinon.stub().returns('CREATE TABLE users (id INT, name TEXT);'), - getTablesContext: sinon.stub().returns([]), - } 
as unknown as DbSchemaHelperService; - - node = new FixQueryNode( - llm, - { - db: {dialect: SupportedDBs.PostgreSQL}, - models: [], - }, - schemaHelper, - ); - }); - - afterEach(() => { - sinon.restore(); - }); - - it('should fix a query and return Pass status with corrected SQL', async () => { - llmStub.resolves({ - content: 'SELECT id, name FROM users WHERE id = 1;', - }); - - const state = { - prompt: 'Get user by id 1', - sql: 'SELECT id, nama FROM users WHERE id = 1;', - schema: { - tables: { - users: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Users table', - context: [], - hash: 'hash1', - }, - }, - relations: [], - }, - feedbacks: [ - 'Query Validation Failed by DB: query_error with error column "nama" does not exist', - ], - syntacticErrorTables: ['users'], - semanticErrorTables: undefined, - validationChecklist: undefined, - }; - - const result = await node.execute( - state as unknown as DbQueryState, - {} as LangGraphRunnableConfig, - ); - - expect(result.status).to.equal(EvaluationResult.Pass); - expect(result.sql).to.equal('SELECT id, name FROM users WHERE id = 1;'); - sinon.assert.calledOnce(llmStub); - }); - - it('should return Failed status when LLM returns empty response', async () => { - llmStub.resolves({ - content: '', - }); - - const state = { - prompt: 'Get user by id 1', - sql: 'SELECT * FROM users', - schema: {tables: {users: {}}, relations: []}, - feedbacks: ['Some error'], - syntacticErrorTables: ['users'], - semanticErrorTables: undefined, - validationChecklist: undefined, - }; - - const result = await node.execute( - state as unknown as DbQueryState, - {} as LangGraphRunnableConfig, - ); - - expect(result.status).to.equal(GenerationError.Failed); - expect(result.replyToUser).to.containEql('Failed to fix SQL query'); - }); - - it('should strip markdown code fences from the LLM response', async () => { - llmStub.resolves({ 
- content: '```sql\nSELECT * FROM users WHERE id = 1;\n```', - }); - - const state = { - prompt: 'Get user by id 1', - sql: 'SELECT * FROM users WHERE id = ?', - schema: { - tables: { - users: { - columns: {id: {type: 'number'}}, - primaryKey: ['id'], - description: '', - context: [], - hash: '', - }, - }, - relations: [], - }, - feedbacks: ['Some error'], - syntacticErrorTables: ['users'], - semanticErrorTables: undefined, - validationChecklist: undefined, - }; - - const result = await node.execute( - state as unknown as DbQueryState, - {} as LangGraphRunnableConfig, - ); - - expect(result.sql).to.equal('SELECT * FROM users WHERE id = 1;'); - }); - - it('should trim schema to only error-related tables', async () => { - llmStub.resolves({ - content: 'SELECT u.id, u.name FROM users u;', - }); - - const state = { - prompt: 'Get users', - sql: 'SELECT u.id, u.nama FROM users u JOIN orders o ON u.id = o.user_id;', - schema: { - tables: { - users: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Users table', - context: [], - hash: 'hash1', - }, - orders: { - columns: { - id: {type: 'number', required: true, id: true}, - // eslint-disable-next-line @typescript-eslint/naming-convention - user_id: {type: 'number', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Orders table', - context: [], - hash: 'hash2', - }, - }, - relations: [ - { - table: 'orders', - column: 'user_id', - referencedTable: 'users', - referencedColumn: 'id', - }, - { - table: 'products', - column: 'category_id', - referencedTable: 'categories', - referencedColumn: 'id', - }, - ], - }, - feedbacks: ['Column nama not found in users'], - syntacticErrorTables: ['users'], - semanticErrorTables: undefined, - validationChecklist: undefined, - }; - - await node.execute( - state as unknown as DbQueryState, - {} as LangGraphRunnableConfig, - ); - - // Verify schemaHelper.asString 
was called with trimmed schema containing only error tables - const asStringStub = schemaHelper.asString as sinon.SinonStub; - const trimmedSchema = asStringStub.firstCall.args[0]; - expect(Object.keys(trimmedSchema.tables)).to.deepEqual(['users']); - expect(trimmedSchema.relations).to.have.length(1); - expect(trimmedSchema.relations[0].table).to.equal('orders'); - expect(trimmedSchema.relations[0].referencedTable).to.equal('users'); - }); - - it('should merge syntactic and semantic error tables', async () => { - llmStub.resolves({ - content: 'SELECT * FROM users JOIN orders ON users.id = orders.user_id;', - }); - - const state = { - prompt: 'Get users with orders', - sql: 'SELECT * FROM users JOIN orders ON users.id = orders.uid;', - schema: { - tables: { - users: { - columns: {}, - primaryKey: [], - description: '', - context: [], - hash: '', - }, - orders: { - columns: {}, - primaryKey: [], - description: '', - context: [], - hash: '', - }, - products: { - columns: {}, - primaryKey: [], - description: '', - context: [], - hash: '', - }, - }, - relations: [], - }, - feedbacks: ['Error in query'], - syntacticErrorTables: ['users'], - semanticErrorTables: ['orders'], - validationChecklist: undefined, - }; - - await node.execute( - state as unknown as DbQueryState, - {} as LangGraphRunnableConfig, - ); - - const asStringStub = schemaHelper.asString as sinon.SinonStub; - const trimmedSchema = asStringStub.firstCall.args[0]; - expect(Object.keys(trimmedSchema.tables).sort()).to.deepEqual([ - 'orders', - 'users', - ]); - }); - - it('should include validation checklist in the prompt when available', async () => { - llmStub.resolves({ - content: 'SELECT * FROM users;', - }); - - const state = { - prompt: 'Get all users', - sql: 'SELECT * FROM usr;', - schema: {tables: {users: {}}, relations: []}, - feedbacks: ['Table usr not found'], - syntacticErrorTables: ['users'], - semanticErrorTables: undefined, - validationChecklist: - '1. Always use full table names\n2. 
Include id column', - }; - - await node.execute( - state as unknown as DbQueryState, - {} as LangGraphRunnableConfig, - ); - - const prompt = llmStub.firstCall.args[0]; - expect(prompt.value).to.containEql('Always use full table names'); - expect(prompt.value).to.containEql('Include id column'); - }); - - it('should include historical errors in the prompt when multiple feedbacks exist', async () => { - llmStub.resolves({ - content: 'SELECT * FROM users WHERE id = 1;', - }); - - const state = { - prompt: 'Get user by id', - sql: 'SELECT * FROM users WHERE id == 1;', - schema: {tables: {users: {}}, relations: []}, - feedbacks: [ - 'First error: syntax issue', - 'Second error: wrong operator', - 'Third error: still wrong', - ], - syntacticErrorTables: ['users'], - semanticErrorTables: undefined, - validationChecklist: undefined, - }; - - await node.execute( - state as unknown as DbQueryState, - {} as LangGraphRunnableConfig, - ); - - const prompt = llmStub.firstCall.args[0]; - // Last feedback is the current error - expect(prompt.value).to.containEql('Third error: still wrong'); - // Historical errors should be included - expect(prompt.value).to.containEql('First error: syntax issue'); - expect(prompt.value).to.containEql('Second error: wrong operator'); - }); - - it('should handle empty error tables gracefully', async () => { - llmStub.resolves({ - content: 'SELECT * FROM users;', - }); - - const state = { - prompt: 'Get all users', - sql: 'SELECT * FROM users', - schema: { - tables: { - users: { - columns: {}, - primaryKey: [], - description: '', - context: [], - hash: '', - }, - }, - relations: [], - }, - feedbacks: ['Some validation error'], - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - validationChecklist: undefined, - }; - - await node.execute( - state as unknown as DbQueryState, - {} as LangGraphRunnableConfig, - ); - - const asStringStub = schemaHelper.asString as sinon.SinonStub; - const trimmedSchema = asStringStub.firstCall.args[0]; 
- expect(Object.keys(trimmedSchema.tables)).to.deepEqual([]); - }); - - it('should pass the current query and prompt to the LLM', async () => { - llmStub.resolves({ - content: 'SELECT * FROM users;', - }); - - const state = { - prompt: 'Get all active users', - sql: 'SELECT * FROM usr WHERE active = true;', - schema: {tables: {users: {}}, relations: []}, - feedbacks: ['Table usr does not exist'], - syntacticErrorTables: ['users'], - semanticErrorTables: undefined, - validationChecklist: undefined, - }; - - await node.execute( - state as unknown as DbQueryState, - {} as LangGraphRunnableConfig, - ); - - const prompt = llmStub.firstCall.args[0]; - expect(prompt.value).to.containEql('Get all active users'); - expect(prompt.value).to.containEql( - 'SELECT * FROM usr WHERE active = true;', - ); - expect(prompt.value).to.containEql('Table usr does not exist'); - }); -}); diff --git a/src/__tests__/db-query/unit/nodes/get-columns.node.unit.ts b/src/__tests__/db-query/unit/nodes/get-columns.node.unit.ts deleted file mode 100644 index 5c20ece..0000000 --- a/src/__tests__/db-query/unit/nodes/get-columns.node.unit.ts +++ /dev/null @@ -1,169 +0,0 @@ -import {juggler} from '@loopback/repository'; -import {expect, sinon} from '@loopback/testlab'; -import { - DbQueryState, - DbSchemaHelperService, - GenerationError, - GetColumnsNode, - SqliteConnector, -} from '../../../../components'; -import {RuntimeLLMProvider} from '../../../../types'; -import {Employee, ExchangeRate} from '../../../fixtures/models'; -import {IAuthUserWithPermissions} from 'loopback4-authorization'; - -describe('GetColumnsNode Unit', function () { - let node: GetColumnsNode; - let llmStub: sinon.SinonStub; - let schemaHelper: DbSchemaHelperService; - - beforeEach(async () => { - llmStub = sinon.stub(); - const llm = llmStub as unknown as RuntimeLLMProvider; - - schemaHelper = new DbSchemaHelperService( - new SqliteConnector( - new juggler.DataSource({ - connector: 'sqlite3', - file: ':memory:', - name: 'db', 
- debug: true, - }), - {} as unknown as IAuthUserWithPermissions, - ), - {models: []}, - ); - - node = new GetColumnsNode( - llm, - schemaHelper, - { - models: [], - columnSelection: true, - }, - ['test context'], - ); - }); - - it('should return state with filtered schema containing only relevant columns', async () => { - const originalSchema = schemaHelper.modelToSchema('', [ - Employee, - ExchangeRate, - ]); - - // Create a state with filtered tables (simulating output from get-tables node) - const filteredSchema = { - tables: { - employees: originalSchema.tables.employees, - // eslint-disable-next-line @typescript-eslint/naming-convention - exchange_rates: originalSchema.tables.exchange_rates, - }, - relations: originalSchema.relations.filter( - r => - (r.table === 'employees' || r.table === 'exchange_rates') && - (r.referencedTable === 'employees' || - r.referencedTable === 'exchange_rates'), - ), - }; - - const state = { - prompt: 'Get me the employee with name Akshat and their salary in USD', - schema: filteredSchema, - } as unknown as DbQueryState; - - // Mock LLM response with selected columns - llmStub.resolves({ - content: - '{\n "employees": ["name", "salary", "currency_id"],\n "exchange_rates": ["currency_id", "rate"]\n}', - }); - - const result = await node.execute(state, {}); - - // Verify that the schema is filtered and contains the expected tables - expect(result.schema?.tables).to.have.property('employees'); - expect(result.schema?.tables).to.have.property('exchange_rates'); - - // Verify that employees table has the selected columns plus primary key - const employeeColumns = Object.keys( - result.schema?.tables.employees.columns || {}, - ); - expect(employeeColumns).to.containEql('id'); - expect(employeeColumns).to.containEql('name'); - expect(employeeColumns).to.containEql('salary'); - expect(employeeColumns).to.containEql('currency_id'); - - // Verify that exchange_rates table has the selected columns plus primary key - const 
exchangeRateColumns = Object.keys( - result.schema?.tables.exchange_rates.columns || {}, - ); - expect(exchangeRateColumns).to.containEql('id'); - expect(exchangeRateColumns).to.containEql('currency_id'); - expect(exchangeRateColumns).to.containEql('rate'); - }); - - it('should handle failed attempt response from LLM', async () => { - const originalSchema = schemaHelper.modelToSchema('', [Employee]); - const filteredSchema = { - tables: { - employees: originalSchema.tables.employees, - }, - relations: [], - }; - - const state = { - prompt: 'Some ambiguous query', - schema: filteredSchema, - } as unknown as DbQueryState; - - llmStub.resolves({ - content: - 'failed attempt: Query is too ambiguous to determine relevant columns', - }); - - const result = await node.execute(state, {}); - - expect(result.status).to.equal(GenerationError.Failed); - expect(result.replyToUser).to.equal( - 'Query is too ambiguous to determine relevant columns', - ); - }); - - it('should throw error if no tables in schema', async () => { - const state = { - prompt: 'Get me some data', - schema: {tables: {}, relations: []}, - } as unknown as DbQueryState; - - await expect(node.execute(state, {})).to.be.rejectedWith( - 'No tables found in the schema. 
Please ensure the get-tables step was completed successfully.', - ); - }); - - it('should include primary key columns even if not explicitly selected', async () => { - const originalSchema = schemaHelper.modelToSchema('', [Employee]); - const filteredSchema = { - tables: { - employees: originalSchema.tables.employees, - }, - relations: [], - }; - - const state = { - prompt: 'Get employee names', - schema: filteredSchema, - } as unknown as DbQueryState; - - // Mock LLM response that doesn't include primary key - llmStub.resolves({ - content: '{\n "employees": ["name"]\n}', - }); - - const result = await node.execute(state, {}); - - // Should include both selected column and primary key - const employeeColumns = Object.keys( - result.schema?.tables.employees.columns || {}, - ); - expect(employeeColumns).to.containEql('id'); - expect(employeeColumns).to.containEql('name'); - }); -}); diff --git a/src/__tests__/db-query/unit/nodes/get-tables.node.unit.ts b/src/__tests__/db-query/unit/nodes/get-tables.node.unit.ts deleted file mode 100644 index 2cab51d..0000000 --- a/src/__tests__/db-query/unit/nodes/get-tables.node.unit.ts +++ /dev/null @@ -1,276 +0,0 @@ -import {juggler} from '@loopback/repository'; -import { - createStubInstance, - expect, - sinon, - StubbedInstanceWithSinonAccessor, -} from '@loopback/testlab'; -import { - DbQueryState, - DbSchemaHelperService, - GetTablesNode, - SchemaStore, - SqliteConnector, - TableSearchService, -} from '../../../../components'; -import {RuntimeLLMProvider} from '../../../../types'; -import { - Currency, - Employee, - EmployeeSkill, - ExchangeRate, - Skill, -} from '../../../fixtures/models'; -import {IAuthUserWithPermissions} from 'loopback4-authorization'; - -describe('GetTablesNode Unit', function () { - let node: GetTablesNode; - let smartllmStub: sinon.SinonStub; - let dumbllmStub: sinon.SinonStub; - let schemaHelper: DbSchemaHelperService; - let schemaStore: SchemaStore; - let tableSearchStub: 
StubbedInstanceWithSinonAccessor; - - beforeEach(async () => { - smartllmStub = sinon.stub(); - dumbllmStub = sinon.stub(); - const llm = dumbllmStub as unknown as RuntimeLLMProvider; - - schemaHelper = new DbSchemaHelperService( - new SqliteConnector( - new juggler.DataSource({ - connector: 'sqlite3', - file: ':memory:', - name: 'db', - debug: true, - }), - {} as unknown as IAuthUserWithPermissions, - ), - {models: []}, - ); - schemaStore = new SchemaStore(); - tableSearchStub = createStubInstance(TableSearchService); - node = new GetTablesNode( - llm, - dumbllmStub as unknown as RuntimeLLMProvider, - { - models: [], - }, - schemaHelper, - schemaStore, - tableSearchStub, - ['test context'], - ); - }); - - it('should return state with minimal schema based on prompt and table search', async () => { - tableSearchStub.stubs.getTables.resolves(['employees', 'exchange_rates']); - const originalSchema = schemaHelper.modelToSchema('', [ - Employee, - ExchangeRate, - Currency, - Skill, - EmployeeSkill, - ]); - await schemaStore.save(originalSchema); - - const state = { - prompt: 'Get me the employee with name Akshat', - schema: originalSchema, - } as unknown as DbQueryState; - - dumbllmStub.resolves({ - content: 'employees', - }); - - const result = await node.execute(state, {}); - - expect(dumbllmStub.getCalls()[0].args[0].value.trim()).equal( - ` -You are an AI assistant that extracts table names that are relevant to the users query that will be used to generate an SQL query later. -- Consider not just the user query but also the context and the table descriptions while selecting the tables. -- Carefully consider each and every table before including or excluding it. -- If doubtful about a table's relevance, include it anyway to give the SQL generation step more options to choose from. -- Assume that the table would have appropriate columns for relating them to any other table even if the description does not mention it. 
-- If you are not sure about the tables to select from the given schema, just return your doubt asking the user for more details or to rephrase the question in the following format - -failed attempt: reason for failure - - - -employees: ${Employee.definition.settings.description} - -exchange_rates: ${ExchangeRate.definition.settings.description} - - - -Get me the employee with name Akshat - - - -- test context -- employee salary must be converted to USD, using the currency_id column and the exchange rate table - - - - - -The output should be just a comma separated list of table names with no other text, comments or formatting. -Ensure that table names are exact and match the names in the input including schema if given. - -public.employees, public.departments - -In case of failure, return the failure message in the format - -failed attempt: - -failed attempt: reason for failure - -`, - ); - - expect(result.schema).to.deepEqual( - schemaStore.filteredSchema(['employees']), - ); - }); - - it('should return state with minimal schema based on prompt and table search with smart llm', async () => { - node = new GetTablesNode( - dumbllmStub as unknown as RuntimeLLMProvider, - smartllmStub as unknown as RuntimeLLMProvider, - { - models: [], - nodes: { - // config to use smart llm for this node - getTablesNode: { - useSmartLLM: true, - }, - }, - }, - schemaHelper, - schemaStore, - tableSearchStub, - ['test context'], - ); - tableSearchStub.stubs.getTables.resolves(['employees', 'exchange_rates']); - const originalSchema = schemaHelper.modelToSchema('', [ - Employee, - ExchangeRate, - Currency, - Skill, - EmployeeSkill, - ]); - await schemaStore.save(originalSchema); - - const state = { - prompt: 'Get me the employee with name Akshat', - schema: originalSchema, - } as unknown as DbQueryState; - - smartllmStub.resolves({ - content: 'employees', - }); - - const result = await node.execute(state, {}); - - expect(smartllmStub.getCalls()[0].args[0].value.trim()).equal( - ` -You 
are an AI assistant that extracts table names that are relevant to the users query that will be used to generate an SQL query later. -- Consider not just the user query but also the context and the table descriptions while selecting the tables. -- Carefully consider each and every table before including or excluding it. -- If doubtful about a table's relevance, include it anyway to give the SQL generation step more options to choose from. -- Assume that the table would have appropriate columns for relating them to any other table even if the description does not mention it. -- If you are not sure about the tables to select from the given schema, just return your doubt asking the user for more details or to rephrase the question in the following format - -failed attempt: reason for failure - - - -employees: ${Employee.definition.settings.description} - -exchange_rates: ${ExchangeRate.definition.settings.description} - - - -Get me the employee with name Akshat - - - -- test context -- employee salary must be converted to USD, using the currency_id column and the exchange rate table - - - - - -The output should be just a comma separated list of table names with no other text, comments or formatting. -Ensure that table names are exact and match the names in the input including schema if given. 
- -public.employees, public.departments - -In case of failure, return the failure message in the format - -failed attempt: - -failed attempt: reason for failure - -`, - ); - - expect(result.schema).to.deepEqual( - schemaStore.filteredSchema(['employees']), - ); - }); - - it('should return throw error if now table available in schema', async () => { - tableSearchStub.stubs.getTables.resolves([]); - const originalSchema = schemaHelper.modelToSchema('', [ - Employee, - ExchangeRate, - Currency, - Skill, - EmployeeSkill, - ]); - await schemaStore.save(originalSchema); - - const state = { - prompt: 'Get me the employee with name Akshat', - schema: originalSchema, - } as unknown as DbQueryState; - - dumbllmStub.resolves({ - content: 'employees', - }); - - await expect(node.execute(state, {})).to.be.rejectedWith( - 'No tables found in the provided database schema. Please ensure the schema is valid.', - ); - }); - - it('should retry selection if table names are not valid', async () => { - tableSearchStub.stubs.getTables.resolves(['employees', 'exchange_rates']); - const originalSchema = schemaHelper.modelToSchema('', [ - Employee, - ExchangeRate, - Currency, - Skill, - EmployeeSkill, - ]); - await schemaStore.save(originalSchema); - - const state = { - prompt: 'Get me the employee with name Akshat', - schema: originalSchema, - } as unknown as DbQueryState; - - dumbllmStub.onFirstCall().resolves({ - content: 'non_existing_table', - }); - dumbllmStub.onSecondCall().resolves({ - content: 'employees', - }); - - const result = await node.execute(state, {}); - - expect(dumbllmStub.callCount).to.equal(2); - expect(result.schema).to.deepEqual( - schemaStore.filteredSchema(['employees']), - ); - }); -}); diff --git a/src/__tests__/db-query/unit/nodes/is-improvement.node.unit.ts b/src/__tests__/db-query/unit/nodes/is-improvement.node.unit.ts deleted file mode 100644 index a83e2dd..0000000 --- a/src/__tests__/db-query/unit/nodes/is-improvement.node.unit.ts +++ /dev/null @@ -1,53 +0,0 
@@ -import {expect, sinon} from '@loopback/testlab'; -import { - DbQueryState, - IDataSetStore, - IsImprovementNode, -} from '../../../../components'; -import {buildDatasetStoreStub} from '../../../test-helper'; - -describe('IsImprovementNode Unit', function () { - let node: IsImprovementNode; - let store: sinon.SinonStubbedInstance; - - beforeEach(() => { - store = buildDatasetStoreStub(); - node = new IsImprovementNode(store); - }); - - it('should return state as it is if datasetId is not set', async () => { - const state = { - datasetId: undefined, - prompt: 'Test prompt', - } as unknown as DbQueryState; - const result = await node.execute(state, {}); - expect(result).to.deepEqual(state); - }); - - it('should return state with sampleSql and sampleSqlPrompt if datasetId is set', async () => { - const dataset = { - id: 'test-dataset-id', - query: 'SELECT * FROM employees', - prompt: 'Test dataset prompt', - tenantId: 'default', - description: 'This is a test dataset', - tables: ['employees'], - schemaHash: 'test-schema-hash', - votes: 0, - }; - store.findById.resolves(dataset); - - const state = { - datasetId: 'test-dataset-id', - prompt: 'Test prompt', - } as unknown as DbQueryState; - const result = await node.execute(state, {}); - - expect(result).to.deepEqual({ - ...state, - sampleSql: dataset.query, - sampleSqlPrompt: dataset.prompt, - prompt: `${dataset.prompt}\n also consider following feedback given by user -\n ${state.prompt}\n`, - }); - }); -}); diff --git a/src/__tests__/db-query/unit/nodes/save-dataset-node.unit.ts b/src/__tests__/db-query/unit/nodes/save-dataset-node.unit.ts deleted file mode 100644 index 312d4ad..0000000 --- a/src/__tests__/db-query/unit/nodes/save-dataset-node.unit.ts +++ /dev/null @@ -1,166 +0,0 @@ -import {HttpErrors} from '@loopback/rest'; -import { - createStubInstance, - expect, - sinon, - StubbedInstanceWithSinonAccessor, -} from '@loopback/testlab'; -import {IAuthUserWithPermissions} from '@sourceloop/core'; -import { - 
DbQueryState, - DbSchemaHelperService, - IDataSetStore, - SaveDataSetNode, -} from '../../../../components'; -import {DataSet} from '../../../../components/db-query/models'; -import {RuntimeLLMProvider} from '../../../../types'; -import {buildDatasetStoreStub} from '../../../test-helper'; - -describe('SaveDataSetNode Unit', function () { - let node: SaveDataSetNode; - let llmStub: sinon.SinonStub; - let store: sinon.SinonStubbedInstance; - let helper: StubbedInstanceWithSinonAccessor; - - beforeEach(() => { - llmStub = sinon.stub(); - const llm = llmStub as unknown as RuntimeLLMProvider; - store = buildDatasetStoreStub(); - helper = createStubInstance(DbSchemaHelperService); - node = new SaveDataSetNode( - llm, - store, - {models: []}, - { - tenantId: 'test-tenant', - userTenantId: 'test-tenant', - permissions: ['1'], - } as IAuthUserWithPermissions, - helper, - ); - }); - - it('should return state with dataset id', async () => { - llmStub.resolves({ - content: 'dataset desc', - }); - store.create.resolves( - new DataSet({ - id: '123', - }), - ); - const result = await node.execute( - { - prompt: 'Save this dataset', - schema: { - tables: {}, - relations: [], - }, - sql: 'SELECT * FROM test_table;', - description: 'dataset desc', - } as unknown as DbQueryState, - {}, - ); - expect(result).to.have.property('datasetId'); - expect(result.datasetId).to.equal('123'); - expect(result.done).to.be.true(); - expect(result.replyToUser).to.equal(`dataset desc`); - }); - - it('should return state with dataset id and result array if readAccessForAI is true', async () => { - node = new SaveDataSetNode( - llmStub as unknown as RuntimeLLMProvider, - store, - {models: [], readAccessForAI: true, maxRowsForAI: 50}, - { - tenantId: 'test-tenant', - userTenantId: 'test-tenant', - permissions: ['1'], - } as IAuthUserWithPermissions, - helper, - ); - llmStub.resolves({ - content: 'dataset desc', - }); - store.create.resolves( - new DataSet({ - id: '123', - }), - ); - const expectedResult 
= [{id: 1, name: 'test'}]; - store.getData.resolves(expectedResult); - const result = await node.execute( - { - prompt: 'Save this dataset', - schema: { - tables: {}, - relations: [], - }, - sql: 'SELECT * FROM test_table;', - description: 'dataset desc', - } as unknown as DbQueryState, - {}, - ); - expect(result).to.have.property('datasetId'); - expect(result.datasetId).to.equal('123'); - expect(result.done).to.be.true(); - expect(result.replyToUser).to.equal(`dataset desc`); - expect(result.resultArray).to.deepEqual(expectedResult); - }); - - it('should throw error if user does not have tenantId', async () => { - const llm = llmStub as unknown as RuntimeLLMProvider; - node = new SaveDataSetNode( - llm, - store, - {models: []}, - { - userTenantId: 'test-tenant', - permissions: ['1'], - } as IAuthUserWithPermissions, - helper, - ); - await expect( - node.execute( - { - prompt: 'Save this dataset', - schema: { - tables: {}, - relations: [], - }, - sql: 'SELECT * FROM test_table;', - } as unknown as DbQueryState, - {}, - ), - ).to.be.rejectedWith( - new HttpErrors.BadRequest(`User does not have a tenantId`), - ); - }); - - it('should throw error if sql is not present in state', async () => { - const llm = llmStub as unknown as RuntimeLLMProvider; - node = new SaveDataSetNode( - llm, - store, - {models: []}, - { - tenantId: 'test-tenant', - userTenantId: 'test-tenant', - permissions: ['1'], - } as IAuthUserWithPermissions, - helper, - ); - await expect( - node.execute( - { - prompt: 'Save this dataset', - schema: { - tables: {}, - relations: [], - }, - } as unknown as DbQueryState, - {}, - ), - ).to.be.rejectedWith(HttpErrors.InternalServerError()); - }); -}); diff --git a/src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts b/src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts deleted file mode 100644 index a6fe36b..0000000 --- a/src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts +++ /dev/null @@ -1,237 +0,0 @@ -import 
{expect, sinon} from '@loopback/testlab'; -import { - DatabaseSchema, - EvaluationResult, - SemanticValidatorNode, -} from '../../../../components'; -import { - DbSchemaHelperService, - TableSearchService, -} from '../../../../components/db-query/services'; -import {RuntimeLLMProvider} from '../../../../types'; - -describe('SemanticValidatorNode Unit', function () { - let node: SemanticValidatorNode; - let llmStub: sinon.SinonStub; - let tableSearchStub: sinon.SinonStubbedInstance; - - beforeEach(() => { - llmStub = sinon.stub(); - const llm = llmStub as unknown as RuntimeLLMProvider; - const schemaHelper = { - asString: sinon.stub().returns(''), - } as unknown as DbSchemaHelperService; - tableSearchStub = sinon.createStubInstance(TableSearchService); - tableSearchStub.getTables.resolves([]); - - node = new SemanticValidatorNode( - llm, - llm, - {models: []}, - tableSearchStub, - schemaHelper, - ); - }); - - afterEach(() => { - sinon.restore(); - }); - - it('should return Pass if the query is valid', async () => { - const state = { - prompt: 'Get all users', - sql: 'SELECT * FROM users', - schema: {tables: {}, relations: []}, - status: EvaluationResult.Pass, - id: 'test-id', - feedbacks: [], - replyToUser: '', - datasetId: 'test-dataset-id', - done: false, - sampleSqlPrompt: '', - sampleSql: '', - fromCache: false, - resultArray: undefined, - description: undefined, - directCall: false, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - fromTemplate: undefined, - templateId: undefined, - validationChecklist: '1. 
Query selects all users', - changeType: undefined, - }; - llmStub.resolves({ - content: '', - }); - - const result = await node.execute(state, {}); - - expect(result.semanticStatus).to.equal(EvaluationResult.Pass); - sinon.assert.calledOnce(llmStub); - }); - - it('should return QueryError if the query is invalid', async () => { - tableSearchStub.getTables.resolves(['users', 'orders']); - const state = { - prompt: 'Get all users', - sql: 'SELECT * FROM invalid_table', - schema: { - tables: {users: {}, orders: {}}, - relations: [], - } as unknown as DatabaseSchema, - status: EvaluationResult.Pass, - id: 'test-id', - feedbacks: [], - replyToUser: '', - datasetId: 'test-dataset-id', - done: false, - sampleSqlPrompt: '', - sampleSql: '', - fromCache: false, - resultArray: undefined, - directCall: false, - description: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - fromTemplate: undefined, - templateId: undefined, - validationChecklist: '1. Query selects from users table', - changeType: undefined, - }; - llmStub.resolves({ - content: - '\n- Query selects from wrong table. Should select from users table instead.\n\nusers', - }); - - const result = await node.execute(state, {}); - - expect(result.semanticStatus).to.equal(EvaluationResult.QueryError); - expect(result.semanticErrorTables).to.deepEqual(['users']); - sinon.assert.calledOnce(llmStub); - - const prompt = llmStub.firstCall.args[0]; - // Verify the prompt contains the user question, checklist, SQL, schema, and table names - expect(prompt.value).to.containEql(state.sql); - expect(prompt.value).to.containEql(state.prompt); - expect(prompt.value).to.containEql('1. 
Query selects from users table'); - expect(prompt.value).to.containEql(''); - expect(prompt.value).to.containEql(''); - expect(prompt.value).to.containEql(''); - expect(prompt.value).to.containEql('users, orders'); - }); - - it('should include feedbacks in the prompt', async () => { - const state = { - prompt: 'Get all users', - sql: 'SELECT * FROM users', - schema: {tables: {}, relations: []}, - status: EvaluationResult.Pass, - id: 'test-id', - feedbacks: ['the previous query was wrong'], - replyToUser: '', - datasetId: 'test-dataset-id', - done: false, - sampleSqlPrompt: '', - sampleSql: '', - fromCache: false, - description: undefined, - directCall: false, - resultArray: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - fromTemplate: undefined, - templateId: undefined, - validationChecklist: '1. Query selects all users', - changeType: undefined, - }; - llmStub.resolves({ - content: '', - }); - - await node.execute(state, {}); - - sinon.assert.calledOnce(llmStub); - const prompt = llmStub.firstCall.args[0]; - expect(prompt.value).to.containEql('the previous query was wrong'); - }); - - it('should pass all accessible tables from tableSearchService into available-tables so LLM can flag missing ones', async () => { - const searchedTables = [ - 'public.users', - 'public.orders', - 'public.payments', - 'analytics.reports', - ]; - tableSearchStub = sinon.createStubInstance(TableSearchService); - tableSearchStub.getTables.resolves(searchedTables); - - const schemaHelper = { - asString: sinon.stub().returns(''), - } as unknown as DbSchemaHelperService; - - const nodeWithTables = new SemanticValidatorNode( - llmStub as unknown as RuntimeLLMProvider, - llmStub as unknown as RuntimeLLMProvider, - {models: []}, - tableSearchStub, - schemaHelper, - ); - - const state = { - prompt: 'Get revenue per user', - sql: 'SELECT 
u.name, SUM(p.amount) FROM users u JOIN payments p ON u.id = p.user_id GROUP BY u.name', - schema: {tables: {}, relations: []}, - status: EvaluationResult.Pass, - id: 'test-id', - feedbacks: [], - replyToUser: '', - datasetId: 'test-dataset-id', - done: false, - sampleSqlPrompt: '', - sampleSql: '', - fromCache: false, - resultArray: undefined, - directCall: false, - description: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - fromTemplate: undefined, - templateId: undefined, - validationChecklist: '1. Revenue grouped by user', - changeType: undefined, - }; - - llmStub.resolves({content: ''}); - - await nodeWithTables.execute(state, {}); - - sinon.assert.calledOnce(tableSearchStub.getTables); - expect(tableSearchStub.getTables.firstCall.args[0]).to.equal( - 'Get revenue per user', - ); - - sinon.assert.calledOnce(llmStub); - const prompt = llmStub.firstCall.args[0]; - expect(prompt.value).to.containEql(''); - expect(prompt.value).to.containEql( - 'public.users, public.orders, public.payments, analytics.reports', - ); - }); -}); diff --git a/src/__tests__/db-query/unit/nodes/sql-generation.node.unit.ts b/src/__tests__/db-query/unit/nodes/sql-generation.node.unit.ts deleted file mode 100644 index 7243638..0000000 --- a/src/__tests__/db-query/unit/nodes/sql-generation.node.unit.ts +++ /dev/null @@ -1,919 +0,0 @@ -import {juggler} from '@loopback/repository'; -import {expect, sinon} from '@loopback/testlab'; -import { - ChangeType, - DbSchemaHelperService, - SqlGenerationNode, - SqliteConnector, -} from '../../../../components'; -import {RuntimeLLMProvider, SupportedDBs} from '../../../../types'; -import {IAuthUserWithPermissions} from 'loopback4-authorization'; - -describe('SqlGenerationNode Unit', function () { - let node: SqlGenerationNode; - let llmStub: sinon.SinonStub; - let schemaHelper: 
DbSchemaHelperService; - - beforeEach(() => { - llmStub = sinon.stub(); - const llm = llmStub as unknown as RuntimeLLMProvider; - - schemaHelper = new DbSchemaHelperService( - new SqliteConnector( - new juggler.DataSource({ - connector: 'sqlite3', - file: ':memory:', - name: 'db', - debug: true, - }), - {} as unknown as IAuthUserWithPermissions, - ), - {models: []}, - ); - - // Mock the getTablesContext method - sinon - .stub(schemaHelper, 'getTablesContext') - .returns(['Table employees contains employee information']); - - node = new SqlGenerationNode( - llm, - llm, - { - db: { - dialect: SupportedDBs.SQLite, - }, - models: [], - }, - schemaHelper, - ['test context'], - ); - }); - - afterEach(() => { - sinon.restore(); - }); - - it('should generate SQL query based on the provided prompt', async () => { - llmStub.resolves({ - content: 'thinking about itSELECT * FROM employees;', - }); - - const state = { - prompt: 'Generate a SQL query to select all employees', - schema: { - tables: { - employees: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Employee table', - context: [], - hash: 'hash1', - }, - }, - relations: [], - }, - feedbacks: [], - sampleSql: undefined, - sampleSqlPrompt: undefined, - done: false, - sql: undefined, - status: undefined, - id: '123', - replyToUser: undefined, - datasetId: undefined, - fromCache: false, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: undefined, - fromTemplate: undefined, - templateId: undefined, - }; - - const result = await node.execute(state, {}); - - expect(result.sql).to.equal('SELECT * FROM employees;'); - - sinon.assert.calledOnce(llmStub); - 
const prompt = llmStub.firstCall.args[0]; - expect(prompt.value).to.eql(` - -You are an expert AI assistant that generates SQL queries based on user questions and a given database schema. -You try to following the instructions carefully to generate the SQL query that answers the question. -Do not hallucinate details or make up information. -Your task is to convert a question into a SQL query, given a ${SupportedDBs.SQLite} database schema. -Adhere to these rules: -- **Deliberately go through the question and database schema word by word** to appropriately answer the question -- **DO NOT make any DML statements** (INSERT, UPDATE, DELETE, DROP etc.) to the database. -- Never query for all the columns from a specific table, only ask for the relevant columns for the given the question. -- You can only generate a single query, so if you need multiple results you can use JOINs, subqueries, CTEs or UNIONS. -- Do not make any assumptions about the user's intent beyond what is explicitly provided in the prompt. -- Ensure proper grouping with brackets for where clauses with multiple conditions using AND and OR. -- Follow each and every single rule in the "must-follow-rules" section carefully while writing the query. DO NOT SKIP ANY RULE. - - -${state.prompt} - - - -${schemaHelper.asString(state.schema)} - - - -You must keep these additional details in mind while writing the query - -- test context -- Table employees contains employee information - - - - - - - - -Output should only be a valid SQL query with no other special character or formatting. -Contains the required valid SQL satisfying all the constraints. -It should have no other character or symbol or character that is not part of SQLs. 
-`); - }); - - it('should generate SQL query based on the provided prompt with a single feedback from some validation stage', async () => { - llmStub.resolves({ - content: 'thinking about itSELECT * FROM employees;', - }); - - const state = { - prompt: 'Generate a SQL query to select all employees', - schema: { - tables: { - employees: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Employee table', - context: [], - hash: 'hash1', - }, - }, - relations: [], - }, - feedbacks: [`The last query was using wrong table`], - sampleSql: 'test sql', - sampleSqlPrompt: `test sql prompt`, - done: false, - sql: `select * from wrong_table;`, - status: undefined, - id: '123', - replyToUser: undefined, - datasetId: undefined, - fromCache: false, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: undefined, - fromTemplate: undefined, - templateId: undefined, - }; - - const result = await node.execute(state, {}); - - expect(result.sql).to.equal('SELECT * FROM employees;'); - - sinon.assert.calledOnce(llmStub); - const prompt = llmStub.firstCall.args[0]; - expect(prompt.value).to.eql(` - -You are an expert AI assistant that generates SQL queries based on user questions and a given database schema. -You try to following the instructions carefully to generate the SQL query that answers the question. -Do not hallucinate details or make up information. -Your task is to convert a question into a SQL query, given a ${SupportedDBs.SQLite} database schema. 
-Adhere to these rules: -- **Deliberately go through the question and database schema word by word** to appropriately answer the question -- **DO NOT make any DML statements** (INSERT, UPDATE, DELETE, DROP etc.) to the database. -- Never query for all the columns from a specific table, only ask for the relevant columns for the given the question. -- You can only generate a single query, so if you need multiple results you can use JOINs, subqueries, CTEs or UNIONS. -- Do not make any assumptions about the user's intent beyond what is explicitly provided in the prompt. -- Ensure proper grouping with brackets for where clauses with multiple conditions using AND and OR. -- Follow each and every single rule in the "must-follow-rules" section carefully while writing the query. DO NOT SKIP ANY RULE. - - -${state.prompt} - - - -${schemaHelper.asString(state.schema)} - - - -You must keep these additional details in mind while writing the query - -- test context -- Table employees contains employee information - - - - - - -We also need to consider the users feedback on the last attempt at query generation. -Make sure you fix the provided error without introducing any new or past errors. -In the last attempt, you generated this SQL query - - -${state.sql} - - - -This was the error in the latest query you generated - \n${state.feedbacks[0]} - - - - - - - -Output should only be a valid SQL query with no other special character or formatting. -Contains the required valid SQL satisfying all the constraints. -It should have no other character or symbol or character that is not part of SQLs. 
-`); - }); - - it('should generate SQL query based on the provided prompt with a multiple feedbacks from from previous loops', async () => { - llmStub.resolves({ - content: 'thinking about itSELECT * FROM employees;', - }); - - const state = { - prompt: 'Generate a SQL query to select all employees', - schema: { - tables: { - employees: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Employee table', - context: [], - hash: 'hash1', - }, - }, - relations: [], - }, - feedbacks: [ - `The last query was using wrong table`, - `The last query was not using the correct types`, - `The last query was not using the correct columns`, - ], - sampleSql: 'test sql', - sampleSqlPrompt: `test sql prompt`, - done: false, - sql: `select * from wrong_table;`, - status: undefined, - id: '123', - replyToUser: undefined, - datasetId: undefined, - fromCache: false, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: undefined, - fromTemplate: undefined, - templateId: undefined, - }; - - const result = await node.execute(state, {}); - - expect(result.sql).to.equal('SELECT * FROM employees;'); - - sinon.assert.calledOnce(llmStub); - const prompt = llmStub.firstCall.args[0]; - expect(prompt.value).to.eql(` - -You are an expert AI assistant that generates SQL queries based on user questions and a given database schema. -You try to following the instructions carefully to generate the SQL query that answers the question. -Do not hallucinate details or make up information. -Your task is to convert a question into a SQL query, given a ${SupportedDBs.SQLite} database schema. 
-Adhere to these rules: -- **Deliberately go through the question and database schema word by word** to appropriately answer the question -- **DO NOT make any DML statements** (INSERT, UPDATE, DELETE, DROP etc.) to the database. -- Never query for all the columns from a specific table, only ask for the relevant columns for the given the question. -- You can only generate a single query, so if you need multiple results you can use JOINs, subqueries, CTEs or UNIONS. -- Do not make any assumptions about the user's intent beyond what is explicitly provided in the prompt. -- Ensure proper grouping with brackets for where clauses with multiple conditions using AND and OR. -- Follow each and every single rule in the "must-follow-rules" section carefully while writing the query. DO NOT SKIP ANY RULE. - - -${state.prompt} - - - -${schemaHelper.asString(state.schema)} - - - -You must keep these additional details in mind while writing the query - -- test context -- Table employees contains employee information - - - - - - -We also need to consider the users feedback on the last attempt at query generation. -Make sure you fix the provided error without introducing any new or past errors. -In the last attempt, you generated this SQL query - - -${state.sql} - - - -This was the error in the latest query you generated - \n${state.feedbacks[2]} - - - -You already faced following issues in the past - -${state.feedbacks[0]} -${state.feedbacks[1]} - - - - - -Output should only be a valid SQL query with no other special character or formatting. -Contains the required valid SQL satisfying all the constraints. -It should have no other character or symbol or character that is not part of SQLs. 
-`); - }); - - it('should generate SQL query with sample queries when no feedbacks but has sample SQL', async () => { - llmStub.resolves({ - content: 'thinking about itSELECT * FROM employees;', - }); - - const state = { - prompt: 'Generate a SQL query to select all employees', - schema: { - tables: { - employees: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Employee table', - context: [], - hash: 'hash1', - }, - }, - relations: [], - }, - feedbacks: [], - sampleSql: 'SELECT name FROM employees WHERE id = 1', - sampleSqlPrompt: 'Get employee name by id', - done: false, - sql: undefined, - status: undefined, - id: '123', - replyToUser: undefined, - datasetId: undefined, - fromCache: true, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: undefined, - fromTemplate: undefined, - templateId: undefined, - }; - - const result = await node.execute(state, {}); - - expect(result.sql).to.equal('SELECT * FROM employees;'); - - sinon.assert.calledOnce(llmStub); - const prompt = llmStub.firstCall.args[0]; - expect(prompt.value).to.match( - /Here is an example query for reference that is similar to the question asked and has been validated by the user/, - ); - expect(prompt.value).to.match(/SELECT name FROM employees WHERE id = 1/); - expect(prompt.value).to.match( - /This was generated for the following question - \nGet employee name by id/, - ); - }); - - it('should generate SQL query with baseline sample queries when no feedbacks and not from cache', async () => { - llmStub.resolves({ - content: 'thinking about itSELECT * FROM employees;', - }); - - const state = { - prompt: 'Generate a SQL query to 
select all employees', - schema: { - tables: { - employees: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Employee table', - context: [], - hash: 'hash1', - }, - }, - relations: [], - }, - feedbacks: [], - sampleSql: 'SELECT name FROM employees WHERE id = 1', - sampleSqlPrompt: 'Get employee name by id', - done: false, - sql: undefined, - status: undefined, - id: '123', - replyToUser: undefined, - datasetId: undefined, - fromCache: false, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: undefined, - fromTemplate: undefined, - templateId: undefined, - }; - - const result = await node.execute(state, {}); - - expect(result.sql).to.equal('SELECT * FROM employees;'); - - sinon.assert.calledOnce(llmStub); - const prompt = llmStub.firstCall.args[0]; - expect(prompt.value).to.match( - /Here is the last valid SQL query that was generated for the user that is supposed to be used as the base line for the next query generation\./, - ); - expect(prompt.value).to.match(/SELECT name FROM employees WHERE id = 1/); - expect(prompt.value).to.match( - /This was generated for the following question - \nGet employee name by id/, - ); - }); - - describe('Cheap LLM usage optimization', () => { - let smartLLMStub: sinon.SinonStub; - let cheapLLMStub: sinon.SinonStub; - let nodeWithTwoLLMs: SqlGenerationNode; - let originalEnv: string | undefined; - - beforeEach(() => { - smartLLMStub = sinon.stub(); - cheapLLMStub = sinon.stub(); - originalEnv = process.env.OPTIMIZE_CACHED_QUERIES; - - const smartLLM = smartLLMStub as unknown as RuntimeLLMProvider; - const cheapLLM = cheapLLMStub as unknown as RuntimeLLMProvider; - - 
nodeWithTwoLLMs = new SqlGenerationNode( - smartLLM, - cheapLLM, - { - db: { - dialect: SupportedDBs.SQLite, - }, - models: [], - }, - schemaHelper, - ['test context'], - ); - }); - - afterEach(() => { - if (originalEnv === undefined) { - delete process.env.OPTIMIZE_CACHED_QUERIES; - } else { - process.env.OPTIMIZE_CACHED_QUERIES = originalEnv; - } - }); - - it('should use cheap LLM when changeType is Minor', async () => { - cheapLLMStub.resolves({ - content: 'SELECT * FROM employees WHERE id = 1;', - }); - - const state = { - prompt: 'Get employee by id 1', - schema: { - tables: { - employees: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Employee table', - context: [], - hash: 'hash1', - }, - departments: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Department table', - context: [], - hash: 'hash2', - }, - }, - relations: [], - }, - feedbacks: [], - sampleSql: 'SELECT name FROM employees WHERE id = 5', - sampleSqlPrompt: 'Get employee name', - done: false, - sql: undefined, - status: undefined, - id: '123', - replyToUser: undefined, - datasetId: undefined, - fromCache: true, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: ChangeType.Minor, - fromTemplate: undefined, - templateId: undefined, - }; - - const result = await nodeWithTwoLLMs.execute(state, {}); - - expect(result.sql).to.equal('SELECT * FROM employees WHERE id = 1;'); - sinon.assert.calledOnce(cheapLLMStub); - sinon.assert.notCalled(smartLLMStub); - }); - - it('should use smart LLM when OPTIMIZE_CACHED_QUERIES is 
false and sampleSql exists', async () => { - process.env.OPTIMIZE_CACHED_QUERIES = 'false'; - smartLLMStub.resolves({ - content: 'SELECT * FROM employees WHERE id = 1;', - }); - - const state = { - prompt: 'Get employee by id 1', - schema: { - tables: { - employees: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Employee table', - context: [], - hash: 'hash1', - }, - departments: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Department table', - context: [], - hash: 'hash2', - }, - }, - relations: [], - }, - feedbacks: [], - sampleSql: 'SELECT name FROM employees WHERE id = 5', - sampleSqlPrompt: 'Get employee name', - done: false, - sql: undefined, - status: undefined, - id: '123', - replyToUser: undefined, - datasetId: undefined, - fromCache: true, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: undefined, - fromTemplate: undefined, - templateId: undefined, - }; - - const result = await nodeWithTwoLLMs.execute(state, {}); - - expect(result.sql).to.equal('SELECT * FROM employees WHERE id = 1;'); - sinon.assert.calledOnce(smartLLMStub); - sinon.assert.notCalled(cheapLLMStub); - }); - - it('should use cheap LLM for single table schemas regardless of cache', async () => { - process.env.OPTIMIZE_CACHED_QUERIES = 'false'; - cheapLLMStub.resolves({ - content: 'SELECT * FROM employees;', - }); - - const state = { - prompt: 'Get all employees', - schema: { - tables: { - employees: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: 
false}, - }, - primaryKey: ['id'], - description: 'Employee table', - context: [], - hash: 'hash1', - }, - }, - relations: [], - }, - feedbacks: [], - sampleSql: undefined, - sampleSqlPrompt: undefined, - done: false, - sql: undefined, - status: undefined, - id: '123', - replyToUser: undefined, - datasetId: undefined, - fromCache: false, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: undefined, - fromTemplate: undefined, - templateId: undefined, - }; - - const result = await nodeWithTwoLLMs.execute(state, {}); - - expect(result.sql).to.equal('SELECT * FROM employees;'); - sinon.assert.calledOnce(cheapLLMStub); - sinon.assert.notCalled(smartLLMStub); - }); - - it('should use smart LLM for multiple tables without cached queries', async () => { - process.env.OPTIMIZE_CACHED_QUERIES = 'true'; - smartLLMStub.resolves({ - content: - 'SELECT e.name, d.name FROM employees e JOIN departments d ON e.dept_id = d.id;', - }); - - const state = { - prompt: 'Get employees with their departments', - schema: { - tables: { - employees: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - deptId: {type: 'number', required: false, id: false}, - }, - primaryKey: ['id'], - description: 'Employee table', - context: [], - hash: 'hash1', - }, - departments: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Department table', - context: [], - hash: 'hash2', - }, - }, - relations: [], - }, - feedbacks: [], - sampleSql: undefined, - sampleSqlPrompt: undefined, - done: false, - sql: undefined, - status: undefined, - id: '123', - replyToUser: undefined, - 
datasetId: undefined, - fromCache: false, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: undefined, - fromTemplate: undefined, - templateId: undefined, - }; - - const result = await nodeWithTwoLLMs.execute(state, {}); - - expect(result.sql).to.equal( - 'SELECT e.name, d.name FROM employees e JOIN departments d ON e.dept_id = d.id;', - ); - sinon.assert.calledOnce(smartLLMStub); - sinon.assert.notCalled(cheapLLMStub); - }); - - it('should use cheap LLM for validation fix retries', async () => { - cheapLLMStub.resolves({ - content: 'SELECT * FROM employees WHERE id = 1;', - }); - - const state = { - prompt: 'Get employee by id 1', - schema: { - tables: { - employees: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Employee table', - context: [], - hash: 'hash1', - }, - departments: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Department table', - context: [], - hash: 'hash2', - }, - }, - relations: [], - }, - feedbacks: [ - 'Query Validation Failed by DB: query_error with error syntax error', - ], - sampleSql: 'SELECT name FROM employees WHERE id = 5', - sampleSqlPrompt: 'Get employee name', - done: false, - sql: 'SELECT * FROM employees WHERE', - status: undefined, - id: '123', - replyToUser: undefined, - datasetId: undefined, - fromCache: true, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - 
syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: undefined, - fromTemplate: undefined, - templateId: undefined, - }; - - const result = await nodeWithTwoLLMs.execute(state, {}); - - expect(result.sql).to.equal('SELECT * FROM employees WHERE id = 1;'); - sinon.assert.calledOnce(cheapLLMStub); - sinon.assert.notCalled(smartLLMStub); - }); - - it('should use smart LLM when sampleSql is null despite optimization being enabled', async () => { - process.env.OPTIMIZE_CACHED_QUERIES = 'true'; - smartLLMStub.resolves({ - content: 'SELECT * FROM employees, departments;', - }); - - const state = { - prompt: 'Get all data', - schema: { - tables: { - employees: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Employee table', - context: [], - hash: 'hash1', - }, - departments: { - columns: { - id: {type: 'number', required: true, id: true}, - name: {type: 'string', required: true, id: false}, - }, - primaryKey: ['id'], - description: 'Department table', - context: [], - hash: 'hash2', - }, - }, - relations: [], - }, - feedbacks: [], - sampleSql: undefined, - sampleSqlPrompt: undefined, - done: false, - sql: undefined, - status: undefined, - id: '123', - replyToUser: undefined, - datasetId: undefined, - fromCache: false, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: undefined, - fromTemplate: undefined, - templateId: undefined, - }; - - const result = await nodeWithTwoLLMs.execute(state, {}); - - expect(result.sql).to.equal('SELECT * FROM employees, departments;'); - sinon.assert.calledOnce(smartLLMStub); - sinon.assert.notCalled(cheapLLMStub); - }); - }); -}); diff --git 
a/src/__tests__/db-query/unit/nodes/syntactic-validator.node.unit.ts b/src/__tests__/db-query/unit/nodes/syntactic-validator.node.unit.ts deleted file mode 100644 index 7d38f6a..0000000 --- a/src/__tests__/db-query/unit/nodes/syntactic-validator.node.unit.ts +++ /dev/null @@ -1,99 +0,0 @@ -import {juggler} from '@loopback/repository'; -import {expect, sinon} from '@loopback/testlab'; -import { - DbQueryState, - EvaluationResult, - IDbConnector, - SqliteConnector, - SyntacticValidatorNode, -} from '../../../../components'; -import {RuntimeLLMProvider} from '../../../../types'; -import {IAuthUserWithPermissions} from 'loopback4-authorization'; - -describe('SyntacticValidatorNode Unit', function () { - let node: SyntacticValidatorNode; - let llmStub: sinon.SinonStub; - let connector: IDbConnector; - - beforeEach(async () => { - llmStub = sinon.stub(); - const llm = llmStub as unknown as RuntimeLLMProvider; - - const ds = new juggler.DataSource({ - connector: 'sqlite3', - file: ':memory:', - name: 'db', - debug: true, - }); - await ds.execute(` - CREATE TABLE users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - email TEXT NOT NULL UNIQUE, - age INTEGER, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - `); - connector = new SqliteConnector( - ds, - {} as unknown as IAuthUserWithPermissions, - ); - - node = new SyntacticValidatorNode(llm, connector); - }); - - it('should return pass status in state if it is valid', async () => { - const state = { - sql: 'SELECT * FROM users', - schema: { - tables: {}, - }, - } as unknown as DbQueryState; - - const result = await node.execute(state, {}); - expect(llmStub.calledOnce).to.be.false(); - expect(result).to.deepEqual({ - syntacticStatus: EvaluationResult.Pass, - }); - }); - - it('should return a feedback with table error if query has table related error', async () => { - const state = { - sql: 'SELECT * FROM users_wrong', - schema: { - tables: {users: {}, orders: {}}, - }, - } as unknown as 
DbQueryState; - - llmStub.resolves({ - content: `${EvaluationResult.TableError}\nusers`, - }); - - const result = await node.execute(state, {}); - expect(result.syntacticStatus).to.equal(EvaluationResult.TableError); - expect(result.syntacticFeedback).to.equal( - `Query Validation Failed by DB: ${EvaluationResult.TableError} with error SQLITE_ERROR: no such table: users_wrong`, - ); - expect(result.syntacticErrorTables).to.deepEqual(['users']); - }); - - it('should return a feedback with query error if query has non table related error', async () => { - const state = { - sql: 'SELECT * users', - schema: { - tables: {users: {}}, - }, - } as unknown as DbQueryState; - - llmStub.resolves({ - content: `${EvaluationResult.QueryError}\nusers`, - }); - - const result = await node.execute(state, {}); - expect(result.syntacticStatus).to.equal(EvaluationResult.QueryError); - expect(result.syntacticFeedback).to.equal( - `Query Validation Failed by DB: ${EvaluationResult.QueryError} with error SQLITE_ERROR: near \"users\": syntax error`, - ); - expect(result.syntacticErrorTables).to.deepEqual(['users']); - }); -}); diff --git a/src/__tests__/fixtures/fake-ai-models.ts b/src/__tests__/fixtures/fake-ai-models.ts new file mode 100644 index 0000000..bb11372 --- /dev/null +++ b/src/__tests__/fixtures/fake-ai-models.ts @@ -0,0 +1,134 @@ +/** + * Creates a minimal fake AI SDK `LanguageModelV3` that returns a preset text + * response. Use this in unit tests to avoid network calls and module-level + * stubbing (which doesn't work on non-configurable getters in the `ai` package). + * + * @param text - The text the model should return. + * @param inputTokens - Simulated input token count (default: 10). + * @param outputTokens - Simulated output token count (default: 5). 
+ */ +export function createFakeLanguageModel( + text: string, + inputTokens = 10, + outputTokens = 5, +) { + return { + specificationVersion: 'v3' as const, + provider: 'fake', + modelId: 'fake-model', + defaultObjectGenerationMode: undefined, + supportsImageUrls: false, + supportsStructuredOutputs: false, + doGenerate: async () => ({ + content: [{type: 'text' as const, text}], + finishReason: 'stop' as const, + usage: { + inputTokens: {total: inputTokens, noCache: inputTokens, cache: 0}, + outputTokens: {total: outputTokens}, + totalTokens: {total: inputTokens + outputTokens}, + }, + warnings: [], + request: {body: '{}'}, + response: { + id: 'fake-id', + timestamp: new Date(), + modelId: 'fake-model', + headers: {}, + }, + }), + doStream: async () => { + throw new Error('doStream not supported in fake model'); + }, + }; +} + +/** + * Creates a fake `LanguageModelV3` that supports `doStream` (for use with + * `streamText`). Emits the provided text as a single text-delta chunk followed + * by a finish event with the given token counts. 
+ */ +export function createFakeStreamingLanguageModel( + text: string, + inputTokens = 10, + outputTokens = 5, +) { + const usage = { + inputTokens: {total: inputTokens, noCache: inputTokens, cache: 0}, + outputTokens: {total: outputTokens}, + totalTokens: {total: inputTokens + outputTokens}, + }; + return { + specificationVersion: 'v3' as const, + provider: 'fake', + modelId: 'fake-streaming-model', + defaultObjectGenerationMode: undefined, + supportsImageUrls: false, + supportsStructuredOutputs: false, + doGenerate: async () => ({ + content: [{type: 'text' as const, text}], + finishReason: 'stop' as const, + usage, + warnings: [], + request: {body: '{}'}, + response: { + id: 'fake-stream-id', + timestamp: new Date(), + modelId: 'fake-streaming-model', + headers: {}, + }, + }), + doStream: async () => { + const parts: object[] = [ + {type: 'text-start', id: 'fake-text-1'}, + {type: 'text-delta', id: 'fake-text-1', delta: text}, + {type: 'text-end', id: 'fake-text-1'}, + {type: 'finish', finishReason: 'stop', usage}, + ]; + let idx = 0; + const stream = new ReadableStream({ + pull(controller) { + if (idx < parts.length) { + controller.enqueue(parts[idx++]); + } else { + controller.close(); + } + }, + }); + return { + stream, + warnings: [], + rawCall: {rawPrompt: null, rawSettings: {}}, + request: {body: '{}'}, + response: { + id: 'fake-stream-id', + timestamp: new Date(), + modelId: 'fake-streaming-model', + headers: {}, + }, + }; + }, + }; +} + +/** + * Creates a minimal fake AI SDK `EmbeddingModel` that returns preset embeddings. + * + * @param embeddingsPerCall - Array of embedding vectors to return. One per value. 
+ */ +export function createFakeEmbeddingModel( + embeddingsPerCall: number[][] = [[0.1, 0.2, 0.3]], +) { + return { + specificationVersion: 'v2' as const, + provider: 'fake', + modelId: 'fake-embedding-model', + maxEmbeddingsPerCall: 100, + supportsParallelCalls: false, + doEmbed: async ({values}: {values: string[]}) => ({ + embeddings: values.map( + (_, i) => embeddingsPerCall[i % embeddingsPerCall.length], + ), + usage: {tokens: values.length}, + }), + }; +} diff --git a/src/__tests__/fixtures/test-app.ts b/src/__tests__/fixtures/test-app.ts index f6d0324..6837436 100644 --- a/src/__tests__/fixtures/test-app.ts +++ b/src/__tests__/fixtures/test-app.ts @@ -27,8 +27,8 @@ import {SupportedDBs} from '../../types'; import {Currency, ExchangeRate} from './models'; import {Employee} from './models/employee.model'; import {EmployeeRepository} from './repositories'; -import {Ollama, OllamaEmbedding} from '../../sub-modules/providers/ollama'; -import {Cerebras} from '../../sub-modules/providers/cerebras'; +import {OllamaSdk} from '../../sub-modules/providers/ollama'; +import {CerebrasSdk} from '../../sub-modules/providers/cerebras'; import {sinon} from '@loopback/testlab'; export class TestApp extends BootMixin( ServiceMixin(RepositoryMixin(RestApplication)), @@ -41,27 +41,40 @@ export class TestApp extends BootMixin( useCustomSequence: true, }); if (process.env.OLLAMA === '1') { - this.bind(AiIntegrationBindings.CheapLLM).toProvider(Ollama); - this.bind(AiIntegrationBindings.SmartLLM).toProvider(Ollama); - this.bind(AiIntegrationBindings.FileLLM).toProvider(Ollama); - this.bind(AiIntegrationBindings.ChatLLM).toProvider(Ollama); - this.bind(AiIntegrationBindings.EmbeddingModel).toProvider( - OllamaEmbedding, + this.bind(AiIntegrationBindings.CheapLLM).toProvider(OllamaSdk); + this.bind(AiIntegrationBindings.SmartLLM).toProvider(OllamaSdk); + this.bind(AiIntegrationBindings.FileLLM).toProvider(OllamaSdk); + this.bind(AiIntegrationBindings.ChatLLM).toProvider(OllamaSdk); + 
this.bind(AiIntegrationBindings.AiSdkCheapLLM).toProvider(OllamaSdk); + this.bind(AiIntegrationBindings.AiSdkSmartLLM).toProvider(OllamaSdk); + this.bind(AiIntegrationBindings.EmbeddingModel).to( + undefined as unknown as never, + ); + this.bind(AiIntegrationBindings.AiSdkEmbeddingModel).to( + undefined as unknown as never, ); } else if (process.env.CEREBRAS === '1') { - this.bind(AiIntegrationBindings.CheapLLM).toProvider(Cerebras); - this.bind(AiIntegrationBindings.SmartLLM).toProvider(Cerebras); - this.bind(AiIntegrationBindings.FileLLM).toProvider(Cerebras); - this.bind(AiIntegrationBindings.ChatLLM).toProvider(Cerebras); - this.bind(AiIntegrationBindings.EmbeddingModel).toProvider( - OllamaEmbedding, + this.bind(AiIntegrationBindings.CheapLLM).toProvider(CerebrasSdk); + this.bind(AiIntegrationBindings.SmartLLM).toProvider(CerebrasSdk); + this.bind(AiIntegrationBindings.FileLLM).toProvider(CerebrasSdk); + this.bind(AiIntegrationBindings.ChatLLM).toProvider(CerebrasSdk); + this.bind(AiIntegrationBindings.AiSdkCheapLLM).toProvider(CerebrasSdk); + this.bind(AiIntegrationBindings.AiSdkSmartLLM).toProvider(CerebrasSdk); + this.bind(AiIntegrationBindings.EmbeddingModel).to( + undefined as unknown as never, + ); + this.bind(AiIntegrationBindings.AiSdkEmbeddingModel).to( + undefined as unknown as never, ); } else if (options.llmStub) { this.bind(AiIntegrationBindings.CheapLLM).to(options.llmStub); this.bind(AiIntegrationBindings.SmartLLM).to(options.llmStub); this.bind(AiIntegrationBindings.FileLLM).to(options.llmStub); this.bind(AiIntegrationBindings.ChatLLM).to(options.llmStub); + this.bind(AiIntegrationBindings.AiSdkCheapLLM).to(options.llmStub); + this.bind(AiIntegrationBindings.AiSdkSmartLLM).to(options.llmStub); this.bind(AiIntegrationBindings.EmbeddingModel).to(options.llmStub); + this.bind(AiIntegrationBindings.AiSdkEmbeddingModel).to(options.llmStub); } this.bind('datasources.readerdb').to( new juggler.DataSource({ @@ -110,7 +123,7 @@ export class TestApp 
extends BootMixin( .toClass(SqliteConnector) .inScope(BindingScope.TRANSIENT); - this.bind(AiIntegrationBindings.VectorStore) + this.bind(AiIntegrationBindings.AiSdkVectorStore) .toProvider(InMemoryVectorStore) .inScope(BindingScope.SINGLETON); diff --git a/src/__tests__/integration/generation.service.integration.ts b/src/__tests__/integration/generation.service.integration.ts index 3e332e7..352413a 100644 --- a/src/__tests__/integration/generation.service.integration.ts +++ b/src/__tests__/integration/generation.service.integration.ts @@ -1,4 +1,3 @@ -import {IterableReadableStream} from '@langchain/core/utils/stream'; import {Request, Response} from '@loopback/rest'; import { createStubInstance, @@ -7,7 +6,7 @@ import { StubbedInstanceWithSinonAccessor, } from '@loopback/testlab'; import {PassThrough} from 'stream'; -import {ChatGraph, LLMStreamEvent} from '../../graphs'; +import {LLMStreamEvent} from '../../types/events'; import {MastraChatAgent} from '../../mastra'; import {GenerationService} from '../../services'; import {HttpTransport, SSETransport} from '../../transports'; @@ -16,12 +15,10 @@ describe(`GenerationService Integration`, () => { let service: GenerationService; let dummyRequest: Request; let dummyResponse: Response; - let graph: StubbedInstanceWithSinonAccessor; let mastraAgent: StubbedInstanceWithSinonAccessor; describe('with SSETransport', () => { beforeEach(() => { - graph = createStubInstance(ChatGraph); mastraAgent = createStubInstance(MastraChatAgent); dummyResponse = { write: sinon.stub(), @@ -33,12 +30,12 @@ describe(`GenerationService Integration`, () => { once: sinon.stub(), } as unknown as Request; const transport = new SSETransport(dummyResponse, dummyRequest); - service = new GenerationService(graph, mastraAgent, transport, undefined); + service = new GenerationService(mastraAgent, transport, undefined); }); it('should handle generation request and return response', async () => { const dummyStream = new PassThrough({objectMode: 
true}); - graph.stubs.execute.callsFake(async () => { - return dummyStream as unknown as IterableReadableStream; + mastraAgent.stubs.execute.callsFake(async function* () { + yield* dummyStream as unknown as AsyncIterable; }); dummyStream.push({ type: 'text', @@ -90,8 +87,8 @@ describe(`GenerationService Integration`, () => { it('should handle error gracyfully', async () => { const dummyStream = new PassThrough({objectMode: true}); - graph.stubs.execute.callsFake(async () => { - return dummyStream as unknown as IterableReadableStream; + mastraAgent.stubs.execute.callsFake(async function* () { + yield* dummyStream as unknown as AsyncIterable; }); dummyStream.push({ type: 'text', @@ -142,7 +139,6 @@ describe(`GenerationService Integration`, () => { describe('with HttpTransport', () => { beforeEach(() => { - graph = createStubInstance(ChatGraph); mastraAgent = createStubInstance(MastraChatAgent); dummyResponse = { write: sinon.stub(), @@ -154,12 +150,12 @@ describe(`GenerationService Integration`, () => { once: sinon.stub(), } as unknown as Request; const transport = new HttpTransport(dummyResponse, dummyRequest); - service = new GenerationService(graph, mastraAgent, transport, undefined); + service = new GenerationService(mastraAgent, transport, undefined); }); it('should handle generation request and return response', async () => { const dummyStream = new PassThrough({objectMode: true}); - graph.stubs.execute.callsFake(async () => { - return dummyStream as unknown as IterableReadableStream; + mastraAgent.stubs.execute.callsFake(async function* () { + yield* dummyStream as unknown as AsyncIterable; }); dummyStream.push({ type: 'text', @@ -205,8 +201,8 @@ describe(`GenerationService Integration`, () => { it('should handle error gracyfully', async () => { const dummyStream = new PassThrough({objectMode: true}); - graph.stubs.execute.callsFake(async () => { - return dummyStream as unknown as IterableReadableStream; + mastraAgent.stubs.execute.callsFake(async function* () 
{ + yield* dummyStream as unknown as AsyncIterable; }); dummyStream.push({ type: 'text', diff --git a/src/__tests__/unit/chat.graph.unit.ts b/src/__tests__/unit/chat.graph.unit.ts deleted file mode 100644 index d870991..0000000 --- a/src/__tests__/unit/chat.graph.unit.ts +++ /dev/null @@ -1,160 +0,0 @@ -import {Context} from '@loopback/core'; -import {expect} from '@loopback/testlab'; -import {GRAPH_NODE_NAME} from '../../constant'; -import {ChatGraph, ChatNodes, ChatState, IGraphTool} from '../../graphs'; -import {AiIntegrationBindings} from '../../keys'; -import {TokenCounter} from '../../services'; -import {buildFileStub, buildNodeStub} from '../test-helper'; - -describe(`ChatGraph Unit`, function () { - let graph: ChatGraph; - let stubMap: Record; - - beforeEach(async () => { - const context = new Context('test-context'); - context.bind(AiIntegrationBindings.Tools).to({ - list: [], - map: { - 'test-tool': { - needsReview: false, - key: 'test-tool', - build: (async () => {}) as unknown as IGraphTool['build'], - }, - }, - }); - context.bind('services.TokenCounter').to(new TokenCounter()); - context.bind('ChatGraph').toClass(ChatGraph); - stubMap = {} as Record; - for (const key of Object.values(ChatNodes)) { - const stub = buildNodeStub(); - context - .bind(`services.${key}`) - .to(stub) - .tag({ - [GRAPH_NODE_NAME]: key, - }); - stubMap[key] = stub.execute; - } - graph = await context.get('ChatGraph'); - }); - - it('should init session, and end session on user prompt', async () => { - await graph.execute('test prompt', [], new AbortController().signal); - - expect(stubMap[ChatNodes.InitSession].calledOnce).to.be.true(); - expect(stubMap[ChatNodes.CallLLM].calledOnce).to.be.true(); - // should end at this point - expect(stubMap[ChatNodes.TrimMessages].calledOnce).to.be.false(); - // called once by default - expect(stubMap[ChatNodes.SummariseFile].calledOnce).to.be.true(); - // should be called after call LLM - 
expect(stubMap[ChatNodes.EndSession].calledOnce).to.be.true(); - }); - - it('should init session, summarise multiple files, and end session on user prompt if no tool call', async () => { - stubMap[ChatNodes.SummariseFile].callsFake((state: ChatState) => { - return { - ...state, - files: state.files?.filter( - (f: Express.Multer.File, index: number) => index !== 0, - ), - }; - }); - - await graph.execute( - 'test prompt', - [buildFileStub(), buildFileStub()], - new AbortController().signal, - ); - - expect(stubMap[ChatNodes.InitSession].calledOnce).to.be.true(); - expect(stubMap[ChatNodes.CallLLM].calledOnce).to.be.true(); - // should end at this point - expect(stubMap[ChatNodes.TrimMessages].calledOnce).to.be.false(); - // called once by default - expect(stubMap[ChatNodes.SummariseFile].getCalls().length).to.be.equal(2); - // should be called after call LLM - expect(stubMap[ChatNodes.EndSession].calledOnce).to.be.true(); - }); - - it('should init session, summarise single files, and end session on user prompt if no tool call', async () => { - stubMap[ChatNodes.SummariseFile].callsFake((state: ChatState) => { - return { - ...state, - files: state.files?.filter( - (f: Express.Multer.File, index: number) => index !== 0, - ), - }; - }); - - await graph.execute( - 'test prompt', - buildFileStub(), - new AbortController().signal, - ); - - expect(stubMap[ChatNodes.InitSession].calledOnce).to.be.true(); - expect(stubMap[ChatNodes.CallLLM].calledOnce).to.be.true(); - // should end at this point - expect(stubMap[ChatNodes.TrimMessages].calledOnce).to.be.false(); - // called once by default - expect(stubMap[ChatNodes.SummariseFile].getCalls().length).to.be.equal(1); - // should be called after call LLM - expect(stubMap[ChatNodes.EndSession].calledOnce).to.be.true(); - }); - - it('should init session, summarise file, call LLM, run tool and end session', async () => { - let calledAlready = false; - stubMap[ChatNodes.CallLLM].callsFake((state: ChatState) => { - if 
(calledAlready) { - // if called already, return the LLM response without tool call - return { - ...state, - messages: [ - ...state.messages, - { - role: 'assistant', - content: 'This is a response from LLM', - // eslint-disable-next-line @typescript-eslint/naming-convention - tool_calls: [], - }, - ], - }; - } - calledAlready = true; - return { - ...state, - messages: [ - ...state.messages, - { - role: 'assistant', - content: 'This is a response from LLM', - // eslint-disable-next-line @typescript-eslint/naming-convention - tool_calls: [ - { - id: 'tool-call-1', - name: 'test-tool', - type: 'function', - arguments: {}, - }, - ], - }, - ], - }; - }); - - await graph.execute('test prompt', [], new AbortController().signal); - - expect(stubMap[ChatNodes.InitSession].calledOnce).to.be.true(); - // this should called twice - expect(stubMap[ChatNodes.CallLLM].calledTwice).to.be.true(); - // should call the tool once - expect(stubMap[ChatNodes.RunTool].calledOnce).to.be.true(); - // should call this once after the tool call - expect(stubMap[ChatNodes.TrimMessages].calledOnce).to.be.true(); - // called once by default - expect(stubMap[ChatNodes.SummariseFile].calledOnce).to.be.true(); - // should be called after call LLM - expect(stubMap[ChatNodes.EndSession].calledOnce).to.be.true(); - }); -}); diff --git a/src/__tests__/unit/langfuse-core.provider.unit.ts b/src/__tests__/unit/langfuse-core.provider.unit.ts new file mode 100644 index 0000000..d0e62c1 --- /dev/null +++ b/src/__tests__/unit/langfuse-core.provider.unit.ts @@ -0,0 +1,65 @@ +import {expect} from '@loopback/testlab'; +import {LangfuseCoreProvider} from '../../sub-modules/obf/langfuse/langfuse-core.provider'; + +describe('LangfuseCoreProvider (unit)', function () { + const ORIGINAL_ENV = {...process.env}; + + afterEach(() => { + // Restore env vars after each test + process.env.LANGFUSE_PUBLIC_KEY = ORIGINAL_ENV.LANGFUSE_PUBLIC_KEY; + process.env.LANGFUSE_SECRET_KEY = ORIGINAL_ENV.LANGFUSE_SECRET_KEY; + 
process.env.LANGFUSE_HOST = ORIGINAL_ENV.LANGFUSE_HOST; + }); + + it('throws when LANGFUSE_PUBLIC_KEY is missing', () => { + delete process.env.LANGFUSE_PUBLIC_KEY; + process.env.LANGFUSE_SECRET_KEY = 'sk-test'; + + const provider = new LangfuseCoreProvider(); + expect(() => provider.value()).to.throwError( + /LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables must be set/, + ); + }); + + it('throws when LANGFUSE_SECRET_KEY is missing', () => { + process.env.LANGFUSE_PUBLIC_KEY = 'pk-test'; + delete process.env.LANGFUSE_SECRET_KEY; + + const provider = new LangfuseCoreProvider(); + expect(() => provider.value()).to.throwError( + /LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables must be set/, + ); + }); + + it('throws when both keys are missing', () => { + delete process.env.LANGFUSE_PUBLIC_KEY; + delete process.env.LANGFUSE_SECRET_KEY; + + const provider = new LangfuseCoreProvider(); + expect(() => provider.value()).to.throwError(/must be set/); + }); + + it('instantiates LangfuseAPIClient when both keys are present', () => { + process.env.LANGFUSE_PUBLIC_KEY = 'pk-test'; + process.env.LANGFUSE_SECRET_KEY = 'sk-test'; + delete process.env.LANGFUSE_HOST; + + const provider = new LangfuseCoreProvider(); + const client = provider.value(); + + // Should be a non-null object (LangfuseAPIClient) + expect(client).to.not.be.null(); + expect(typeof client).to.equal('object'); + }); + + it('uses LANGFUSE_HOST when provided', () => { + process.env.LANGFUSE_PUBLIC_KEY = 'pk-test'; + process.env.LANGFUSE_SECRET_KEY = 'sk-test'; + process.env.LANGFUSE_HOST = 'https://my.langfuse.server'; + + const provider = new LangfuseCoreProvider(); + // Just confirm it doesn't throw — host validation is internal to the client + const client = provider.value(); + expect(client).to.not.be.null(); + }); +}); diff --git a/src/__tests__/unit/mastra-bridge.unit.ts b/src/__tests__/unit/mastra-bridge.unit.ts index bf605ae..af303e5 100644 --- 
a/src/__tests__/unit/mastra-bridge.unit.ts +++ b/src/__tests__/unit/mastra-bridge.unit.ts @@ -7,7 +7,7 @@ import { TOOL_NAME, TOOL_TAG, } from '../../constant'; -import {IGraphNode, IGraphTool} from '../../graphs/types'; +import {IGraphNode, IGraphTool} from '../../types/tool'; import { MastraBridgeService, MastraRuntimeFactory, diff --git a/src/__tests__/unit/nodes/call-llm.node.unit.ts b/src/__tests__/unit/nodes/call-llm.node.unit.ts deleted file mode 100644 index 91d9c4b..0000000 --- a/src/__tests__/unit/nodes/call-llm.node.unit.ts +++ /dev/null @@ -1,100 +0,0 @@ -import {Context} from '@loopback/core'; -import {juggler} from '@loopback/repository'; -import {expect, sinon} from '@loopback/testlab'; -import {AuthenticationBindings} from 'loopback4-authentication'; -import {CallLLMNode, ChatStore, RunnableConfig} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {Chat} from '../../../models'; -import {ChatRepository, MessageRepository} from '../../../repositories'; -import {RuntimeLLMProvider} from '../../../types'; -import {setupChats, setupMessages, stubUser} from '../../test-helper'; - -describe('CallLLMNode Unit', function () { - let node: CallLLMNode; - let bindToolsStub: sinon.SinonStub; - let llmStub: sinon.SinonStub; - let chatStore: ChatStore; - let baseChat: Chat; - beforeEach(async () => { - bindToolsStub = sinon.stub(); - llmStub = sinon.stub(); - const llmProvider = { - bindTools: bindToolsStub.callsFake(() => { - return { - invoke: llmStub, - }; - }), - } as unknown as RuntimeLLMProvider; - const context = new Context('test-context'); - context.bind('services.CallLLMNode').toClass(CallLLMNode); - context.bind('services.ChatStore').toClass(ChatStore); - context.bind('repositories.ChatRepository').toClass(ChatRepository); - context.bind('repositories.MessageRepository').toClass(MessageRepository); - context.bind(AiIntegrationBindings.Tools).to({ - list: [], - map: {}, - }); - 
context.bind(AuthenticationBindings.CURRENT_USER).to(stubUser()); - context.bind(AiIntegrationBindings.SmartLLM).to(llmProvider); - context.bind(AiIntegrationBindings.CheapLLM).to(llmProvider); - context.bind(AiIntegrationBindings.ChatLLM).to(llmProvider); - context.bind('datasources.readerdb').to( - new juggler.DataSource({ - connector: 'sqlite3', - file: ':memory:', - name: 'db', - debug: true, - }), - ); - context.bind(`datasources.writerdb`).to( - new juggler.DataSource({ - connector: 'memory', - name: 'db', - }), - ); - - await setupChats(context); - await setupMessages(context); - - node = await context.get(`services.CallLLMNode`); - - chatStore = await context.get(`services.ChatStore`); - baseChat = await chatStore.init(`test`); - }); - - it('should call llm with all tools, and add response to messages list, and update chat state', async () => { - llmStub.resolves({ - content: 'This is a response from LLM', - // eslint-disable-next-line @typescript-eslint/naming-convention - tool_calls: [], - }); - await node.execute( - { - id: baseChat.id, - prompt: 'test prompt', - messages: [], - files: [], - userMessage: undefined, - aiMessage: undefined, - }, - { - writer: sinon.stub(), - } as unknown as RunnableConfig, - ); - - expect(bindToolsStub.calledOnceWith([])).to.be.true(); - const chat = await chatStore.findById(baseChat.id, { - include: ['messages'], - }); - // should have added a message from LLM - expect(chat).to.have.property('messages'); - expect(chat.messages).to.have.length(1); - expect(chat.messages[0]).to.have.property( - 'body', - 'This is a response from LLM', - ); - expect(chat.messages[0].metadata).to.deepEqual({ - type: 'ai', - }); - }); -}); diff --git a/src/__tests__/unit/nodes/context-compression.node.unit.ts b/src/__tests__/unit/nodes/context-compression.node.unit.ts deleted file mode 100644 index 79bf400..0000000 --- a/src/__tests__/unit/nodes/context-compression.node.unit.ts +++ /dev/null @@ -1,69 +0,0 @@ -import {AIMessage} from 
'@langchain/core/messages'; -import {expect} from '@loopback/testlab'; -import {ContextCompressionNode} from '../../../graphs'; - -describe(`ContextCompressionNode Unit`, function () { - let node: ContextCompressionNode; - - beforeEach(() => { - node = new ContextCompressionNode({ - maxTokenCount: 15, - }); - }); - - it('should not compress context if within limit', async () => { - const state = { - prompt: 'test prompt', - // approx 13 tokens - context: ['This is a long context that needs to be compressed.'], - id: 'test-id', - done: false, - userMessage: undefined, - aiMessage: undefined, - messages: [ - new AIMessage({ - content: 'This is a long context that needs to be compressed', - }), - ], - files: undefined, - }; - - const result = await node.execute(state, {}); - - // no changes in messages as limit is not reached - expect(result.messages).to.have.length(1); - }); - - it('should compress context if not within limit', async () => { - const state = { - prompt: 'test prompt', - // approx 13 tokens - context: ['This is a long context that needs to be compressed.'], - id: 'test-id', - done: false, - userMessage: undefined, - aiMessage: undefined, - messages: [ - new AIMessage({ - content: 'This is a first context that needs to be compressed', - }), - new AIMessage({ - content: 'This is a second context that needs to be compressed', - }), - new AIMessage({ - content: 'This is a third context that needs to be compressed', - }), - ], - files: undefined, - }; - - const result = await node.execute(state, {}); - - // should only have the last message after compression - expect(result.messages).to.have.length(1); - - expect(result.messages[0].text).to.eql( - 'This is a third context that needs to be compressed', - ); - }); -}); diff --git a/src/__tests__/unit/nodes/end-session.node.unit.ts b/src/__tests__/unit/nodes/end-session.node.unit.ts deleted file mode 100644 index 41b909d..0000000 --- a/src/__tests__/unit/nodes/end-session.node.unit.ts +++ /dev/null @@ -1,110 
+0,0 @@ -import {LLMResult} from '@langchain/core/outputs'; -import { - createStubInstance, - expect, - sinon, - StubbedInstanceWithSinonAccessor, -} from '@loopback/testlab'; -import { - ChatState, - ChatStore, - EndSessionNode, - LLMStreamEventType, - RunnableConfig, -} from '../../../graphs'; -import {TokenCounter} from '../../../services'; - -describe('EndSessionNode Unit', function () { - let node: EndSessionNode; - let chatStore: StubbedInstanceWithSinonAccessor; - - beforeEach(async () => { - chatStore = createStubInstance(ChatStore); - const counter = new TokenCounter(); - // first llm call - counter.handleLlmStart('1', 'test'); - counter.handleLlmEnd('1', { - generations: [ - [ - { - message: { - // eslint-disable-next-line @typescript-eslint/naming-convention - usage_metadata: { - // eslint-disable-next-line @typescript-eslint/naming-convention - input_tokens: 10, - // eslint-disable-next-line @typescript-eslint/naming-convention - output_tokens: 5, - }, - }, - }, - ], - ], - } as unknown as LLMResult); - // second llm call - counter.handleLlmStart('2', 'test-2'); - counter.handleLlmEnd('2', { - generations: [ - [ - { - message: { - // eslint-disable-next-line @typescript-eslint/naming-convention - usage_metadata: { - // eslint-disable-next-line @typescript-eslint/naming-convention - input_tokens: 20, - // eslint-disable-next-line @typescript-eslint/naming-convention - output_tokens: 30, - }, - }, - }, - ], - ], - } as unknown as LLMResult); - node = new EndSessionNode(chatStore, counter); - }); - - it('should update token counts and return the state, and update the chat counter', async () => { - const state = { - id: 'test-session-id', - prompt: 'test prompt', - messages: [], - done: false, - userMessage: undefined, - aiMessage: undefined, - files: undefined, - } as ChatState; - - const writerStub = sinon.stub(); - const config = { - writer: writerStub, - } as unknown as RunnableConfig; - - const result = await node.execute(state, config); - - 
expect(result).to.equal(state); - const writerCalls = writerStub.getCalls(); - expect(writerCalls).to.have.length(1); - expect(writerCalls[0].args).to.deepEqual([ - { - type: LLMStreamEventType.TokenCount, - data: { - // sum of 10 and 20 from the two calls above - inputTokens: 30, - // sum of 5 and 30 from the two calls above - outputTokens: 35, - }, - }, - ]); - - const calls = chatStore.stubs.updateCounts.getCalls(); - expect(calls[0].args).to.deepEqual([ - 'test-session-id', - 30, - 35, - // model wise map of token counts - { - test: {inputTokens: 10, outputTokens: 5}, - 'test-2': {inputTokens: 20, outputTokens: 30}, - }, - ]); - }); -}); diff --git a/src/__tests__/unit/nodes/init-session.node.unit.ts b/src/__tests__/unit/nodes/init-session.node.unit.ts deleted file mode 100644 index a257b97..0000000 --- a/src/__tests__/unit/nodes/init-session.node.unit.ts +++ /dev/null @@ -1,93 +0,0 @@ -import { - createStubInstance, - expect, - sinon, - StubbedInstanceWithSinonAccessor, -} from '@loopback/testlab'; -import { - ChatState, - ChatStore, - InitSessionNode, - LLMStreamEventType, -} from '../../../graphs'; -import {Chat} from '../../../models'; - -describe(`InitSessionNode Unit`, function () { - let node: InitSessionNode; - let chatStore: StubbedInstanceWithSinonAccessor; - - beforeEach(() => { - chatStore = createStubInstance(ChatStore); - node = new InitSessionNode(chatStore); - }); - - it('should initialize a new chat session', async () => { - const writerStub = sinon.stub(); - chatStore.stubs.init.callsFake(async () => { - return new Chat({ - id: 'test-session-id', - }); - }); - const result = await node.execute( - {prompt: 'Hello'} as unknown as ChatState, - { - writer: writerStub, - }, - ); - expect(result).to.have.property('id', 'test-session-id'); - expect(writerStub.calledOnce).to.be.true(); - expect(writerStub.getCalls()[0].args[0]).to.deepEqual({ - type: LLMStreamEventType.Init, - data: { - sessionId: 'test-session-id', - }, - }); - }); - - it('should 
have date in system message', async () => { - const writerStub = sinon.stub(); - chatStore.stubs.init.callsFake(async () => { - return new Chat({ - id: 'test-session-id', - }); - }); - const result = await node.execute( - {prompt: 'Hello'} as unknown as ChatState, - { - writer: writerStub, - }, - ); - const systemMessage = result.messages?.find( - msg => msg.getType() === 'system', - )?.content; - if (typeof systemMessage !== 'string') { - throw new Error('System message is not a string'); - } - expect( - systemMessage?.includes(`Current date is ${new Date().toDateString()}`), - ).to.be.true(); - }); - - it('should have extra provided context in system context', async () => { - node = new InitSessionNode(chatStore, ['Test context.']); - const writerStub = sinon.stub(); - chatStore.stubs.init.callsFake(async () => { - return new Chat({ - id: 'test-session-id', - }); - }); - const result = await node.execute( - {prompt: 'Hello'} as unknown as ChatState, - { - writer: writerStub, - }, - ); - const systemMessage = result.messages?.find( - msg => msg.getType() === 'system', - )?.content; - if (typeof systemMessage !== 'string') { - throw new Error('System message is not a string'); - } - expect(systemMessage?.includes(`Test context.`)).to.be.true(); - }); -}); diff --git a/src/__tests__/unit/nodes/run-tool.node.unit.ts b/src/__tests__/unit/nodes/run-tool.node.unit.ts deleted file mode 100644 index afcb259..0000000 --- a/src/__tests__/unit/nodes/run-tool.node.unit.ts +++ /dev/null @@ -1,125 +0,0 @@ -import {AIMessage, ToolMessage} from '@langchain/core/messages'; -import {HttpErrors} from '@loopback/rest'; -import { - createStubInstance, - expect, - sinon, - StubbedInstanceWithSinonAccessor, -} from '@loopback/testlab'; // Changed import -import {Message} from '@sourceloop/chat-service'; -import {ChatStore} from '../../../graphs/chat/chat.store'; -import {RunToolNode} from '../../../graphs/chat/nodes/run-tool.node'; -import {ChatState} from '../../../graphs/state'; 
-import {IGraphTool, RunnableConfig} from '../../../graphs/types'; -import {ToolStore} from '../../../types'; - -describe('RunToolNode Unit', () => { - let runToolNode: RunToolNode; - let tools: ToolStore; - let chatStore: StubbedInstanceWithSinonAccessor; - let writerStub: sinon.SinonStub; - let invokeStub: sinon.SinonStub; - - beforeEach(() => { - writerStub = sinon.stub(); - invokeStub = sinon.stub(); - const dummyTool = { - testTool: { - build: sinon.stub().resolves({ - // Changed to sinon.stub() - invoke: invokeStub.resolves('tool output'), // Changed to sinon.stub() - }), - } as unknown as IGraphTool, - }; - tools = { - map: dummyTool, - list: Object.values(dummyTool), - }; - chatStore = createStubInstance(ChatStore); - runToolNode = new RunToolNode(tools, chatStore); - }); - - it('should return the state if the last message does not have tool_calls', async () => { - const state = { - id: 'testId', - aiMessage: new Message(), - messages: [new AIMessage({content: 'hello'})], - } as unknown as ChatState; - const config = {} as RunnableConfig; - const result = await runToolNode.execute(state, config); - expect(result).to.equal(state); // Changed to equal - }); - - it('should return the state if no messages', async () => { - const state = { - id: 'testId', - messages: [], - aiMessage: new Message(), - } as unknown as ChatState; - const config = {} as RunnableConfig; - const result = await runToolNode.execute(state, config); - expect(result).to.equal(state); // Changed to equal - }); - - it('should throw an error if no chat ID found in state', async () => { - const state = { - messages: [new AIMessage({content: 'hello'})], - aiMessage: new Message(), - } as unknown as ChatState; - const config = {} as RunnableConfig; - await expect(runToolNode.execute(state, config)).to.be.rejectedWith( - new HttpErrors.InternalServerError(), - ); - }); - - it('should throw an error if no last AI message found in state', async () => { - const state = { - id: 'testId', - messages: 
[], - } as unknown as ChatState; - const config = {} as RunnableConfig; - await expect(runToolNode.execute(state, config)).to.be.rejectedWith( - new HttpErrors.InternalServerError(), - ); - }); - - it('should call the tool with the correct arguments and add the ToolMessage to the chat store', async () => { - const state: ChatState = { - id: 'testId', - messages: [ - new AIMessage({ - content: 'hello', - // eslint-disable-next-line @typescript-eslint/naming-convention - tool_calls: [ - { - id: 'toolCallId', - name: 'testTool', - args: {input: 'test input'}, - }, - ], - }), - ], - aiMessage: new AIMessage({content: 'hello'}), - } as unknown as ChatState; - const config = { - writer: writerStub, - } as unknown as RunnableConfig; - const result = await runToolNode.execute(state, config); - sinon.assert.calledWith(invokeStub, { - input: 'test input', - }); - const calls = chatStore.stubs.addToolMessage.getCalls(); - expect(calls).to.have.length(1); - expect(calls[0].args[0]).to.equal('testId'); - expect(calls[0].args[1]).to.be.instanceOf(ToolMessage); - expect(calls[0].args[1].name).to.equal('testTool'); - expect(result.messages).to.deepEqual([ - new ToolMessage({ - name: 'testTool', - content: 'tool output', - // eslint-disable-next-line @typescript-eslint/naming-convention - tool_call_id: 'toolCallId', - }), - ]); - }); -}); diff --git a/src/__tests__/unit/nodes/summarise-file.node.unit.ts b/src/__tests__/unit/nodes/summarise-file.node.unit.ts deleted file mode 100644 index 8a44d17..0000000 --- a/src/__tests__/unit/nodes/summarise-file.node.unit.ts +++ /dev/null @@ -1,142 +0,0 @@ -import {HumanMessage} from '@langchain/core/messages'; -import {HttpErrors} from '@loopback/rest'; -import { - createStubInstance, - expect, - sinon, - StubbedInstanceWithSinonAccessor, -} from '@loopback/testlab'; -import {ChatState, ChatStore, SummariseFileNode} from '../../../graphs'; -import {Message} from '../../../models'; -import {RuntimeLLMProvider} from '../../../types'; -import 
{buildFileStub} from '../../test-helper'; - -describe(`SummariseFileNode Unit`, function () { - let node: SummariseFileNode; - let llmStub: sinon.SinonStub; - let chatStore: StubbedInstanceWithSinonAccessor; - let writerStub: sinon.SinonStub; - const dummyState: ChatState = { - id: 'test-session-id', - prompt: 'test prompt', - messages: [], - userMessage: new Message(), - aiMessage: undefined, - files: undefined, - }; - - beforeEach(() => { - llmStub = sinon.stub(); - writerStub = sinon.stub(); - chatStore = createStubInstance(ChatStore); - node = new SummariseFileNode( - llmStub as unknown as RuntimeLLMProvider, - chatStore, - ); - }); - - it('should throw an error if no chat ID is found in state', async () => { - await expect( - node.execute({} as ChatState, {writer: writerStub}), - ).to.be.rejectedWith(new HttpErrors.InternalServerError()); - }); - - it('should throw an error if no last user message is found in state', async () => { - await expect( - node.execute({id: 'test-id'} as ChatState, {writer: writerStub}), - ).to.be.rejectedWith(new HttpErrors.InternalServerError()); - }); - - it('should return the state with human message if no file is provided', async () => { - const result = await node.execute( - { - ...dummyState, - files: [], - prompt: 'test prompt', - }, - { - writer: writerStub, - }, - ); - - expect(result).to.deepEqual({ - ...dummyState, - files: [], - messages: [ - new HumanMessage({ - content: 'test prompt', - }), - ], - }); - }); - - it('should return the state with human message if file is undefined', async () => { - const result = await node.execute( - { - ...dummyState, - files: undefined, - prompt: 'test prompt', - }, - { - writer: writerStub, - }, - ); - - expect(result).to.deepEqual({ - ...dummyState, - messages: [ - new HumanMessage({ - content: 'test prompt', - }), - ], - files: [], - }); - }); - - it('should the state with no file and a human message if 1 file is provided', async () => { - llmStub.resolves({content: 'This is a 
summary of the file.'}); - const file = buildFileStub(); - const result = await node.execute( - { - ...dummyState, - files: [file], - }, - { - writer: writerStub, - }, - ); - - expect(result).to.deepEqual({ - ...dummyState, - files: [], - prompt: `test prompt\nsummary of file - ${file.originalname}:\nThis is a summary of the file.`, - messages: [ - new HumanMessage({ - content: `test prompt\nsummary of file - ${file.originalname}:\nThis is a summary of the file.`, - }), - ], - }); - }); - - it('should return the state with 1 file and no human message if 2 files are provided', async () => { - llmStub.resolves({content: 'This is a summary of the file.'}); - const file1 = buildFileStub(); - const file2 = buildFileStub(); - const result = await node.execute( - { - ...dummyState, - files: [file1, file2], - }, - { - writer: writerStub, - }, - ); - - expect(result).to.deepEqual({ - ...dummyState, - files: [file2], - prompt: `test prompt\nsummary of file - ${file1.originalname}:\nThis is a summary of the file.`, - messages: [], - }); - }); -}); diff --git a/src/__tests__/unit/pgvector-sdk.store.unit.ts b/src/__tests__/unit/pgvector-sdk.store.unit.ts new file mode 100644 index 0000000..cc102f0 --- /dev/null +++ b/src/__tests__/unit/pgvector-sdk.store.unit.ts @@ -0,0 +1,266 @@ +import {expect, sinon} from '@loopback/testlab'; +import {IVectorStoreDocument} from '../../types'; + +// Sinon stubs for AI SDK embed functions — injected via constructor +const embedManyStub = sinon.stub(); +const embedStub = sinon.stub(); + +function buildFakePool(queryResult: Record = {rows: []}) { + const clientStub = { + query: sinon.stub().resolves(), + release: sinon.stub(), + }; + const pool = { + connect: sinon.stub().resolves(clientStub), + query: sinon.stub().resolves(queryResult), + _client: clientStub, + }; + return {pool, clientStub}; +} + +/** + * Self-contained re-implementation of PgVectorSdkStoreImpl for unit tests. 
+ * Uses injected embed stubs instead of real `ai` module calls — necessary + * because `ai` module exports are non-configurable getters and cannot be + * patched via sinon or property assignment. + */ +class TestableVectorStore { + private readonly _embedMany: typeof embedManyStub; + private readonly _embed: typeof embedStub; + + constructor( + private readonly pool: ReturnType['pool'], + private readonly schema: string, + embedMany: typeof embedManyStub, + embed: typeof embedStub, + ) { + this._embedMany = embedMany; + this._embed = embed; + } + + async addDocuments(docs: IVectorStoreDocument[]): Promise { + if (docs.length === 0) return; + const {embeddings} = await this._embedMany({ + model: {}, + values: docs.map((d: IVectorStoreDocument) => d.pageContent), + }); + const client = await this.pool.connect(); + try { + for (let i = 0; i < docs.length; i++) { + const vectorLiteral = `[${embeddings[i].join(',')}]`; + await client.query( + `INSERT INTO ${this.schema}.semantic_cache (id, content, metadata, vector) VALUES (gen_random_uuid(), $1, $2::jsonb, $3::vector)`, + [ + docs[i].pageContent, + JSON.stringify(docs[i].metadata), + vectorLiteral, + ], + ); + } + } finally { + client.release(); + } + } + + async similaritySearch>( + query: string, + k: number, + filter?: Record, + ): Promise[]> { + const {embedding} = await this._embed({model: {}, value: query}); + const vectorLiteral = `[${embedding.join(',')}]`; + const params: unknown[] = [vectorLiteral]; + let filterClause = ''; + if (filter && Object.keys(filter).length > 0) { + params.push(JSON.stringify(filter)); + filterClause = `WHERE metadata @> $2::jsonb`; + } + params.push(k); + const limitParam = `$${params.length}`; + const sql = ` + SELECT content, metadata + FROM ${this.schema}.semantic_cache + ${filterClause} + ORDER BY vector <=> $1::vector + LIMIT ${limitParam} + `; + const {rows} = await this.pool.query(sql, params); + return rows.map((row: Record) => ({ + pageContent: row.content as string, + 
      metadata: row.metadata as T,
+    }));
+  }
+
+  async delete(params: {filter: Record<string, unknown>}): Promise<void> {
+    await this.pool.query(
+      `DELETE FROM ${this.schema}.semantic_cache WHERE metadata @> $1::jsonb`,
+      [JSON.stringify(params.filter)],
+    );
+  }
+}
+
+describe('PgVectorSdkStore (unit)', function () {
+  const schema = 'public';
+
+  beforeEach(() => {
+    embedManyStub.reset();
+    embedStub.reset();
+  });
+
+  describe('addDocuments()', function () {
+    it('does nothing when docs array is empty', async () => {
+      const {pool} = buildFakePool();
+      const store = new TestableVectorStore(
+        pool,
+        schema,
+        embedManyStub,
+        embedStub,
+      );
+
+      await store.addDocuments([]);
+
+      sinon.assert.notCalled(embedManyStub);
+      sinon.assert.notCalled(pool.connect);
+    });
+
+    it('calls embedMany and inserts each document', async () => {
+      const {pool, clientStub} = buildFakePool();
+      const store = new TestableVectorStore(
+        pool,
+        schema,
+        embedManyStub,
+        embedStub,
+      );
+
+      embedManyStub.resolves({
+        embeddings: [
+          [0.1, 0.2],
+          [0.3, 0.4],
+        ],
+      });
+
+      await store.addDocuments([
+        {pageContent: 'doc one', metadata: {id: 1}},
+        {pageContent: 'doc two', metadata: {id: 2}},
+      ]);
+
+      sinon.assert.calledOnce(embedManyStub);
+      expect(clientStub.query.callCount).to.equal(2);
+      expect(clientStub.release.calledOnce).to.be.true();
+    });
+
+    it('formats vector as pgvector literal [f1,f2,...]', async () => {
+      const {pool, clientStub} = buildFakePool();
+      const store = new TestableVectorStore(
+        pool,
+        schema,
+        embedManyStub,
+        embedStub,
+      );
+
+      embedManyStub.resolves({embeddings: [[0.5, 0.6]]});
+
+      await store.addDocuments([{pageContent: 'hello', metadata: {}}]);
+
+      const args = clientStub.query.firstCall.args;
+      expect(args[1][2]).to.equal('[0.5,0.6]');
+    });
+
+    it('releases the client even if query throws', async () => {
+      const {pool, clientStub} = buildFakePool();
+      clientStub.query.rejects(new Error('DB error'));
+      const store = new TestableVectorStore(
+        pool,
+        schema,
embedManyStub, + embedStub, + ); + + embedManyStub.resolves({embeddings: [[0.1, 0.2]]}); + + await expect( + store.addDocuments([{pageContent: 'x', metadata: {}}]), + ).to.be.rejectedWith('DB error'); + + expect(clientStub.release.calledOnce).to.be.true(); + }); + }); + + describe('similaritySearch()', function () { + it('returns documents mapped from row results', async () => { + const {pool} = buildFakePool({ + rows: [ + {content: 'employee query', metadata: {datasetId: '1'}}, + {content: 'salary query', metadata: {datasetId: '2'}}, + ], + }); + const store = new TestableVectorStore( + pool, + schema, + embedManyStub, + embedStub, + ); + + embedStub.resolves({embedding: [0.1, 0.2]}); + + const results = await store.similaritySearch('employees', 5); + + expect(results).to.have.length(2); + expect(results[0].pageContent).to.equal('employee query'); + expect(results[0].metadata).to.deepEqual({datasetId: '1'}); + }); + + it('appends filter clause when filter is provided', async () => { + const {pool} = buildFakePool({rows: []}); + const store = new TestableVectorStore( + pool, + schema, + embedManyStub, + embedStub, + ); + + embedStub.resolves({embedding: [0.1, 0.2]}); + + await store.similaritySearch('employees', 3, {tenantId: 'abc'}); + + const [sql, params] = pool.query.firstCall.args; + expect(sql).to.match(/WHERE metadata @> \$2::jsonb/); + expect(params[1]).to.equal(JSON.stringify({tenantId: 'abc'})); + }); + + it('omits filter clause when no filter provided', async () => { + const {pool} = buildFakePool({rows: []}); + const store = new TestableVectorStore( + pool, + schema, + embedManyStub, + embedStub, + ); + + embedStub.resolves({embedding: [0.1, 0.2]}); + + await store.similaritySearch('employees', 3); + + const [sql] = pool.query.firstCall.args; + expect(sql).to.not.match(/WHERE/); + }); + }); + + describe('delete()', function () { + it('runs DELETE query with filter', async () => { + const {pool} = buildFakePool(); + const store = new TestableVectorStore( 
+ pool, + schema, + embedManyStub, + embedStub, + ); + + await store.delete({filter: {tenantId: 'abc'}}); + + const [sql, params] = pool.query.firstCall.args; + expect(sql).to.match(/DELETE FROM public.semantic_cache/); + expect(params[0]).to.equal(JSON.stringify({tenantId: 'abc'})); + }); + }); +}); diff --git a/src/__tests__/unit/token-counter.unit.ts b/src/__tests__/unit/token-counter.unit.ts new file mode 100644 index 0000000..205b911 --- /dev/null +++ b/src/__tests__/unit/token-counter.unit.ts @@ -0,0 +1,78 @@ +import {expect} from '@loopback/testlab'; +import {TokenCounter} from '../../services/token-counter.service'; + +describe('TokenCounter (Mastra path)', function () { + let counter: TokenCounter; + + beforeEach(() => { + counter = new TokenCounter(); + }); + + describe('accumulate()', function () { + it('increments global totals on the first call', () => { + counter.accumulate(10, 5, 'gpt-4o'); + + const counts = counter.getCounts(); + expect(counts.inputs).to.equal(10); + expect(counts.outputs).to.equal(5); + }); + + it('sums multiple calls across the same model', () => { + counter.accumulate(10, 5, 'gpt-4o'); + counter.accumulate(20, 8, 'gpt-4o'); + + const counts = counter.getCounts(); + expect(counts.inputs).to.equal(30); + expect(counts.outputs).to.equal(13); + expect(counts.map['gpt-4o'].inputTokens).to.equal(30); + expect(counts.map['gpt-4o'].outputTokens).to.equal(13); + }); + + it('tracks different models separately in the map', () => { + counter.accumulate(100, 50, 'gpt-4o'); + counter.accumulate(200, 80, 'claude-3-5-sonnet'); + + const counts = counter.getCounts(); + expect(counts.inputs).to.equal(300); + expect(counts.outputs).to.equal(130); + expect(counts.map['gpt-4o'].inputTokens).to.equal(100); + expect(counts.map['claude-3-5-sonnet'].outputTokens).to.equal(80); + }); + + it('handles zero-token calls without error', () => { + counter.accumulate(0, 0, 'unknown'); + + const counts = counter.getCounts(); + expect(counts.inputs).to.equal(0); 
+ expect(counts.outputs).to.equal(0); + }); + + it('does not interfere with clear()', () => { + counter.accumulate(50, 20, 'gpt-4o'); + counter.clear(); + + const counts = counter.getCounts(); + expect(counts.inputs).to.equal(0); + expect(counts.outputs).to.equal(0); + expect(Object.keys(counts.map).length).to.equal(0); + }); + }); + + describe('getCounts()', function () { + it('returns zero totals on a fresh instance', () => { + const counts = counter.getCounts(); + expect(counts.inputs).to.equal(0); + expect(counts.outputs).to.equal(0); + expect(counts.map).to.deepEqual({}); + }); + + it('returns a snapshot — not a live reference', () => { + counter.accumulate(10, 5, 'gpt-4o'); + const snapshot = counter.getCounts(); + + counter.accumulate(10, 5, 'gpt-4o'); + // snapshot should still show the first call's values + expect(snapshot.inputs).to.equal(10); + }); + }); +}); diff --git a/src/__tests__/visualization/unit/mastra-steps/select-and-render.step.unit.ts b/src/__tests__/visualization/unit/mastra-steps/select-and-render.step.unit.ts new file mode 100644 index 0000000..8b628f2 --- /dev/null +++ b/src/__tests__/visualization/unit/mastra-steps/select-and-render.step.unit.ts @@ -0,0 +1,178 @@ +import {expect, sinon} from '@loopback/testlab'; +import {LLMProvider} from '../../../../types'; +import { + IMastraVisualizer, + MastraVisualizationContext, + MastraVisualizationState, +} from '../../../../mastra/visualization/types/visualization.types'; +import {selectVisualizationStep} from '../../../../mastra/visualization/workflow/steps/select-visualization.step'; +import {renderVisualizationStep} from '../../../../mastra/visualization/workflow/steps/render-visualization.step'; +import {createFakeLanguageModel} from '../../../fixtures/fake-ai-models'; + +function makeVisualizer(name: string): IMastraVisualizer { + return { + name, + description: `A ${name} chart`, + getConfig: sinon.stub().resolves({key: `${name}-config`}), + }; +} + +describe('selectVisualizationStep 
(Mastra)', function () { + let barViz: IMastraVisualizer; + let lineViz: IMastraVisualizer; + let pieViz: IMastraVisualizer; + let onUsageSpy: sinon.SinonSpy; + let context: MastraVisualizationContext; + + const baseState = { + prompt: 'Show employee count by department', + sql: 'SELECT dept, COUNT(*) FROM employees GROUP BY dept', + queryDescription: 'Employee count per department', + } as unknown as MastraVisualizationState; + + beforeEach(() => { + barViz = makeVisualizer('bar'); + lineViz = makeVisualizer('line'); + pieViz = makeVisualizer('pie'); + onUsageSpy = sinon.spy(); + context = {onUsage: onUsageSpy}; + }); + + describe('fast-path (explicit type)', function () { + it('resolves explicit type without calling LLM', async () => { + const state = { + ...baseState, + type: 'bar', + } as unknown as MastraVisualizationState; + + const result = await selectVisualizationStep(state, context, { + llm: createFakeLanguageModel('bar') as unknown as LLMProvider, + visualizers: [barViz, lineViz, pieViz], + }); + + expect(result.visualizer).to.equal(barViz); + expect(result.visualizerName).to.equal('bar'); + sinon.assert.notCalled(onUsageSpy); + }); + + it('throws when explicit type is unknown', async () => { + const state = { + ...baseState, + type: 'heatmap', + } as unknown as MastraVisualizationState; + + await expect( + selectVisualizationStep(state, context, { + llm: createFakeLanguageModel('bar') as unknown as LLMProvider, + visualizers: [barViz], + }), + ).to.be.rejectedWith(/No visualizer found with name "heatmap"/); + }); + }); + + describe('LLM-selection path', function () { + it('returns visualizer matching LLM output', async () => { + const result = await selectVisualizationStep(baseState, context, { + llm: createFakeLanguageModel('bar') as unknown as LLMProvider, + visualizers: [barViz, lineViz, pieViz], + }); + + expect(result.visualizer).to.equal(barViz); + expect(result.visualizerName).to.equal('bar'); + sinon.assert.calledOnce(onUsageSpy); + }); + + 
it('returns error object when LLM says none', async () => { + const result = await selectVisualizationStep(baseState, context, { + llm: createFakeLanguageModel( + 'none: data has too many dimensions', + ) as unknown as LLMProvider, + visualizers: [barViz, lineViz, pieViz], + }); + + expect(result.error).to.match(/data has too many dimensions/); + expect(result.visualizer).to.be.undefined(); + }); + + it('throws when LLM returns unknown visualizer name', async () => { + await expect( + selectVisualizationStep(baseState, context, { + llm: createFakeLanguageModel('scatter') as unknown as LLMProvider, + visualizers: [barViz], + }), + ).to.be.rejectedWith(/LLM returned unknown visualizer "scatter"/); + }); + }); +}); + +describe('renderVisualizationStep (Mastra)', function () { + let barViz: IMastraVisualizer; + let onUsageSpy: sinon.SinonSpy; + let context: MastraVisualizationContext; + + const baseState = { + prompt: 'Show employee count', + sql: 'SELECT dept, COUNT(*) FROM employees GROUP BY dept', + queryDescription: 'Employee count per department', + datasetId: '42', + visualizerName: 'bar', + } as unknown as MastraVisualizationState; + + beforeEach(() => { + barViz = makeVisualizer('bar'); + onUsageSpy = sinon.spy(); + context = {onUsage: onUsageSpy}; + }); + + it('calls visualizer.getConfig() and returns done=true', async () => { + const state = { + ...baseState, + visualizer: barViz, + } as unknown as MastraVisualizationState; + + const result = await renderVisualizationStep(state, context, {}); + + expect(result.done).to.be.true(); + expect(result.visualizerConfig).to.deepEqual({key: 'bar-config'}); + sinon.assert.calledOnce(barViz.getConfig as sinon.SinonStub); + }); + + it('passes context.onUsage to visualizer.getConfig()', async () => { + const state = { + ...baseState, + visualizer: barViz, + } as unknown as MastraVisualizationState; + + await renderVisualizationStep(state, context, {}); + + const call = (barViz.getConfig as sinon.SinonStub).firstCall; + 
expect(call.args[1]).to.equal(onUsageSpy); + }); + + it('throws when visualizer is missing from state', async () => { + const state = { + ...baseState, + visualizer: undefined, + } as unknown as MastraVisualizationState; + + await expect( + renderVisualizationStep(state, context, {}), + ).to.be.rejectedWith( + /visualizer, sql, and queryDescription are all required/, + ); + }); + + it('throws when sql is missing from state', async () => { + const state = { + ...baseState, + visualizer: barViz, + sql: undefined, + } as unknown as MastraVisualizationState; + + await expect( + renderVisualizationStep(state, context, {}), + ).to.be.rejectedWith( + /visualizer, sql, and queryDescription are all required/, + ); + }); +}); diff --git a/src/__tests__/visualization/unit/visualizers/bar.visualizer.unit.ts b/src/__tests__/visualization/unit/visualizers/bar.visualizer.unit.ts deleted file mode 100644 index c60e480..0000000 --- a/src/__tests__/visualization/unit/visualizers/bar.visualizer.unit.ts +++ /dev/null @@ -1,196 +0,0 @@ -import {expect, sinon} from '@loopback/testlab'; -import {BarVisualizer} from '../../../../components/visualization/visualizers/bar.visualizer'; -import {RuntimeLLMProvider} from '../../../../types'; -import {fail} from 'assert'; -import {VisualizationGraphState} from '../../../../components'; - -describe('BarVisualizer Unit', function () { - let visualizer: BarVisualizer; - let llmProvider: sinon.SinonStubbedInstance; - let withStructuredOutputStub: sinon.SinonStub; - - beforeEach(() => { - // Create stub for LLM provider - withStructuredOutputStub = sinon.stub(); - llmProvider = { - withStructuredOutput: withStructuredOutputStub, - } as sinon.SinonStubbedInstance; - - visualizer = new BarVisualizer(llmProvider); - }); - - afterEach(() => { - sinon.restore(); - }); - - it('should have correct name and description', () => { - expect(visualizer.name).to.equal('bar'); - expect(visualizer.description).to.match(/bar chart/); - 
expect(visualizer.description).to.match(/comparing values/); - }); - - it('should have valid schema with required fields', () => { - const schema = visualizer.schema; - expect(schema).to.be.ok(); - - // Test schema structure by trying to parse valid data - const validData = { - categoryColumn: 'category', - valueColumn: 'value', - orientation: 'vertical', - }; - - const result = schema.safeParse(validData); - expect(result.success).to.be.true(); - - if (result.success) { - expect(result.data).to.deepEqual(validData); - } - }); - - it('should validate schema with default orientation', () => { - const schema = visualizer.schema; - const dataWithoutOrientation = { - categoryColumn: 'category', - valueColumn: 'value', - }; - - const result = schema.safeParse(dataWithoutOrientation); - expect(result.success).to.be.true(); - - if (result.success) { - expect(result.data.orientation).to.equal('vertical'); - } - }); - - it('should reject invalid orientation values', () => { - const schema = visualizer.schema; - const invalidData = { - categoryColumn: 'category', - valueColumn: 'value', - orientation: 42, // invalid type - }; - - const result = schema.safeParse(invalidData); - expect(result.success).to.be.false(); - }); - - it('should throw error when state is invalid (missing sql)', async () => { - const invalidState = { - prompt: 'test prompt', - datasetId: 'test-id', - queryDescription: 'test description', - // sql is missing - will be undefined - } as unknown as VisualizationGraphState; - - try { - await visualizer.getConfig(invalidState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.have.property('message', 'Invalid State'); - } - }); - - it('should throw error when state is invalid (missing queryDescription)', async () => { - const invalidState = { - prompt: 'test prompt', - datasetId: 'test-id', - sql: 'SELECT * FROM test', - // queryDescription is missing - will be undefined - } as unknown as VisualizationGraphState; - - try { - await 
visualizer.getConfig(invalidState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.have.property('message', 'Invalid State'); - } - }); - - it('should throw error when state is invalid (missing prompt)', async () => { - const invalidState = { - datasetId: 'test-id', - sql: 'SELECT * FROM test', - queryDescription: 'test description', - // prompt is missing - will be undefined - } as unknown as VisualizationGraphState; - - try { - await visualizer.getConfig(invalidState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.have.property('message', 'Invalid State'); - } - }); - - it('should successfully generate config with valid state', async () => { - const mockLLMResponse = { - categoryColumn: 'department', - valueColumn: 'salary', - orientation: 'vertical', - }; - - const mockInvoke = sinon.stub().resolves(mockLLMResponse); - withStructuredOutputStub.returns(mockInvoke); - - const validState = { - prompt: 'Show me a bar chart of salaries by department', - datasetId: 'test-dataset', - sql: 'SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department', - queryDescription: 'Average salary by department', - } as unknown as VisualizationGraphState; - - const config = await visualizer.getConfig(validState); - - expect(config).to.deepEqual(mockLLMResponse); - expect( - withStructuredOutputStub.calledOnceWith(visualizer.schema), - ).to.be.true(); - expect(mockInvoke.calledOnce).to.be.true(); - - // Check that the mock was called with a StringPromptValue containing our data - const invokeArgs = mockInvoke.getCall(0).args[0]; - expect(invokeArgs).to.have.property('value'); - // Escape special regex characters in SQL - const escapedSQL = - validState.sql?.replace(/[.*+?^${}()|[\]\\]/g, '\\$&') ?? ''; - expect(invokeArgs.value).to.match(new RegExp(escapedSQL)); - expect(invokeArgs.value).to.match( - new RegExp(validState.queryDescription ?? 
''), - ); - expect(invokeArgs.value).to.match(new RegExp(validState.prompt)); - }); - - it('should handle LLM errors gracefully', async () => { - const mockError = new Error('LLM processing failed'); - const mockInvoke = sinon.stub().rejects(mockError); - withStructuredOutputStub.returns(mockInvoke); - - const validState = { - prompt: 'test prompt', - datasetId: 'test-dataset', - sql: 'SELECT * FROM test', - queryDescription: 'test description', - } as unknown as VisualizationGraphState; - - try { - await visualizer.getConfig(validState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.equal(mockError); - } - }); - - it('should contain proper prompt template structure', () => { - const promptTemplate = visualizer.renderPrompt; - expect(promptTemplate).to.be.ok(); - - const templateText = promptTemplate.template; - expect(templateText).to.match(/bar chart/); - expect(templateText).to.match(/\{sql\}/); - expect(templateText).to.match(/\{description\}/); - expect(templateText).to.match(/\{userPrompt\}/); - expect(templateText).to.match(/x-axis/); - expect(templateText).to.match(/y-axis/); - }); -}); diff --git a/src/__tests__/visualization/unit/visualizers/line.visualizer.unit.ts b/src/__tests__/visualization/unit/visualizers/line.visualizer.unit.ts deleted file mode 100644 index e5dc3ef..0000000 --- a/src/__tests__/visualization/unit/visualizers/line.visualizer.unit.ts +++ /dev/null @@ -1,236 +0,0 @@ -import {expect, sinon} from '@loopback/testlab'; -import {LineVisualizer} from '../../../../components/visualization/visualizers/line.visualizer'; -import {RuntimeLLMProvider} from '../../../../types'; -import {fail} from 'assert'; -import {VisualizationGraphState} from '../../../../components'; - -describe('LineVisualizer Unit', function () { - let visualizer: LineVisualizer; - let llmProvider: sinon.SinonStubbedInstance; - let withStructuredOutputStub: sinon.SinonStub; - - beforeEach(() => { - // Create stub for LLM provider - 
withStructuredOutputStub = sinon.stub(); - llmProvider = { - withStructuredOutput: withStructuredOutputStub, - } as sinon.SinonStubbedInstance; - - visualizer = new LineVisualizer(llmProvider); - }); - - afterEach(() => { - sinon.restore(); - }); - - it('should have correct name and description', () => { - expect(visualizer.name).to.equal('line'); - expect(visualizer.description).to.match(/line chart/); - expect(visualizer.description).to.match(/trends/); - expect(visualizer.description).to.match(/time/); - }); - - it('should have valid schema with required fields', () => { - const schema = visualizer.schema; - expect(schema).to.be.ok(); - - // Test schema structure by trying to parse valid data - const validData = { - xAxisColumn: 'date', - yAxisColumn: 'value', - seriesColumns: 'category', - }; - - const result = schema.safeParse(validData); - expect(result.success).to.be.true(); - - if (result.success) { - expect(result.data).to.deepEqual(validData); - } - }); - - it('should accept empty string seriesColumn', () => { - const schema = visualizer.schema; - const dataWithNullSeries = { - xAxisColumn: 'date', - yAxisColumn: 'value', - seriesColumns: '', - }; - - const result = schema.safeParse(dataWithNullSeries); - expect(result.success).to.be.true(); - }); - - it('should reject missing seriesColumn field', () => { - const schema = visualizer.schema; - const dataWithoutSeries = { - xAxisColumn: 'date', - yAxisColumn: 'value', - }; - - const result = schema.safeParse(dataWithoutSeries); - // seriesColumn is nullable but still required - omitting it should fail - expect(result.success).to.be.false(); - }); - - it('should reject missing required fields', () => { - const schema = visualizer.schema; - - // Missing xAxisColumn - const missingXAxis = { - yAxisColumn: 'value', - seriesColumn: 'category', - }; - expect(schema.safeParse(missingXAxis).success).to.be.false(); - - // Missing yAxisColumn - const missingYAxis = { - xAxisColumn: 'date', - seriesColumn: 'category', 
- }; - expect(schema.safeParse(missingYAxis).success).to.be.false(); - }); - - it('should throw error when state is invalid (missing sql)', async () => { - const invalidState = { - prompt: 'test prompt', - datasetId: 'test-id', - queryDescription: 'test description', - // sql is missing - will be undefined - } as unknown as VisualizationGraphState; - - try { - await visualizer.getConfig(invalidState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.have.property('message', 'Invalid State'); - } - }); - - it('should throw error when state is invalid (missing queryDescription)', async () => { - const invalidState = { - prompt: 'test prompt', - datasetId: 'test-id', - sql: 'SELECT * FROM test', - // queryDescription is missing - will be undefined - } as unknown as VisualizationGraphState; - - try { - await visualizer.getConfig(invalidState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.have.property('message', 'Invalid State'); - } - }); - - it('should throw error when state is invalid (missing prompt)', async () => { - const invalidState = { - datasetId: 'test-id', - sql: 'SELECT * FROM test', - queryDescription: 'test description', - // prompt is missing - will be undefined - } as unknown as VisualizationGraphState; - - try { - await visualizer.getConfig(invalidState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.have.property('message', 'Invalid State'); - } - }); - - it('should successfully generate config with valid state', async () => { - const mockLLMResponse = { - xAxisColumn: 'month', - yAxisColumn: 'revenue', - seriesColumns: 'product_line', - }; - - const mockInvoke = sinon.stub().resolves(mockLLMResponse); - withStructuredOutputStub.returns(mockInvoke); - - const validState = { - prompt: - 'Show me a line chart of revenue trends over time by product line', - datasetId: 'test-dataset', - sql: 'SELECT month, product_line, SUM(revenue) as revenue FROM sales GROUP 
BY month, product_line', - queryDescription: 'Revenue trends by product line over time', - } as unknown as VisualizationGraphState; - - const config = await visualizer.getConfig(validState); - - expect(config).to.deepEqual(mockLLMResponse); - expect( - withStructuredOutputStub.calledOnceWith(visualizer.schema), - ).to.be.true(); - expect(mockInvoke.calledOnce).to.be.true(); - - // Check that the mock was called with a StringPromptValue containing our data - const invokeArgs = mockInvoke.getCall(0).args[0]; - expect(invokeArgs).to.have.property('value'); - // Escape special regex characters in SQL - const escapedSQL = validState.sql?.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); - expect(invokeArgs.value).to.match(new RegExp(escapedSQL ?? '')); - expect(invokeArgs.value).to.match( - new RegExp(validState.queryDescription ?? ''), - ); - expect(invokeArgs.value).to.match(new RegExp(validState.prompt)); - }); - - it('should successfully generate config without series column', async () => { - const mockLLMResponse = { - xAxisColumn: 'month', - yAxisColumn: 'total_sales', - seriesColumns: null, - }; - - const mockInvoke = sinon.stub().resolves(mockLLMResponse); - withStructuredOutputStub.returns(mockInvoke); - - const validState = { - prompt: 'Show me total sales over time', - datasetId: 'test-dataset', - sql: 'SELECT month, SUM(sales) as total_sales FROM sales GROUP BY month', - queryDescription: 'Total sales over time', - } as unknown as VisualizationGraphState; - - const config = await visualizer.getConfig(validState); - - expect(config).to.deepEqual(mockLLMResponse); - expect(config.seriesColumns).to.be.null(); - }); - - it('should handle LLM errors gracefully', async () => { - const mockError = new Error('LLM processing failed'); - const mockInvoke = sinon.stub().rejects(mockError); - withStructuredOutputStub.returns(mockInvoke); - - const validState = { - prompt: 'test prompt', - datasetId: 'test-dataset', - sql: 'SELECT * FROM test', - queryDescription: 'test 
description', - } as unknown as VisualizationGraphState; - - try { - await visualizer.getConfig(validState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.equal(mockError); - } - }); - - it('should contain proper prompt template structure', () => { - const promptTemplate = visualizer.renderPrompt; - expect(promptTemplate).to.be.ok(); - - const templateText = promptTemplate.template; - expect(templateText).to.match(/line chart/); - expect(templateText).to.match(/\{sql\}/); - expect(templateText).to.match(/\{description\}/); - expect(templateText).to.match(/\{userPrompt\}/); - expect(templateText).to.match(/x-axis/); - expect(templateText).to.match(/y-axis/); - expect(templateText).to.match(/multiple series/); - }); -}); diff --git a/src/__tests__/visualization/unit/visualizers/pie.visualizer.unit.ts b/src/__tests__/visualization/unit/visualizers/pie.visualizer.unit.ts deleted file mode 100644 index 87000a9..0000000 --- a/src/__tests__/visualization/unit/visualizers/pie.visualizer.unit.ts +++ /dev/null @@ -1,236 +0,0 @@ -import {expect, sinon} from '@loopback/testlab'; -import {PieVisualizer} from '../../../../components/visualization/visualizers/pie.visualizer'; -import {RuntimeLLMProvider} from '../../../../types'; -import {fail} from 'assert'; -import {VisualizationGraphState} from '../../../../components'; - -describe('PieVisualizer Unit', function () { - let visualizer: PieVisualizer; - let llmProvider: sinon.SinonStubbedInstance; - let withStructuredOutputStub: sinon.SinonStub; - - beforeEach(() => { - // Create stub for LLM provider - withStructuredOutputStub = sinon.stub(); - llmProvider = { - withStructuredOutput: withStructuredOutputStub, - } as sinon.SinonStubbedInstance; - - visualizer = new PieVisualizer(llmProvider); - }); - - afterEach(() => { - sinon.restore(); - }); - - it('should have correct name and description', () => { - expect(visualizer.name).to.equal('pie'); - expect(visualizer.description).to.match(/pie 
chart/); - expect(visualizer.description).to.match(/proportions/); - expect(visualizer.description).to.match(/percentages/); - }); - - it('should have valid schema with required fields', () => { - const schema = visualizer.schema; - expect(schema).to.be.ok(); - - // Test schema structure by trying to parse valid data - const validData = { - labelColumn: 'category', - valueColumn: 'amount', - }; - - const result = schema.safeParse(validData); - expect(result.success).to.be.true(); - - if (result.success) { - expect(result.data).to.deepEqual(validData); - } - }); - - it('should reject missing required fields', () => { - const schema = visualizer.schema; - - // Missing labelColumn - const missingLabel = { - valueColumn: 'amount', - }; - expect(schema.safeParse(missingLabel).success).to.be.false(); - - // Missing valueColumn - const missingValue = { - labelColumn: 'category', - }; - expect(schema.safeParse(missingValue).success).to.be.false(); - }); - - it('should reject invalid field types', () => { - const schema = visualizer.schema; - - // Non-string labelColumn - const invalidLabel = { - labelColumn: 123, - valueColumn: 'amount', - }; - expect(schema.safeParse(invalidLabel).success).to.be.false(); - - // Non-string valueColumn - const invalidValue = { - labelColumn: 'category', - valueColumn: 456, - }; - expect(schema.safeParse(invalidValue).success).to.be.false(); - }); - - it('should throw error when state is invalid (missing sql)', async () => { - const invalidState = { - prompt: 'test prompt', - datasetId: 'test-id', - queryDescription: 'test description', - // sql is missing - will be undefined - } as unknown as VisualizationGraphState; - - try { - await visualizer.getConfig(invalidState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.have.property('message', 'Invalid State'); - } - }); - - it('should throw error when state is invalid (missing queryDescription)', async () => { - const invalidState = { - prompt: 'test prompt', - 
datasetId: 'test-id', - sql: 'SELECT * FROM test', - // queryDescription is missing - will be undefined - } as unknown as VisualizationGraphState; - - try { - await visualizer.getConfig(invalidState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.have.property('message', 'Invalid State'); - } - }); - - it('should throw error when state is invalid (missing prompt)', async () => { - const invalidState = { - datasetId: 'test-id', - sql: 'SELECT * FROM test', - queryDescription: 'test description', - // prompt is missing - will be undefined - } as unknown as VisualizationGraphState; - - try { - await visualizer.getConfig(invalidState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.have.property('message', 'Invalid State'); - } - }); - - it('should successfully generate config with valid state', async () => { - const mockLLMResponse = { - labelColumn: 'department', - valueColumn: 'budget_allocation', - }; - - const mockInvoke = sinon.stub().resolves(mockLLMResponse); - withStructuredOutputStub.returns(mockInvoke); - - const validState = { - prompt: 'Show me a pie chart of budget allocation by department', - datasetId: 'test-dataset', - sql: 'SELECT department, SUM(budget) as budget_allocation FROM departments GROUP BY department', - queryDescription: 'Budget allocation by department', - } as unknown as VisualizationGraphState; - - const config = await visualizer.getConfig(validState); - - expect(config).to.deepEqual(mockLLMResponse); - expect( - withStructuredOutputStub.calledOnceWith(visualizer.schema), - ).to.be.true(); - expect(mockInvoke.calledOnce).to.be.true(); - - // Check that the mock was called with a StringPromptValue containing our data - const invokeArgs = mockInvoke.getCall(0).args[0]; - expect(invokeArgs).to.have.property('value'); - // Escape special regex characters in SQL - const escapedSQL = - validState.sql?.replace(/[.*+?^${}()|[\]\\]/g, '\\$&') ?? 
''; - expect(invokeArgs.value).to.match(new RegExp(escapedSQL)); - expect(invokeArgs.value).to.match( - new RegExp(validState.queryDescription ?? ''), - ); - expect(invokeArgs.value).to.match(new RegExp(validState.prompt)); - }); - - it('should handle LLM response with percentage data', async () => { - const mockLLMResponse = { - labelColumn: 'product_category', - valueColumn: 'sales_percentage', - }; - - const mockInvoke = sinon.stub().resolves(mockLLMResponse); - withStructuredOutputStub.returns(mockInvoke); - - const validState = { - prompt: 'Show me sales distribution by product category as percentages', - datasetId: 'test-dataset', - sql: 'SELECT product_category, (sales / total_sales * 100) as sales_percentage FROM sales_summary', - queryDescription: 'Sales distribution by product category', - } as unknown as VisualizationGraphState; - - const config = await visualizer.getConfig(validState); - - expect(config).to.deepEqual(mockLLMResponse); - expect(config.labelColumn).to.equal('product_category'); - expect(config.valueColumn).to.equal('sales_percentage'); - }); - - it('should handle LLM errors gracefully', async () => { - const mockError = new Error('LLM processing failed'); - const mockInvoke = sinon.stub().rejects(mockError); - withStructuredOutputStub.returns(mockInvoke); - - const validState = { - prompt: 'test prompt', - datasetId: 'test-dataset', - sql: 'SELECT * FROM test', - queryDescription: 'test description', - } as unknown as VisualizationGraphState; - - try { - await visualizer.getConfig(validState); - fail('Should have thrown an error'); - } catch (error) { - expect(error).to.equal(mockError); - } - }); - - it('should contain proper prompt template structure', () => { - const promptTemplate = visualizer.renderPrompt; - expect(promptTemplate).to.be.ok(); - - const templateText = promptTemplate.template; - expect(templateText).to.match(/pie chart/); - expect(templateText).to.match(/\{sql\}/); - expect(templateText).to.match(/\{description\}/); - 
expect(templateText).to.match(/\{userPrompt\}/); - expect(templateText).to.match(/categories/); - }); - - it('should validate that schema describes columns correctly', () => { - const schema = visualizer.schema; - - // Access the schema shape to check descriptions - const shape = schema._def.shape(); - - expect(shape.labelColumn._def.description).to.match(/labels/); - expect(shape.labelColumn._def.description).to.match(/pie chart/); - expect(shape.valueColumn._def.description).to.match(/values/); - expect(shape.valueColumn._def.description).to.match(/pie chart/); - }); -}); diff --git a/src/component.ts b/src/component.ts index df0450f..ad727d7 100644 --- a/src/component.ts +++ b/src/component.ts @@ -34,16 +34,7 @@ import { } from './components'; import {DEFAULT_FILE_SIZE, MAX_TOTAL_SIZE} from './constant'; import {ChatController, GenerationController} from './controllers'; -import { - CallLLMNode, - ChatGraph, - ChatStore, - ContextCompressionNode, - EndSessionNode, - InitSessionNode, - RunToolNode, - SummariseFileNode, -} from './graphs/chat'; +import {ChatStore} from './services/chat.store'; import {WriterDB, AiIntegrationBindings, ReaderDB} from './keys'; import {Chat, Message} from './models'; import {CacheModel, ToolsProvider} from './providers'; @@ -61,7 +52,7 @@ import { import {TokenCounter} from './services/token-counter.service'; import {SSETransport} from './transports'; import {AIIntegrationConfig} from './types'; -import {PgVectorStore} from './sub-modules/db/postgresql'; +import {PgVectorSdkStore} from './sub-modules/db/postgresql'; const debug = require('debug')('ai-integration:log-events:component'); export class AiIntegrationsComponent implements Component { @@ -88,7 +79,7 @@ export class AiIntegrationsComponent implements Component { ]; this.providers = { - [AiIntegrationBindings.VectorStore.key]: PgVectorStore, + [AiIntegrationBindings.AiSdkVectorStore.key]: PgVectorSdkStore, [AiIntegrationBindings.Tools.key]: ToolsProvider, }; @@ -99,15 +90,6 
@@ export class AiIntegrationsComponent implements Component { ChatStore, // mastra MastraChatAgent, - // graph - ChatGraph, - // nodes - CallLLMNode, - RunToolNode, - InitSessionNode, - SummariseFileNode, - ContextCompressionNode, - EndSessionNode, ]; this.controllers = [GenerationController, ChatController]; diff --git a/src/components/db-query/controller/template.controller.ts b/src/components/db-query/controller/template.controller.ts index 73ad7b8..d0a66ee 100644 --- a/src/components/db-query/controller/template.controller.ts +++ b/src/components/db-query/controller/template.controller.ts @@ -7,7 +7,6 @@ import { post, requestBody, } from '@loopback/rest'; -import {BaseRetriever} from '@langchain/core/retrievers'; import { CONTENT_TYPE, IAuthUserWithPermissions, @@ -21,28 +20,25 @@ import { AuthenticationBindings, } from 'loopback4-authentication'; import {authorize} from 'loopback4-authorization'; -import {VectorStore} from '@langchain/core/vectorstores'; import {AiIntegrationBindings} from '../../../keys'; +import {IVectorStore} from '../../../types'; import {PermissionKey} from '../../../permissions'; import {QueryTemplateDTO, TemplatePlaceholderDTO} from '../models'; -import { - DbQueryStoredTypes, - IQueryTemplateStore, - QueryTemplateMetadata, -} from '../types'; +import {DbQueryStoredTypes, IQueryTemplateStore} from '../types'; import {DbQueryAIExtensionBindings} from '../keys'; import {SchemaStore} from '../services/schema.store'; +import {TemplateSearchService} from '../../../mastra/db-query/services/template-search.service'; export class TemplateController { constructor( - @inject(AiIntegrationBindings.VectorStore) - private readonly vectorStore: VectorStore, + @inject(AiIntegrationBindings.AiSdkVectorStore) + private readonly vectorStore: IVectorStore, @inject(AuthenticationBindings.CURRENT_USER) private readonly user: IAuthUserWithPermissions, @service(SchemaStore) private readonly schemaStore: SchemaStore, - 
@inject(DbQueryAIExtensionBindings.TemplateCache) - private readonly templateRetriever: BaseRetriever, + @service(TemplateSearchService) + private readonly templateSearchService: TemplateSearchService, @inject(DbQueryAIExtensionBindings.TemplateStore, {optional: true}) private readonly templateStore: IQueryTemplateStore | undefined, ) {} @@ -155,7 +151,7 @@ export class TemplateController { ) { // Similarity-ranked search when query is provided if (query) { - const docs = await this.templateRetriever.invoke(query); + const docs = await this.templateSearchService.search(query); return docs.map(doc => ({ id: doc.metadata.templateId, prompt: doc.pageContent, diff --git a/src/components/db-query/db-query.component.ts b/src/components/db-query/db-query.component.ts index 4fba116..9d3b614 100644 --- a/src/components/db-query/db-query.component.ts +++ b/src/components/db-query/db-query.component.ts @@ -12,38 +12,20 @@ import { import {AnyObject} from '@loopback/repository'; import {DataSetController, TemplateController} from './controller'; import {DatasetServiceComponent} from './dataset-service.component'; -import {DbQueryGraph} from './db-query.graph'; import {DbQueryAIExtensionBindings} from './keys'; -import { - CheckCacheNode, - CheckPermissionsNode, - ClassifyChangeNode, - FixQueryNode, - CheckTemplatesNode, - GenerateChecklistNode, - GenerateDescriptionNode, - FailedNode, - GetColumnsNode, - GetTablesNode, - IsImprovementNode, - SaveDataSetNode, - SemanticValidatorNode, - SqlGenerationNode, - SyntacticValidatorNode, - VerifyChecklistNode, -} from './nodes'; import {TableSeedObserver} from './observers'; -import {DatasetRetriever, TemplateRetriever} from './providers'; import {DataSetHelper, DbSchemaHelperService, TemplateHelper} from './services'; import {PermissionHelper} from './services/permission-helper.service'; import {SchemaStore} from './services/schema.store'; import {TableSearchService} from './services/search/table-search.service'; -import { - 
AskAboutDatasetTool, - GetDataAsDatasetTool, - ImproveDatasetTool, -} from './tools'; +import {GetDataAsDatasetTool, ImproveDatasetTool} from './tools'; import {PgWithRlsConnector} from './connectors/pg'; +import {MastraDbQueryWorkflow} from '../../mastra/db-query'; +import { + DatasetSearchService, + MastraTemplateHelperService, + TemplateSearchService, +} from '../../mastra/db-query/services'; export class DbQueryComponent implements Component { services: ServiceOrProviderClass[] | undefined; @@ -54,10 +36,7 @@ export class DbQueryComponent implements Component { lifeCycleObservers: Constructor[] | undefined; constructor() { this.controllers = [DataSetController, TemplateController]; - this.providers = { - [DbQueryAIExtensionBindings.QueryCache.key]: DatasetRetriever, - [DbQueryAIExtensionBindings.TemplateCache.key]: TemplateRetriever, - }; + this.providers = {}; this.bindings = [ createBindingFromClass(PgWithRlsConnector, { key: DbQueryAIExtensionBindings.Connector.key, @@ -73,29 +52,15 @@ export class DbQueryComponent implements Component { SchemaStore, TableSearchService, TemplateHelper, - // graph - DbQueryGraph, + // Mastra workflow (Phase 3: AI SDK-based nodes, no LangChain) + MastraDbQueryWorkflow, + // Mastra search services (replace LangChain BaseRetriever pattern) + DatasetSearchService, + TemplateSearchService, + MastraTemplateHelperService, // tools - AskAboutDatasetTool, GetDataAsDatasetTool, ImproveDatasetTool, - // nodes - IsImprovementNode, - GetTablesNode, - CheckPermissionsNode, - SqlGenerationNode, - SyntacticValidatorNode, - SemanticValidatorNode, - FailedNode, - SaveDataSetNode, - CheckCacheNode, - ClassifyChangeNode, - FixQueryNode, - GenerateChecklistNode, - GenerateDescriptionNode, - VerifyChecklistNode, - GetColumnsNode, - CheckTemplatesNode, ]; this.components = [DatasetServiceComponent]; } diff --git a/src/components/db-query/db-query.graph.ts b/src/components/db-query/db-query.graph.ts deleted file mode 100644 index 6f56b08..0000000 
--- a/src/components/db-query/db-query.graph.ts +++ /dev/null @@ -1,246 +0,0 @@ -import {END, START, StateGraph} from '@langchain/langgraph'; -import {BaseGraph} from '../../graphs'; -import {MAX_ATTEMPTS} from './constant'; -import {DbQueryNodes} from './nodes.enum'; -import {DbQueryGraphStateAnnotation, DbQueryState} from './state'; -import {EvaluationResult, GenerationError} from './types'; - -export class DbQueryGraph extends BaseGraph { - async build() { - const graph = new StateGraph(DbQueryGraphStateAnnotation); - await this._addNodes(graph); - this._addEdges(graph); - return graph.compile(); - } - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - private async _addNodes(graph: any) { - graph - .addNode( - DbQueryNodes.GetTables, - await this._getNodeFn(DbQueryNodes.GetTables), - ) - .addNode( - DbQueryNodes.GetColumns, - await this._getNodeFn(DbQueryNodes.GetColumns), - ) - .addNode( - DbQueryNodes.CheckCache, - await this._getNodeFn(DbQueryNodes.CheckCache), - ) - .addNode( - DbQueryNodes.CheckTemplates, - await this._getNodeFn(DbQueryNodes.CheckTemplates), - ) - .addNode( - DbQueryNodes.GenerateChecklist, - await this._getNodeFn(DbQueryNodes.GenerateChecklist), - ) - .addNode( - DbQueryNodes.GenerateDescription, - await this._getNodeFn(DbQueryNodes.GenerateDescription), - ) - .addNode( - DbQueryNodes.VerifyChecklist, - await this._getNodeFn(DbQueryNodes.VerifyChecklist), - ) - .addNode( - DbQueryNodes.SqlGeneration, - await this._getNodeFn(DbQueryNodes.SqlGeneration), - ) - .addNode( - DbQueryNodes.SyntacticValidator, - await this._getNodeFn(DbQueryNodes.SyntacticValidator), - ) - .addNode( - DbQueryNodes.SemanticValidator, - await this._getNodeFn(DbQueryNodes.SemanticValidator), - ) - .addNode( - DbQueryNodes.IsImprovement, - await this._getNodeFn(DbQueryNodes.IsImprovement), - ) - .addNode(DbQueryNodes.Failed, await this._getNodeFn(DbQueryNodes.Failed)) - .addNode( - DbQueryNodes.SaveDataset, - await 
this._getNodeFn(DbQueryNodes.SaveDataset), - ) - .addNode( - DbQueryNodes.ClassifyChange, - await this._getNodeFn(DbQueryNodes.ClassifyChange), - ) - .addNode( - DbQueryNodes.FixQuery, - await this._getNodeFn(DbQueryNodes.FixQuery), - ) - // Pass-through routing nodes - .addNode(DbQueryNodes.PostCacheAndTables, async () => ({})) - .addNode(DbQueryNodes.PreValidation, async () => ({})) - // PostValidation: merges syntactic + semantic results into status/feedbacks - .addNode(DbQueryNodes.PostValidation, async (state: DbQueryState) => - this._mergeValidationResults(state), - ); - } - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - private _addEdges(graph: any) { - graph - // Parallel fan-out: cache check, table selection, template check, and classify change - .addEdge(START, DbQueryNodes.IsImprovement) - .addEdge(DbQueryNodes.IsImprovement, DbQueryNodes.CheckCache) - .addEdge(DbQueryNodes.IsImprovement, DbQueryNodes.GetTables) - .addEdge(DbQueryNodes.IsImprovement, DbQueryNodes.CheckTemplates) - .addEdge(DbQueryNodes.IsImprovement, DbQueryNodes.ClassifyChange) - .addEdge(DbQueryNodes.CheckCache, DbQueryNodes.PostCacheAndTables) - .addEdge(DbQueryNodes.GetTables, DbQueryNodes.PostCacheAndTables) - .addEdge(DbQueryNodes.CheckTemplates, DbQueryNodes.PostCacheAndTables) - .addEdge(DbQueryNodes.ClassifyChange, DbQueryNodes.PostCacheAndTables) - .addConditionalEdges( - DbQueryNodes.PostCacheAndTables, - (state: DbQueryState) => { - if (state.fromTemplate) return 'FromTemplate'; - if (state.fromCache) return 'AsIs'; - if (state.status === GenerationError.Failed) return 'Failed'; - return 'Continue'; - }, - { - FromTemplate: DbQueryNodes.SaveDataset, - AsIs: END, - Failed: DbQueryNodes.Failed, - Continue: DbQueryNodes.GetColumns, - }, - ) - // GetColumns → GenerateChecklist (no-op when disabled via config) - .addEdge(DbQueryNodes.GetColumns, DbQueryNodes.GenerateChecklist) - .addEdge(DbQueryNodes.GenerateChecklist, DbQueryNodes.SqlGeneration) - 
.addEdge(DbQueryNodes.GenerateChecklist, DbQueryNodes.VerifyChecklist) - // Both fan-in to PreValidation - .addEdge(DbQueryNodes.VerifyChecklist, DbQueryNodes.PreValidation) - // SqlGeneration routes to validation or failure - .addConditionalEdges( - DbQueryNodes.SqlGeneration, - (state: DbQueryState) => { - if (state.status === GenerationError.Failed) return 'Failed'; - return 'Validate'; - }, - { - Validate: DbQueryNodes.PreValidation, - Failed: DbQueryNodes.Failed, - }, - ) - // Parallel fan-out: validators and description generation run concurrently - .addEdge(DbQueryNodes.PreValidation, DbQueryNodes.SyntacticValidator) - .addEdge(DbQueryNodes.PreValidation, DbQueryNodes.SemanticValidator) - .addEdge(DbQueryNodes.PreValidation, DbQueryNodes.GenerateDescription) - // Fan-in at PostValidation - .addEdge(DbQueryNodes.SyntacticValidator, DbQueryNodes.PostValidation) - .addEdge(DbQueryNodes.SemanticValidator, DbQueryNodes.PostValidation) - .addEdge(DbQueryNodes.GenerateDescription, DbQueryNodes.PostValidation) - .addConditionalEdges( - DbQueryNodes.PostValidation, - (state: DbQueryState) => { - const validatorErrors = state.feedbacks ?? 
[]; - if (validatorErrors.length >= MAX_ATTEMPTS) return 'Failed'; - if (state.status === EvaluationResult.TableError) - return 'ReselectTables'; - if (state.status === EvaluationResult.QueryError) return 'FixSQL'; - if (state.status === EvaluationResult.Pass) return 'Accepted'; - return 'Failed'; - }, - { - Accepted: DbQueryNodes.SaveDataset, - FixSQL: DbQueryNodes.FixQuery, - ReselectTables: DbQueryNodes.GetTables, - Failed: DbQueryNodes.Failed, - }, - ) - // FixQuery routes back to validation or failure - .addConditionalEdges( - DbQueryNodes.FixQuery, - (state: DbQueryState) => { - if (state.status === GenerationError.Failed) return 'Failed'; - return 'Validate'; - }, - { - Validate: DbQueryNodes.PreValidation, - Failed: DbQueryNodes.Failed, - }, - ) - .addEdge(DbQueryNodes.SaveDataset, END); - } - - private _mergeValidationResults(state: DbQueryState) { - const hasSyntacticFailure = this._isValidationFailure( - state.syntacticStatus, - ); - const hasSemanticFailure = this._isValidationFailure(state.semanticStatus); - - if (!hasSyntacticFailure && !hasSemanticFailure) { - return this._buildPassedResult(state); - } - - return this._buildFailedResult(state, hasSyntacticFailure); - } - - private _isValidationFailure(status: DbQueryState['syntacticStatus']) { - return !!status && status !== EvaluationResult.Pass; - } - - private _buildFailedResult( - state: DbQueryState, - hasSyntacticFailure: boolean, - ) { - const clearedState = this._buildClearedState(state); - const baseFeedbacks = state.feedbacks ?? []; - const semanticFb = this._toArray(state.semanticFeedback); - const syntacticFb = hasSyntacticFailure - ? this._toArray(state.syntacticFeedback) - : []; - - return { - status: hasSyntacticFailure - ? 
state.syntacticStatus - : state.semanticStatus, - feedbacks: [...baseFeedbacks, ...syntacticFb, ...semanticFb], - ...clearedState, - }; - } - - private _buildPassedResult(state: DbQueryState) { - return { - status: EvaluationResult.Pass, - feedbacks: (state.feedbacks ?? []).filter( - f => !f.startsWith('Query Validation Failed'), - ), - syntacticStatus: undefined, - syntacticFeedback: undefined, - syntacticErrorTables: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - semanticErrorTables: undefined, - }; - } - - private _buildClearedState(state: DbQueryState) { - const mergedErrorTables = [ - ...new Set([ - ...(state.syntacticErrorTables ?? []), - ...(state.semanticErrorTables ?? []), - ]), - ]; - const errorTables = - mergedErrorTables.length > 0 ? mergedErrorTables : undefined; - return { - syntacticStatus: undefined, - syntacticFeedback: undefined, - syntacticErrorTables: errorTables, - semanticStatus: undefined, - semanticFeedback: undefined, - semanticErrorTables: errorTables, - }; - } - - private _toArray(value: string | undefined): string[] { - return value ? 
[value] : []; - } -} diff --git a/src/components/db-query/index.ts b/src/components/db-query/index.ts index b39a5eb..34ac938 100644 --- a/src/components/db-query/index.ts +++ b/src/components/db-query/index.ts @@ -3,10 +3,8 @@ export * from './constant'; export * from './controller'; export * from './dataset-service.component'; export * from './db-query.component'; -export * from './db-query.graph'; export * from './keys'; export * from './models'; -export * from './nodes'; export * from './nodes.enum'; export * from './services'; export * from './state'; diff --git a/src/components/db-query/nodes/check-cache.node.ts b/src/components/db-query/nodes/check-cache.node.ts deleted file mode 100644 index 6846b70..0000000 --- a/src/components/db-query/nodes/check-cache.node.ts +++ /dev/null @@ -1,201 +0,0 @@ -import {DocumentInterface} from '@langchain/core/documents'; -import {PromptTemplate} from '@langchain/core/prompts'; -import {BaseRetriever} from '@langchain/core/retrievers'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {inject, service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import { - IGraphNode, - LLMStreamEventType, - RunnableConfig, - ToolStatus, -} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import {DataSetHelper} from '../services'; -import {DbQueryState} from '../state'; -import {CacheResults, QueryCacheMetadata} from '../types'; -import {DatasetActionType} from '../constant'; - -@graphNode(DbQueryNodes.CheckCache) -export class CheckCacheNode implements IGraphNode { - constructor( - @inject(DbQueryAIExtensionBindings.QueryCache) - private readonly cache: BaseRetriever, - @inject(AiIntegrationBindings.CheapLLM) - private readonly smartLLM: RuntimeLLMProvider, - 
@service(DataSetHelper) - private readonly dataSetHelper: DataSetHelper, - ) {} - prompt = PromptTemplate.fromTemplate(` - -You are an expert Semantic analyser, you will be given a prompt from the user and a list of past prompts that were handled successfully, along with description of the sql generated from those prompts. -You need to return the most relevant prompt from the list and in which of the following ways is it relevant - -- return '${CacheResults.AsIs}' if the prompt's result would contain the information the user is looking for without any changes in the result, and can be used as it is. -- return '${CacheResults.Similar}' if the prompt's result would be similar to the question in the new prompt but not exactly, and can be modified to get the data user needs. -- return '${CacheResults.NotRelevant}' if the prompt is not relevant to the new prompt at all. -Remember that if the cached prompt has extra information, then still the old prompt could be considered exactly same as long as it does not contradict the new prompt. - - -{prompt} - - -{queries} - - -format - -relevant index-of-query-starting-from-1 -examples - -${CacheResults.AsIs} 2 - -${CacheResults.Similar} 1 - -${CacheResults.NotRelevant} - - - -Do not return any other text or explanation, just the output in the above format. -If no queries are relevant, return '${CacheResults.NotRelevant}' and nothing else. 
-`); - async execute( - state: DbQueryState, - config: RunnableConfig, - ): Promise> { - if (state.sampleSql) { - return {}; - } - const relevantDocs = await this.cache.invoke(state.prompt, config); - if (relevantDocs.length === 0) { - return {}; - } - const chain = RunnableSequence.from([ - this.prompt, - this.smartLLM, - stripThinkingTokens, - ]); - - const response = await chain.invoke( - { - queries: relevantDocs - .map( - (doc, index) => - `\n\n${doc.pageContent}\n\n${doc.metadata.description}`, - ) - .join('\n'), - prompt: state.prompt, - }, - config, - ); - - const [relevance, index] = response.split(' '); - const indexNum = parseInt(index, 10) - 1; // Convert to 0-based index - if (relevance === CacheResults.NotRelevant) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `No relevant queries found in cache for this prompt`, - }); - return {}; - } - if (indexNum >= relevantDocs.length || indexNum < 0 || isNaN(indexNum)) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Index ${index} is out of bounds for the list of relevant queries. 
- Available queries: ${this._buildCacheLog(relevantDocs)}`, - }); - return {}; - } - if (relevance === CacheResults.AsIs) { - const missingPermissions = await this.dataSetHelper.checkPermissions( - relevantDocs[indexNum].metadata.datasetId, - ); - if (missingPermissions.length > 0) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Found relevant query in cache, but missing permissions: ${missingPermissions.join( - ', ', - )} so generating new query`, - }); - return {}; - } - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Found relevant query in cache, using it as is`, - }); - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: `Found relevant query in cache`, - }, - }); - const [dataset] = await this.dataSetHelper.find({ - where: { - id: relevantDocs[indexNum].metadata.datasetId, - }, - include: [{relation: 'actions'}], - }); - if ( - !dataset || - (dataset.actions?.length && - dataset.actions?.some(a => a.action === DatasetActionType.Disliked)) - ) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Found relevant query in cache, but the dataset was not found or was disliked by the user, so generating new query`, - }); - return {}; - } - const datasetId = relevantDocs[indexNum].metadata.datasetId; - if (!state.directCall) { - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: ToolStatus.Completed, - data: { - datasetId, - }, - }, - }); - } - - return { - fromCache: true, - datasetId, - replyToUser: `I found this dataset in the cache - ${relevantDocs[indexNum].pageContent}`, - }; - } - if (relevance === CacheResults.Similar) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Found similar query in cache, using it as example`, - }); - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: `Found similar query in cache, using it as example`, - }, - }); - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: `Found relevant 
query in cache`, - }, - }); - return { - sampleSql: relevantDocs[indexNum].metadata.query, - sampleSqlPrompt: relevantDocs[indexNum].pageContent, - }; - } - return {}; - } - - private _buildCacheLog( - relevantDocs: DocumentInterface[], - ) { - return relevantDocs - .map((doc, i) => `${i + 1}. ${doc.pageContent}`) - .join('\n'); - } -} diff --git a/src/components/db-query/nodes/check-permissions.node.ts b/src/components/db-query/nodes/check-permissions.node.ts deleted file mode 100644 index d6ba80e..0000000 --- a/src/components/db-query/nodes/check-permissions.node.ts +++ /dev/null @@ -1,76 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {inject} from '@loopback/context'; -import {service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, RunnableConfig} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryNodes} from '../nodes.enum'; -import {PermissionHelper} from '../services'; -import {DbQueryState} from '../state'; -import {Errors} from '../types'; - -@graphNode(DbQueryNodes.CheckPermissions) -export class CheckPermissionsNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, // Replace with actual type if available - - @service(PermissionHelper) - private readonly permissions: PermissionHelper, - ) {} - - prompt = - PromptTemplate.fromTemplate(`You are an AI assistant that received the following request from the user - - {prompt} - - But as this request requires access to the following tables - - {tables} - - and user the does not have permissions for the following tables - - {missingPermissions} - - You must return an error message that explains the user that they do not have permissions to access the required 
tables and cannot proceed with the request, and then asking him to give a new request. - Do not give direct tables names or any technical details, use plain language to explain the error. - Do not return any other text, comments, or explanations. Only return a simple error message with request for new prompt. - `); - - async execute( - state: DbQueryState, - config: RunnableConfig, - ): Promise { - const missingPermissions = this.permissions.findMissingPermissions( - this.getTableNames(state), - ); - - if (missingPermissions.length > 0) { - const chain = RunnableSequence.from([ - this.prompt, - this.llm, - stripThinkingTokens, - ]); - - const response = await chain.invoke({ - prompt: state.prompt, - tables: this.getTableNames(state).join(', '), - missingPermissions: missingPermissions.join(', '), - }); - - return { - ...state, - status: Errors.PermissionError, - replyToUser: response, - }; - } - return state; - } - - private getTableNames(state: DbQueryState) { - return Object.keys(state.schema.tables || {}).map( - // exclude the schema name and dot from the table names - table => table.toLowerCase().slice(table.indexOf('.') + 1), - ); - } -} diff --git a/src/components/db-query/nodes/check-templates.node.ts b/src/components/db-query/nodes/check-templates.node.ts deleted file mode 100644 index 2aa127c..0000000 --- a/src/components/db-query/nodes/check-templates.node.ts +++ /dev/null @@ -1,188 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {BaseRetriever} from '@langchain/core/retrievers'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {inject, service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType, RunnableConfig} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryAIExtensionBindings} from '../keys'; -import 
{DbQueryNodes} from '../nodes.enum'; -import {DbQueryState} from '../state'; -import {QueryTemplateMetadata} from '../types'; -import {PermissionHelper} from '../services/permission-helper.service'; -import {SchemaStore} from '../services/schema.store'; -import {TemplateHelper} from '../services/template-helper.service'; - -@graphNode(DbQueryNodes.CheckTemplates) -export class CheckTemplatesNode implements IGraphNode { - constructor( - @inject(DbQueryAIExtensionBindings.TemplateCache) - private readonly templateCache: BaseRetriever, - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - @service(PermissionHelper) - private readonly permissionHelper: PermissionHelper, - @service(TemplateHelper) - private readonly templateHelper: TemplateHelper, - @service(SchemaStore) - private readonly schemaStore: SchemaStore, - ) {} - - matchPrompt = PromptTemplate.fromTemplate(` - -You are an expert at matching user prompts to query templates. -Given a user prompt and a list of query templates with their canonical prompts and placeholders, determine if any template can EXACTLY fulfill the user's request. 
- -A template is a match ONLY if ALL of the following are true: -- The template produces exactly the data the user is asking for — not more, not less -- The user's intent is identical to the template's purpose, just with different parameter values -- All non-optional placeholders can be filled from the user's prompt or have defaults -- The template does not include extra filters, columns, or logic that the user did not ask for -- The template does not omit any filters, columns, or logic that the user is asking for - -Do NOT match if: -- The template is only similar or partially relevant -- The template would need structural changes beyond placeholder substitution to answer the question -- The user is asking for something the template cannot express through its placeholders alone - - -{prompt} - - -{templates} - - -If a template is an exact match, return: match -If no template exactly matches, return: no_match - -Do not return any other text or explanation. -`); - - async execute( - state: DbQueryState, - config: RunnableConfig, - ): Promise> { - const relevantDocs = await this.templateCache.invoke(state.prompt, config); - if (relevantDocs.length === 0) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: 'No templates found for this prompt', - }); - return {}; - } - - const chain = RunnableSequence.from([ - this.matchPrompt, - this.llm, - stripThinkingTokens, - ]); - - const templatesText = relevantDocs - .map((doc, index) => { - const metadata = doc.metadata; - const placeholders = JSON.parse(metadata.placeholders); - const placeholderText = placeholders - .map( - (p: {name: string; type: string; description: string}) => - ` - {{${p.name}}} (${p.type}): ${p.description}`, - ) - .join('\n'); - return ` -${doc.pageContent} - -${placeholderText} - -`; - }) - .join('\n'); - - const response = await chain.invoke( - { - prompt: state.prompt, - templates: templatesText, - }, - config, - ); - - const trimmed = response.trim(); - if (trimmed === 'no_match') { - 
config.writer?.({ - type: LLMStreamEventType.Log, - data: 'No matching template found for this prompt', - }); - return {}; - } - - const matchResult = trimmed.match(/^match\s+(\d+)$/); - if (!matchResult) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Unexpected template match response: ${trimmed}`, - }); - return {}; - } - - const matchIndex = Number.parseInt(matchResult[1], 10) - 1; - if (matchIndex < 0 || matchIndex >= relevantDocs.length) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Template match index ${matchResult[1]} out of bounds`, - }); - return {}; - } - - const matchedDoc = relevantDocs[matchIndex]; - const template = this.templateHelper.parseTemplateMetadata( - matchedDoc.metadata, - ); - - // Permission check - const missingPermissions = this.permissionHelper.findMissingPermissions( - template.tables, - ); - if (missingPermissions.length > 0) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Template matched but missing permissions: ${missingPermissions.join(', ')}`, - }); - return {}; - } - - // Resolve placeholders with column context from schema - try { - const schema = this.schemaStore.filteredSchema(template.tables); - const resolved = await this.templateHelper.resolveTemplate( - template, - state.prompt, - config, - schema, - ); - - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Template matched: ${template.description}`, - }); - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: `Matched query template`, - }, - }); - - return { - sql: resolved.sql, - description: resolved.description, - fromTemplate: true, - templateId: template.id, - }; - } catch (error) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Template resolution failed: ${(error as Error).message}`, - }); - return {}; - } - } -} diff --git a/src/components/db-query/nodes/classify-change.node.ts b/src/components/db-query/nodes/classify-change.node.ts deleted file mode 100644 index 
d0b325f..0000000 --- a/src/components/db-query/nodes/classify-change.node.ts +++ /dev/null @@ -1,80 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {inject} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbQueryState} from '../state'; -import {ChangeType} from '../types'; - -@graphNode(DbQueryNodes.ClassifyChange) -export class ClassifyChangeNode implements IGraphNode { - prompt = PromptTemplate.fromTemplate(` - -You are given the original description of a SQL query and a new description that includes user feedback. -Your task is to classify the level of change required to transform the original query into the new one. - -Classify as one of: -- **minor**: Small tweaks such as changing a filter value, adjusting a limit, adding/removing a single condition, or renaming an alias. -- **major**: Structural changes like adding/removing joins, changing grouping logic, adding subqueries, or significantly altering the WHERE clause. -- **rewrite**: The intent of the query has fundamentally changed, requiring a completely new query from scratch. - - - -{originalDescription} - - - -{newDescription} - - - -Return ONLY one of: minor, major, rewrite -Do not include any other text, explanation, or formatting. 
-`); - - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - ) {} - - async execute( - state: DbQueryState, - config: LangGraphRunnableConfig, - ): Promise { - if (!state.sampleSql) { - return {} as DbQueryState; - } - - config.writer?.({ - type: LLMStreamEventType.Log, - data: 'Classifying the level of change required for the query.', - }); - - const chain = RunnableSequence.from([this.prompt, this.llm]); - const output = await chain.invoke({ - originalDescription: state.sampleSqlPrompt ?? '', - newDescription: state.prompt, - }); - - const response = stripThinkingTokens(output).trim().toLowerCase(); - const changeType = this.parseChangeType(response); - - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Change classified as: ${changeType}`, - }); - - return {changeType} as DbQueryState; - } - - private parseChangeType(response: string): ChangeType { - if (response.includes(ChangeType.Minor)) return ChangeType.Minor; - if (response.includes(ChangeType.Rewrite)) return ChangeType.Rewrite; - return ChangeType.Major; - } -} diff --git a/src/components/db-query/nodes/failed.node.ts b/src/components/db-query/nodes/failed.node.ts deleted file mode 100644 index c7509e1..0000000 --- a/src/components/db-query/nodes/failed.node.ts +++ /dev/null @@ -1,27 +0,0 @@ -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType, ToolStatus} from '../../../graphs'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbQueryState} from '../state'; - -@graphNode(DbQueryNodes.Failed) -export class FailedNode implements IGraphNode { - async execute( - state: DbQueryState, - config: LangGraphRunnableConfig, - ): Promise { - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: ToolStatus.Failed, - }, - }); - return { - ...state, - replyToUser: - state.replyToUser ?? 
- `I am sorry, I was not able to generate a valid SQL query for your request. Please try again with a more detailed or a more specific prompt.\n` + - `These were the errors I encountered:\n${state.feedbacks?.join('\n') ?? 'No errors reported.'}`, - }; - } -} diff --git a/src/components/db-query/nodes/fix-query.node.ts b/src/components/db-query/nodes/fix-query.node.ts deleted file mode 100644 index bd2e9ff..0000000 --- a/src/components/db-query/nodes/fix-query.node.ts +++ /dev/null @@ -1,195 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {inject, service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider, SupportedDBs} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbSchemaHelperService} from '../services'; -import {DbQueryState} from '../state'; -import { - DatabaseSchema, - DbQueryConfig, - EvaluationResult, - GenerationError, -} from '../types'; - -@graphNode(DbQueryNodes.FixQuery) -export class FixQueryNode implements IGraphNode { - fixPrompt = PromptTemplate.fromTemplate(` - -You are an expert AI assistant that fixes SQL query errors. -You are given a SQL query that has validation errors related to specific tables. -Your task is to fix ONLY the parts of the query related to the listed error tables. -DO NOT change any part of the query that does not involve the error tables. -Preserve the overall structure, logic, and all other table references exactly as they are. - -Rules: -- Only modify clauses, joins, columns, or conditions that involve the error tables. 
-- Do not add, remove, or reorder columns or tables that are not related to the error. -- Do not change aliases, formatting, or logic for unrelated parts of the query. -- **DO NOT make any DML statements** (INSERT, UPDATE, DELETE, DROP etc.) to the database. -- Use the provided schema for the error-related tables to write correct SQL. -- The dialect is {dialect}. - - - -{question} - - - -{currentQuery} - - - -{errorSchema} - - - -{errorFeedback} - - -{checks} - -{historicalErrors} - - -Output should only be a valid SQL query with no other special character or formatting. -Contains the required valid SQL with the error fixed. -It should have no other character or symbol or character that is not part of SQLs. -`); - - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - @inject(DbQueryAIExtensionBindings.Config) - private readonly config: DbQueryConfig, - @service(DbSchemaHelperService) - private readonly schemaHelper: DbSchemaHelperService, - ) {} - - async execute( - state: DbQueryState, - config: LangGraphRunnableConfig, - ): Promise { - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: 'Fixing SQL query based on validation errors', - }, - }); - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Fixing SQL query based on validation errors`, - }); - - const errorTables = [ - ...(state.syntacticErrorTables ?? []), - ...(state.semanticErrorTables ?? []), - ]; - - const trimmedSchema = this.trimSchema(state.schema, errorTables); - const errorSchemaString = this.schemaHelper.asString(trimmedSchema); - - const feedbacks = state.feedbacks ?? []; - const lastFeedback = feedbacks[feedbacks.length - 1] ?? ''; - const historicalErrors = feedbacks.slice(0, -1); - - const chain = RunnableSequence.from([this.fixPrompt, this.llm]); - const output = await chain.invoke({ - dialect: this.config.db?.dialect ?? SupportedDBs.PostgreSQL, - question: state.prompt, - currentQuery: state.sql ?? 
'', - errorSchema: errorSchemaString, - errorFeedback: lastFeedback, - checks: this.buildChecks(state, trimmedSchema), - historicalErrors: historicalErrors.length - ? [ - ``, - `You already faced following issues in the past -`, - historicalErrors.join('\n'), - ``, - ].join('\n') - : '', - }); - - const response = stripThinkingTokens(output); - const sql = - response - .replace(/^```(?:sql)?\s*/i, '') - .replace(/```\s*$/, '') - .trim() || undefined; - - if (!sql) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `SQL fix failed: ${response}`, - }); - return { - status: GenerationError.Failed, - replyToUser: - 'Failed to fix SQL query. Please try rephrasing your question or provide more details.', - } as DbQueryState; - } - - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Fixed SQL query: ${sql}`, - }); - - return { - status: EvaluationResult.Pass, - sql, - } as DbQueryState; - } - - private trimSchema( - fullSchema: DatabaseSchema, - errorTables: string[], - ): DatabaseSchema { - const errorTableSet = new Set(errorTables); - const trimmedTables: DatabaseSchema['tables'] = {}; - - for (const tableName of Object.keys(fullSchema.tables)) { - if (errorTableSet.has(tableName)) { - trimmedTables[tableName] = fullSchema.tables[tableName]; - } - } - - const trimmedRelations = fullSchema.relations.filter( - rel => - errorTableSet.has(rel.table) || errorTableSet.has(rel.referencedTable), - ); - - return { - tables: trimmedTables, - relations: trimmedRelations, - }; - } - - private buildChecks( - state: DbQueryState, - trimmedSchema: DatabaseSchema, - ): string { - if (state.validationChecklist) { - return [ - '', - 'You must keep these additional details in mind while fixing the query -', - ...state.validationChecklist.split('\n').map(check => `- ${check}`), - '', - ].join('\n'); - } - const context = this.schemaHelper.getTablesContext(trimmedSchema); - if (context.length === 0) return ''; - return [ - '', - 'You must keep these additional details in 
mind while fixing the query -', - ...context.map(check => `- ${check}`), - '', - ].join('\n'); - } -} diff --git a/src/components/db-query/nodes/generate-checklist.node.ts b/src/components/db-query/nodes/generate-checklist.node.ts deleted file mode 100644 index d79d03b..0000000 --- a/src/components/db-query/nodes/generate-checklist.node.ts +++ /dev/null @@ -1,153 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {inject, service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {AIMessage} from '@langchain/core/messages'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbSchemaHelperService} from '../services'; -import {DbQueryState} from '../state'; -import {DbQueryConfig} from '../types'; - -@graphNode(DbQueryNodes.GenerateChecklist) -export class GenerateChecklistNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - @inject(DbQueryAIExtensionBindings.Config) - private readonly config: DbQueryConfig, - @service(DbSchemaHelperService) - private readonly schemaHelper: DbSchemaHelperService, - @inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true}) - private readonly checks?: string[], - ) {} - - prompt = PromptTemplate.fromTemplate(` - -You are given a user question, the tables selected for SQL generation, the relevant database schema, and a numbered list of rules/checks. -Return ONLY the indexes of the rules that are relevant to the user's question, the selected tables, and the given schema. 
- -A rule is relevant if: -- It directly affects how a correct SQL query should be written for this question. -- It is a dependency of another relevant rule (e.g. if rule 3 requires a currency conversion, and rule 5 defines how currency conversion works, both must be included). -- It applies to any of the selected tables or their relationships. - -After selecting relevant rules, review your selection and ensure: -- Any rule that is referenced by, or is a prerequisite for, another selected rule is also included. -- Do not include rules that are completely unrelated to the question, schema, or selected tables. - - - -{prompt} - - - -{tables} - - - -{schema} - - - -{indexedChecks} - - - -Return only a comma-separated list of the relevant rule indexes. -Do not include any other text, explanation, or formatting. -Example: 1,3,5 -If no rules are relevant, return: none -`); - - async execute( - state: DbQueryState, - config: LangGraphRunnableConfig, - ): Promise { - const empty = {} as DbQueryState; - if (this.config.nodes?.generateChecklistNode?.enabled === false) { - return empty; - } - if (state.validationChecklist) { - return empty; - } - - const tableCount = Object.keys(state.schema?.tables ?? {}).length; - if (tableCount <= 2) { - return empty; - } - - const allChecks = [ - ...(this.checks ?? 
[]), - ...this.schemaHelper.getTablesContext(state.schema), - ]; - - if (allChecks.length === 0) { - return empty; - } - - config.writer?.({ - type: LLMStreamEventType.Log, - data: 'Filtering validation checklist for semantic validation.', - }); - - const mergedIndexes = await this.runParallelChecklist(state, allChecks); - - if (mergedIndexes.size === 0) { - return empty; - } - - const validationChecklist = Array.from(mergedIndexes) - .sort((a, b) => a - b) - .map(i => allChecks[i - 1]) - .join('\n'); - - return {validationChecklist} as DbQueryState; - } - - private async runParallelChecklist( - state: DbQueryState, - allChecks: string[], - ): Promise> { - const indexedChecks = allChecks - .map((check, i) => `${i + 1}. ${check}`) - .join('\n'); - - const parallelism = - this.config.nodes?.generateChecklistNode?.parallelism ?? 1; - - const chain = RunnableSequence.from([this.prompt, this.llm]); - const invokeArgs = { - prompt: state.prompt, - tables: Object.keys(state.schema?.tables ?? {}).join(', '), - schema: this.schemaHelper.asString(state.schema), - indexedChecks, - }; - - const results = await Promise.all( - Array.from({length: parallelism}, () => chain.invoke(invokeArgs)), - ); - - const mergedIndexes = new Set(); - for (const output of results) { - this.parseIndexes(output, allChecks.length).forEach(n => - mergedIndexes.add(n), - ); - } - return mergedIndexes; - } - - private parseIndexes(output: AIMessage, maxIndex: number): number[] { - const response = stripThinkingTokens(output).trim(); - if (!response || response === 'none') return []; - return response - .split(',') - .map(s => Number.parseInt(s.trim(), 10)) - .filter(n => !Number.isNaN(n) && n >= 1 && n <= maxIndex); - } -} diff --git a/src/components/db-query/nodes/generate-description.node.ts b/src/components/db-query/nodes/generate-description.node.ts deleted file mode 100644 index 41704ff..0000000 --- a/src/components/db-query/nodes/generate-description.node.ts +++ /dev/null @@ -1,110 +0,0 @@ 
-import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {inject, service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbSchemaHelperService} from '../services'; -import {DbQueryState} from '../state'; -import {DbQueryConfig} from '../types'; - -@graphNode(DbQueryNodes.GenerateDescription) -export class GenerateDescriptionNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - @inject(DbQueryAIExtensionBindings.Config) - private readonly config: DbQueryConfig, - @service(DbSchemaHelperService) - private readonly schemaHelper: DbSchemaHelperService, - @inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true}) - private readonly checks?: string[], - ) {} - - prompt = PromptTemplate.fromTemplate(` - -You are an AI assistant that describes what a SQL query does in plain english. -Analyze the actual query below and write a concise, bulleted summary of the data it retrieves and any filters/conditions it applies. -Write in plain english. No SQL, no technical jargon, no table/column names. - - - -{prompt} - - - -{sql} - - - -{schema} - - -{checks} - - -Return a short bulleted list where each bullet is one condition, filter, or piece of data the query retrieves. -- Use plain, non-technical language a business user would understand. -- Do NOT mention tables, columns, joins, CTEs, enums, or any DB concepts. -- Keep each bullet to one line. -- Do not add any preamble, heading, or closing text — just the bullets. 
-`); - - async execute( - state: DbQueryState, - config: LangGraphRunnableConfig, - ): Promise { - const generateDesc = - this.config.nodes?.sqlGenerationNode?.generateDescription !== false; - - if (!generateDesc || !state.sql) { - return {} as DbQueryState; - } - - config.writer?.({ - type: LLMStreamEventType.Log, - data: 'Generating query description.', - }); - - const chain = RunnableSequence.from([this.prompt, this.llm]); - const stream = await chain.stream({ - prompt: state.prompt, - sql: state.sql, - schema: this.schemaHelper.asString(state.schema), - checks: [ - '', - ...(this.checks ?? []), - ...this.schemaHelper.getTablesContext(state.schema), - '', - ].join('\n'), - }); - - let output = ''; - for await (const chunk of stream) { - const token = - typeof chunk === 'string' ? chunk : (chunk?.content ?? '').toString(); - if (token) { - output += token; - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: {thinkingToken: token}, - }); - } - } - - // Strip thinking tokens from the accumulated string - let description = output.replace(/.*?<\/think(ing)?>/gs, ''); - description = description.replace(/.*?<\/think(ing)?>/gs, '').trim(); - - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Query description: ${description}`, - }); - - return {description} as DbQueryState; - } -} diff --git a/src/components/db-query/nodes/get-columns.node.ts b/src/components/db-query/nodes/get-columns.node.ts deleted file mode 100644 index 71b98c3..0000000 --- a/src/components/db-query/nodes/get-columns.node.ts +++ /dev/null @@ -1,332 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {inject} from '@loopback/context'; -import {service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType, RunnableConfig} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from 
'../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbSchemaHelperService} from '../services'; -import {DbQueryState} from '../state'; -import { - ColumnSchema, - DatabaseSchema, - DbQueryConfig, - GenerationError, - TableSchema, -} from '../types'; - -@graphNode(DbQueryNodes.GetColumns) -export class GetColumnsNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - @service(DbSchemaHelperService) - private readonly schemaHelper: DbSchemaHelperService, - @inject(DbQueryAIExtensionBindings.Config) - private readonly config: DbQueryConfig, - @inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true}) - private readonly checks?: string[], - ) {} - - prompt = PromptTemplate.fromTemplate(` - -You are an AI assistant that identifies relevant columns from database tables based on a user's query. -Given a set of tables with their columns, you need to identify which columns are relevant to answer the user's query. - -For each table, return only the column names that are relevant to the query. Include: -1. Columns directly mentioned or implied in the query -2. Primary key columns (always needed for joins and identification) -3. Foreign key columns (needed for relationships) -4. Columns that might be needed for filtering, sorting, or calculations -5. It is better to include a few extra relevant columns than to miss important ones. - -Do not include: -- Columns that are clearly irrelevant to the query -- Descriptions, types, or any other metadata about the columns - -Return the result as a JSON object where each table name is a key and the value is an array of relevant column names. 
-If you are not sure about which columns to select, return your doubt asking the user for more details in the following format: -failed attempt: - - - -{tablesWithColumns} - - - -{query} - - -{checks} - -{feedbacks} - - -Return a valid JSON object with table names as keys and arrays of column names as values. -Example format (do not copy these exact values): -{{ - "table_name1": ["column1", "column2", "column3"], - "table_name2": ["column1", "column2"] -}} - -In case of failure, return the failure message in the format: -failed attempt: -`); - - feedbackPrompt = PromptTemplate.fromTemplate(` - -We also need to consider the errors from last attempt at query generation. - -In the last attempt, these were the columns selected: -{lastColumns} - -But it was rejected with the following errors: -{feedback} - -Use these errors to refine your column selection. Consider if you need additional columns for joins, filtering, or calculations. - -`); - - async execute( - state: DbQueryState, - config: RunnableConfig, - ): Promise { - if (!this.config.columnSelection) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Skipping column selection as per configuration`, - }); - return state; - } - if ( - !state.schema?.tables || - Object.keys(state.schema.tables).length === 0 - ) { - throw new Error( - 'No tables found in the schema. 
Please ensure the get-tables step was completed successfully.', - ); - } - - const tablesWithColumns = this._getTablesWithColumns(state.schema); - - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Selecting relevant columns from ${Object.keys(state.schema.tables).length} tables`, - }); - - const chain = RunnableSequence.from([this.prompt, this.llm]); - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: 'Extracting relevant columns from the schema', - }, - }); - - let attempts = 0; - let selectedColumns: Record = {}; - - while (attempts < 3) { - attempts++; - const result = await chain.invoke({ - tablesWithColumns: tablesWithColumns.join('\n\n'), - query: state.prompt, - feedbacks: await this.getFeedbacks(state), - checks: [ - ``, - ...(this.checks ?? []), - ...this.schemaHelper.getTablesContext(state.schema), - ``, - ].join('\n'), - }); - - const output = stripThinkingTokens(result); - - if (output.startsWith('failed attempt:')) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Column selection failed: ${output}`, - }); - return { - ...state, - status: GenerationError.Failed, - replyToUser: output.replace('failed attempt: ', ''), - }; - } - - try { - // Extract JSON from the output - const jsonMatch = output.match(/\{[\s\S]*\}/); - if (!jsonMatch) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Failed to find JSON in LLM response, trying again (attempt ${attempts})`, - }); - continue; - } - - selectedColumns = JSON.parse(jsonMatch[0]); - - if (this._validateColumns(selectedColumns, state.schema)) { - break; - } else { - if (attempts === 3) { - return { - ...state, - status: GenerationError.Failed, - replyToUser: `Not able to select relevant columns from the schema. 
Please rephrase the question or provide more details.`, - }; - } - config.writer?.({ - type: LLMStreamEventType.Log, - data: `LLM returned invalid columns, trying again (attempt ${attempts})`, - }); - } - } catch (error) { - if (attempts === 3) { - return { - ...state, - status: GenerationError.Failed, - replyToUser: `Failed to parse column selection response. Please try again.`, - }; - } - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Failed to parse LLM response: ${error}, trying again (attempt ${attempts})`, - }); - } - } - - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Selected columns: ${JSON.stringify(selectedColumns, null, 2)}`, - }); - - // Create filtered schema with only selected columns - const filteredSchema = this._createFilteredSchema( - state.schema, - selectedColumns, - ); - - return { - ...state, - schema: filteredSchema, - }; - } - - async getFeedbacks(state: DbQueryState) { - if (state.feedbacks) { - const lastColumns = this._getSelectedColumnsFromSchema(state.schema); - const feedbacks = await this.feedbackPrompt.format({ - feedback: state.feedbacks.join('\n'), - lastColumns: JSON.stringify(lastColumns, null, 2), - }); - - return feedbacks; - } - return ''; - } - - private _getTablesWithColumns(schema: DatabaseSchema): string[] { - return Object.entries(schema.tables).map(([tableName, table]) => { - const columnDescriptions = Object.entries(table.columns).map( - ([columnName, column]) => { - const details = [ - `${columnName} (${column.type})`, - column.required ? 'NOT NULL' : 'NULL', - column.id ? 'PRIMARY KEY' : '', - column.description ? 
`- ${column.description}` : '', - ] - .filter(Boolean) - .join(' '); - - return ` - ${details}`; - }, - ); - - return `${tableName}: ${table.description}\nColumns:\n${columnDescriptions.join('\n')}`; - }); - } - - private _validateColumns( - selectedColumns: Record, - schema: DatabaseSchema, - ): boolean { - // Check if all tables exist in schema - for (const tableName of Object.keys(selectedColumns)) { - if (!schema.tables[tableName]) { - return false; - } - - // Check if all columns exist in the table - const tableColumns = Object.keys(schema.tables[tableName].columns); - for (const columnName of selectedColumns[tableName]) { - if (!tableColumns.includes(columnName)) { - return false; - } - } - } - return true; - } - - private _createFilteredSchema( - originalSchema: DatabaseSchema, - selectedColumns: Record, - ): DatabaseSchema { - const filteredTables: Record = {}; - - // Filter tables and columns based on selection - for (const [tableName, columnNames] of Object.entries(selectedColumns)) { - if (originalSchema.tables[tableName]) { - const originalTable = originalSchema.tables[tableName]; - const filteredColumns: Record = {}; - - // Include selected columns - for (const columnName of columnNames) { - if (originalTable.columns[columnName]) { - filteredColumns[columnName] = originalTable.columns[columnName]; - } - } - - // Always include primary key columns if not already included - for (const pkColumn of originalTable.primaryKey) { - if (!filteredColumns[pkColumn] && originalTable.columns[pkColumn]) { - filteredColumns[pkColumn] = originalTable.columns[pkColumn]; - } - } - - filteredTables[tableName] = { - ...originalTable, - columns: filteredColumns, - }; - } - } - - // Filter relations to only include those between selected tables - const filteredRelations = originalSchema.relations.filter( - relation => - filteredTables[relation.table] && - filteredTables[relation.referencedTable], - ); - - return { - tables: filteredTables, - relations: filteredRelations, - 
}; - } - - private _getSelectedColumnsFromSchema( - schema: DatabaseSchema, - ): Record { - const result: Record = {}; - - for (const [tableName, table] of Object.entries(schema.tables)) { - result[tableName] = Object.keys(table.columns); - } - - return result; - } -} diff --git a/src/components/db-query/nodes/get-tables.node.ts b/src/components/db-query/nodes/get-tables.node.ts deleted file mode 100644 index d342758..0000000 --- a/src/components/db-query/nodes/get-tables.node.ts +++ /dev/null @@ -1,225 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {inject} from '@loopback/context'; -import {service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType, RunnableConfig} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbSchemaHelperService, PermissionHelper} from '../services'; -import {SchemaStore} from '../services/schema.store'; -import {TableSearchService} from '../services/search/table-search.service'; -import {DbQueryState} from '../state'; -import {DatabaseSchema, DbQueryConfig, GenerationError} from '../types'; - -@graphNode(DbQueryNodes.GetTables) -export class GetTablesNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llmCheap: RuntimeLLMProvider, - @inject(AiIntegrationBindings.SmartLLM) - private readonly llmSmart: RuntimeLLMProvider, - @inject(DbQueryAIExtensionBindings.Config) - private readonly config: DbQueryConfig, - @service(DbSchemaHelperService) - private readonly schemaHelper: DbSchemaHelperService, - @service(SchemaStore) - private readonly schemaStore: SchemaStore, - @service(TableSearchService) - private 
readonly tableSearchService: TableSearchService, - @inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true}) - private readonly checks?: string[], - @service(PermissionHelper) - private readonly permissionHelper?: PermissionHelper, - ) {} - prompt = PromptTemplate.fromTemplate(` - -You are an AI assistant that extracts table names that are relevant to the users query that will be used to generate an SQL query later. -- Consider not just the user query but also the context and the table descriptions while selecting the tables. -- Carefully consider each and every table before including or excluding it. -- If doubtful about a table's relevance, include it anyway to give the SQL generation step more options to choose from. -- Assume that the table would have appropriate columns for relating them to any other table even if the description does not mention it. -- If you are not sure about the tables to select from the given schema, just return your doubt asking the user for more details or to rephrase the question in the following format - -failed attempt: reason for failure - - - -{tables} - - - -{query} - - -{checks} - -{feedbacks} - - -The output should be just a comma separated list of table names with no other text, comments or formatting. -Ensure that table names are exact and match the names in the input including schema if given. - -public.employees, public.departments - -In case of failure, return the failure message in the format - -failed attempt: - -failed attempt: reason for failure - -`); - - feedbackPrompt = PromptTemplate.fromTemplate(` - -We also need to consider the errors from last attempt at query generation. - -In the last attempt, these were the last tables selected: -{lastTables} - -But it was rejected with the following errors: -{feedback} - -Use these if they are relevant to the table selection, otherwise ignore them, they would be considered again during the SQL generation step. 
- -`); - async execute( - state: DbQueryState, - config: RunnableConfig, - ): Promise> { - const tableList = await this.tableSearchService.getTables(state.prompt, 10); - const accessibleTables = this._filterByPermissions(tableList); - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Selecting from tables: ${accessibleTables}`, - }); - const dbSchema = this.schemaStore.filteredSchema(accessibleTables); - const allTables = this._getTablesFromSchema(dbSchema); - if (allTables.length === 0) { - throw new Error( - 'No tables found in the provided database schema. Please ensure the schema is valid.', - ); - } - - const useSmartLLM = this.config.nodes?.getTablesNode?.useSmartLLM ?? false; - const llm = useSmartLLM ? this.llmSmart : this.llmCheap; - - const chain = RunnableSequence.from([this.prompt, llm]); - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: 'Extracting relevant tables from the schema', - }, - }); - - let attempts = 0; - let requiredTables: string[] = []; - while (attempts < 2) { - attempts++; - const result = await chain.invoke({ - tables: allTables.join('\n\n'), - query: state.prompt, - feedbacks: await this.getFeedbacks(state), - checks: [ - ``, - ...(this.checks ?? []).map(check => `- ${check}`), - ...this.schemaHelper - .getTablesContext(dbSchema) - .map(check => `- ${check}`), - ``, - ].join('\n'), - }); - - const output = stripThinkingTokens(result); - - if (output.startsWith('failed attempt:')) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Table selection failed: ${output}`, - }); - return { - status: GenerationError.Failed, - replyToUser: output.replace('failed attempt: ', ''), - }; - } - - const lastLine = output.split('\n').pop() ?? 
''; - requiredTables = lastLine.split(',').map(t => t.trim()); - if (this._validateTables(requiredTables, dbSchema)) { - break; - } else { - if (attempts === 3) { - return { - status: GenerationError.Failed, - replyToUser: `Not able to select relevant tables from the schema. Please rephrase the question or provide more details.`, - }; - } - config.writer?.({ - type: LLMStreamEventType.Log, - data: `LLM returned invalid tables: ${lastLine}, trying again`, - }); - } - } - - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Picked tables - ${requiredTables.join(', ')}`, - }); - - if (requiredTables.length === 0) { - throw new Error( - 'LLM did not return a valid comma separated string response.', - ); - } - - return { - schema: this.schemaStore.filteredSchema(requiredTables), - }; - } - - async getFeedbacks(state: DbQueryState) { - if (state.feedbacks) { - const feedbacks = await this.feedbackPrompt.format({ - query: state.sql, - feedback: state.feedbacks.join('\n'), - lastTables: this._tableListFromSchema(state.schema).join(', '), - }); - - return feedbacks; - } - return ''; - } - - private _tableListFromSchema(schema: DatabaseSchema): string[] { - if (!schema?.tables) { - return []; - } - return Object.keys(schema.tables); - } - - private _getTablesFromSchema(schema: DatabaseSchema): string[] { - if (!schema?.tables) { - return []; - } - return Object.keys(schema.tables).map(tableName => { - const table = schema.tables[tableName]; - return `${tableName}: ${table.description}`; - }); - } - - private _filterByPermissions(tables: string[]): string[] { - const permHelper = this.permissionHelper; - if (!permHelper) { - return tables; - } - return tables.filter(t => { - const name = t.toLowerCase().slice(t.indexOf('.') + 1); - return permHelper.findMissingPermissions([name]).length === 0; - }); - } - - private _validateTables(tables: string[], schema: DatabaseSchema): boolean { - return tables.every(t => schema.tables[t] !== undefined); - } -} diff --git 
a/src/components/db-query/nodes/index.ts b/src/components/db-query/nodes/index.ts deleted file mode 100644 index b736811..0000000 --- a/src/components/db-query/nodes/index.ts +++ /dev/null @@ -1,16 +0,0 @@ -export * from './check-cache.node'; -export * from './classify-change.node'; -export * from './check-permissions.node'; -export * from './check-templates.node'; -export * from './failed.node'; -export * from './fix-query.node'; -export * from './generate-checklist.node'; -export * from './generate-description.node'; -export * from './get-columns.node'; -export * from './get-tables.node'; -export * from './is-improvement.node'; -export * from './save-dataset-node'; -export * from './semantic-validator.node'; -export * from './sql-generation.node'; -export * from './syntactic-validator.node'; -export * from './verify-checklist.node'; diff --git a/src/components/db-query/nodes/is-improvement.node.ts b/src/components/db-query/nodes/is-improvement.node.ts deleted file mode 100644 index bbd44a7..0000000 --- a/src/components/db-query/nodes/is-improvement.node.ts +++ /dev/null @@ -1,32 +0,0 @@ -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {inject} from '@loopback/context'; -import {graphNode} from '../../../decorators'; -import {IGraphNode} from '../../../graphs'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbQueryState} from '../state'; -import {IDataSetStore} from '../types'; - -@graphNode(DbQueryNodes.IsImprovement) -export class IsImprovementNode implements IGraphNode { - constructor( - @inject(DbQueryAIExtensionBindings.DatasetStore) - private readonly store: IDataSetStore, - ) {} - - async execute( - state: DbQueryState, - config: LangGraphRunnableConfig, - ): Promise { - if (state.datasetId) { - const dataset = await this.store.findById(state.datasetId); - return { - ...state, - sampleSql: dataset.query, - sampleSqlPrompt: dataset.prompt, - prompt: `${dataset.prompt}\n also 
consider following feedback given by user -\n ${state.prompt}\n`, - }; - } - return state; - } -} diff --git a/src/components/db-query/nodes/save-dataset-node.ts b/src/components/db-query/nodes/save-dataset-node.ts deleted file mode 100644 index 97692f0..0000000 --- a/src/components/db-query/nodes/save-dataset-node.ts +++ /dev/null @@ -1,145 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {inject, service} from '@loopback/core'; -import {HttpErrors} from '@loopback/rest'; -import {IAuthUserWithPermissions} from '@sourceloop/core'; -import {createHash} from 'crypto'; -import {AuthenticationBindings} from 'loopback4-authentication'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType, ToolStatus} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbQueryState} from '../state'; -import {DatabaseSchema, DbQueryConfig, IDataSetStore} from '../types'; -import {DEFAULT_MAX_READ_ROWS_FOR_AI} from '../constant'; -import {AnyObject} from '@loopback/repository'; -import {DbSchemaHelperService} from '../services'; - -@graphNode(DbQueryNodes.SaveDataset) -export class SaveDataSetNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - @inject(DbQueryAIExtensionBindings.DatasetStore) - private readonly store: IDataSetStore, - @inject(DbQueryAIExtensionBindings.Config) - private readonly config: DbQueryConfig, - @inject(AuthenticationBindings.CURRENT_USER) - private readonly user: IAuthUserWithPermissions, - @service(DbSchemaHelperService) - private readonly dbSchemaHelper: 
DbSchemaHelperService, - @inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true}) - private readonly checks?: string[], - ) {} - - prompt = - PromptTemplate.fromTemplate(`You are an AI assitant that generates a short description of a query based on a given schema, providing a summary of the query's intent and user's demand in a way that is short but does not miss any importance detail. - - Here is the query that you need to describe - {query} - - And here is the schema that was used to generate the query - - {schema} - - - {checks} - The output should be a valid description of the query that is easy to understand by the user in plain text, without any formatting`); - - async execute( - state: DbQueryState, - config: LangGraphRunnableConfig, - ): Promise { - config.writer?.({ - type: LLMStreamEventType.Log, - data: 'Dataset generated', - }); - - const tenantId = this.user.tenantId; - if (!tenantId) { - throw new HttpErrors.BadRequest(`User does not have a tenantId`); - } - if (!state.sql) { - throw new HttpErrors.InternalServerError(); - } - - if (!state.description) { - const chain = RunnableSequence.from([this.prompt, this.llm]); - - const output = await chain.invoke({ - checks: [ - 'You must keep these additional details in consideration while describing the query -', - ...(this.checks ?? 
[]), - ].join('\n'), - query: state.sql, - schema: this.dbSchemaHelper.asString(state.schema), - }); - - state.description = stripThinkingTokens(output); - } - - const dataset = await this.store.create({ - query: state.sql, - tenantId, - description: state.description, - prompt: state.prompt, - tables: this._getTableList(state.schema), - schemaHash: this._hashSchema(state.schema), - votes: 0, - }); - - if (!state.directCall) { - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: ToolStatus.Completed, - data: { - datasetId: dataset.id, - }, - }, - }); - } - - let result: undefined | AnyObject[] = undefined; - if (this.config.readAccessForAI && dataset.id) { - result = await this.store.getData( - dataset.id, - this.config.maxRowsForAI ?? DEFAULT_MAX_READ_ROWS_FOR_AI, - ); - } - - return { - ...state, - datasetId: dataset.id, - replyToUser: state.description, - done: true, - resultArray: result, - }; - } - - private _hashSchema(schema: DatabaseSchema): string { - const hash = createHash('sha256'); - const tableList = this._getTableList(schema).sort((a, b) => - a.localeCompare(b), - ); - tableList.forEach(table => { - hash.update(table); - const columns = schema.tables[table]?.columns || {}; - Object.keys(columns) - .sort((a, b) => a.localeCompare(b)) - .forEach(column => { - hash.update(`${column}:${columns[column].type}`); - }); - }); - return hash.digest('hex'); - } - - private _getTableList(schema: DatabaseSchema): string[] { - if (!schema?.tables) { - return []; - } - return Object.keys(schema.tables); - } -} diff --git a/src/components/db-query/nodes/semantic-validator.node.ts b/src/components/db-query/nodes/semantic-validator.node.ts deleted file mode 100644 index 1e8d7a6..0000000 --- a/src/components/db-query/nodes/semantic-validator.node.ts +++ /dev/null @@ -1,177 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {LangGraphRunnableConfig} from 
'@langchain/langgraph'; -import {inject, service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import { - DbSchemaHelperService, - PermissionHelper, - TableSearchService, -} from '../services'; -import {DbQueryState} from '../state'; -import {DbQueryConfig, EvaluationResult} from '../types'; - -@graphNode(DbQueryNodes.SemanticValidator) -export class SemanticValidatorNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.SmartLLM) - private readonly smartllm: RuntimeLLMProvider, - @inject(AiIntegrationBindings.CheapLLM) - private readonly cheapllm: RuntimeLLMProvider, - @inject(DbQueryAIExtensionBindings.Config) - private readonly config: DbQueryConfig, - @service(TableSearchService) - private readonly tableSearchService: TableSearchService, - @service(DbSchemaHelperService) - private readonly schemaHelper: DbSchemaHelperService, - @service(PermissionHelper) - private readonly permissionHelper?: PermissionHelper, - ) {} - - prompt = PromptTemplate.fromTemplate(` - -You are an AI assistant that validates whether a SQL query satisfies a given checklist. -The query has already been validated for syntax and correctness. -Go through each checklist item and verify it against the SQL query. -DO NOT make up issues that do not exist in the query. - - - -{userPrompt} - - - -{query} - - - -{schema} - - - -{tableNames} - - - -{checklist} - - -{feedbacks} - - -If the query satisfies ALL checklist items, return ONLY a valid tag with no other text: - - - - -If any checklist item is NOT satisfied, return your response in two sections: -1. 
An invalid tag containing each failed item with a detailed explanation of what is wrong and how it should be fixed. -2. A tables tag listing ALL table names from the available tables that are related to the errors. Be generous - include tables directly involved in the error, tables that need to be joined to fix the issue, and any tables that could be relevant. It is better to include extra tables than to miss any. - - - -- Salary values are not converted to USD. The query should join the exchange_rates table using currency_id and multiply salary by the rate. -- Lost and hold deals are not excluded. Add a WHERE condition to filter out deals with status 0 and 2. - -exchange_rates, deals, employees - - -`); - - feedbackPrompt = PromptTemplate.fromTemplate(` - -We also need to consider the users feedback on the last attempt at query generation. - -But was rejected by validator with the following errors - -{feedback} - -Keep these feedbacks in mind while validating the new query. -`); - - async execute( - state: DbQueryState, - config: LangGraphRunnableConfig, - ): Promise { - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: `Verifying if the query fully satisfies the user's requirement`, - }, - }); - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Validating the query semantically.`, - }); - const useSmartLLM = - this.config.nodes?.semanticValidatorNode?.useSmartLLM ?? false; - const llm = useSmartLLM ? this.smartllm : this.cheapllm; - const tableList = - (await this.tableSearchService.getTables(state.prompt)) ?? []; - const accessibleTables = this._filterByPermissions(tableList); - const chain = RunnableSequence.from([this.prompt, llm]); - const output = await chain.invoke({ - userPrompt: state.prompt, - query: state.sql, - schema: this.schemaHelper.asString(state.schema), - tableNames: accessibleTables.join(', '), - checklist: state.validationChecklist ?? 
'No checklist provided.', - feedbacks: await this.getFeedbacks(state), - }); - const response = stripThinkingTokens(output); - - const invalidMatch = /(.*?)<\/invalid>/s.exec(response); - const tablesMatch = /(.*?)<\/tables>/s.exec(response); - const isValid = - response.includes('') || response.includes(''); - - if (isValid && !invalidMatch) { - return { - semanticStatus: EvaluationResult.Pass, - } as DbQueryState; - } else { - const reason = invalidMatch ? invalidMatch[1].trim() : response.trim(); - const errorTables = tablesMatch - ? tablesMatch[1] - .split(',') - .map(t => t.trim()) - .filter(t => t.length > 0) - : []; - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Query Validation Failed by LLM: ${reason}`, - }); - return { - semanticStatus: EvaluationResult.QueryError, - semanticFeedback: `Query Validation Failed by LLM: ${reason}`, - semanticErrorTables: errorTables, - } as DbQueryState; - } - } - - async getFeedbacks(state: DbQueryState) { - if (state.feedbacks?.length) { - const feedbacks = await this.feedbackPrompt.format({ - feedback: state.feedbacks.join('\n'), - }); - return feedbacks; - } - return ''; - } - - private _filterByPermissions(tables: string[]): string[] { - const permHelper = this.permissionHelper; - if (!permHelper) { - return tables; - } - return tables.filter(t => { - const name = t.toLowerCase().slice(t.indexOf('.') + 1); - return permHelper.findMissingPermissions([name]).length === 0; - }); - } -} diff --git a/src/components/db-query/nodes/sql-generation.node.ts b/src/components/db-query/nodes/sql-generation.node.ts deleted file mode 100644 index 3abd79a..0000000 --- a/src/components/db-query/nodes/sql-generation.node.ts +++ /dev/null @@ -1,229 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {inject, service} from '@loopback/core'; -import {graphNode} from 
'../../../decorators'; -import {IGraphNode, LLMStreamEventType} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider, SupportedDBs} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbSchemaHelperService} from '../services'; -import {DbQueryState} from '../state'; -import { - ChangeType, - DbQueryConfig, - EvaluationResult, - GenerationError, -} from '../types'; - -@graphNode(DbQueryNodes.SqlGeneration) -export class SqlGenerationNode implements IGraphNode { - sqlGenerationPrompt = PromptTemplate.fromTemplate(` - -You are an expert AI assistant that generates SQL queries based on user questions and a given database schema. -You try to following the instructions carefully to generate the SQL query that answers the question. -Do not hallucinate details or make up information. -Your task is to convert a question into a SQL query, given a {dialect} database schema. -Adhere to these rules: -- **Deliberately go through the question and database schema word by word** to appropriately answer the question -- **DO NOT make any DML statements** (INSERT, UPDATE, DELETE, DROP etc.) to the database. -- Never query for all the columns from a specific table, only ask for the relevant columns for the given the question. -- You can only generate a single query, so if you need multiple results you can use JOINs, subqueries, CTEs or UNIONS. -- Do not make any assumptions about the user's intent beyond what is explicitly provided in the prompt. -- Ensure proper grouping with brackets for where clauses with multiple conditions using AND and OR. -- Follow each and every single rule in the "must-follow-rules" section carefully while writing the query. DO NOT SKIP ANY RULE. 
- - -{question} - - - -{dbschema} - - -{checks} - -{exampleQueries} - -{feedbacks} - - -{outputFormat} -`); - - outputFormat = ` -Output should only be a valid SQL query with no other special character or formatting. -Contains the required valid SQL satisfying all the constraints. -It should have no other character or symbol or character that is not part of SQLs.`; - - feedbackPrompt = PromptTemplate.fromTemplate(` - -We also need to consider the users feedback on the last attempt at query generation. -Make sure you fix the provided error without introducing any new or past errors. -In the last attempt, you generated this SQL query - - -{query} - - - -{feedback} - - -{historicalErrors} -`); - constructor( - @inject(AiIntegrationBindings.SmartLLM) - private readonly sqlLLM: RuntimeLLMProvider, - @inject(AiIntegrationBindings.CheapLLM) - private readonly cheapllm: RuntimeLLMProvider, - @inject(DbQueryAIExtensionBindings.Config) - private readonly config: DbQueryConfig, - @service(DbSchemaHelperService) - private readonly schemaHelper: DbSchemaHelperService, - @inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true}) - private readonly checks?: string[], - ) {} - async execute( - state: DbQueryState, - config: LangGraphRunnableConfig, - ): Promise { - let llm; - - const isSingleTable = - state.schema.tables && Object.keys(state.schema.tables).length === 1; - - // Use cheap LLM for validation fix retries — the query is close, just needs small corrections - const isValidationFixRetry = - state.feedbacks?.length && - state.feedbacks[state.feedbacks.length - 1].startsWith( - 'Query Validation Failed', - ); - - // Use changeType from ClassifyChangeNode to pick the right LLM - if ( - state.changeType === ChangeType.Minor || - isSingleTable || - isValidationFixRetry - ) { - llm = this.cheapllm; - } else { - llm = this.sqlLLM; - } - - const chain = RunnableSequence.from([this.sqlGenerationPrompt, llm]); - - config.writer?.({ - type: LLMStreamEventType.Log, - data: 
`Generating SQL query from the prompt - ${state.prompt}`, - }); - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: 'Generating SQL query from the prompt', - }, - }); - - const output = await chain.invoke({ - dialect: this.config.db?.dialect ?? SupportedDBs.PostgreSQL, - question: state.prompt, - dbschema: this.schemaHelper.asString(state.schema), - checks: this._buildChecks(state), - feedbacks: await this.getFeedbacks(state), - exampleQueries: state.feedbacks?.length - ? '' - : await this.sampleQueries(state), - outputFormat: this.outputFormat, - }); - const response = stripThinkingTokens(output); - - const sql = - response - .replace(/^```(?:sql)?\s*/i, '') - .replace(/```\s*$/, '') - .trim() || undefined; - - if (!sql) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `SQL generation failed: ${response}`, - }); - return { - status: GenerationError.Failed, - replyToUser: - 'Failed to generate SQL query. Please try rephrasing your question or provide more details.', - } as DbQueryState; - } - - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Generated SQL query: ${sql}`, - }); - - return { - status: EvaluationResult.Pass, - sql, - } as DbQueryState; - } - - async getFeedbacks(state: DbQueryState) { - if (state.feedbacks?.length) { - const lastFeedback = state.feedbacks[state.feedbacks.length - 1]; - const otherFeedbacks = state.feedbacks.slice(0, -1); - const feedbacks = await this.feedbackPrompt.format({ - query: state.sql, - feedback: `This was the error in the latest query you generated - \n${lastFeedback}`, - historicalErrors: otherFeedbacks.length - ? 
[ - ``, - `You already faced following issues in the past -`, - otherFeedbacks.join('\n'), - ``, - ].join('\n') - : '', - }); - return feedbacks; - } - return ''; - } - - async sampleQueries(state: DbQueryState) { - let startTag = ``; - let endTag = ``; - let baseLine = `Here is an example query for reference that is similar to the question asked and has been validated by the user`; - if (!state.fromCache) { - startTag = ``; - endTag = ``; - baseLine = `Here is the last valid SQL query that was generated for the user that is supposed to be used as the base line for the next query generation.`; - } - return state.sampleSql - ? `${startTag}\n${baseLine} - -${state.sampleSql} -This was generated for the following question - \n${state.sampleSqlPrompt} \n\n -${endTag}` - : ''; - } - - private _buildChecks(state: DbQueryState): string { - // Use the filtered checklist from GenerateChecklist if available - if (state.validationChecklist) { - return [ - '', - 'You must keep these additional details in mind while writing the query -', - ...state.validationChecklist.split('\n').map(check => `- ${check}`), - '', - ].join('\n'); - } - // Fallback to full checks - return [ - '', - 'You must keep these additional details in mind while writing the query -', - ...(this.checks ?? 
[]).map(check => `- ${check}`), - ...this.schemaHelper - .getTablesContext(state.schema) - .map(check => `- ${check}`), - '', - ].join('\n'); - } -} diff --git a/src/components/db-query/nodes/syntactic-validator.node.ts b/src/components/db-query/nodes/syntactic-validator.node.ts deleted file mode 100644 index 0574ac8..0000000 --- a/src/components/db-query/nodes/syntactic-validator.node.ts +++ /dev/null @@ -1,105 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {inject} from '@loopback/context'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbQueryState} from '../state'; -import {EvaluationResult, IDbConnector} from '../types'; - -@graphNode(DbQueryNodes.SyntacticValidator) -export class SyntacticValidatorNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - @inject(DbQueryAIExtensionBindings.Connector) - private readonly connector: IDbConnector, - ) {} - - prompt = - PromptTemplate.fromTemplate(`You are an AI assistant that categorizes the SQL query error and identifies related tables. - -Here is the SQL query error that you need to categorize - -{error} - -Here is the query that resulted in the error - -{query} - -Here are all the available tables in the database - -{tableNames} - -Categorize the error into one of these two categories: -- table_not_found: Any error that indicates a table or column is missing -- query_error: All other errors - -Also identify ALL tables that are related to the error. 
Be generous - include tables that are directly involved in the error, tables referenced in the failing part of the query, and tables that might need to be joined or referenced to fix the error. It is better to include extra tables than to miss any. - -Return your response in exactly this format with no other text: -table_not_found or query_error -comma, separated, table, names -`); - - async execute( - state: DbQueryState, - config: LangGraphRunnableConfig, - ): Promise { - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: 'Validating generated SQL query', - }, - }); - - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Validating the query syntactically.`, - }); - - try { - if (!state.sql) { - throw new Error('No SQL query generated to validate'); - } - await this.connector.validate(state.sql); - return { - syntacticStatus: EvaluationResult.Pass, - } as DbQueryState; - } catch (error) { - const tableNames = Object.keys(state.schema?.tables ?? {}); - const chain = RunnableSequence.from([this.prompt, this.llm]); - const output = await chain.invoke({ - error: error.message, - query: state.sql, - tableNames: tableNames.join(', '), - }); - const result = stripThinkingTokens(output); - - const categoryMatch = /(.*?)<\/category>/s.exec(result); - const tablesMatch = /(.*?)<\/tables>/s.exec(result); - - const category = categoryMatch - ? (categoryMatch[1].trim() as EvaluationResult) - : (result.trim() as EvaluationResult); - const errorTables = tablesMatch - ? 
tablesMatch[1] - .split(',') - .map(t => t.trim()) - .filter(t => t.length > 0) - : []; - - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Query Validation Failed by DB: ${category} with error ${error.message}`, - }); - return { - syntacticStatus: category, - syntacticFeedback: `Query Validation Failed by DB: ${category} with error ${error.message}`, - syntacticErrorTables: errorTables, - } as DbQueryState; - } - } -} diff --git a/src/components/db-query/nodes/verify-checklist.node.ts b/src/components/db-query/nodes/verify-checklist.node.ts deleted file mode 100644 index 14b816e..0000000 --- a/src/components/db-query/nodes/verify-checklist.node.ts +++ /dev/null @@ -1,194 +0,0 @@ -import {AIMessage} from '@langchain/core/messages'; -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {inject, service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbQueryNodes} from '../nodes.enum'; -import {DbSchemaHelperService} from '../services'; -import {DbQueryState} from '../state'; -import {DbQueryConfig} from '../types'; - -@graphNode(DbQueryNodes.VerifyChecklist) -export class VerifyChecklistNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.SmartLLM) - private readonly smartLlm: RuntimeLLMProvider, - @inject(DbQueryAIExtensionBindings.Config) - private readonly config: DbQueryConfig, - @service(DbSchemaHelperService) - private readonly schemaHelper: DbSchemaHelperService, - @inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true}) - private readonly checks?: string[], - 
@inject(AiIntegrationBindings.SmartNonThinkingLLM, {optional: true}) - private readonly smartNonThinkingLlm?: RuntimeLLMProvider, - ) {} - - private get llm(): RuntimeLLMProvider { - return this.smartNonThinkingLlm ?? this.smartLlm; - } - - basePrompt = ` - -You are given a user question, the tables selected for SQL generation, the relevant database schema, and a numbered list of rules/checks. -Return ONLY the indexes of the rules that are relevant to the user's question, the selected tables, and the given schema. - -A rule is relevant if: -- It directly affects how a correct SQL query should be written for this question. -- It is a dependency of another relevant rule (e.g. if rule 3 requires a currency conversion, and rule 5 defines how currency conversion works, both must be included). -- It applies to any of the selected tables or their relationships. - -Ensure: -- Any rule that is referenced by, or is a prerequisite for, another selected rule is also included. -- Do not include rules that are completely unrelated to the question, schema, or selected tables. - - - -{prompt} - - - -{tables} - - - -{schema} - - - -{indexedChecks} - - -`; - - evaluationOutputInstructions = ` -First, evaluate each rule inside an evaluation tag. For each rule, repeat the full rule text exactly as given, followed by " — Include" or " — Exclude" with a brief reason. -Then, return only the comma-separated list of included rule indexes inside a result tag. - -Example: - -1. When matching names, use ilike with wildcards — Include, query involves name matching -2. Format dates using to_char — Exclude, no date fields in this query -3. Always exclude lost deals — Include, query involves deals - -1,3 - -If no rules are relevant: none -`; - - simpleOutputInstructions = ` -Return ONLY the comma-separated list of relevant rule indexes inside a result tag. -Do NOT include any reasoning, analysis, or explanation — only the result tag. 
-Example: -1,3,5 -If no rules are relevant: -none -`; - - async execute( - state: DbQueryState, - config: LangGraphRunnableConfig, - ): Promise { - const empty = {} as DbQueryState; - - if (this.config.nodes?.verifyChecklistNode?.enabled === false) { - return empty; - } - - if (state.feedbacks?.length) { - return empty; - } - - const tableCount = Object.keys(state.schema?.tables ?? {}).length; - if (tableCount <= 2) { - return empty; - } - - const allChecks = [ - ...(this.checks ?? []), - ...this.schemaHelper.getTablesContext(state.schema), - ]; - - if (allChecks.length === 0) { - return empty; - } - - config.writer?.({ - type: LLMStreamEventType.Log, - data: 'Verifying validation checklist with chain-of-thought.', - }); - - const output = await this.invokeVerification(state, allChecks); - const verifiedIndexes = this.parseVerifiedIndexes(output, allChecks.length); - - if (verifiedIndexes.length === 0) { - return empty; - } - - const validationChecklist = this.mergeWithExisting( - state.validationChecklist, - verifiedIndexes, - allChecks, - ); - - return {validationChecklist} as DbQueryState; - } - - private async invokeVerification( - state: DbQueryState, - allChecks: string[], - ): Promise { - const indexedChecks = allChecks - .map((check, i) => `${i + 1}. ${check}`) - .join('\n'); - - const useEvaluation = - this.config.nodes?.verifyChecklistNode?.evaluation ?? false; - const promptTemplate = PromptTemplate.fromTemplate( - this.basePrompt + - (useEvaluation - ? this.evaluationOutputInstructions - : this.simpleOutputInstructions), - ); - - const chain = RunnableSequence.from([promptTemplate, this.llm]); - return chain.invoke({ - prompt: state.prompt, - tables: Object.keys(state.schema?.tables ?? 
{}).join(', '), - schema: this.schemaHelper.asString(state.schema), - indexedChecks, - }); - } - - private parseVerifiedIndexes(output: AIMessage, maxIndex: number): number[] { - const response = stripThinkingTokens(output).trim(); - const resultMatch = /(.*?)<\/result>/s.exec(response); - const indexStr = resultMatch ? resultMatch[1].trim() : response; - - if (!indexStr || indexStr === 'none') return []; - - return indexStr - .split(',') - .map(s => Number.parseInt(s.trim(), 10)) - .filter(n => !Number.isNaN(n) && n >= 1 && n <= maxIndex); - } - - private mergeWithExisting( - existing: string | undefined, - verifiedIndexes: number[], - allChecks: string[], - ): string { - const existingChecks = new Set( - (existing ?? '').split('\n').filter(c => c.length > 0), - ); - for (const check of verifiedIndexes.map(i => allChecks[i - 1])) { - existingChecks.add(check); - } - return Array.from(existingChecks).join('\n'); - } -} diff --git a/src/components/db-query/providers/datasets.retriever.ts b/src/components/db-query/providers/datasets.retriever.ts deleted file mode 100644 index 2c58b79..0000000 --- a/src/components/db-query/providers/datasets.retriever.ts +++ /dev/null @@ -1,37 +0,0 @@ -import {BaseRetriever} from '@langchain/core/retrievers'; -import {VectorStore} from '@langchain/core/vectorstores'; -import {inject, Provider, ValueOrPromise} from '@loopback/core'; -import {AnyObject} from '@loopback/repository'; -import {MemoryVectorStore} from 'langchain/vectorstores/memory'; -import {AiIntegrationBindings} from '../../../keys'; -import {DbQueryStoredTypes} from '../types'; -import {AuthenticationBindings} from 'loopback4-authentication'; -import {IAuthUserWithPermissions} from '@sourceloop/core'; - -export class DatasetRetriever implements Provider { - constructor( - @inject(AiIntegrationBindings.VectorStore) - private readonly vectorStore: VectorStore, - @inject(AuthenticationBindings.CURRENT_USER) - private readonly user: IAuthUserWithPermissions, - ) {} - 
value(): ValueOrPromise> { - if (this.vectorStore instanceof MemoryVectorStore) { - return this.vectorStore.asRetriever({ - k: 5, - filter: doc => - doc.metadata.type === DbQueryStoredTypes.DataSet && - doc.metadata.tenantId === this.user.tenantId, - searchType: 'similarity', - }); - } - return this.vectorStore.asRetriever({ - k: 5, - filter: { - type: DbQueryStoredTypes.DataSet, - tenantId: this.user.tenantId, - }, - searchType: 'similarity', - }); - } -} diff --git a/src/components/db-query/providers/index.ts b/src/components/db-query/providers/index.ts index c0c6839..3ece2d0 100644 --- a/src/components/db-query/providers/index.ts +++ b/src/components/db-query/providers/index.ts @@ -1,2 +1 @@ -export * from './datasets.retriever'; -export * from './templates.retriever'; +// Retriever files removed — LangChain-based retrievers deleted in Phase 7 diff --git a/src/components/db-query/providers/templates.retriever.ts b/src/components/db-query/providers/templates.retriever.ts deleted file mode 100644 index 2edeb79..0000000 --- a/src/components/db-query/providers/templates.retriever.ts +++ /dev/null @@ -1,37 +0,0 @@ -import {BaseRetriever} from '@langchain/core/retrievers'; -import {VectorStore} from '@langchain/core/vectorstores'; -import {inject, Provider, ValueOrPromise} from '@loopback/core'; -import {AnyObject} from '@loopback/repository'; -import {MemoryVectorStore} from 'langchain/vectorstores/memory'; -import {AiIntegrationBindings} from '../../../keys'; -import {DbQueryStoredTypes} from '../types'; -import {AuthenticationBindings} from 'loopback4-authentication'; -import {IAuthUserWithPermissions} from '@sourceloop/core'; - -export class TemplateRetriever implements Provider { - constructor( - @inject(AiIntegrationBindings.VectorStore) - private readonly vectorStore: VectorStore, - @inject(AuthenticationBindings.CURRENT_USER) - private readonly user: IAuthUserWithPermissions, - ) {} - value(): ValueOrPromise> { - if (this.vectorStore instanceof 
MemoryVectorStore) { - return this.vectorStore.asRetriever({ - k: 5, - filter: doc => - doc.metadata.type === DbQueryStoredTypes.Template && - doc.metadata.tenantId === this.user.tenantId, - searchType: 'similarity', - }); - } - return this.vectorStore.asRetriever({ - k: 5, - filter: { - type: DbQueryStoredTypes.Template, - tenantId: this.user.tenantId, - }, - searchType: 'similarity', - }); - } -} diff --git a/src/components/db-query/services/dataset-helper.service.ts b/src/components/db-query/services/dataset-helper.service.ts index 2ff374c..a7f8ea1 100644 --- a/src/components/db-query/services/dataset-helper.service.ts +++ b/src/components/db-query/services/dataset-helper.service.ts @@ -1,8 +1,8 @@ -import {VectorStore} from '@langchain/core/vectorstores'; import {inject, service} from '@loopback/core'; import {Filter} from '@loopback/repository'; import {HttpErrors} from '@loopback/rest'; import {AiIntegrationBindings} from '../../../keys'; +import {IVectorStore} from '../../../types'; import {DbQueryAIExtensionBindings} from '../keys'; import {DbQueryStoredTypes, IDataSet, IDataSetStore} from '../types'; import {PermissionHelper} from './permission-helper.service'; @@ -14,8 +14,8 @@ export class DataSetHelper { private readonly store: IDataSetStore, @service(PermissionHelper) private readonly permissionHelper: PermissionHelper, - @inject(AiIntegrationBindings.VectorStore) - private readonly vectorStore: VectorStore, + @inject(AiIntegrationBindings.AiSdkVectorStore) + private readonly vectorStore: IVectorStore, ) {} async checkPermissions(datasetId: string) { diff --git a/src/components/db-query/services/knowledge-graph/db-knowledge-graph.service.ts b/src/components/db-query/services/knowledge-graph/db-knowledge-graph.service.ts index 1caaa77..c32b134 100644 --- a/src/components/db-query/services/knowledge-graph/db-knowledge-graph.service.ts +++ b/src/components/db-query/services/knowledge-graph/db-knowledge-graph.service.ts @@ -1,8 +1,9 @@ -import 
{RunnableSequence} from '@langchain/core/runnables'; +import {generateText} from 'ai'; +import {embedMany} from 'ai'; import {BindingScope, inject, injectable} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {AiIntegrationBindings} from '../../../../keys'; -import {EmbeddingProvider, RuntimeLLMProvider} from '../../../../types'; +import {AiSdkEmbeddingModel, LLMProvider} from '../../../../types'; import {stripThinkingTokens} from '../../../../utils'; import {DbQueryAIExtensionBindings} from '../../keys'; import {DatabaseSchema, DbQueryConfig, TableSchema} from '../../types'; @@ -34,10 +35,10 @@ export class DbKnowledgeGraphService implements KnowledgeGraph< private maxClusterSize: number; // Max size of clusters to consider for concept extraction constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - @inject(AiIntegrationBindings.EmbeddingModel) - private readonly embeddingModel: EmbeddingProvider, + @inject(AiIntegrationBindings.AiSdkCheapLLM) + private readonly llm: LLMProvider, + @inject(AiIntegrationBindings.AiSdkEmbeddingModel) + private readonly embeddingModel: AiSdkEmbeddingModel, @inject(DbQueryAIExtensionBindings.Config) private readonly config: DbQueryConfig, ) { @@ -264,12 +265,14 @@ export class DbKnowledgeGraphService implements KnowledgeGraph< } private async generateEmbedding(text: string): Promise { - return this.embeddingModel.embedDocuments([text]).then(embeddings => { - if (embeddings.length === 0 || !embeddings[0]) { - throw new Error('Failed to generate embedding'); - } - return embeddings[0]; + const {embeddings} = await embedMany({ + model: this.embeddingModel, + values: [text], }); + if (embeddings.length === 0 || !embeddings[0]) { + throw new Error('Failed to generate embedding'); + } + return embeddings[0]; } // Smart concept extraction using clustering @@ -405,8 +408,11 @@ The output should be JUST a valid JSON and no other markdown or formatting text. 
Focus on the core business concept or data domain. AGAIN, ensure the output is a valid JSON object with no additional text or formatting that can be parsed directly.`; try { - const chain = RunnableSequence.from([this.llm, stripThinkingTokens]); - const response = await chain.invoke([{role: 'user', content: prompt}]); + const {text} = await generateText({ + model: this.llm, + messages: [{role: 'user', content: prompt}], + }); + const response = stripThinkingTokens(text); debug(`Extracted concept for cluster ${clusterIndex}:`, response); const concept = JSON.parse(response); diff --git a/src/components/db-query/services/template-helper.service.ts b/src/components/db-query/services/template-helper.service.ts index aabf390..cbd5102 100644 --- a/src/components/db-query/services/template-helper.service.ts +++ b/src/components/db-query/services/template-helper.service.ts @@ -1,8 +1,7 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; +import {generateText} from 'ai'; import {inject} from '@loopback/core'; import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; +import {LLMProvider} from '../../../types'; import {stripThinkingTokens} from '../../../utils'; import { DatabaseSchema, @@ -10,7 +9,7 @@ import { QueryTemplateMetadata, TemplatePlaceholder, } from '../types'; -import {RunnableConfig} from '../../../graphs'; +import {RunnableConfig} from '../../../types/tool'; const MAX_TEMPLATE_RECURSION_DEPTH = 3; @@ -21,24 +20,28 @@ type ResolvedTemplate = { export class TemplateHelper { constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, + @inject(AiIntegrationBindings.AiSdkCheapLLM) + private readonly llm: LLMProvider, ) {} - extractionPrompt = PromptTemplate.fromTemplate(` - + private buildExtractionPrompt( + prompt: string, + template: string, + placeholders: string, + ): string { + return ` You are an expert 
at extracting parameter values from natural language prompts. Given a user prompt, a SQL template, and a list of placeholders with their descriptions and types, extract the value for each placeholder from the prompt. For sql_expression placeholders, generate a valid SQL fragment that fits the position of the placeholder in the template. -{prompt} +${prompt} -{template} +${template} -{placeholders} +${placeholders} Return each extracted value as an XML tag where the tag name is the placeholder name. @@ -51,7 +54,8 @@ Rules per type: - sql_expression: Return a complete, valid SQL fragment with proper SQL syntax including quotes where needed. Example: created_at > '2024-01-01' Do not return any other text or explanation, just the XML tags. -`); +`; + } async extractPlaceholderValues( placeholders: TemplatePlaceholder[], @@ -60,12 +64,6 @@ Do not return any other text or explanation, just the XML tags. config: RunnableConfig, schema?: DatabaseSchema, ): Promise> { - const chain = RunnableSequence.from([ - this.extractionPrompt, - this.llm, - stripThinkingTokens, - ]); - const placeholderDescriptions = placeholders .map(p => { let desc = `- ${p.name} (type: ${p.type}): ${p.description}`; @@ -76,16 +74,22 @@ Do not return any other text or explanation, just the XML tags. 
}) .join('\n'); - const response = await chain.invoke( - { - prompt, - template: sqlTemplate, - placeholders: placeholderDescriptions, - }, - config, - ); + const {text} = await generateText({ + model: this.llm, + messages: [ + { + role: 'user', + content: this.buildExtractionPrompt( + prompt, + sqlTemplate, + placeholderDescriptions, + ), + }, + ], + abortSignal: config.signal, + }); - return this._parseXmlValues(response, placeholders); + return this._parseXmlValues(stripThinkingTokens(text), placeholders); } private _getColumnContext( diff --git a/src/components/db-query/state.ts b/src/components/db-query/state.ts index 76ff040..8620e93 100644 --- a/src/components/db-query/state.ts +++ b/src/components/db-query/state.ts @@ -1,33 +1,35 @@ -import {Annotation} from '@langchain/langgraph'; -import {ChangeType, DatabaseSchema, Status} from './types'; import {AnyObject} from '@loopback/repository'; +import {ChangeType, DatabaseSchema, Status} from './types'; -export const DbQueryGraphStateAnnotation = Annotation.Root({ - prompt: Annotation, - schema: Annotation, - sql: Annotation, - status: Annotation, - id: Annotation, - feedbacks: Annotation, - replyToUser: Annotation, - datasetId: Annotation, - sampleSqlPrompt: Annotation, - sampleSql: Annotation, - fromCache: Annotation, - done: Annotation, - resultArray: Annotation, - description: Annotation, - directCall: Annotation, - validationChecklist: Annotation, - syntacticStatus: Annotation, - syntacticFeedback: Annotation, - semanticStatus: Annotation, - semanticFeedback: Annotation, - syntacticErrorTables: Annotation, - semanticErrorTables: Annotation, - changeType: Annotation, - fromTemplate: Annotation, - templateId: Annotation, -}); - -export type DbQueryState = typeof DbQueryGraphStateAnnotation.State; +/** + * State shape threaded through every step of the Mastra DbQuery workflow. 
+ * Previously generated by `@langchain/langgraph` `Annotation.Root()` — + * rewritten as a plain TypeScript interface with zero LangGraph dependency. + */ +export interface DbQueryState { + prompt: string; + schema: DatabaseSchema; + sql?: string; + status?: Status; + id?: string; + feedbacks?: string[]; + replyToUser?: string; + datasetId?: string; + sampleSqlPrompt?: string; + sampleSql?: string; + fromCache?: boolean; + done?: boolean; + resultArray?: AnyObject[string][]; + description?: string; + directCall?: boolean; + validationChecklist?: string; + syntacticStatus?: Status; + syntacticFeedback?: string; + semanticStatus?: Status; + semanticFeedback?: string; + syntacticErrorTables?: string[]; + semanticErrorTables?: string[]; + changeType?: ChangeType; + fromTemplate?: boolean; + templateId?: string; +} diff --git a/src/components/db-query/testing/db-query.graph.builder.ts b/src/components/db-query/testing/db-query.graph.builder.ts deleted file mode 100644 index 66a6470..0000000 --- a/src/components/db-query/testing/db-query.graph.builder.ts +++ /dev/null @@ -1,39 +0,0 @@ -import {AnyObject} from '@loopback/repository'; -import {expect} from '@loopback/testlab'; -import {randomUUID} from 'crypto'; -import {DbQueryGraph} from '../db-query.graph'; -import {DatabaseSchema} from '../types'; -import {DbQueryGraphTestCase} from './types'; - -export function dbQueryToolTests(cases: DbQueryGraphTestCase[]) { - return cases.map(testCase => ({ - desc: testCase.prompt, - fn: async ( - schema: DatabaseSchema, - graphBuilder: DbQueryGraph, - datasetExecuter: (id: string) => Promise, - ) => { - const graph = await graphBuilder.build(); - const id = randomUUID(); - const state = await graph.invoke({ - prompt: testCase.prompt, - id, - schema, - sql: undefined, - status: undefined, - feedbacks: undefined, - replyToUser: undefined, - datasetId: undefined, - sampleSql: undefined, - sampleSqlPrompt: undefined, - fromCache: undefined, - done: false, - }); - if 
(!state.datasetId) { - throw new Error('Dataset ID is not defined in the state'); - } - const results = await datasetExecuter(state.datasetId); - expect(results).deepEqual(testCase.result); - }, - })); -} diff --git a/src/components/db-query/testing/generation.acceptance.builder.ts b/src/components/db-query/testing/generation.acceptance.builder.ts index 511966f..7c47502 100644 --- a/src/components/db-query/testing/generation.acceptance.builder.ts +++ b/src/components/db-query/testing/generation.acceptance.builder.ts @@ -15,7 +15,7 @@ import { LLMStreamTokenCountEvent, LLMStreamToolStatusEvent, ToolStatus, -} from '../../../graphs'; +} from '../../../types/events'; import {generateMarkdownTable, getModelNameFromEnv} from './utils'; import {writeFileSync} from 'fs'; import {AnyObject} from '@loopback/repository'; diff --git a/src/components/db-query/testing/get-table.node.builder.ts b/src/components/db-query/testing/get-table.node.builder.ts deleted file mode 100644 index 5ab16af..0000000 --- a/src/components/db-query/testing/get-table.node.builder.ts +++ /dev/null @@ -1,49 +0,0 @@ -import {RunnableConfig} from '@langchain/core/runnables'; -import {AnyObject} from '@loopback/repository'; -import {expect} from '@loopback/testlab'; -import {GetTablesNode} from '../nodes'; -import {DatabaseSchema} from '../types'; -import {GetTableNodeTestCase} from './types'; - -export function getTableNodeTests(cases: GetTableNodeTestCase[]) { - return cases.map(testCase => ({ - desc: testCase.query, - fn: async (schema: DatabaseSchema, node: GetTablesNode) => { - const result = await node.execute( - { - prompt: testCase.query, - id: 'test-query', - schema, - sql: undefined, - status: undefined, - feedbacks: undefined, - replyToUser: undefined, - datasetId: undefined, - sampleSql: undefined, - sampleSqlPrompt: undefined, - fromCache: undefined, - done: false, - resultArray: undefined, - directCall: false, - description: undefined, - validationChecklist: undefined, - syntacticStatus: 
undefined, - syntacticFeedback: undefined, - semanticStatus: undefined, - semanticFeedback: undefined, - syntacticErrorTables: undefined, - semanticErrorTables: undefined, - changeType: undefined, - fromTemplate: undefined, - templateId: undefined, - }, - { - writer: (event: AnyObject[string]) => {}, - } as unknown as RunnableConfig, - ); - testCase.expectedTables.forEach(table => { - expect(result.schema?.tables).to.have.property(table); - }); - }, - })); -} diff --git a/src/components/db-query/testing/index.ts b/src/components/db-query/testing/index.ts index 0e39e4b..97e2137 100644 --- a/src/components/db-query/testing/index.ts +++ b/src/components/db-query/testing/index.ts @@ -1,4 +1,3 @@ -export * from './db-query.graph.builder'; export * from './generation.acceptance.builder'; -export * from './get-table.node.builder'; export * from './types'; +export * from './utils'; diff --git a/src/components/db-query/tools/ask-about-dataset.tool.ts b/src/components/db-query/tools/ask-about-dataset.tool.ts deleted file mode 100644 index c09c132..0000000 --- a/src/components/db-query/tools/ask-about-dataset.tool.ts +++ /dev/null @@ -1,138 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {tool} from '@langchain/core/tools'; -import {Context, inject} from '@loopback/context'; -import {service} from '@loopback/core'; -import {AnyObject} from '@loopback/repository'; -import z from 'zod'; -import {graphTool} from '../../../decorators'; -import {IGraphTool, IRuntimeTool} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {stripThinkingTokens} from '../../../utils'; -import {DbQueryAIExtensionBindings} from '../keys'; -import {DbSchemaHelperService} from '../services'; -import {SchemaStore} from '../services/schema.store'; -import {IDataSetStore} from '../types'; - -@graphTool({ - description: - 'Tool for 
answering questions about an existing dataset, note that it can only answer questions about the dataset definition, not the data it contains. Call this only if you have a valid dataset ID available.', - inputSchema: z.object({ - datasetId: z - .string() - .describe('uuid ID of the dataset to answer the question for'), - question: z - .string() - .describe('The question that the user asked about the query.'), - }), -}) -export class AskAboutDatasetTool implements IGraphTool { - constructor( - @inject(DbQueryAIExtensionBindings.DatasetStore) - private readonly store: IDataSetStore, - @inject(AiIntegrationBindings.CheapLLM) - private readonly sqlllm: RuntimeLLMProvider, - @service(DbSchemaHelperService) - private readonly dbSchemaHelper: DbSchemaHelperService, - @service(SchemaStore) - private readonly schemaStore: SchemaStore, - // Use context injection so GlobalContext is resolved lazily at call time, - // not at construction time. This allows the tool to be instantiated from - // the application (singleton) context without requiring a live request. - @inject.context() - private readonly _ctx: Context, - ) {} - - key = 'ask-about-dataset'; - needsReview = false; - description = - 'Tool for answering questions about an existing dataset, note that it can only answer questions about the dataset definition, not the data it contains. Call this only if you have a valid dataset ID available.'; - inputSchema = z.object({ - datasetId: z - .string() - .describe('uuid ID of the dataset to answer the question for'), - question: z - .string() - .describe('The question that the user asked about the query.'), - }); - - private readonly prompt = - PromptTemplate.fromTemplate(`You are an AI assistant that answers questions about a query, without revealing any technical details, you need to answer the question the user's question. - Make sure you don't reveal the original query to the user, just answer the question based on the query. 
- Here is the query that the question was for - - {query} - - and here is the schema the query was generated for - - {schema} - - and here is the context that was provided for the query - - {context} - - and here is the user's question - - {question}`); - - /** - * Creates a runtime-agnostic tool that answers questions about an existing dataset. - */ - async createTool(): Promise { - // Resolve GlobalContext lazily. When called from a request context the - // checks will be populated; when called from the application context - // (e.g. during Mastra bridge startup) the resolution may fail and we - // gracefully fall back to an empty list. - let checks: string[] | undefined; - try { - checks = await this._ctx.get( - DbQueryAIExtensionBindings.GlobalContext, - {optional: true}, - ); - } catch { - checks = undefined; - } - - const chain = RunnableSequence.from([ - this.prompt, - this.sqlllm, - stripThinkingTokens, - ]); - - const schema = z.object({ - datasetId: z - .string() - .describe('uuid ID of the dataset to answer the question for'), - question: z - .string() - .describe('The question that the user asked about the query.'), - }) as AnyObject[string]; - - return tool( - async (args: {datasetId: string; question: string}) => { - const {query, tables} = await this.store.findById(args.datasetId); - const compressedSchema = this.schemaStore.filteredSchema(tables); - const response = await chain.invoke({ - query, - question: args.question, - schema: compressedSchema, - context: [ - ...(checks ?? []), - ...this.dbSchemaHelper.getTablesContext(compressedSchema), - ].join('\n'), - }); - return response; - }, - { - name: this.key, - description: - 'Tool for answering questions about an existing dataset, note that it can only answer questions about the dataset definition, not the data it contains. Call this only if you have a valid dataset ID available.', - schema, - }, - ); - } - - /** - * @deprecated Use createTool(). 
- */ - async build(): Promise { - return this.createTool(); - } -} diff --git a/src/components/db-query/tools/get-data-as-dataset.tool.ts b/src/components/db-query/tools/get-data-as-dataset.tool.ts index dda742e..c11c6f5 100644 --- a/src/components/db-query/tools/get-data-as-dataset.tool.ts +++ b/src/components/db-query/tools/get-data-as-dataset.tool.ts @@ -2,8 +2,11 @@ import {inject, service} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {z} from 'zod'; import {graphTool} from '../../../decorators'; -import {IGraphTool, IRuntimeTool, ToolStatus} from '../../../graphs'; -import {DbQueryGraph} from '../db-query.graph'; +import {IGraphTool, IRuntimeTool, ToolStatus} from '../../../types/tool'; +import { + MastraDbQueryWorkflow, + MastraDbQueryContext, +} from '../../../mastra/db-query'; import {DbQueryConfig, Errors, GenerationError} from '../types'; import {DbQueryAIExtensionBindings} from '../keys'; import {DEFAULT_MAX_READ_ROWS_FOR_AI} from '../constant'; @@ -36,10 +39,10 @@ export class GetDataAsDatasetTool implements IGraphTool { ), }); constructor( - @service(DbQueryGraph) - private readonly queryPipeline: DbQueryGraph, @inject(DbQueryAIExtensionBindings.Config) private readonly config: DbQueryConfig, + @service(MastraDbQueryWorkflow) + private readonly mastraWorkflow: MastraDbQueryWorkflow, ) {} getValue(result: Record): string { @@ -64,25 +67,30 @@ export class GetDataAsDatasetTool implements IGraphTool { } /** - * Creates a runtime-agnostic tool for dataset generation. + * Creates a Mastra-compatible `IRuntimeTool` that executes the imperative workflow. 
*/ async createTool(): Promise { - const graph = await this.queryPipeline.build(); - const schema = z.object({ - prompt: z - .string() - .describe( - `Prompt from the user that will be used for generating an SQL query and create a dataset from it.`, - ), - }) as AnyObject[string]; - return graph.asTool({ + return { name: this.key, - description: `Query tool for generating SQL queries for a users request. Use it only when the user needs raw tabular data from the database. - Do not use this tool if the user's request involves trends, growth, decline, comparisons, distributions, patterns, or any form of analytical insight — use the 'generate-visualization' tool instead. - Note that it does not return the query, instead only a dataset ID that is not relevant to the user. - It internally fires an event that renders a grid for the dataset on the UI for the user to see.`, - schema, - }); + description: this.description, + schema: this.inputSchema, + invoke: async ( + input: unknown, + opts?: { + writer?: MastraDbQueryContext['writer']; + signal?: AbortSignal; + }, + ) => { + const {prompt} = input as {prompt: string}; + return this.mastraWorkflow.run( + {prompt}, + { + writer: opts?.writer, + signal: opts?.signal, + }, + ); + }, + } as IRuntimeTool; } /** diff --git a/src/components/db-query/tools/improve-dataset.tool.ts b/src/components/db-query/tools/improve-dataset.tool.ts index 39791da..d4d11df 100644 --- a/src/components/db-query/tools/improve-dataset.tool.ts +++ b/src/components/db-query/tools/improve-dataset.tool.ts @@ -2,8 +2,11 @@ import {inject, service} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {z} from 'zod'; import {graphTool} from '../../../decorators'; -import {IGraphTool, IRuntimeTool, ToolStatus} from '../../../graphs'; -import {DbQueryGraph} from '../db-query.graph'; +import {IGraphTool, IRuntimeTool, ToolStatus} from '../../../types/tool'; +import { + MastraDbQueryWorkflow, + MastraDbQueryContext, +} from 
'../../../mastra/db-query'; import {DbQueryConfig, Errors, GenerationError} from '../types'; import {DbQueryAIExtensionBindings} from '../keys'; import {DEFAULT_MAX_READ_ROWS_FOR_AI} from '../constant'; @@ -38,10 +41,10 @@ export class ImproveDatasetTool implements IGraphTool { ), }); constructor( - @service(DbQueryGraph) - private readonly queryPipeline: DbQueryGraph, @inject(DbQueryAIExtensionBindings.Config) private readonly config: DbQueryConfig, + @service(MastraDbQueryWorkflow) + private readonly mastraWorkflow: MastraDbQueryWorkflow, ) {} getValue(result: Record): string { @@ -66,26 +69,33 @@ export class ImproveDatasetTool implements IGraphTool { } /** - * Creates a runtime-agnostic tool for dataset improvement. + * Creates a Mastra-compatible `IRuntimeTool` that executes the imperative workflow. */ async createTool(): Promise { - const graph = await this.queryPipeline.build(); - const schema = z.object({ - datasetId: z - .string() - .describe(`UUID ID of the existing dataset to improve`), - prompt: z - .string() - .describe( - `A description of what changes or improvements the user wants in the existing dataset.`, - ), - }) as AnyObject[string]; - return graph.asTool({ + return { name: this.key, - description: - 'Tool for improving an existing dataset based on user feedback. It takes a dataset ID and a prompt describing the desired changes, and returns an updated dataset. 
Call this only if you have a valid dataset ID available.', - schema, - }); + description: this.description, + schema: this.inputSchema, + invoke: async ( + input: unknown, + opts?: { + writer?: MastraDbQueryContext['writer']; + signal?: AbortSignal; + }, + ) => { + const {datasetId, prompt} = input as { + datasetId: string; + prompt: string; + }; + return this.mastraWorkflow.run( + {prompt, datasetId}, + { + writer: opts?.writer, + signal: opts?.signal, + }, + ); + }, + } as IRuntimeTool; } /** diff --git a/src/components/db-query/tools/index.ts b/src/components/db-query/tools/index.ts index f410f5f..181aab0 100644 --- a/src/components/db-query/tools/index.ts +++ b/src/components/db-query/tools/index.ts @@ -1,3 +1,2 @@ -export * from './ask-about-dataset.tool'; export * from './get-data-as-dataset.tool'; export * from './improve-dataset.tool'; diff --git a/src/components/visualization/index.ts b/src/components/visualization/index.ts index 1c9a671..9b1a3c5 100644 --- a/src/components/visualization/index.ts +++ b/src/components/visualization/index.ts @@ -1,7 +1,5 @@ export * from './decorators'; -export * from './nodes'; export * from './tools'; export * from './types'; -export * from './visualizers'; export * from './state'; export * from './visualizer.component'; diff --git a/src/components/visualization/nodes/call-query-generation.node.ts b/src/components/visualization/nodes/call-query-generation.node.ts deleted file mode 100644 index 2a390db..0000000 --- a/src/components/visualization/nodes/call-query-generation.node.ts +++ /dev/null @@ -1,54 +0,0 @@ -import {service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType, RunnableConfig} from '../../../graphs'; -import {DbQueryGraph, POST_DATASET_TAG} from '../../db-query'; -import {VisualizationGraphNodes} from '../nodes.enum'; -import {VisualizationGraphState} from '../state'; - -@graphNode(VisualizationGraphNodes.CallQueryGeneration, { - 
[POST_DATASET_TAG]: true, -}) -export class CallQueryGenerationNode implements IGraphNode { - constructor( - @service(DbQueryGraph) - private readonly queryPipeline: DbQueryGraph, - ) {} - async execute( - state: VisualizationGraphState, - config: RunnableConfig, - ): Promise { - if (state.datasetId) { - return state; - } - - const queryGraph = await this.queryPipeline.build(); - - const result = await queryGraph.invoke( - { - datasetId: state.datasetId, - directCall: true, - prompt: `Generate a query to fetch data for visualization based on the following user prompt: ${state.prompt}.${state.visualizer?.context ? ` Ensure that the query structure satisfies the following context: ${state.visualizer.context}` : ''}`, - }, - config, - ); - - if (!result.datasetId) { - config.writer?.({ - type: LLMStreamEventType.Error, - data: { - status: `Failed to create dataset for visualization: ${result.replyToUser ?? 'Unknown error'}`, - }, - }); - return { - ...state, - error: - result.replyToUser ?? 
'Failed to create dataset for visualization', - }; - } - - return { - ...state, - datasetId: result.datasetId, - }; - } -} diff --git a/src/components/visualization/nodes/get-dataset-data.node.ts b/src/components/visualization/nodes/get-dataset-data.node.ts deleted file mode 100644 index 6a2fd81..0000000 --- a/src/components/visualization/nodes/get-dataset-data.node.ts +++ /dev/null @@ -1,38 +0,0 @@ -import {inject} from '@loopback/context'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType, RunnableConfig} from '../../../graphs'; -import {VisualizationGraphState} from '../state'; -import {VisualizationGraphNodes} from '../nodes.enum'; -import { - DbQueryAIExtensionBindings, - IDataSetStore, - POST_DATASET_TAG, -} from '../..'; - -@graphNode(VisualizationGraphNodes.GetDatasetData, { - [POST_DATASET_TAG]: true, -}) -export class GetDatasetDataNode implements IGraphNode { - constructor( - @inject(DbQueryAIExtensionBindings.DatasetStore) - private readonly store: IDataSetStore, - ) {} - - async execute( - state: VisualizationGraphState, - config: RunnableConfig, - ): Promise { - const dataset = await this.store.findById(state.datasetId); - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: 'Preparing visualization', - }, - }); - return { - ...state, - sql: dataset.query, - queryDescription: dataset.description, - }; - } -} diff --git a/src/components/visualization/nodes/index.ts b/src/components/visualization/nodes/index.ts deleted file mode 100644 index b54d513..0000000 --- a/src/components/visualization/nodes/index.ts +++ /dev/null @@ -1,4 +0,0 @@ -export * from './get-dataset-data.node'; -export * from './render-visualization.node'; -export * from './select-visualization.node'; -export * from './call-query-generation.node'; diff --git a/src/components/visualization/nodes/render-visualization.node.ts b/src/components/visualization/nodes/render-visualization.node.ts deleted file mode 100644 index 
262fd31..0000000 --- a/src/components/visualization/nodes/render-visualization.node.ts +++ /dev/null @@ -1,47 +0,0 @@ -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType, ToolStatus} from '../../../graphs'; -import {VisualizationGraphState} from '../state'; -import {VisualizationGraphNodes} from '../nodes.enum'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {POST_DATASET_TAG} from '../../db-query'; - -@graphNode(VisualizationGraphNodes.RenderVisualization, { - [POST_DATASET_TAG]: true, -}) -export class RenderVisualizationNode implements IGraphNode { - constructor() {} - - async execute( - state: VisualizationGraphState, - config: LangGraphRunnableConfig, - ): Promise { - const visualizer = state.visualizer; - if (!visualizer || !state.sql || !state.queryDescription) { - throw new Error('Invalid State'); - } - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: `Configuring ${visualizer.name}`, - }, - }); - const settings = await visualizer.getConfig(state); - - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: ToolStatus.Completed, - data: { - datasetId: state.datasetId, - visualization: visualizer.name, - config: settings || {}, - }, - }, - }); - return { - ...state, - done: true, - visualizerConfig: settings || {}, - }; - } -} diff --git a/src/components/visualization/nodes/select-visualization.node.ts b/src/components/visualization/nodes/select-visualization.node.ts deleted file mode 100644 index 8b8ab70..0000000 --- a/src/components/visualization/nodes/select-visualization.node.ts +++ /dev/null @@ -1,136 +0,0 @@ -import {Context, inject} from '@loopback/context'; -import {graphNode} from '../../../decorators'; -import {IGraphNode, LLMStreamEventType, RunnableConfig} from '../../../graphs'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {VisualizationGraphState} from '../state'; 
-import {VisualizationGraphNodes} from '../nodes.enum'; -import {PromptTemplate} from '@langchain/core/prompts'; -import {stripThinkingTokens} from '../../../utils'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {VISUALIZATION_KEY} from '../keys'; -import {IVisualizer} from '../types'; -import {POST_DATASET_TAG} from '../../db-query'; - -@graphNode(VisualizationGraphNodes.SelectVisualisation, { - [POST_DATASET_TAG]: true, -}) -export class SelectVisualizationNode implements IGraphNode { - prompt = PromptTemplate.fromTemplate(` - -You are expert Data Analysis Agent whose job is to suggest visualisations that would be best suited to display the results for a particular user prompt and the data extracted based on that prompt. -You are provided with 2 inputs - -- user prompt -- A list of visualization names with their descriptions that are supported. - -You need to suggest a visualisation from a list of visualisation that would best fit the user's request. - - - -{prompt} - - - -{sql} - - -{description} - - - -{visualizations} - - - - -The output should be a single string that has the name from the visualizations list and nothing else. -If none of the visualizations fit the requirement, return "none" followed by the changes required in the data to be able to render the visualization. -Do not try to force fit the prompt to any visualization if it does not make sense. Prefer to returning none with appropriate reason instead. - - -type-of-visualization - - -none: reason why the visualization is not possible with the current prompt. 
- - -`); - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - @inject.context() - private readonly context: Context, - ) {} - - async execute( - state: VisualizationGraphState, - config: RunnableConfig, - ): Promise { - const visualizations = await this._getVisualizations(); - if (state.type) { - const selected = visualizations.find(v => v.name === state.type); - if (!selected) { - throw new Error( - `No visualizer found with name ${state.type}, available visualizers are ${visualizations - .map(v => v.name) - .join(', ')}`, - ); - } - return { - ...state, - visualizer: selected, - visualizerName: selected.name, - }; - } - const chain = RunnableSequence.from([ - this.prompt, - this.llm, - stripThinkingTokens, - ]); - config.writer?.({ - type: LLMStreamEventType.ToolStatus, - data: { - status: 'Selecting best visualization for the data', - }, - }); - const output = await chain.invoke({ - prompt: state.prompt, - sql: state.sql, - description: state.queryDescription, - visualizations: visualizations - .map(v => `- ${v.name}: ${v.description}`) - .join('\n'), - }); - if (output.trim().startsWith('none')) { - return { - ...state, - error: output.trim().substring(4).trim(), - }; - } - const selected = visualizations.find(v => v.name === output.trim()); - if (!selected) { - throw new Error( - `No visualizer found with name ${output.trim()}, available visualizers are ${visualizations - .map(v => v.name) - .join(', ')}`, - ); - } - return { - ...state, - visualizer: selected, - visualizerName: selected.name, - }; - } - - private async _getVisualizations() { - const bindings = this.context.findByTag({ - [VISUALIZATION_KEY]: true, - }); - if (bindings.length === 0) { - throw new Error(`Node with key ${VISUALIZATION_KEY} not found`); - } - return Promise.all( - bindings.map(binding => this.context.get(binding.key)), - ); - } -} diff --git a/src/components/visualization/state.ts b/src/components/visualization/state.ts index 
778e3da..9ba2d4b 100644 --- a/src/components/visualization/state.ts +++ b/src/components/visualization/state.ts @@ -1,19 +1,20 @@ -import {Annotation} from '@langchain/langgraph'; -import {IVisualizer} from './types'; import {AnyObject} from '@loopback/repository'; +import {IVisualizer} from './types'; -export const VisualizationGraphStateAnnotation = Annotation.Root({ - prompt: Annotation, - datasetId: Annotation, - sql: Annotation, - queryDescription: Annotation, - visualizer: Annotation, - visualizerName: Annotation, - done: Annotation, - visualizerConfig: Annotation, - error: Annotation, - type: Annotation, -}); - -export type VisualizationGraphState = - typeof VisualizationGraphStateAnnotation.State; +/** + * State shape threaded through every step of the Mastra Visualization workflow. + * Previously generated by `@langchain/langgraph` `Annotation.Root()` — + * rewritten as a plain TypeScript interface with zero LangGraph dependency. + */ +export interface VisualizationGraphState { + prompt: string; + datasetId: string; + sql?: string; + queryDescription?: string; + visualizer?: IVisualizer; + visualizerName?: string; + done?: boolean; + visualizerConfig?: AnyObject; + error?: string; + type?: string; +} diff --git a/src/components/visualization/tools/generate-visualization.tool.ts b/src/components/visualization/tools/generate-visualization.tool.ts index 42eda85..53617cb 100644 --- a/src/components/visualization/tools/generate-visualization.tool.ts +++ b/src/components/visualization/tools/generate-visualization.tool.ts @@ -2,10 +2,11 @@ import {Context, inject, service} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; import {z} from 'zod'; import {graphTool} from '../../../decorators'; -import {IGraphTool, IRuntimeTool, ToolStatus} from '../../../graphs'; -import {VisualizationGraph} from '../visualization.graph'; -import {VISUALIZATION_KEY} from '../keys'; -import {IVisualizer} from '../types'; +import {IGraphTool, IRuntimeTool, 
ToolStatus} from '../../../types/tool'; +import { + MastraVisualizationWorkflow, + MastraVisualizationContext, +} from '../../../mastra/visualization'; @graphTool({ description: `Generates a visualization for the user's request. It takes in a prompt and an optional dataset ID. @@ -62,11 +63,12 @@ It does not return anything, instead it fires an event internally that renders t `Type of visualization to be generated (e.g. bar, line, pie). If not provided, the system will decide the best visualization based on the data and prompt.`, ), }); + constructor( - @service(VisualizationGraph) - private readonly visualizationGraph: VisualizationGraph, @inject.context() private readonly context: Context, + @service(MastraVisualizationWorkflow) + private readonly mastraWorkflow: MastraVisualizationWorkflow, ) {} getValue(result: Record): string { @@ -90,39 +92,31 @@ It does not return anything, instead it fires an event internally that renders t } /** - * Creates a runtime-agnostic visualization tool. + * Creates a Mastra-compatible visualization tool. */ async createTool(): Promise { - const visualizations = await this._getVisualizations(); - const graph = await this.visualizationGraph.build(); - const schema = z.object({ - prompt: z - .string() - .describe( - `Prompt from the user that will be used for generating the visualization.`, - ), - datasetId: z - .string() - .optional() - .describe( - `ID of the dataset that needs to be visualized. Use the dataset ID from 'get-data-as-dataset' or 'improve-dataset' tool if available. If not provided, the tool will internally fetch the data.`, - ), - type: z - .string() - .optional() - .describe( - `Type of visualization to be generated. It can be one of the following: ${visualizations.map(v => v.name).join(', ')}. 
If not provided, the system will decide the best visualization based on the data and prompt.`, - ), - }) as AnyObject[string]; - return graph.asTool({ + return { name: this.key, - description: `Generates a visualization for the user's request. It takes in a prompt and an optional dataset ID. -If the user's request involves trends, growth, decline, comparisons, distributions, patterns, correlations, or any analytical insight, ALWAYS use this tool instead of 'get-data-as-dataset'. -No need to call 'get-data-as-dataset' tool before this — if the dataset ID is not provided, this tool will internally fetch the data to be visualized. -It does not return anything, instead it fires an event internally that renders the visualization on the UI for the user to see. -It supports the following types of visualizations: ${visualizations.map(v => v.name).join(', ')}.`, - schema, - }); + description: this.description, + schema: this.inputSchema, + invoke: async ( + input: unknown, + opts?: { + writer?: MastraVisualizationContext['writer']; + signal?: AbortSignal; + }, + ) => { + const {prompt, datasetId, type} = input as { + prompt: string; + datasetId?: string; + type?: string; + }; + return this.mastraWorkflow.run( + {prompt, datasetId, type}, + {writer: opts?.writer, signal: opts?.signal}, + ); + }, + } as IRuntimeTool; } /** @@ -131,16 +125,4 @@ It supports the following types of visualizations: ${visualizations.map(v => v.n async build(): Promise { return this.createTool(); } - - private async _getVisualizations() { - const bindings = this.context.findByTag({ - [VISUALIZATION_KEY]: true, - }); - if (bindings.length === 0) { - throw new Error(`Node with key ${VISUALIZATION_KEY} not found`); - } - return Promise.all( - bindings.map(binding => this.context.get(binding.key)), - ); - } } diff --git a/src/components/visualization/visualization.graph.ts b/src/components/visualization/visualization.graph.ts deleted file mode 100644 index 65eb05e..0000000 --- 
a/src/components/visualization/visualization.graph.ts +++ /dev/null @@ -1,63 +0,0 @@ -import {END, START, StateGraph} from '@langchain/langgraph'; -import {BaseGraph} from '../../graphs'; -import { - VisualizationGraphState, - VisualizationGraphStateAnnotation, -} from './state'; -import {VisualizationGraphNodes} from './nodes.enum'; - -export class VisualizationGraph extends BaseGraph { - async build() { - const graph = new StateGraph(VisualizationGraphStateAnnotation); - graph - .addNode( - VisualizationGraphNodes.CallQueryGeneration, - await this._getNodeFn(VisualizationGraphNodes.CallQueryGeneration), - ) - .addNode( - VisualizationGraphNodes.GetDatasetData, - await this._getNodeFn(VisualizationGraphNodes.GetDatasetData), - ) - .addNode( - VisualizationGraphNodes.SelectVisualisation, - await this._getNodeFn(VisualizationGraphNodes.SelectVisualisation), - ) - .addNode( - VisualizationGraphNodes.RenderVisualization, - await this._getNodeFn(VisualizationGraphNodes.RenderVisualization), - ) - .addEdge(START, VisualizationGraphNodes.SelectVisualisation) - .addConditionalEdges( - VisualizationGraphNodes.SelectVisualisation, - state => { - if (state.error) { - return 'Error'; - } - return 'Success'; - }, - { - Error: END, - Success: VisualizationGraphNodes.CallQueryGeneration, - }, - ) - .addConditionalEdges( - VisualizationGraphNodes.CallQueryGeneration, - state => { - if (state.error) { - return 'Error'; - } - return 'Success'; - }, - { - Error: END, - Success: VisualizationGraphNodes.GetDatasetData, - }, - ) - .addEdge( - VisualizationGraphNodes.GetDatasetData, - VisualizationGraphNodes.RenderVisualization, - ) - .addEdge(VisualizationGraphNodes.RenderVisualization, END); - return graph.compile(); - } -} diff --git a/src/components/visualization/visualizer.component.ts b/src/components/visualization/visualizer.component.ts index b734744..718e5fc 100644 --- a/src/components/visualization/visualizer.component.ts +++ 
b/src/components/visualization/visualizer.component.ts @@ -8,15 +8,13 @@ import { ServiceOrProviderClass, } from '@loopback/core'; import {AnyObject} from '@loopback/repository'; -import {VisualizationGraph} from './visualization.graph'; -import { - CallQueryGenerationNode, - GetDatasetDataNode, - RenderVisualizationNode, - SelectVisualizationNode, -} from './nodes'; import {GenerateVisualizationTool} from './tools/generate-visualization.tool'; -import {PieVisualizer, BarVisualizer, LineVisualizer} from './visualizers'; +import { + MastraVisualizationWorkflow, + MastraBarVisualizerService, + MastraLineVisualizerService, + MastraPieVisualizerService, +} from '../../mastra/visualization'; export class VisualizerComponent implements Component { services: ServiceOrProviderClass[] | undefined; @@ -32,19 +30,16 @@ export class VisualizerComponent implements Component { this.bindings = []; this.lifeCycleObservers = []; this.services = [ - // graph - VisualizationGraph, // tools GenerateVisualizationTool, - // nodes - GetDatasetDataNode, - SelectVisualizationNode, - RenderVisualizationNode, - CallQueryGenerationNode, - // visualizers - PieVisualizer, - BarVisualizer, - LineVisualizer, + + // ── Mastra path ────────────────────────────────────────────────────── + // Workflow orchestrator + MastraVisualizationWorkflow, + // Visualizer services (use AI SDK generateObject()) + MastraBarVisualizerService, + MastraLineVisualizerService, + MastraPieVisualizerService, ]; this.components = []; } diff --git a/src/components/visualization/visualizers/bar.visualizer.ts b/src/components/visualization/visualizers/bar.visualizer.ts deleted file mode 100644 index cd44ad5..0000000 --- a/src/components/visualization/visualizers/bar.visualizer.ts +++ /dev/null @@ -1,79 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {IVisualizer} from '../types'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {inject} 
from '@loopback/core'; -import {AnyObject} from '@loopback/repository'; -import {VisualizationGraphState} from '../state'; -import z from 'zod'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {visualizer} from '../decorators/visualizer.decorator'; - -@visualizer() -export class BarVisualizer implements IVisualizer { - name = 'bar'; - description = `Renders the data in a bar chart format. Best for comparing values across different categories or showing trends over time.`; - renderPrompt = PromptTemplate.fromTemplate(` - -You are an expert data visualization assistant. Your task is to create a bar chart config based on the provided SQL query, it's description and user prompt. Follow these steps: -1. Analyze the SQL query results to understand the data structure. -2. Identify the category column (x-axis) and value column (y-axis) for the bar chart. -3. Create a configuration object for the bar chart using the identified columns. -4. Return the bar chart configuration object. - - - -{sql} - - -{description} - - -{userPrompt} - -`); - - context?: string | undefined = - `A bar chart requires data with at exactly two columns: one for the categories (x-axis) and one for the values (y-axis). Ensure that the category column contains discrete values representing different groups or categories, while the value column contains numerical data that can be compared across these categories. 
Bar charts can be oriented either vertically or horizontally depending on the data representation needs.`; - - schema = z.object({ - categoryColumn: z - .string() - .describe('Column to be used for categories (x-axis) in the bar chart'), - valueColumn: z - .string() - .describe('Column to be used for values (y-axis) in the bar chart'), - orientation: z - .string() - .default('vertical') - .describe( - 'Orientation of the bar chart: `vertical` or `horizontal` without backticks', - ), - }) as z.AnyZodObject; - - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - ) {} - - async getConfig(state: VisualizationGraphState): Promise { - if (!state.sql || !state.queryDescription || !state.prompt) { - throw new Error('Invalid State'); - } - const llmWithStructuredOutput = this.llm.withStructuredOutput( - this.schema, - ); - - const chain = RunnableSequence.from([ - this.renderPrompt, - llmWithStructuredOutput, - ]); - - const settings = await chain.invoke({ - sql: state.sql!, - description: state.queryDescription!, - userPrompt: state.prompt!, - }); - return settings; - } -} diff --git a/src/components/visualization/visualizers/index.ts b/src/components/visualization/visualizers/index.ts deleted file mode 100644 index b0080b5..0000000 --- a/src/components/visualization/visualizers/index.ts +++ /dev/null @@ -1,3 +0,0 @@ -export * from './pie.visualizer'; -export * from './bar.visualizer'; -export * from './line.visualizer'; diff --git a/src/components/visualization/visualizers/line.visualizer.ts b/src/components/visualization/visualizers/line.visualizer.ts deleted file mode 100644 index d0356a5..0000000 --- a/src/components/visualization/visualizers/line.visualizer.ts +++ /dev/null @@ -1,94 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {IVisualizer} from '../types'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {inject} from 
'@loopback/core'; -import {AnyObject} from '@loopback/repository'; -import {VisualizationGraphState} from '../state'; -import z from 'zod'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {visualizer} from '../decorators/visualizer.decorator'; - -@visualizer() -export class LineVisualizer implements IVisualizer { - name = 'line'; - - context?: string | undefined = - `A line chart requires data with exactly 3 columns: one for the x-axis (typically time or sequential data), one for the y-axis (values), and one series type column to distinguish multiple lines/series in the chart. The series type column is important for grouping data into separate lines.`; - - description = `Renders the data in a line chart format. Best for showing trends and changes over time or continuous data.`; - renderPrompt = PromptTemplate.fromTemplate(` - -You are an expert data visualization assistant. Your task is to create a line chart config based on the provided SQL query, it's description and user prompt. Follow these steps: -1. Analyze the SQL query results to understand the data structure. -2. Identify the x-axis column (typically time or sequential data) and y-axis column (values) for the line chart. -3. Determine if there are multiple series to be plotted (multiple lines) with combination of multiple columns, or single series based on single column. -4. Create a configuration object for the line chart using the identified columns. -5. Return the line chart configuration object. - - - -{sql} - - -{description} - - -{userPrompt} - -`); - - schema = z.object({ - xAxisColumn: z - .string() - .describe( - 'Single column name to be used for x-axis in the line chart (typically time or sequential data)', - ), - yAxisColumn: z - .string() - .describe( - 'Single column name to be used for y-axis values in the line chart', - ), - seriesColumns: z - .string() - .describe( - 'Optional column to group data into multiple lines/series, leave it as empty string if not needed. 
It can cover multiple columns separated by comma if the query needs to show multiple lines based on multiple columns. The UI supports multiple series in line chart by forming a combined key.', - ), - }) as z.AnyZodObject; - - constructor( - @inject(AiIntegrationBindings.SmartNonThinkingLLM) - private readonly llm: RuntimeLLMProvider, - ) {} - - async getConfig(state: VisualizationGraphState): Promise { - if (!state.sql || !state.queryDescription || !state.prompt) { - throw new Error('Invalid State'); - } - const llmWithStructuredOutput = this.llm.withStructuredOutput( - this.schema, - ); - - const chain = RunnableSequence.from([ - this.renderPrompt, - llmWithStructuredOutput, - ]); - - const settings = await chain.invoke({ - sql: state.sql!, - description: state.queryDescription!, - userPrompt: state.prompt!, - }); - if ( - settings.seriesColumns === '' || - settings.seriesColumns === undefined || - settings.seriesColumns === null - ) { - settings.seriesColumns = null; - } else { - settings.seriesColumns = - settings.seriesColumns?.split(',').map((s: string) => s.trim()) ?? 
[]; - } - return settings; - } -} diff --git a/src/components/visualization/visualizers/pie.visualizer.ts b/src/components/visualization/visualizers/pie.visualizer.ts deleted file mode 100644 index 4004ce5..0000000 --- a/src/components/visualization/visualizers/pie.visualizer.ts +++ /dev/null @@ -1,73 +0,0 @@ -import {PromptTemplate} from '@langchain/core/prompts'; -import {IVisualizer} from '../types'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {inject} from '@loopback/core'; -import {AnyObject} from '@loopback/repository'; -import {VisualizationGraphState} from '../state'; -import z from 'zod'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {visualizer} from '../decorators/visualizer.decorator'; - -@visualizer() -export class PieVisualizer implements IVisualizer { - name = 'pie'; - description = `Renders the data in a pie chart format. Best for visualizing proportions and percentages among categories.`; - renderPrompt = PromptTemplate.fromTemplate(` - -You are an expert data visualization assistant. Your task is to create a pie chart config based on the provided SQL query, it's description and user prompt. Follow these steps: -1. Analyze the SQL query results to understand the data structure. -2. Identify the key categories and their corresponding values for the pie chart. -3. Create a configuration object for the pie chart using the identified categories and values. -4. Return the pie chart configuration object. - - - -{sql} - - -{description} - - -{userPrompt} - -`); - - context?: string | undefined = - `A pie chart requires data with at least two columns: one for the labels (categories) and one for the values (numerical data). 
Ensure that the values are non-negative and represent parts of a whole, as pie charts are used to visualize proportions and percentages among different categories.`; - - schema = z.object({ - labelColumn: z - .string() - .describe('Column to be used for labels in the pie chart'), - valueColumn: z - .string() - .describe('Column to be used for values in the pie chart'), - }) as z.AnyZodObject; - - constructor( - @inject(AiIntegrationBindings.CheapLLM) - private readonly llm: RuntimeLLMProvider, - ) {} - - async getConfig(state: VisualizationGraphState): Promise { - if (!state.sql || !state.queryDescription || !state.prompt) { - throw new Error('Invalid State'); - } - const llmWithStructuredOutput = this.llm.withStructuredOutput( - this.schema, - ); - - const chain = RunnableSequence.from([ - this.renderPrompt, - llmWithStructuredOutput, - ]); - - const settings = await chain.invoke({ - sql: state.sql!, - description: state.queryDescription!, - userPrompt: state.prompt!, - }); - return settings; - } -} diff --git a/src/controllers/chat.controller.ts b/src/controllers/chat.controller.ts index c1013ea..9e1ce80 100644 --- a/src/controllers/chat.controller.ts +++ b/src/controllers/chat.controller.ts @@ -4,7 +4,7 @@ import {get, param} from '@loopback/rest'; import {OPERATION_SECURITY_SPEC} from '@sourceloop/core'; import {authenticate, STRATEGY} from 'loopback4-authentication'; import {authorize} from 'loopback4-authorization'; -import {ChatStore} from '../graphs/chat/chat.store'; +import {ChatStore} from '../services/chat.store'; import {Chat} from '../models'; import {PermissionKey} from '../permissions'; diff --git a/src/graphs/base.graph.ts b/src/graphs/base.graph.ts deleted file mode 100644 index cc61e4c..0000000 --- a/src/graphs/base.graph.ts +++ /dev/null @@ -1,27 +0,0 @@ -import {CompiledGraph} from '@langchain/langgraph'; -import {Context, inject} from '@loopback/core'; -import {AnyObject} from '@loopback/repository'; -import {GRAPH_NODE_NAME} from 
'../constant'; -import {IGraphNode, resolveNodeExecution} from './types'; - -export abstract class BaseGraph { - @inject.context() - protected context: Context; - - abstract build(): Promise>; - - protected async _getNodeFn(key: string) { - const bindings = this.context.findByTag({ - [GRAPH_NODE_NAME]: key, - }); - if (bindings.length === 0) { - throw new Error(`Node with key ${key} not found`); - } - if (bindings.length > 1) { - throw new Error(`Multiple nodes found with key ${key}`); - } - const binding = bindings[0]; - const node = await this.context.get>(binding.key); - return resolveNodeExecution(node); - } -} diff --git a/src/graphs/chat/chat.graph.ts b/src/graphs/chat/chat.graph.ts deleted file mode 100644 index 7e0cdc0..0000000 --- a/src/graphs/chat/chat.graph.ts +++ /dev/null @@ -1,144 +0,0 @@ -import {AIMessage} from '@langchain/core/messages'; -import {END, START, StateGraph} from '@langchain/langgraph'; -import {BindingScope, inject, injectable} from '@loopback/core'; -import {AiIntegrationBindings} from '../../keys'; -import {TokenCounter} from '../../services/token-counter.service'; -import {ToolStore} from '../../types'; -import {BaseGraph} from '../base.graph'; -import {ChatGraphAnnotation, ChatState} from '../state'; -import {ChatNodes} from './nodes.enum'; -import {AnyObject} from '@loopback/repository'; - -@injectable({scope: BindingScope.REQUEST}) -export class ChatGraph extends BaseGraph { - constructor( - @inject(AiIntegrationBindings.Tools) - private readonly tools: ToolStore, - @inject('services.TokenCounter') - private readonly tokenCounter: TokenCounter, - @inject(AiIntegrationBindings.ObfHandler, {optional: true}) - protected readonly obfHandler?: AnyObject[string], - ) { - super(); - } - async execute( - query: string, - files: Express.Multer.File[] | Express.Multer.File, - abort: AbortSignal, - id?: string, - ) { - let fileArray: Express.Multer.File[] = []; - if (Array.isArray(files)) { - fileArray = files; - } else if (files) { - 
fileArray.push(files); - } else { - // do nothing if no files are provided - } - const graph = await this.build(); - - const inputs: ChatState = { - id, - messages: [], - files: fileArray, - prompt: query, - userMessage: undefined, - aiMessage: undefined, - }; - - return graph.stream(inputs, { - streamMode: 'custom' as const, - recursionLimit: 60, - configurable: { - // eslint-disable-next-line @typescript-eslint/naming-convention - thread_id: id, - }, - signal: abort, - callbacks: [ - { - handleLLMStart: ( - llm, - prompts, - runId, - parentRunId, - extraParams, - tags, - metadata, - ) => { - this.tokenCounter.handleLlmStart( - runId, - (metadata?.ls_model_name as string) || 'unknown', - ); - }, - handleLLMEnd: (output, runId) => { - this.tokenCounter.handleLlmEnd(runId, output); - }, - }, - this.obfHandler ? this.obfHandler : {}, - ], - }); - } - async build() { - const graph = new StateGraph(ChatGraphAnnotation); - const toolsMap = this.tools.map; - // add nodes - graph - .addNode( - ChatNodes.TrimMessages, - await this._getNodeFn(ChatNodes.TrimMessages), - ) - .addNode(ChatNodes.CallLLM, await this._getNodeFn(ChatNodes.CallLLM)) - .addNode(ChatNodes.RunTool, await this._getNodeFn(ChatNodes.RunTool)) - .addNode( - ChatNodes.SummariseFile, - await this._getNodeFn(ChatNodes.SummariseFile), - ) - .addNode( - ChatNodes.InitSession, - await this._getNodeFn(ChatNodes.InitSession), - ) - .addNode( - ChatNodes.EndSession, - await this._getNodeFn(ChatNodes.EndSession), - ) - // add edges - .addEdge(START, ChatNodes.InitSession) - .addEdge(ChatNodes.InitSession, ChatNodes.SummariseFile) - .addConditionalEdges( - ChatNodes.SummariseFile, - (state: ChatState) => { - if (state.files && state.files.length > 0) { - return ChatNodes.SummariseFile; - } - return ChatNodes.CallLLM; - }, - [ChatNodes.SummariseFile, ChatNodes.CallLLM], - ) - .addConditionalEdges( - ChatNodes.CallLLM, - (state: ChatState) => { - const lastMessage = state.messages[ - state.messages.length - 1 - ] as 
AIMessage; - if (!lastMessage?.tool_calls?.length) { - return ChatNodes.EndSession; - } - if (toolsMap[lastMessage?.tool_calls[0].name].needsReview === false) { - return ChatNodes.RunTool; - } else { - throw new Error( - `Tool ${lastMessage.tool_calls[0].name} requires user review which is not implemented yet.`, - ); - } - }, - [ChatNodes.RunTool, ChatNodes.EndSession], - ) - .addEdge(ChatNodes.RunTool, ChatNodes.TrimMessages) - .addEdge(ChatNodes.TrimMessages, ChatNodes.CallLLM) - .addEdge(ChatNodes.EndSession, END); - - const compiled = graph.compile({}); - - return compiled; - } -} diff --git a/src/graphs/chat/index.ts b/src/graphs/chat/index.ts deleted file mode 100644 index 219b49f..0000000 --- a/src/graphs/chat/index.ts +++ /dev/null @@ -1,5 +0,0 @@ -export * from './chat-metadata.type'; -export * from './chat.graph'; -export * from './chat.store'; -export * from './nodes'; -export * from './nodes.enum'; diff --git a/src/graphs/chat/nodes.enum.ts b/src/graphs/chat/nodes.enum.ts deleted file mode 100644 index af1359f..0000000 --- a/src/graphs/chat/nodes.enum.ts +++ /dev/null @@ -1,8 +0,0 @@ -export enum ChatNodes { - CallLLM = 'call_llm', - TrimMessages = 'trim_messages', - RunTool = 'run_tool', - SummariseFile = 'summarise_file', - InitSession = 'init_session', - EndSession = 'end_session', -} diff --git a/src/graphs/chat/nodes/call-llm.node.ts b/src/graphs/chat/nodes/call-llm.node.ts deleted file mode 100644 index b065798..0000000 --- a/src/graphs/chat/nodes/call-llm.node.ts +++ /dev/null @@ -1,57 +0,0 @@ -import {AIMessage} from '@langchain/core/messages'; -import {inject} from '@loopback/context'; -import {service} from '@loopback/core'; -import {HttpErrors} from '@loopback/rest'; -import {graphNode} from '../../../decorators'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider, ToolStore} from '../../../types'; -import {getTextContent} from '../../../utils'; -import {LLMStreamEventType} from '../../event.types'; -import 
{ChatState} from '../../state'; -import {IGraphNode, resolveGraphTool, RunnableConfig} from '../../types'; -import {ChatStore} from '../chat.store'; -import {ChatNodes} from '../nodes.enum'; - -const debug = require('debug')('ai-integration:chat:call-llm.node'); - -@graphNode(ChatNodes.CallLLM) -export class CallLLMNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.ChatLLM) - private readonly llm: RuntimeLLMProvider, - @inject(AiIntegrationBindings.Tools) - private readonly tools: ToolStore, - @service(ChatStore) - private readonly chatStore: ChatStore, - ) {} - - async execute(state: ChatState, config: RunnableConfig): Promise { - const tools = await Promise.all( - this.tools.list.map(tool => resolveGraphTool(tool, config)), - ); - debug( - 'Calling LLM with tools:', - tools.map(tool => tool.name), - ); - const response: AIMessage = await this.llm - .bindTools(tools) - .invoke(state.messages); - const text = getTextContent(response.content).trim(); - if (!state.id) { - debug('No chat ID found in state, this is unexpected'); - throw new HttpErrors.InternalServerError(); - } - const aiMessage = await this.chatStore.addAIMessage(state.id, response); - - if (text) { - config.writer?.({ - type: LLMStreamEventType.Message, - data: { - message: getTextContent(response.content), - }, - }); - } - - return {...state, messages: [response], aiMessage}; - } -} diff --git a/src/graphs/chat/nodes/context-compression.node.ts b/src/graphs/chat/nodes/context-compression.node.ts deleted file mode 100644 index 7cca76a..0000000 --- a/src/graphs/chat/nodes/context-compression.node.ts +++ /dev/null @@ -1,50 +0,0 @@ -import {trimMessages} from '@langchain/core/messages'; -import {inject} from '@loopback/core'; -import {DEFAULT_MAX_TOKEN_COUNT} from '../../../constant'; -import {graphNode} from '../../../decorators'; -import {AiIntegrationBindings} from '../../../keys'; -import {AIIntegrationConfig} from '../../../types'; -import {approxTokenCounter} from 
'../../../utils'; -import {LLMStreamEventType} from '../../event.types'; -import {ChatState} from '../../state'; -import {IGraphNode, RunnableConfig} from '../../types'; -import {ChatNodes} from '../nodes.enum'; - -@graphNode(ChatNodes.TrimMessages) -export class ContextCompressionNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.Config) - private readonly config: AIIntegrationConfig, - ) {} - - async execute(state: ChatState, config: RunnableConfig): Promise { - const maxTokenCount = +( - this.config.maxTokenCount ?? - process.env.MAX_TOKEN_COUNT ?? - DEFAULT_MAX_TOKEN_COUNT - ); - const tokenCount = state.messages.reduce( - (count, message) => count + approxTokenCounter(message.content), - 0, - ); - - if (tokenCount > maxTokenCount) { - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Token count ${tokenCount} exceeds limit ${maxTokenCount}. Compressing context.`, - }); - const trimmed = await trimMessages(state.messages, { - maxTokens: maxTokenCount, - strategy: 'last', - tokenCounter: approxTokenCounter, - includeSystem: true, - }); - return { - ...state, - messages: trimmed, - }; - } - - return state; - } -} diff --git a/src/graphs/chat/nodes/end-session.node.ts b/src/graphs/chat/nodes/end-session.node.ts deleted file mode 100644 index 497669c..0000000 --- a/src/graphs/chat/nodes/end-session.node.ts +++ /dev/null @@ -1,43 +0,0 @@ -import {service} from '@loopback/core'; -import {HttpErrors} from '@loopback/rest'; -import {ChatStore} from '..'; -import {graphNode} from '../../../decorators'; -import {TokenCounter} from '../../../services'; -import {LLMStreamEventType} from '../../event.types'; -import {ChatState} from '../../state'; -import {IGraphNode, RunnableConfig} from '../../types'; -import {ChatNodes} from '../nodes.enum'; -const debug = require('debug')('ai-integration:chat:end-session.node'); -@graphNode(ChatNodes.EndSession) -export class EndSessionNode implements IGraphNode { - constructor( - @service(ChatStore) - 
private readonly chatStore: ChatStore, - @service(TokenCounter) - private readonly tokenCounter: TokenCounter, - ) {} - async execute(state: ChatState, config: RunnableConfig): Promise { - const tokenCounts = this.tokenCounter.getCounts(); - config.writer?.({ - type: LLMStreamEventType.TokenCount, - data: { - inputTokens: tokenCounts.inputs, - outputTokens: tokenCounts.outputs, - }, - }); - if (!state.id) { - // If the chat ID is not defined, we cannot proceed with the session end. - debug('No chat ID found in state, this is unexpected'); - throw new HttpErrors.InternalServerError(); - } - await this.chatStore.updateCounts( - state.id, - tokenCounts.inputs, - tokenCounts.outputs, - tokenCounts.map, - ); - // This node is used to end the session, so we can return the state as is. - // You might want to add any cleanup logic here if needed. - return state; - } -} diff --git a/src/graphs/chat/nodes/index.ts b/src/graphs/chat/nodes/index.ts deleted file mode 100644 index 27b017c..0000000 --- a/src/graphs/chat/nodes/index.ts +++ /dev/null @@ -1,6 +0,0 @@ -export * from './call-llm.node'; -export * from './context-compression.node'; -export * from './end-session.node'; -export * from './init-session.node'; -export * from './run-tool.node'; -export * from './summarise-file.node'; diff --git a/src/graphs/chat/nodes/init-session.node.ts b/src/graphs/chat/nodes/init-session.node.ts deleted file mode 100644 index 3f5a221..0000000 --- a/src/graphs/chat/nodes/init-session.node.ts +++ /dev/null @@ -1,81 +0,0 @@ -import { - BaseMessage, - HumanMessage, - SystemMessage, -} from '@langchain/core/messages'; -import {LangGraphRunnableConfig} from '@langchain/langgraph'; -import {inject, service} from '@loopback/core'; -import {graphNode} from '../../../decorators'; -import {Message} from '../../../models'; -import {LLMStreamEventType} from '../../event.types'; -import {ChatState} from '../../state'; -import {IGraphNode} from '../../types'; -import {ChatStore} from '../chat.store'; 
-import {ChatNodes} from '../nodes.enum'; -import {AiIntegrationBindings} from '../../../keys'; -const debug = require('debug')('ai-integration:chat:init-session.node'); -@graphNode(ChatNodes.InitSession) -export class InitSessionNode implements IGraphNode { - constructor( - @service(ChatStore) - private readonly chatStore: ChatStore, - @inject(AiIntegrationBindings.SystemContext, {optional: true}) - private readonly systemContext?: string[], - ) {} - async execute( - state: ChatState, - config: LangGraphRunnableConfig, - ): Promise { - const chat = await this.chatStore.init(state.prompt, state.id); - if (!state.id) { - debug(`New session created with ID: ${chat.id}`); - config.writer?.({ - type: LLMStreamEventType.Init, - data: { - sessionId: chat.id, - }, - }); - } - const userMessage = new HumanMessage({ - content: state.prompt, - }); - const savedUserMessage = await this.chatStore.addHumanMessage( - chat.id, - userMessage, - ); - return { - ...state, - id: chat.id, - userMessage: savedUserMessage, - messages: [ - new SystemMessage({ - content: [ - `You are a helpful AI assistant. You MUST always use one of the available tools to handle the user's request. Never respond with just text on the first message — always call the closest matching tool, even if you are unsure. 
The tool will reject the request if it is not suitable.`, - `If you are not sure about the result, you can ask the user to review the result and provide feedback.`, - `Only use a single tool in a single message, but you can use multiple tools over subsequent messages if it could help with the user's requirements.`, - `If the user provides feedback, you can use that feedback to improve the result.`, - `Do not write any redundant messages before or after tool calls, be as concise as possible.`, - `Do not hallucinate details or make up information.`, - `Do not make assumptions about user's intent beyond what is explicitly provided in the prompt, and keep this in mind while calling tools.`, - `Do not use technical jargon in the response, show any internal IDs, or implementation details to the user.`, - `Current date is ${new Date().toDateString()}`, - ...(this.systemContext ?? []), - ].join('\n'), - }), - ...(await this._formatMessage(chat.messages)), - ], - }; - } - - private async _formatMessage(messages: Message[]): Promise { - if (!messages) { - return []; - } - const graphMessages = await Promise.all( - messages.map(message => this.chatStore.toMessage(message)), - ); - return graphMessages.filter( - (message): message is BaseMessage => message !== undefined, - ); - } -} diff --git a/src/graphs/chat/nodes/run-tool.node.ts b/src/graphs/chat/nodes/run-tool.node.ts deleted file mode 100644 index 382723f..0000000 --- a/src/graphs/chat/nodes/run-tool.node.ts +++ /dev/null @@ -1,94 +0,0 @@ -import {AIMessage, ToolMessage} from '@langchain/core/messages'; -import {inject} from '@loopback/context'; -import {service} from '@loopback/core'; -import {HttpErrors} from '@loopback/rest'; -import {graphNode} from '../../../decorators'; -import {AiIntegrationBindings} from '../../../keys'; -import {ToolStore} from '../../../types'; -import {LLMStreamEventType} from '../../event.types'; -import {ChatState} from '../../state'; -import { - IGraphNode, - resolveGraphTool, - 
RunnableConfig, - ToolStatus, -} from '../../types'; -import {ChatStore} from '../chat.store'; -import {ChatNodes} from '../nodes.enum'; - -const debug = require('debug')('ai-integration:chat:run-tool.node'); - -@graphNode(ChatNodes.RunTool) -export class RunToolNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.Tools) - private readonly tools: ToolStore, - @service(ChatStore) - private readonly chatStore: ChatStore, - ) {} - - async execute(state: ChatState, config: RunnableConfig): Promise { - if (!state.id) { - debug('No chat ID found in state, this is unexpected'); - throw new HttpErrors.InternalServerError(); - } - if (!state.aiMessage) { - debug('No last AI message found in state, this is unexpected'); - throw new HttpErrors.InternalServerError(); - } - const newMessages: ToolMessage[] = []; - const tools = this.tools.map; - const lastMessage = state.messages[state.messages.length - 1] as AIMessage & - ToolMessage; - if ( - !lastMessage || - lastMessage.tool_call_id || - !lastMessage.tool_calls?.length - ) { - return state; - } - const toolCalls = lastMessage.tool_calls!; - - for (const toolCall of toolCalls) { - config.writer?.({ - type: LLMStreamEventType.Tool, - data: { - id: toolCall.id, - tool: toolCall.name, - data: toolCall.args, - status: ToolStatus.Running, - }, - }); - const toolObj = tools[toolCall.name as keyof typeof tools]; - const tool = await resolveGraphTool(toolObj, config); - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Running tool: ${toolCall.name} with args: ${JSON.stringify(toolCall.args, undefined, 2)}`, - }); - const result = await tool.invoke(toolCall.args); - - const output = toolObj.getValue?.(result) ?? result; - const metadata = toolObj.getMetadata?.(result) ?? 
{}; - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Tool output: ${output}`, - }); - const toolMessage = new ToolMessage({ - name: toolCall.name, - content: String(output), - // eslint-disable-next-line @typescript-eslint/naming-convention - tool_call_id: toolCall.id!, - }); - - await this.chatStore.addToolMessage( - state.id, - toolMessage, - metadata, - state.aiMessage, - toolCall.args, - ); - newMessages.push(toolMessage); - } - return {...state, messages: newMessages}; - } -} diff --git a/src/graphs/chat/nodes/summarise-file.node.ts b/src/graphs/chat/nodes/summarise-file.node.ts deleted file mode 100644 index 657fd00..0000000 --- a/src/graphs/chat/nodes/summarise-file.node.ts +++ /dev/null @@ -1,148 +0,0 @@ -import {HumanMessage} from '@langchain/core/messages'; -import {PromptTemplate} from '@langchain/core/prompts'; -import {RunnableSequence} from '@langchain/core/runnables'; -import {LangGraphRunnableConfig, Messages} from '@langchain/langgraph'; -import {inject} from '@loopback/context'; -import {service} from '@loopback/core'; -import {AnyObject} from '@loopback/repository'; -import {HttpErrors} from '@loopback/rest'; -import {graphNode} from '../../../decorators'; -import {AiIntegrationBindings} from '../../../keys'; -import {RuntimeLLMProvider} from '../../../types'; -import {mergeAttachments, stripThinkingTokens} from '../../../utils'; -import {LLMStreamEventType} from '../../event.types'; -import {ChatState} from '../../state'; -import {IGraphNode} from '../../types'; -import {ChatStore} from '../chat.store'; -import {ChatNodes} from '../nodes.enum'; - -const debug = require('debug')('ai-integration:chat:summarise-file.node'); - -@graphNode(ChatNodes.SummariseFile) -export class SummariseFileNode implements IGraphNode { - constructor( - @inject(AiIntegrationBindings.FileLLM) - private readonly llm: RuntimeLLMProvider, - @service(ChatStore) - private readonly chatStore: ChatStore, - ) {} - - prompt = - PromptTemplate.fromTemplate(`You are 
an AI assistant that summarizes file content keeping all the important details in mind. - Make sure that you don't miss any important details and summarize the content in a concise manner. - While summarizing the content, make sure that you keep the user's prompt in mind and summarize the content in a way that it can be used to answer the user's query. - You will be provided with user's original prompt and one file among the files that user provided. - You will summarize the one file at a time so don't worry about the other files mentioned in the user's prompt. - The summary should be relatively short and only contain the important details that are relevant to the user's query. - The output should just be a plain text string without any additional markdown syntax or any special formatting. - Here is the user's prompt: - {prompt} - `); - - async execute( - state: ChatState, - config: LangGraphRunnableConfig, - ): Promise { - if (!state.id) { - debug('No chat ID found in state, this is unexpected'); - throw new HttpErrors.InternalServerError(); - } - if (!state.userMessage) { - debug('No last user message found in state, this is unexpected'); - throw new HttpErrors.InternalServerError(); - } - if (state.files && state.files.length > 0) { - const file = state.files[0]; // Assuming we are only processing the first file for now, we'll iterate till this list is empty - config.writer?.({ - type: LLMStreamEventType.Log, - data: `Processing file: ${file.originalname}`, - }); - config.writer?.({ - type: LLMStreamEventType.Status, - data: `Reading file: ${file.originalname}`, - }); - const fileContent = this.buildFileContent(file); - const prompt: Messages = [ - { - role: 'system', - content: await this.prompt.format({ - prompt: state.prompt, - }), - }, - { - role: 'user', - content: [ - { - type: 'text', - text: state.prompt, - }, - fileContent, - ], - }, - ]; - const chain = RunnableSequence.from([this.llm, stripThinkingTokens]); - - const summary = await 
chain.invoke(prompt); - - await this.chatStore.addAttachmentMessage( - state.id, - state.userMessage, - file, - summary, - ); - - const response = mergeAttachments( - state.prompt, - file.originalname, - summary, - ); - const newFiles = state.files.slice(1); // Remove the first file after processing - if (newFiles.length > 0) { - // If there are more files, we need to continue processing them - return { - ...state, - prompt: response, - files: newFiles, - }; - } else { - // If there are no more files, we can return the final state - return { - ...state, - prompt: response, - messages: [ - new HumanMessage({ - content: response, - }), - ], - files: [], - }; - } - } - - // code should reach here if there were no files to process to begin with - return { - ...state, - messages: [ - new HumanMessage({ - content: state.prompt, - }), - ], - files: [], - }; - } - - private buildFileContent(file: Express.Multer.File): AnyObject { - if (this.llm.getFile) { - return this.llm.getFile(file); - } else { - return { - type: 'file', - // eslint-disable-next-line @typescript-eslint/naming-convention - source_type: 'base64', - data: file.buffer?.toString('base64') ?? 
'', - // eslint-disable-next-line @typescript-eslint/naming-convention - mime_type: 'application/pdf', - }; - } - } -} diff --git a/src/graphs/index.ts b/src/graphs/index.ts deleted file mode 100644 index 6d30462..0000000 --- a/src/graphs/index.ts +++ /dev/null @@ -1,5 +0,0 @@ -export * from './base.graph'; -export * from './chat'; -export * from './event.types'; -export * from './state'; -export * from './types'; diff --git a/src/graphs/state.ts b/src/graphs/state.ts deleted file mode 100644 index 01e2a62..0000000 --- a/src/graphs/state.ts +++ /dev/null @@ -1,12 +0,0 @@ -import {Annotation, MessagesAnnotation} from '@langchain/langgraph'; -import {Message} from '../models'; - -export const ChatGraphAnnotation = Annotation.Root({ - ...MessagesAnnotation.spec, - id: Annotation, - files: Annotation, - prompt: Annotation, - userMessage: Annotation, - aiMessage: Annotation, -}); -export type ChatState = typeof ChatGraphAnnotation.State; diff --git a/src/graphs/types.ts b/src/graphs/types.ts deleted file mode 100644 index 2249256..0000000 --- a/src/graphs/types.ts +++ /dev/null @@ -1,143 +0,0 @@ -import {AIMessage, HumanMessage, ToolMessage} from '@langchain/core/messages'; -import {AnyObject, Command} from '@loopback/repository'; - -/** - * Runtime-agnostic execution config that can carry stream writers and runtime metadata. - */ -export type RunnableConfig = { - configurable?: Record; - signal?: AbortSignal; - writer?: (chunk: unknown) => void; -}; - -/** - * Node step execution function compatible with Mastra-like step execution. - */ -export type GraphStepExecuteFn = ( - state: T, - config: RunnableConfig, -) => Promise | Command>; - -/** - * Minimal step contract required by Phase 1 interface migration. - */ -export interface IGraphStep { - execute: GraphStepExecuteFn; -} - -/** - * Graph node contract supporting both legacy `execute` and Mastra-style `createStep`. 
- */ -export interface IGraphNode { - createStep?(config?: RunnableConfig): Promise> | IGraphStep; - execute?: GraphStepExecuteFn; -} - -export type SavedMessage = HumanMessage | AIMessage | ToolMessage; - -/** - * Minimal runtime tool contract shared across LangGraph and Mastra-compatible tooling. - * - * `description` and `schema` are optional but MUST be populated by any tool that - * needs to work with Mastra. LangChain StructuredTool instances (returned by - * `build()` / `createTool()`) already carry these properties; they just were not - * exposed through this interface previously. - */ -export interface IRuntimeTool { - name: string; - /** Human-readable description shown to the LLM when deciding which tool to call. */ - description?: string; - /** Zod schema describing the tool's input. Typed as `unknown` to avoid a hard - * dependency on `zod` in consuming packages; cast to `ZodObject` at the call site. */ - schema?: unknown; - invoke(input: TArgs): Promise; -} - -/** - * Tool contract supporting Mastra-style `createTool` and legacy `build` for compatibility. - */ -export interface IGraphTool { - key: string; - /** - * Human-readable description exposed at Mastra agent registration time. - * - * Populate this as a class property so the Mastra bridge factory can read it - * WITHOUT calling `createTool()` / `build()`. This avoids resolving the full - * dependency tree (e.g. graph nodes with request-scoped dependencies) at - * application startup. - */ - description?: string; - /** - * Zod schema for the tool's input, exposed at Mastra agent registration time. - * - * Same rationale as `description` — must be accessible without a `createTool()` call. - * Type is `unknown` to avoid a hard dependency on `zod` in consuming code; cast to - * `ZodObject` at the call site. - */ - inputSchema?: unknown; - createTool?(config: RunnableConfig): Promise; - /** - * @deprecated Use `createTool()`. 
- */ - build?(config: RunnableConfig): Promise; - getValue?(result: unknown): string; - getMetadata?(result: unknown): AnyObject; - needsReview?: boolean; -} - -/** - * Resolves the executable function for a node, preferring `execute` and falling back to `createStep`. - */ -export async function resolveNodeExecution( - node: IGraphNode, -): Promise> { - if (node.execute) { - return node.execute.bind(node); - } - - if (node.createStep) { - const step = await node.createStep(); - return step.execute.bind(step); - } - - throw new Error('Node must implement either execute() or createStep().'); -} - -/** - * Resolves a runtime tool from the migrated contract while preserving legacy fallback. - */ -export async function resolveGraphTool( - tool: IGraphTool, - config: RunnableConfig, -): Promise { - if (tool.createTool) { - return tool.createTool(config); - } - - if (tool.build) { - return tool.build(config); - } - - throw new Error(`Tool ${tool.key} does not implement createTool().`); -} - -export type IGraphDirectEdge = { - from: string; - to: string; -}; - -export type IGraphConditionalEdge = { - from: string; - toList: string[]; - branchingFunction(state: T): string; -}; - -export type IGraphEdge = - | IGraphDirectEdge - | IGraphConditionalEdge; - -export enum ToolStatus { - Running = 'running', - Completed = 'completed', - Failed = 'failed', -} diff --git a/src/index.ts b/src/index.ts index d2358ec..b16c7ee 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,11 +3,12 @@ export * from './components'; export * from './constant'; export * from './controllers'; export * from './decorators'; -export * from './graphs'; export * from './keys'; export * from './mastra'; export * from './providers'; export * from './services'; export * from './transports'; export * from './types'; +export * from './types/events'; +export * from './types/tool'; export * from './utils'; diff --git a/src/keys.ts b/src/keys.ts index cb0beb2..dae91b8 100644 --- a/src/keys.ts +++ b/src/keys.ts @@ 
-1,4 +1,3 @@ -import {VectorStore as VectorStoreType} from '@langchain/core/vectorstores'; import {BindingKey} from '@loopback/context'; import { IMastraBridge, @@ -7,10 +6,12 @@ import { import {ITransport} from './transports/types'; import { AIIntegrationConfig, + AiSdkEmbeddingModel as AiSdkEmbeddingModelType, EmbeddingProvider, ICache, + IVectorStore, IWorkflowPersistence, - RuntimeLLMProvider, + LLMProvider, ToolStore, } from './types'; import {ILimitStrategy} from './services/limit-strategies/types'; @@ -19,24 +20,76 @@ export namespace AiIntegrationBindings { export const Config = BindingKey.create( 'services.ai-reporting.config', ); - export const SmartLLM = BindingKey.create( + /** + * @deprecated Use `AiSdkSmartLLM` for the Mastra/AI SDK execution path. + */ + export const SmartLLM = BindingKey.create( 'services.ai-reporting.smartLLMProvider', ); - export const CheapLLM = BindingKey.create( + /** + * @deprecated Use `AiSdkCheapLLM` for the Mastra/AI SDK execution path. + */ + export const CheapLLM = BindingKey.create( 'services.ai-reporting.cheapLLMProvider', ); - export const FileLLM = BindingKey.create( + /** + * @deprecated Use `AiSdkFileLLM` for the Mastra/AI SDK execution path. + */ + export const FileLLM = BindingKey.create( 'services.ai-reporting.fileLLMProvider', ); - export const ChatLLM = BindingKey.create( + /** + * @deprecated Use `AiSdkFileLLM` for the Mastra/AI SDK execution path. + */ + export const ChatLLM = BindingKey.create( 'services.ai-reporting.chatLLMProvider', ); - export const SmartNonThinkingLLM = BindingKey.create( + /** + * @deprecated Use `AiSdkSmartNonThinkingLLM` for the Mastra/AI SDK execution path. + */ + export const SmartNonThinkingLLM = BindingKey.create( 'services.ai-reporting.smartNonThinkingLLMProvider', ); + /** + * AI SDK (`LanguageModel`) bindings — used by Mastra path nodes (Phase 3+). + * Bind these to AI SDK provider instances (e.g. from `lb4-llm-chat-component/openai` + * using `OpenAISdk`, etc.). 
The legacy `SmartLLM` / `CheapLLM` bindings remain + * for the LangGraph path and are unaffected. + */ + export const AiSdkSmartLLM = BindingKey.create( + 'services.ai-reporting.aiSdkSmartLLMProvider', + ); + export const AiSdkCheapLLM = BindingKey.create( + 'services.ai-reporting.aiSdkCheapLLMProvider', + ); + export const AiSdkFileLLM = BindingKey.create( + 'services.ai-reporting.aiSdkFileLLMProvider', + ); + export const AiSdkSmartNonThinkingLLM = BindingKey.create( + 'services.ai-reporting.aiSdkSmartNonThinkingLLMProvider', + ); + /** + * AI SDK (`LanguageModel`) binding for the chat execution path. + * + * Bind a `LanguageModel` instance here. Used by `MastraChatAgent` to run + * the conversational LLM loop directly via `streamText()`. + */ + export const AiSdkChatLLM = BindingKey.create( + 'services.ai-reporting.aiSdkChatLLMProvider', + ); export const EmbeddingModel = BindingKey.create( 'services.ai-reporting.embeddingModel', ); + /** + * AI SDK embedding model binding for the Mastra execution path. + * + * Bind an `EmbeddingModel` (e.g. from `@ai-sdk/openai`) here. + * Used by `PgVectorSdkStore` to compute document and query embeddings + * without any LangChain dependency. + */ + export const AiSdkEmbeddingModel = BindingKey.create( + 'services.ai-reporting.aiSdkEmbeddingModel', + ); export const WorkflowPersistence = BindingKey.create( 'services.ai-reporting.workflow-persistence', ); @@ -50,9 +103,22 @@ export namespace AiIntegrationBindings { export const Transport = BindingKey.create( 'services.ai-reporting.transport', ); - export const VectorStore = BindingKey.create( + /** + * @deprecated Use `AiSdkVectorStore` for the Mastra/AI SDK execution path. + */ + export const VectorStore = BindingKey.create( 'services.ai-reporting.vector-store', ); + /** + * Mastra-path vector store binding. 
+ * + * Bind a `PgVectorSdkStore` (or any `IVectorStore` implementation) here for use + * by `DatasetSearchService` and `TemplateSearchService` in the Mastra execution path. + * The LangGraph path continues to use `VectorStore` above. + */ + export const AiSdkVectorStore = BindingKey.create( + 'services.ai-reporting.aiSdkVectorStore', + ); export const Cache = BindingKey.create('services.ai-reporting.cache'); export const LimitStrategy = BindingKey.create( 'services.ai-reporting.limit-strategy', @@ -60,6 +126,15 @@ export namespace AiIntegrationBindings { export const ObfHandler = BindingKey.create( 'services.ai-reporting.obf-handler', ); + /** + * Mastra-path Langfuse client binding. + * + * Registered automatically by `LangfuseMastraComponent`. + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + export const LangfuseMastraClient = BindingKey.create( + 'services.ai-reporting.langfuse-mastra-client', + ); export const MastraBridge = BindingKey.create( 'services.ai-reporting.mastra-bridge', ); diff --git a/src/mastra/chat/mappers/message.mapper.ts b/src/mastra/chat/mappers/message.mapper.ts deleted file mode 100644 index 759d8f3..0000000 --- a/src/mastra/chat/mappers/message.mapper.ts +++ /dev/null @@ -1,60 +0,0 @@ -import {AIMessage, BaseMessage, ToolMessage} from '@langchain/core/messages'; -import {getTextContent} from '../../../utils'; -import {MastraAgentMessage, MastraAssistantContentPart} from '../../types'; - -/** - * Converts a LangChain `BaseMessage[]` to the `MastraAgentMessage[]` format - * expected by `IMastraChatAgentRunnable.stream()`. - * - * The output shapes are compatible with the AI SDK `CoreMessage` type so - * real `@mastra/core` Agent instances accept them without further adaptation. 
- */ -export function toMastraMessages( - messages: BaseMessage[], -): MastraAgentMessage[] { - const result: MastraAgentMessage[] = []; - for (const msg of messages) { - const type = msg._getType(); - if (type === 'system') { - result.push({role: 'system', content: getTextContent(msg.content)}); - } else if (type === 'human') { - result.push({role: 'user', content: getTextContent(msg.content)}); - } else if (type === 'ai') { - const aiMsg = msg as AIMessage; - if (aiMsg.tool_calls?.length) { - const parts: MastraAssistantContentPart[] = []; - const text = getTextContent(aiMsg.content); - if (text.trim()) parts.push({type: 'text', text}); - for (const tc of aiMsg.tool_calls) { - parts.push({ - type: 'tool-call', - toolCallId: tc.id ?? '', - toolName: tc.name, - args: tc.args as Record, - }); - } - result.push({role: 'assistant', content: parts}); - } else { - result.push({ - role: 'assistant', - content: getTextContent(aiMsg.content), - }); - } - } else if (type === 'tool') { - const toolMsg = msg as ToolMessage; - result.push({ - role: 'tool', - content: [ - { - type: 'tool-result', - toolCallId: toolMsg.tool_call_id, - toolName: toolMsg.name ?? 
'', - result: toolMsg.content, - }, - ], - }); - } - // Other message types are dropped — not used in this flow - } - return result; -} diff --git a/src/mastra/chat/mastra-chat.agent.ts b/src/mastra/chat/mastra-chat.agent.ts index d5f1a99..3848b11 100644 --- a/src/mastra/chat/mastra-chat.agent.ts +++ b/src/mastra/chat/mastra-chat.agent.ts @@ -1,31 +1,24 @@ -import {BaseMessage, HumanMessage} from '@langchain/core/messages'; import {inject, injectable, BindingScope, service} from '@loopback/core'; -import {HttpErrors} from '@loopback/rest'; -import {LLMStreamEvent, LLMStreamEventType} from '../../graphs/event.types'; -import {ChatStore} from '../../graphs/chat/chat.store'; +import {streamText, tool, ToolSet, stepCountIs} from 'ai'; +import {LLMStreamEvent, LLMStreamEventType} from '../../types/events'; +import {ChatStore} from '../../services/chat.store'; import {AiIntegrationBindings} from '../../keys'; -import {IMastraBridge} from '../../services/mastra-bridge.service'; -import {AIIntegrationConfig, RuntimeLLMProvider, ToolStore} from '../../types'; -import {IMastraChatAgentRunnable} from '../types'; -import { - mastraRequestToolStore, - mastraRequestWriterStore, -} from '../request-tool-store'; +import {AIIntegrationConfig, LLMProvider, ToolStore} from '../../types'; +import {MastraAgentMessage} from '../types'; +import {IRuntimeTool} from '../../types/tool'; +import {z} from 'zod'; +import {normalizeMessages} from './utils/normalize-messages.util'; +import {adaptStreamResult} from './utils/adapt-stream.util'; +import {mastraRequestWriterStore} from '../request-tool-store'; import {compressContextIfNeeded} from './steps/context-compression.step'; import {initSession} from './steps/init-session.step'; import {summariseOneFile} from './steps/summarise-file.step'; import {handleStream} from './steps/stream-handler.step'; -import {toMastraMessages} from './mappers/message.mapper'; import {accumulateUsage} from './utils/token-accumulator.util'; import 
{TokenAccumulator} from './types/chat.types'; const debug = require('debug')('ai-integration:mastra:chat-agent'); -/** - * Registered name used to retrieve the chat agent via the Mastra bridge. - */ -export const MASTRA_CHAT_AGENT_NAME = 'chat-agent'; - /** * Mastra-runtime chat execution service. * @@ -34,18 +27,17 @@ export const MASTRA_CHAT_AGENT_NAME = 'chat-agent'; * 1. `initSession` — load/create chat, persist human message, build history * 2. `summariseOneFile`* — pre-process uploaded files before the agent sees them * 3. `compressContextIfNeeded`— trim history if it exceeds the token budget - * 4. `toMastraMessages` — convert LangChain messages → Mastra/AI SDK format - * 5. Bridge agent execution — Mastra Agent owns the CallLLM ↔ RunTool loop - * 6. `handleStream` — adapt Mastra events → LLMStreamEvent, persist steps - * 7. EndSession — emit TokenCount, update DB + * 4. `streamText` execution — library calls AI SDK directly, owns the LLM ↔ tool loop + * 5. `handleStream` — adapt AI SDK events → LLMStreamEvent, persist steps + * 6. 
EndSession — emit TokenCount, update DB */ @injectable({scope: BindingScope.REQUEST}) export class MastraChatAgent { constructor( - @inject(AiIntegrationBindings.MastraBridge) - private readonly mastraBridge: IMastraBridge, - @inject(AiIntegrationBindings.FileLLM) - private readonly fileLLM: RuntimeLLMProvider, + @inject(AiIntegrationBindings.AiSdkChatLLM) + private readonly chatLLM: LLMProvider, + @inject(AiIntegrationBindings.AiSdkFileLLM) + private readonly fileLLM: LLMProvider, @inject(AiIntegrationBindings.Config) private readonly aiConfig: AIIntegrationConfig, @inject(AiIntegrationBindings.Tools) @@ -107,35 +99,17 @@ export class MastraChatAgent { } // ── Step 3: Build message list + fallback context compression ──────────── - const rawMessages: BaseMessage[] = [ + const rawMessages: MastraAgentMessage[] = [ ...baseMessages, - new HumanMessage({content: finalPrompt}), + {role: 'user', content: finalPrompt}, ]; const compressedMessages = await compressContextIfNeeded( rawMessages, this.aiConfig.maxTokenCount, ); - // ── Step 4: Map messages to Mastra format ───────────────────────────────── - const agentMessages = toMastraMessages(compressedMessages); - - // ── Step 5: Obtain agent from bridge ───────────────────────────────────── - const agent = this.mastraBridge.getTypedAgent( - MASTRA_CHAT_AGENT_NAME, - ); - if (!agent) { - throw new HttpErrors.NotImplemented( - `Mastra chat agent '${MASTRA_CHAT_AGENT_NAME}' is not registered. 
` + - 'Bind a MastraRuntimeFactory at AiIntegrationBindings.MastraRuntimeFactory ' + - "that registers a chat agent under the name 'chat-agent'.", - ); - } - - // ── Step 5b: Build per-request tool map and register for bridge tools ───── - const requestToolMap = new Map< - string, - import('../../graphs/types').IRuntimeTool - >(); + // ── Step 4: Build per-request tool map ────────────────────────────────── + const requestToolMap = new Map(); for (const graphTool of this.tools.list) { try { // Build tools at request time (they may have request-scoped dependencies). @@ -148,15 +122,13 @@ export class MastraChatAgent { // Wrap invoke to inject the lazy writer into the LangGraph config so // internal graph nodes (e.g. RenderVisualizationNode, SaveDatasetNode) // get config.writer and their ToolStatus events reach the SSE stream. - // Tools built with createTool() ignore the config param, but LangGraph + // Tools built with createTool() ignore the config param, but the // tool.invoke(input, { writer }) passes it straight to every node. const lazyWriter = { writer: (event: unknown) => - mastraRequestWriterStore.get(chatId)?.( - event as import('../../graphs/event.types').LLMStreamEvent, - ), + mastraRequestWriterStore.get(chatId)?.(event as LLMStreamEvent), }; - const wrappedRt: import('../../graphs/types').IRuntimeTool = { + const wrappedRt: IRuntimeTool = { name: rt.name, description: rt.description, schema: rt.schema, @@ -164,13 +136,10 @@ export class MastraChatAgent { invoke: (input: unknown) => (rt as any).invoke(input, lazyWriter), }; - // Store by kebab key (e.g. 'get-data-as-dataset') — LangGraph path + // Keyed by graphTool.key (e.g. 'get-data-as-dataset'). + // This is the name the LLM will use when calling the tool, and what + // handleStream uses to look up display values. requestToolMap.set(graphTool.key, wrappedRt); - // Also store by class name (e.g. 
'GetDataAsDatasetTool') — Mastra factory path - const className = (graphTool as object).constructor?.name; - if (className && className !== graphTool.key) { - requestToolMap.set(className, wrappedRt); - } } } catch (err) { debug( @@ -180,25 +149,49 @@ export class MastraChatAgent { ); } } - mastraRequestToolStore.set(chatId, requestToolMap); debug( - 'Registered %d tools for chatId %s: %s', + 'Built %d tools for chatId %s: %s', requestToolMap.size, chatId, [...requestToolMap.keys()].join(', '), ); - // ── Step 6: Stream from bridge agent, adapt events to LLMStreamEvent ───── + // ── Step 5: Build AI SDK tool set and call streamText() directly ────────── + const aiTools: ToolSet = {}; + for (const [toolName, rt] of requestToolMap) { + if (aiTools[toolName]) continue; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const inputSchema: z.ZodTypeAny = (rt.schema as any) ?? z.object({}); + aiTools[toolName] = tool({ + description: rt.description ?? toolName, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + inputSchema: inputSchema as any, + execute: async (input: unknown) => { + try { + return await rt.invoke(input); + } catch (err: unknown) { + const msg = err instanceof Error ? 
err.message : String(err); + debug('Tool "%s" failed: %s', toolName, msg); + return {error: true, message: msg}; + } + }, + }); + } + debug( - 'Delegating to Mastra bridge agent — %d messages, %d tools', - agentMessages.length, - requestToolMap.size, + 'Calling streamText() directly — %d messages, %d tools', + compressedMessages.length, + Object.keys(aiTools).length, ); try { - const agentStream = await agent.stream(agentMessages, { - signal: abort, - threadId: chatId, + const streamResult = streamText({ + model: this.chatLLM, + messages: normalizeMessages(compressedMessages), + tools: aiTools, + stopWhen: stepCountIs(10), + abortSignal: abort, }); + const agentStream = adaptStreamResult(streamResult); for await (const event of handleStream({ agentStream, @@ -246,8 +239,7 @@ export class MastraChatAgent { tokens.map, ); } finally { - // Always clean up per-request stores so memory doesn't leak. - mastraRequestToolStore.delete(chatId); + // Clean up the per-request writer store so memory doesn't leak. mastraRequestWriterStore.delete(chatId); } } diff --git a/src/mastra/chat/steps/context-compression.step.ts b/src/mastra/chat/steps/context-compression.step.ts index 0e19a11..2de91a9 100644 --- a/src/mastra/chat/steps/context-compression.step.ts +++ b/src/mastra/chat/steps/context-compression.step.ts @@ -1,8 +1,8 @@ -import {BaseMessage, trimMessages} from '@langchain/core/messages'; import {DEFAULT_MAX_TOKEN_COUNT} from '../../../constant'; import {approxTokenCounter} from '../../../utils'; +import {MastraAgentMessage} from '../../types'; -const debug = require('debug')('ai-integration:mastra:chat-agent'); +const debug = require('debug')('mastra:chat:context-compression'); /** * Mirrors `ContextCompressionNode`: trims the message list to `maxTokenCount` @@ -18,9 +18,9 @@ const debug = require('debug')('ai-integration:mastra:chat-agent'); * `MAX_TOKEN_COUNT` env var or the package default. 
*/ export async function compressContextIfNeeded( - messages: BaseMessage[], + messages: MastraAgentMessage[], maxTokenCount: number | undefined, -): Promise { +): Promise { const limit = +( maxTokenCount ?? process.env.MAX_TOKEN_COUNT ?? @@ -30,18 +30,43 @@ export async function compressContextIfNeeded( (acc, m) => acc + approxTokenCounter(m.content), 0, ); - if (tokenCount > limit) { - debug( - 'Compressing context before agent call: %d tokens > limit %d', - tokenCount, - limit, - ); - return trimMessages(messages, { - maxTokens: limit, - strategy: 'last', - tokenCounter: approxTokenCounter, - includeSystem: true, - }); + + debug('Original messages: %d', messages.length); + + if (tokenCount <= limit) { + return messages; + } + + debug( + 'Compressing context before agent call: %d tokens > limit %d', + tokenCount, + limit, + ); + + // Always keep the system message, then fill remaining budget with the most + // recent messages (reverse iteration = newest first). + const systemMessages = messages.filter(m => m.role === 'system'); + const nonSystemMessages = messages.filter(m => m.role !== 'system'); + + const systemTokens = systemMessages.reduce( + (acc, m) => acc + approxTokenCounter(m.content), + 0, + ); + let remaining = limit - systemTokens; + const kept: MastraAgentMessage[] = []; + + for (let i = nonSystemMessages.length - 1; i >= 0; i--) { + const msg = nonSystemMessages[i]; + const tokens = approxTokenCounter(msg.content); + if (tokens > remaining) break; + kept.unshift(msg); + remaining -= tokens; } - return messages; + + const trimmed = [...systemMessages, ...kept]; + + debug('Trimmed messages: %d', trimmed.length); + debug('Token count: %d', limit - remaining); + + return trimmed; } diff --git a/src/mastra/chat/steps/init-session.step.ts b/src/mastra/chat/steps/init-session.step.ts index b361ad9..55b35b4 100644 --- a/src/mastra/chat/steps/init-session.step.ts +++ b/src/mastra/chat/steps/init-session.step.ts @@ -1,17 +1,15 @@ -import { - BaseMessage, - 
HumanMessage, - SystemMessage, -} from '@langchain/core/messages'; -import {ChatStore} from '../../../graphs/chat/chat.store'; +import {ChatStore, SavedMessage} from '../../../services/chat.store'; import {Message} from '../../../models'; +import {MastraAgentMessage} from '../../types'; + +const debug = require('debug')('ai-integration:mastra:chat:init-session'); /** * Result returned by `initSession`. */ export interface InitSessionResult { chatId: string; - baseMessages: BaseMessage[]; + baseMessages: MastraAgentMessage[]; userMessage: Message; } @@ -30,13 +28,16 @@ export async function initSession( chatStore: ChatStore, systemPrompt: string, ): Promise { + debug('step start', {id, promptLength: prompt.length}); const chat = await chatStore.init(prompt, id); - const savedUserMessage = await chatStore.addHumanMessage( - chat.id, - new HumanMessage({content: prompt}), - ); + debug('chat initialised: %s (new=%s)', chat.id, !id); + const savedUserMessage = await chatStore.addHumanMessage(chat.id, prompt); const history = await formatHistory(chat.messages ?? [], chatStore); - const systemMessage = new SystemMessage({content: systemPrompt}); + const systemMessage: MastraAgentMessage = { + role: 'system', + content: systemPrompt, + }; + debug('history loaded: %d messages', history.length); return { chatId: chat.id, baseMessages: [systemMessage, ...history], @@ -45,15 +46,15 @@ export async function initSession( } /** - * Converts DB `Message` rows back to LangChain `BaseMessage` instances. + * Converts DB `Message` rows back to `MastraAgentMessage` instances. * Undefined entries (unsupported message roles) are filtered out. 
*/ async function formatHistory( dbMessages: Message[], chatStore: ChatStore, -): Promise { +): Promise { const converted = await Promise.all( dbMessages.map(m => chatStore.toMessage(m)), ); - return converted.filter((m): m is BaseMessage => m !== undefined); + return converted.filter((m): m is SavedMessage => m !== undefined); } diff --git a/src/mastra/chat/steps/save-step.step.ts b/src/mastra/chat/steps/save-step.step.ts index b3eee8f..44c4414 100644 --- a/src/mastra/chat/steps/save-step.step.ts +++ b/src/mastra/chat/steps/save-step.step.ts @@ -1,5 +1,4 @@ -import {AIMessage, ToolMessage} from '@langchain/core/messages'; -import {ChatStore} from '../../../graphs/chat/chat.store'; +import {ChatStore} from '../../../services/chat.store'; import {ToolStore} from '../../../types'; import {StepBuffer} from '../types/chat.types'; @@ -26,19 +25,14 @@ export async function saveStep( const hasToolCalls = step.toolCalls.length > 0; if (!text.trim() && !hasToolCalls) return; - const aiMsg = new AIMessage({ - content: text || ' ', - // eslint-disable-next-line @typescript-eslint/naming-convention - tool_calls: hasToolCalls - ? step.toolCalls.map(tc => ({ - id: tc.id, - name: tc.name, - args: tc.args, - type: 'tool_call' as const, - })) - : [], - }); - const savedAiMsg = await chatStore.addAIMessage(chatId, aiMsg); + const toolCallsForAi = hasToolCalls + ? step.toolCalls.map(tc => ({id: tc.id, name: tc.name, args: tc.args})) + : []; + const savedAiMsg = await chatStore.addAIMessage( + chatId, + text || ' ', + toolCallsForAi, + ); for (const toolCall of step.toolCalls) { const toolResult = step.toolResults.get(toolCall.id); @@ -49,15 +43,11 @@ export async function saveStep( } const output = toolDef?.getValue?.(toolResult.result) ?? toolResult.result; const metadata = toolDef?.getMetadata?.(toolResult.result) ?? 
{}; - const toolMsg = new ToolMessage({ - name: toolCall.name, - content: String(output), - // eslint-disable-next-line @typescript-eslint/naming-convention - tool_call_id: toolCall.id, - }); await chatStore.addToolMessage( chatId, - toolMsg, + toolCall.id, + toolCall.name, + String(output), metadata, savedAiMsg, toolCall.args, diff --git a/src/mastra/chat/steps/stream-handler.step.ts b/src/mastra/chat/steps/stream-handler.step.ts index eaae622..bf17b7d 100644 --- a/src/mastra/chat/steps/stream-handler.step.ts +++ b/src/mastra/chat/steps/stream-handler.step.ts @@ -1,7 +1,7 @@ -import {LLMStreamEvent, LLMStreamEventType} from '../../../graphs/event.types'; -import {ToolStatus} from '../../../graphs/types'; +import {LLMStreamEvent, LLMStreamEventType} from '../../../types/events'; +import {ToolStatus} from '../../../types/tool'; import {ToolStore} from '../../../types'; -import {ChatStore} from '../../../graphs/chat/chat.store'; +import {ChatStore} from '../../../services/chat.store'; import {MastraAgentStreamOutput} from '../../types'; import {StepBuffer, TokenAccumulator} from '../types/chat.types'; import {accumulateUsage} from '../utils/token-accumulator.util'; diff --git a/src/mastra/chat/steps/summarise-file.step.ts b/src/mastra/chat/steps/summarise-file.step.ts index 2690d46..7315a34 100644 --- a/src/mastra/chat/steps/summarise-file.step.ts +++ b/src/mastra/chat/steps/summarise-file.step.ts @@ -1,11 +1,14 @@ -import {AIMessage} from '@langchain/core/messages'; -import {ChatStore} from '../../../graphs/chat/chat.store'; +import {generateText} from 'ai'; +import {ChatStore} from '../../../services/chat.store'; import {Message} from '../../../models'; -import {resolveLegacyLLMProvider, RuntimeLLMProvider} from '../../../types'; -import {mergeAttachments, stripThinkingTokens} from '../../../utils'; +import {LLMProvider} from '../../../types'; +import {mergeAttachments} from '../../../utils'; +import {stripThinkingFromText} from 
'../../db-query/utils/thinking.util'; import {TokenAccumulator} from '../types/chat.types'; import {accumulateUsage} from '../utils/token-accumulator.util'; +const debug = require('debug')('mastra:chat:summarise-file'); + /** * Parameters for `summariseOneFile`. */ @@ -16,15 +19,15 @@ export interface SummariseFileParams { userMessage: Message; /** Mutated in place — updated with file-summarisation token usage. */ tokens: TokenAccumulator; - fileLLM: RuntimeLLMProvider; + fileLLM: LLMProvider; chatStore: ChatStore; } /** * Mirrors `SummariseFileNode` for a single file. * - * Invokes the file LLM to produce a concise summary of the file in the - * context of the user prompt, persists the attachment message, and returns + * Uses `generateText()` from the Vercel AI SDK instead of the LangChain + * `BaseChatModel.invoke()` path. Persists the attachment message and returns * an updated prompt that embeds the summary. */ export async function summariseOneFile( @@ -33,33 +36,40 @@ export async function summariseOneFile( const {file, currentPrompt, chatId, userMessage, tokens, fileLLM, chatStore} = params; - const llm = resolveLegacyLLMProvider(fileLLM); - const fileContent = buildFileContent(file, fileLLM); - const messages = [ - { - role: 'system' as const, - content: buildFileSummaryPrompt(currentPrompt), - }, - { - role: 'user' as const, - content: [{type: 'text', text: currentPrompt}, fileContent], - }, - ]; + debug('Summarising file: %s', file.originalname); + debug('Prompt length: %d', currentPrompt.length); + + const fileData = file.buffer?.toString('base64') ?? 
''; - const aiResponse = (await llm.invoke(messages)) as AIMessage; - const usage = aiResponse.usage_metadata; - if (usage) { - accumulateUsage( + const {text, usage} = await generateText({ + model: fileLLM, + messages: [ { - promptTokens: usage.input_tokens, - completionTokens: usage.output_tokens, + role: 'user', + content: [ + {type: 'text', text: buildFileSummaryPrompt(currentPrompt)}, + { + type: 'file', + data: fileData, + mediaType: file.mimetype, + }, + ], }, - 'mastra-file', - tokens, - ); - } + ], + }); - const summary = stripThinkingTokens(aiResponse); + debug('Summary generated, tokens: %o', usage); + + accumulateUsage( + { + promptTokens: usage.inputTokens, + completionTokens: usage.outputTokens, + }, + 'mastra-file', + tokens, + ); + + const summary = stripThinkingFromText(text); await chatStore.addAttachmentMessage(chatId, userMessage, file, summary); return mergeAttachments(currentPrompt, file.originalname, summary); } @@ -80,20 +90,3 @@ function buildFileSummaryPrompt(userPrompt: string): string { ${userPrompt} `; } - -function buildFileContent( - file: Express.Multer.File, - fileLLM: RuntimeLLMProvider, -): object { - if (fileLLM.getFile) { - return fileLLM.getFile(file); - } - return { - type: 'file', - // eslint-disable-next-line @typescript-eslint/naming-convention - source_type: 'base64', - data: file.buffer?.toString('base64') ?? 
'', - // eslint-disable-next-line @typescript-eslint/naming-convention - mime_type: file.mimetype, - }; -} diff --git a/src/mastra/chat/types/chat.types.ts b/src/mastra/chat/types/chat.types.ts index 9a65026..45f2eef 100644 --- a/src/mastra/chat/types/chat.types.ts +++ b/src/mastra/chat/types/chat.types.ts @@ -1,4 +1,4 @@ -import {LLMStreamEvent} from '../../../graphs/event.types'; +import {LLMStreamEvent} from '../../../types/events'; import {TokenMetadata} from '../../../types'; /** diff --git a/src/mastra/chat/utils/adapt-stream.util.ts b/src/mastra/chat/utils/adapt-stream.util.ts new file mode 100644 index 0000000..bda34f2 --- /dev/null +++ b/src/mastra/chat/utils/adapt-stream.util.ts @@ -0,0 +1,103 @@ +import {streamText} from 'ai'; +import {MastraAgentStreamOutput, MastraStreamEvent} from '../../types'; + +/** + * Adapts the AI SDK v6 `StreamTextResult.fullStream` into the + * `MastraAgentStreamOutput` shape consumed by `handleStream()`. + * + * This allows `MastraChatAgent` to call `streamText()` directly while + * reusing the existing `handleStream` event-processing logic unchanged. 
+ * + * ### AI SDK v6 → Mastra event mapping + * | AI SDK v6 event type | Mapped to Mastra `type` | Notes | + * |----------------------|--------------------------|--------------------------------------| + * | `text-delta` | `text-delta` | `part.text` → `payload.text` | + * | `tool-call` | `tool-call` | `part.input` → `payload.args` | + * | `tool-result` | `tool-result` | `part.output` → `payload.result` | + * | `finish-step` | `step-finish` | usage from `part.usage` | + * | `finish` | `finish` | usage from `part.totalUsage` | + */ +export function adaptStreamResult( + result: ReturnType, +): MastraAgentStreamOutput { + return { + fullStream: adaptFullStream(result), + // result.usage is a PromiseLike — wrap in Promise for interface compatibility + usage: Promise.resolve(result.usage).then(u => ({ + inputTokens: u?.inputTokens, + outputTokens: u?.outputTokens, + })), + }; +} + +async function* adaptFullStream( + result: ReturnType, +): AsyncGenerator { + for await (const part of result.fullStream) { + switch (part.type) { + case 'text-delta': + yield { + type: 'text-delta', + payload: {text: part.text, id: (part as {id?: string}).id ?? 
''}, + }; + break; + + case 'tool-call': + yield { + type: 'tool-call', + payload: { + toolCallId: part.toolCallId, + toolName: part.toolName, + // `handleStream` reads `payload.args`; AI SDK v6 calls this `input` + args: part.input, + }, + }; + break; + + case 'tool-result': + yield { + type: 'tool-result', + payload: { + toolCallId: part.toolCallId, + toolName: part.toolName, + // `handleStream` reads `payload.result`; AI SDK v6 calls this `output` + result: part.output, + args: part.input, + }, + }; + break; + + case 'finish-step': + yield { + type: 'step-finish', + payload: { + output: { + usage: { + inputTokens: part.usage?.inputTokens, + outputTokens: part.usage?.outputTokens, + }, + }, + }, + }; + break; + + case 'finish': + yield { + type: 'finish', + payload: { + output: { + usage: { + inputTokens: part.totalUsage?.inputTokens, + outputTokens: part.totalUsage?.outputTokens, + }, + }, + }, + }; + break; + + default: + // Unknown event types are silently dropped + break; + } + } +} diff --git a/src/mastra/chat/utils/normalize-messages.util.ts b/src/mastra/chat/utils/normalize-messages.util.ts new file mode 100644 index 0000000..cf31053 --- /dev/null +++ b/src/mastra/chat/utils/normalize-messages.util.ts @@ -0,0 +1,85 @@ +import {MastraAgentMessage} from '../../types'; + +/** + * Converts Mastra-format messages into AI SDK v6 ModelMessage format. + * + * Key differences handled: + * - `tool-call` parts: `args` (Mastra) → `input` (AI SDK v6) + * - `tool-result` parts: `result` (Mastra) → `output: {type, value}` (AI SDK v6) + */ +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function normalizeMessages(messages: MastraAgentMessage[]): any[] { + return messages.map(msg => { + const {role, content} = msg; + + // System / user messages: content must be a string + if (role === 'system' || role === 'user') { + return { + role, + content: typeof content === 'string' ? content : String(content ?? 
''), + }; + } + + // Assistant messages: convert tool-call parts (args → input) + if (role === 'assistant') { + if (typeof content === 'string') { + return {role, content}; + } + if (Array.isArray(content)) { + const parts = content.map( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (part: Record) => { + if (part.type === 'tool-call') { + return { + type: 'tool-call', + toolCallId: part.toolCallId, + toolName: part.toolName, + // AI SDK v6 uses `input`; Mastra stores `args` + input: part.input ?? part.args ?? {}, + }; + } + if (part.type === 'text') { + return {type: 'text', text: part.text ?? ''}; + } + return part; + }, + ); + return {role, content: parts}; + } + return { + role, + content: typeof content === 'string' ? content : String(content ?? ''), + }; + } + + // Tool messages: convert tool-result parts (result → output) + if (role === 'tool') { + if (Array.isArray(content)) { + const parts = content.map( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (part: Record) => { + if (part.type === 'tool-result') { + // AI SDK v6 uses `output`; Mastra stores `result` + const rawResult = part.output ?? part.result; + const output = + typeof rawResult === 'string' + ? 
{type: 'text', value: rawResult} + : {type: 'json', value: rawResult}; + return { + type: 'tool-result', + toolCallId: part.toolCallId, + toolName: part.toolName, + output, + }; + } + return part; + }, + ); + return {role, content: parts}; + } + return {role, content}; + } + + return msg; + }); +} diff --git a/src/mastra/db-query/index.ts b/src/mastra/db-query/index.ts new file mode 100644 index 0000000..da272c5 --- /dev/null +++ b/src/mastra/db-query/index.ts @@ -0,0 +1,15 @@ +export {MastraDbQueryWorkflow} from './mastra-db-query.workflow'; +export type { + DbQueryWorkflowInput, + DbQueryWriterFn, + MastraDbQueryContext, +} from './types/db-query.types'; + +// Mastra-path utilities +export * from './utils'; + +// Mastra-path services +export * from './services'; + +// Workflow steps, conditions, and routing helpers +export * from './workflow'; diff --git a/src/mastra/db-query/mastra-db-query.workflow.ts b/src/mastra/db-query/mastra-db-query.workflow.ts new file mode 100644 index 0000000..cff5cf4 --- /dev/null +++ b/src/mastra/db-query/mastra-db-query.workflow.ts @@ -0,0 +1,470 @@ +import {BindingScope, inject, injectable, service} from '@loopback/core'; +import {IAuthUserWithPermissions} from '@sourceloop/core'; +import {AuthenticationBindings} from 'loopback4-authentication'; +import { + DataSetHelper, + DbSchemaHelperService, + PermissionHelper, +} from '../../components/db-query/services'; +import {SchemaStore} from '../../components/db-query/services/schema.store'; +import {TableSearchService} from '../../components/db-query/services/search/table-search.service'; +import {DbQueryState} from '../../components/db-query/state'; +import { + DbQueryConfig, + Errors, + GenerationError, + IDataSetStore, + IDbConnector, +} from '../../components/db-query/types'; +import {DbQueryAIExtensionBindings} from '../../components/db-query/keys'; +import {MAX_ATTEMPTS} from '../../components/db-query/constant'; +import {AiIntegrationBindings} from '../../keys'; +import 
{LLMProvider} from '../../types'; +import {TokenCounter} from '../../services/token-counter.service'; +import {DatasetSearchService} from './services/dataset-search.service'; +import {MastraTemplateHelperService} from './services/mastra-template-helper.service'; +import {TemplateSearchService} from './services/template-search.service'; +import { + DbQueryWorkflowInput, + MastraDbQueryContext, +} from './types/db-query.types'; +import { + checkPostCacheAndTablesConditions, + checkPostValidationConditions, + mergeValidationResults, + failedStep, + isImprovementStep, + classifyChangeStep, + checkCacheStep, + checkPermissionsStep, + checkTemplatesStep, + getTablesStep, + getColumnsStep, + generateChecklistStep, + verifyChecklistStep, + sqlGenerationStep, + syntacticValidatorStep, + semanticValidatorStep, + generateDescriptionStep, + fixQueryStep, + saveDatasetStep, +} from './workflow'; + +const debug = require('debug')('mastra:db-query:workflow'); + +/** + * Mastra-path imperative workflow for the DbQuery feature. + * + * Injects all services directly and delegates to step functions in + * `workflow/steps/`. This eliminates class-based node wrappers and keeps + * execution units as plain async functions. + * + * Preserves 100% of the original orchestration behaviour: + * - Same parallel fan-out / fan-in with `Promise.all()` + * - Same validation-retry loop with `MAX_ATTEMPTS` guard + * - Same conditional routing (template hit, cache hit, table error, query error) + * - Same state-merging semantics (last-write-wins per field) + * + * @injectable `BindingScope.REQUEST` — one instance per HTTP request. 
+ */ +@injectable({scope: BindingScope.REQUEST}) +export class MastraDbQueryWorkflow { + constructor( + // ── LLM providers ────────────────────────────────────────────────────── + @inject(AiIntegrationBindings.AiSdkCheapLLM) + private readonly cheapLlm: LLMProvider, + @inject(AiIntegrationBindings.AiSdkSmartLLM) + private readonly smartLlm: LLMProvider, + @inject(AiIntegrationBindings.AiSdkSmartNonThinkingLLM, {optional: true}) + private readonly smartNonThinkingLlm: LLMProvider | undefined, + + // ── Data stores & config ──────────────────────────────────────────────── + @inject(DbQueryAIExtensionBindings.DatasetStore) + private readonly datasetStore: IDataSetStore, + @inject(DbQueryAIExtensionBindings.Config) + private readonly config: DbQueryConfig, + @inject(DbQueryAIExtensionBindings.Connector) + private readonly connector: IDbConnector, + @inject(AuthenticationBindings.CURRENT_USER) + private readonly user: IAuthUserWithPermissions, + @inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true}) + private readonly checks: string[] | undefined, + + // ── Services ──────────────────────────────────────────────────────────── + @service(DatasetSearchService) + private readonly datasetSearch: DatasetSearchService, + @service(DataSetHelper) + private readonly dataSetHelper: DataSetHelper, + @service(PermissionHelper) + private readonly permissionHelper: PermissionHelper, + @service(TemplateSearchService) + private readonly templateSearch: TemplateSearchService, + @service(MastraTemplateHelperService) + private readonly templateHelper: MastraTemplateHelperService, + @service(SchemaStore) + private readonly schemaStore: SchemaStore, + @service(TableSearchService) + private readonly tableSearch: TableSearchService, + @service(DbSchemaHelperService) + private readonly schemaHelper: DbSchemaHelperService, + @service(TokenCounter) + private readonly tokenCounter: TokenCounter, + ) {} + + /** + * Execute the full DbQuery workflow. 
+ * + * @param input User prompt and optional datasetId / directCall flag. + * @param ctx Execution context: SSE writer and/or AbortSignal. + */ + async run( + input: DbQueryWorkflowInput, + ctx?: MastraDbQueryContext, + ): Promise { + const context: MastraDbQueryContext = { + ...ctx, + onUsage: (i, o, m) => { + this.tokenCounter.accumulate(i, o, m); + if (ctx?.onUsage) ctx.onUsage(i, o, m); + }, + }; + + debug( + 'Workflow START prompt=%s datasetId=%s', + input.prompt, + input.datasetId, + ); + + // ── Initial state ─────────────────────────────────────────────────────── + let state: DbQueryState = { + prompt: input.prompt, + datasetId: input.datasetId, + directCall: input.directCall ?? false, + schema: {tables: {}, relations: []}, + } as unknown as DbQueryState; + + // ── Step 1: IsImprovement ─────────────────────────────────────────────── + debug('Executing step: IsImprovement'); + state = this._merge( + state, + await isImprovementStep(state, context, {store: this.datasetStore}), + ); + debug('Completed step: IsImprovement'); + + // ── Step 2: Parallel fan-out ──────────────────────────────────────────── + debug( + 'Executing step: CheckCache|GetTables|CheckTemplates|ClassifyChange (parallel)', + ); + const [cachePartial, tablesPartial, templatesPartial, classifyPartial] = + await Promise.all([ + checkCacheStep(state, context, { + datasetSearch: this.datasetSearch, + llm: this.cheapLlm, + dataSetHelper: this.dataSetHelper, + }), + getTablesStep(state, context, { + llmCheap: this.cheapLlm, + llmSmart: this.smartLlm, + config: this.config, + schemaHelper: this.schemaHelper, + schemaStore: this.schemaStore, + tableSearchService: this.tableSearch, + checks: this.checks, + permissionHelper: this.permissionHelper, + }), + checkTemplatesStep(state, context, { + templateSearch: this.templateSearch, + llm: this.cheapLlm, + permissionHelper: this.permissionHelper, + templateHelper: this.templateHelper, + schemaStore: this.schemaStore, + }), + classifyChangeStep(state, 
context, {llm: this.cheapLlm}), + ]); + state = this._merge( + state, + cachePartial, + tablesPartial, + templatesPartial, + classifyPartial, + ); + debug( + 'Completed step: CheckCache|GetTables|CheckTemplates|ClassifyChange (parallel)', + ); + + // ── Step 3: PostCacheAndTables routing ────────────────────────────────── + const cacheCondition = checkPostCacheAndTablesConditions(state); + debug('Branch decision: %o', cacheCondition); + + if (cacheCondition === 'fromTemplate') { + debug('Executing step: SaveDataset (fromTemplate)'); + state = this._merge(state, await this._runSaveDataset(state, context)); + debug('Workflow END success=true (fromTemplate)'); + return state; + } + + if (cacheCondition === 'fromCache') { + debug('Workflow END success=true (fromCache)'); + return state; + } + + if (cacheCondition === 'failed') { + debug('Executing step: Failed (post-cache)'); + state = this._merge(state, await failedStep(state, context)); + debug('Workflow END success=false (failed post-cache)'); + return state; + } + + // ── Step 4: GetColumns ────────────────────────────────────────────────── + debug('Executing step: GetColumns'); + state = this._merge( + state, + await getColumnsStep(state, context, { + llm: this.cheapLlm, + schemaHelper: this.schemaHelper, + config: this.config, + checks: this.checks, + }), + ); + debug('Completed step: GetColumns'); + if (state.status === GenerationError.Failed) { + debug('Executing step: Failed (GetColumns)'); + state = this._merge(state, await failedStep(state, context)); + debug('Workflow END success=false (GetColumns failed)'); + return state; + } + + // ── Step 5: CheckPermissions ──────────────────────────────────────────── + debug('Executing step: CheckPermissions'); + state = this._merge( + state, + await checkPermissionsStep(state, context, { + llm: this.cheapLlm, + permissions: this.permissionHelper, + }), + ); + debug('Completed step: CheckPermissions'); + if (state.status === Errors.PermissionError) { + debug('Workflow 
END success=false (PermissionError)'); + return state; + } + + // ── Step 6: GenerateChecklist ─────────────────────────────────────────── + debug('Executing step: GenerateChecklist'); + state = this._merge( + state, + await generateChecklistStep(state, context, { + llm: this.cheapLlm, + config: this.config, + schemaHelper: this.schemaHelper, + checks: this.checks, + }), + ); + debug('Completed step: GenerateChecklist'); + + // ── Steps 7–10: SQL generation + validation retry loop ────────────────── + type LoopEntry = 'generate' | 'validate'; + let loopEntry: LoopEntry = 'generate'; + + const maxIterations = MAX_ATTEMPTS * 2 + 2; + let iterations = 0; + + while (iterations++ < maxIterations) { + debug('Retry attempt: %d loopEntry=%s', iterations, loopEntry); + + if (loopEntry === 'generate') { + debug('Executing step: SqlGeneration|VerifyChecklist (parallel)'); + const [sqlPartial, checklistPartial] = await Promise.all([ + sqlGenerationStep(state, context, { + sqlLLM: this.smartLlm, + cheapLLM: this.cheapLlm, + config: this.config, + schemaHelper: this.schemaHelper, + checks: this.checks, + }), + verifyChecklistStep(state, context, { + smartLlm: this.smartLlm, + smartNonThinkingLlm: this.smartNonThinkingLlm, + config: this.config, + schemaHelper: this.schemaHelper, + checks: this.checks, + }), + ]); + state = this._merge(state, sqlPartial, checklistPartial); + debug('Completed step: SqlGeneration|VerifyChecklist (parallel)'); + + if (state.status === GenerationError.Failed) { + debug('Executing step: Failed (SqlGeneration)'); + state = this._merge(state, await failedStep(state, context)); + debug('Workflow END success=false (SqlGeneration failed)'); + return state; + } + } + + debug( + 'Executing step: SyntacticValidator|SemanticValidator|GenerateDescription (parallel)', + ); + const [syntacticPartial, semanticPartial, descPartial] = + await Promise.all([ + syntacticValidatorStep(state, context, { + llm: this.cheapLlm, + connector: this.connector, + }), + 
semanticValidatorStep(state, context, { + smartLlm: this.smartLlm, + cheapLlm: this.cheapLlm, + config: this.config, + tableSearchService: this.tableSearch, + schemaHelper: this.schemaHelper, + permissionHelper: this.permissionHelper, + }), + generateDescriptionStep(state, context, { + llm: this.cheapLlm, + config: this.config, + schemaHelper: this.schemaHelper, + checks: this.checks, + }), + ]); + state = this._merge( + state, + syntacticPartial, + semanticPartial, + descPartial, + ); + debug( + 'Completed step: SyntacticValidator|SemanticValidator|GenerateDescription (parallel)', + ); + + state = this._merge(state, mergeValidationResults(state)); + + if ((state.feedbacks?.length ?? 0) >= MAX_ATTEMPTS) { + debug( + 'Workflow error: max attempts reached feedbacks=%d', + state.feedbacks?.length, + ); + state = this._merge(state, await failedStep(state, context)); + debug('Workflow END success=false (max attempts)'); + return state; + } + + const validationCondition = checkPostValidationConditions(state); + debug('Branch decision: %o', validationCondition); + + if (validationCondition === 'accepted') { + debug('Executing step: SaveDataset (accepted)'); + state = this._merge(state, await this._runSaveDataset(state, context)); + debug('Workflow END success=true (accepted)'); + return state; + } + + if (validationCondition === 'reselectTables') { + debug('Executing step: GetTables (reselectTables)'); + state = this._merge( + state, + await getTablesStep(state, context, { + llmCheap: this.cheapLlm, + llmSmart: this.smartLlm, + config: this.config, + schemaHelper: this.schemaHelper, + schemaStore: this.schemaStore, + tableSearchService: this.tableSearch, + checks: this.checks, + permissionHelper: this.permissionHelper, + }), + ); + debug('Completed step: GetTables'); + if (state.status === GenerationError.Failed) { + debug('Executing step: Failed (GetTables reselect)'); + state = this._merge(state, await failedStep(state, context)); + debug('Workflow END success=false 
(GetTables reselect failed)'); + return state; + } + debug('Executing step: GetColumns (reselectTables)'); + state = this._merge( + state, + await getColumnsStep(state, context, { + llm: this.cheapLlm, + schemaHelper: this.schemaHelper, + config: this.config, + checks: this.checks, + }), + ); + debug('Completed step: GetColumns'); + if (state.status === GenerationError.Failed) { + debug('Executing step: Failed (GetColumns reselect)'); + state = this._merge(state, await failedStep(state, context)); + debug('Workflow END success=false (GetColumns reselect failed)'); + return state; + } + debug('Executing step: GenerateChecklist (reselectTables)'); + state = this._merge( + state, + await generateChecklistStep(state, context, { + llm: this.cheapLlm, + config: this.config, + schemaHelper: this.schemaHelper, + checks: this.checks, + }), + ); + debug('Completed step: GenerateChecklist'); + loopEntry = 'generate'; + continue; + } + + if (validationCondition === 'fixSql') { + debug('Executing step: FixQuery'); + state = this._merge( + state, + await fixQueryStep(state, context, { + llm: this.cheapLlm, + config: this.config, + schemaHelper: this.schemaHelper, + }), + ); + debug('Completed step: FixQuery'); + if (state.status === GenerationError.Failed) { + debug('Executing step: Failed (FixQuery)'); + state = this._merge(state, await failedStep(state, context)); + debug('Workflow END success=false (FixQuery failed)'); + return state; + } + loopEntry = 'validate'; + continue; + } + + debug( + 'Workflow error: unknown validationCondition=%o', + validationCondition, + ); + state = this._merge(state, await failedStep(state, context)); + debug('Workflow END success=false (unknown condition)'); + return state; + } + + debug('Workflow error: exceeded safety cap iterations=%d', iterations); + state = this._merge(state, await failedStep(state, context)); + debug('Workflow END success=false (safety cap)'); + return state; + } + + private _runSaveDataset( + state: DbQueryState, + 
context: MastraDbQueryContext, + ): Promise> { + return saveDatasetStep(state, context, { + llm: this.cheapLlm, + store: this.datasetStore, + config: this.config, + user: this.user, + dbSchemaHelper: this.schemaHelper, + checks: this.checks, + }); + } + + private _merge( + base: DbQueryState, + ...partials: Partial[] + ): DbQueryState { + return Object.assign({}, base, ...partials) as DbQueryState; + } +} diff --git a/src/mastra/db-query/services/dataset-search.service.ts b/src/mastra/db-query/services/dataset-search.service.ts new file mode 100644 index 0000000..badb466 --- /dev/null +++ b/src/mastra/db-query/services/dataset-search.service.ts @@ -0,0 +1,59 @@ +import {inject, injectable, BindingScope} from '@loopback/core'; +import {IAuthUserWithPermissions} from '@sourceloop/core'; +import {AuthenticationBindings} from 'loopback4-authentication'; +import {AiIntegrationBindings} from '../../../keys'; +import { + DbQueryStoredTypes, + QueryCacheMetadata, +} from '../../../components/db-query/types'; +import {IVectorStore, IVectorStoreDocument} from '../../../types'; + +const debug = require('debug')('ai-integration:mastra:db-query:dataset-search'); + +/** + * Mastra-path replacement for the LangChain `DatasetRetriever` provider. + * + * Injects `AiIntegrationBindings.AiSdkVectorStore` (`IVectorStore`) instead of + * the LangChain `VectorStore`, eliminating all `@langchain/core` dependencies + * from the Mastra execution path. The returned document shape mirrors + * `DocumentInterface` (`pageContent` + `metadata`) so existing step callers + * need no changes. + */ +@injectable({scope: BindingScope.REQUEST}) +export class DatasetSearchService { + constructor( + @inject(AiIntegrationBindings.AiSdkVectorStore) + private readonly vectorStore: IVectorStore, + @inject(AuthenticationBindings.CURRENT_USER) + private readonly user: IAuthUserWithPermissions, + ) {} + + /** + * Performs a similarity search against the stored dataset vector index. 
+ * + * @param query - The natural-language query to search with. + * @param k - Number of results to return (default: 5). + * @returns Matching dataset documents ordered by descending relevance. + */ + async search( + query: string, + k = 5, + ): Promise[]> { + const tenantId = this.user.tenantId; + debug('search start', {query: query.slice(0, 80), k, tenantId}); + if (!tenantId) { + debug('no tenantId — returning empty results'); + return []; + } + const results = await this.vectorStore.similaritySearch( + query, + k, + { + type: DbQueryStoredTypes.DataSet, + tenantId, + }, + ); + debug('search complete: %d results', results.length); + return results; + } +} diff --git a/src/mastra/db-query/services/index.ts b/src/mastra/db-query/services/index.ts new file mode 100644 index 0000000..c9ac30a --- /dev/null +++ b/src/mastra/db-query/services/index.ts @@ -0,0 +1,3 @@ +export {DatasetSearchService} from './dataset-search.service'; +export {TemplateSearchService} from './template-search.service'; +export {MastraTemplateHelperService} from './mastra-template-helper.service'; diff --git a/src/mastra/db-query/services/mastra-template-helper.service.ts b/src/mastra/db-query/services/mastra-template-helper.service.ts new file mode 100644 index 0000000..265cd43 --- /dev/null +++ b/src/mastra/db-query/services/mastra-template-helper.service.ts @@ -0,0 +1,317 @@ +import {generateText} from 'ai'; +import {inject, injectable, BindingScope} from '@loopback/core'; +import {AiIntegrationBindings} from '../../../keys'; +import {LLMProvider} from '../../../types'; +import {RunnableConfig} from '../../../types/tool'; +import {buildPrompt} from '../utils/prompt.util'; +import {stripThinkingFromText} from '../utils/thinking.util'; +import { + DatabaseSchema, + QueryTemplate, + QueryTemplateMetadata, + TemplatePlaceholder, +} from '../../../components/db-query/types'; + +const MAX_TEMPLATE_RECURSION_DEPTH = 3; + +type ResolvedTemplate = { + sql: string; + description: string; +}; + 
+const EXTRACTION_TEMPLATE = ` + +You are an expert at extracting parameter values from natural language prompts. +Given a user prompt, a SQL template, and a list of placeholders with their descriptions and types, extract the value for each placeholder from the prompt. +For sql_expression placeholders, generate a valid SQL fragment that fits the position of the placeholder in the template. + + +{prompt} + + +{template} + + +{placeholders} + + +Return each extracted value as an XML tag where the tag name is the placeholder name. +If a placeholder value cannot be determined from the prompt, use the default value if provided, or leave the tag empty. + +Rules per type: +- string: Return the raw value only, without any surrounding quotes. Example: Acme Corp +- number: Return the numeric value only. Example: 10 +- boolean: Return true or false. Example: true +- sql_expression: Return a complete, valid SQL fragment with proper SQL syntax including quotes where needed. Example: created_at > '2024-01-01' + +Do not return any other text or explanation, just the XML tags. +`; + +/** + * Mastra-path replacement for `TemplateHelper`. + * + * Implements the same template-resolution logic as `TemplateHelper` but uses + * `generateText()` from the Vercel AI SDK instead of `RunnableSequence` from + * LangChain, removing the LangChain dependency from the Mastra orchestration + * path. All non-LLM methods are exact ports of the original implementation. + */ +@injectable({scope: BindingScope.REQUEST}) +export class MastraTemplateHelperService { + constructor( + @inject(AiIntegrationBindings.AiSdkCheapLLM) + private readonly llm: LLMProvider, + ) {} + + /** + * Extracts placeholder values from a natural-language prompt using the LLM. + * Uses `generateText()` instead of `RunnableSequence`. 
+ */ + async extractPlaceholderValues( + placeholders: TemplatePlaceholder[], + prompt: string, + sqlTemplate: string, + _config: RunnableConfig, + schema?: DatabaseSchema, + ): Promise> { + const placeholderDescriptions = placeholders + .map(p => { + let desc = `- ${p.name} (type: ${p.type}): ${p.description}`; + if (p.default) desc += ` [default: ${p.default}]`; + const columnContext = this._getColumnContext(p, schema); + if (columnContext) desc += `\n ${columnContext}`; + return desc; + }) + .join('\n'); + + const content = buildPrompt(EXTRACTION_TEMPLATE, { + prompt, + template: sqlTemplate, + placeholders: placeholderDescriptions, + }); + + const {text} = await generateText({ + model: this.llm, + messages: [{role: 'user', content}], + }); + + const response = stripThinkingFromText(text); + return this._parseXmlValues(response, placeholders); + } + + /** + * Fully resolves a query template: expands template_ref placeholders + * recursively, extracts values for remaining placeholders via LLM, and + * substitutes all values into the SQL string. 
+ */ + async resolveTemplate( + template: QueryTemplate, + prompt: string, + config: RunnableConfig, + schema?: DatabaseSchema, + templateFetcher?: (id: string) => Promise, + depth = 0, + ): Promise { + if (depth > MAX_TEMPLATE_RECURSION_DEPTH) { + throw new Error( + `Max template recursion depth exceeded (${MAX_TEMPLATE_RECURSION_DEPTH})`, + ); + } + + let sql = await this._resolveTemplateRefs( + template, + prompt, + config, + schema, + templateFetcher, + depth, + ); + + const extractablePlaceholders = template.placeholders.filter( + p => p.type !== 'template_ref', + ); + + let values: Record = {}; + if (extractablePlaceholders.length > 0) { + values = await this.extractPlaceholderValues( + extractablePlaceholders, + prompt, + sql, + config, + schema, + ); + } + + sql = this._substitutePlaceholders(sql, extractablePlaceholders, values); + + return {sql, description: template.description}; + } + + /** Parses stored `QueryTemplateMetadata` into a `QueryTemplate` object. */ + parseTemplateMetadata(metadata: QueryTemplateMetadata): QueryTemplate { + return { + id: metadata.templateId, + tenantId: '', + template: metadata.template, + description: metadata.description, + placeholders: JSON.parse(metadata.placeholders), + tables: JSON.parse(metadata.tables), + schemaHash: metadata.schemaHash, + votes: metadata.votes, + prompt: '', + }; + } + + // ─── Private helpers (exact ports of TemplateHelper) ───────────────────── + + private async _resolveTemplateRefs( + template: QueryTemplate, + prompt: string, + config: RunnableConfig, + schema: DatabaseSchema | undefined, + templateFetcher: + | ((id: string) => Promise) + | undefined, + depth: number, + ): Promise { + let sql = template.template; + const templateRefPlaceholders = template.placeholders.filter( + p => p.type === 'template_ref', + ); + for (const placeholder of templateRefPlaceholders) { + const marker = `{{${placeholder.name}}}`; + if (!sql.includes(marker)) continue; + if (!templateFetcher || 
!placeholder.templateId) { + throw new Error( + `Cannot resolve template_ref placeholder "${placeholder.name}" - no template fetcher or templateId`, + ); + } + const refTemplate = await templateFetcher(placeholder.templateId); + if (!refTemplate) { + throw new Error( + `Referenced template "${placeholder.templateId}" not found`, + ); + } + const resolved = await this.resolveTemplate( + refTemplate, + prompt, + config, + schema, + templateFetcher, + depth + 1, + ); + sql = sql.replace(marker, `(${resolved.sql})`); + } + return sql; + } + + private _substitutePlaceholders( + sql: string, + placeholders: TemplatePlaceholder[], + values: Record, + ): string { + for (const placeholder of placeholders) { + const value = values[placeholder.name] ?? placeholder.default ?? null; + const marker = `{{${placeholder.name}}}`; + if (!sql.includes(marker)) continue; + if (placeholder.optional && !value) { + sql = sql.replace( + new RegExp(String.raw`\s*${this._escapeRegex(marker)}\s*`), + ' ', + ); + continue; + } + sql = sql.replace(marker, this._formatValue(placeholder.type, value)); + } + return sql; + } + + private _formatValue(type: string, value: string | null): string { + switch (type) { + case 'string': + return `'${(value ?? '').replace(/'/g, "''")}'`; + case 'number': + return `${Number(value) || 0}`; + case 'boolean': + return this._isTruthy(value) ? 'TRUE' : 'FALSE'; + case 'sql_expression': + return value ?? '1=1'; + default: + return value ?? 
''; + } + } + + private _isTruthy(value: string | null): boolean { + const lower = value?.toLowerCase(); + return lower === 'true' || lower === 'yes' || value === '1'; + } + + private _escapeRegex(str: string): string { + return str.replace(/[.*+?^${}()|[\]\\]/g, String.raw`\$&`); + } + + private _parseXmlValues( + xml: string, + placeholders: TemplatePlaceholder[], + ): Record { + const result: Record = {}; + for (const p of placeholders) { + const match = new RegExp( + String.raw`<${p.name}>([\s\S]*?)`, + ).exec(xml); + const value = match?.[1]?.trim(); + result[p.name] = value?.length ? value : null; + } + return result; + } + + private _getColumnContext( + placeholder: TemplatePlaceholder, + schema?: DatabaseSchema, + ): string | null { + if (!schema || !placeholder.table || !placeholder.column) return null; + const tableSchema = schema.tables[placeholder.table]; + if (!tableSchema) return null; + const columnSchema = tableSchema.columns[placeholder.column]; + if (!columnSchema) return null; + const parts: string[] = [ + `Column "${placeholder.column}" in "${placeholder.table}" (${columnSchema.type})`, + ]; + if (columnSchema.description) parts.push(columnSchema.description); + if (columnSchema.metadata) { + const metaStr = Object.entries(columnSchema.metadata) + .map(([k, v]) => `${k}: ${JSON.stringify(v)}`) + .join(', '); + if (metaStr) parts.push(metaStr); + } + parts.push( + ...this._getRelevantContextEntries( + tableSchema.context, + placeholder.column, + ), + ); + return parts.join('. 
'); + } + + private _getRelevantContextEntries( + context: unknown[] | undefined, + column: string, + ): string[] { + if (!context?.length) return []; + const results: string[] = []; + for (const ctx of context) { + if ( + typeof ctx === 'string' && + ctx.toLowerCase().includes(column.toLowerCase()) + ) { + results.push(ctx); + } else if ( + typeof ctx === 'object' && + ctx !== null && + (ctx as Record)[column] + ) { + results.push((ctx as Record)[column]); + } + } + return results; + } +} diff --git a/src/mastra/db-query/services/template-search.service.ts b/src/mastra/db-query/services/template-search.service.ts new file mode 100644 index 0000000..00290da --- /dev/null +++ b/src/mastra/db-query/services/template-search.service.ts @@ -0,0 +1,58 @@ +import {inject, injectable, BindingScope} from '@loopback/core'; +import {IAuthUserWithPermissions} from '@sourceloop/core'; +import {AuthenticationBindings} from 'loopback4-authentication'; +import {AiIntegrationBindings} from '../../../keys'; +import { + DbQueryStoredTypes, + QueryTemplateMetadata, +} from '../../../components/db-query/types'; +import {IVectorStore, IVectorStoreDocument} from '../../../types'; + +const debug = require('debug')( + 'ai-integration:mastra:db-query:template-search', +); + +/** + * Mastra-path replacement for the LangChain `TemplateRetriever` provider. + * + * Injects `AiIntegrationBindings.AiSdkVectorStore` (`IVectorStore`) instead of + * the LangChain `VectorStore`, eliminating all `@langchain/core` dependencies + * from the Mastra execution path. The returned document shape mirrors + * `DocumentInterface` (`pageContent` + `metadata`) so existing step callers + * need no changes. 
+ */ +@injectable({scope: BindingScope.REQUEST}) +export class TemplateSearchService { + constructor( + @inject(AiIntegrationBindings.AiSdkVectorStore) + private readonly vectorStore: IVectorStore, + @inject(AuthenticationBindings.CURRENT_USER) + private readonly user: IAuthUserWithPermissions, + ) {} + + /** + * Performs a similarity search against the stored query-template vector index. + * + * @param query - The natural-language query to match against template descriptions. + * @param k - Number of results to return (default: 5). + * @returns Matching template documents ordered by descending relevance. + */ + async search( + query: string, + k = 5, + ): Promise[]> { + const tenantId = this.user.tenantId; + debug('search start', {query: query.slice(0, 80), k, tenantId}); + if (!tenantId) { + debug('no tenantId — returning empty results'); + return []; + } + const results = + await this.vectorStore.similaritySearch(query, k, { + type: DbQueryStoredTypes.Template, + tenantId, + }); + debug('search complete: %d results', results.length); + return results; + } +} diff --git a/src/mastra/db-query/types/db-query.types.ts b/src/mastra/db-query/types/db-query.types.ts new file mode 100644 index 0000000..845f7f2 --- /dev/null +++ b/src/mastra/db-query/types/db-query.types.ts @@ -0,0 +1,52 @@ +import {LLMStreamEvent} from '../../../types/events'; + +/** + * SSE event writer passed from the Mastra chat agent into the DbQuery workflow. + * Matches the signature of `config.writer` on LangGraphRunnableConfig. + */ +export type DbQueryWriterFn = (event: LLMStreamEvent) => void; + +/** + * Execution context threaded through every step of the Mastra DbQuery workflow. + * Passed as `RunnableConfig` to each node so they can emit SSE events and + * respect request cancellation. + */ +export interface MastraDbQueryContext { + /** + * Callback to emit events back to the SSE transport. 
+ * Accepts `unknown` to allow arbitrary event shapes from step functions + * (matches `RunnableConfig.writer` semantics). + */ + writer?: (chunk: unknown) => void; + /** AbortSignal forwarded from the request lifecycle. Optional. */ + signal?: AbortSignal; + /** + * Optional callback invoked by each step after a `generateText()` or + * `generateObject()` call to report AI SDK token usage. + * + * Wire this to `TokenCounter.accumulate()` in the workflow runner so that + * the Mastra execution path accumulates token counts without LangChain + * callbacks. + * + * @param inputTokens - Number of prompt tokens consumed. + * @param outputTokens - Number of completion tokens produced. + * @param model - Model identifier string (e.g. `llm.modelId`). + */ + onUsage?: (inputTokens: number, outputTokens: number, model: string) => void; +} + +/** + * Input accepted by `MastraDbQueryWorkflow.run()`. + */ +export interface DbQueryWorkflowInput { + /** Natural-language prompt from the user. */ + prompt: string; + /** Existing dataset UUID when running an improvement flow; omitted for new datasets. */ + datasetId?: string; + /** + * When `true`, suppresses ToolStatus events because the caller is rendering + * results directly (not via the SSE chat transport). + * Defaults to `false`. + */ + directCall?: boolean; +} diff --git a/src/mastra/db-query/utils/index.ts b/src/mastra/db-query/utils/index.ts new file mode 100644 index 0000000..d11572f --- /dev/null +++ b/src/mastra/db-query/utils/index.ts @@ -0,0 +1,2 @@ +export {buildPrompt} from './prompt.util'; +export {stripThinkingFromText} from './thinking.util'; diff --git a/src/mastra/db-query/utils/prompt.util.ts b/src/mastra/db-query/utils/prompt.util.ts new file mode 100644 index 0000000..8c2515c --- /dev/null +++ b/src/mastra/db-query/utils/prompt.util.ts @@ -0,0 +1,22 @@ +/** + * Formats a LangChain-style prompt template using simple variable substitution. 
+ * + * Handles the standard `PromptTemplate.fromTemplate` syntax: + * - `{variableName}` → replaced with the corresponding value + * - `{{` / `}}` → literal braces (used in JSON examples in templates) + * + * @param template - The template string with `{var}` placeholders. + * @param vars - A map of variable names to their string values. + * @returns The formatted prompt string ready to send to the LLM. + */ +export function buildPrompt( + template: string, + vars: Record, +): string { + return template + .replace(/\{\{/g, '__LBRACE__') + .replace(/\}\}/g, '__RBRACE__') + .replace(/\{(\w+)\}/g, (_, key) => vars[key] ?? '') + .replace(/__LBRACE__/g, '{') + .replace(/__RBRACE__/g, '}'); +} diff --git a/src/mastra/db-query/utils/thinking.util.ts b/src/mastra/db-query/utils/thinking.util.ts new file mode 100644 index 0000000..58cb64b --- /dev/null +++ b/src/mastra/db-query/utils/thinking.util.ts @@ -0,0 +1,24 @@ +/** + * Strips `` / `` reasoning blocks from an AI response string. + * + * This is the string-based counterpart of `stripThinkingTokens()` in + * `src/utils.ts` (which operates on a LangChain `AIMessage`). Use this + * function in all Mastra-path nodes where the AI SDK returns a plain `string`. + * + * Handles three cases: + * 1. Complete blocks — `...` or `...` + * 2. Dangling close tags — `...some preamble` (reasoning model + * token budget exhausted mid-block) + * 3. Whitespace trimming after stripping + * + * @param text - Raw text string from `generateText().text` or accumulated + * `streamText` chunks. + * @returns The cleaned response string with all thinking tokens removed. + */ +export function stripThinkingFromText(text: string): string { + // Remove complete ... and ... 
blocks + let result = text.replace(/[\s\S]*?<\/think(ing)?>/gi, ''); + // Remove dangling close tags and everything before them + result = result.replace(/^[\s\S]*?<\/think(ing)?>/gi, ''); + return result.trim(); +} diff --git a/src/mastra/db-query/workflow/conditions/db-query.conditions.ts b/src/mastra/db-query/workflow/conditions/db-query.conditions.ts new file mode 100644 index 0000000..864057c --- /dev/null +++ b/src/mastra/db-query/workflow/conditions/db-query.conditions.ts @@ -0,0 +1,62 @@ +import {DbQueryState} from '../../../../components/db-query/state'; +import { + EvaluationResult, + GenerationError, +} from '../../../../components/db-query/types'; + +/** + * Routing outcomes after the parallel fan-in of CheckCache, GetTables, + * CheckTemplates, and ClassifyChange. Mirrors the `PostCacheAndTables` + * conditional-edge function in `DbQueryGraph._addEdges()`. + */ +export type PostCacheCondition = + | 'fromTemplate' + | 'fromCache' + | 'failed' + | 'continue'; + +/** + * Evaluates the merged state after the initial parallel fan-out and returns + * the appropriate routing decision. + * + * - `fromTemplate` → a pre-defined SQL template matched; skip generation. + * - `fromCache` → a semantically identical query was already in cache. + * - `failed` → a node (e.g. GetTables) already set status to Failed. + * - `continue` → proceed to column selection and SQL generation. + * + * Mirrors the `PostCacheAndTables` conditional edge in `DbQueryGraph`. + */ +export function checkPostCacheAndTablesConditions( + state: DbQueryState, +): PostCacheCondition { + if (state.fromTemplate) return 'fromTemplate'; + if (state.fromCache) return 'fromCache'; + if (state.status === GenerationError.Failed) return 'failed'; + return 'continue'; +} + +/** + * Routing outcomes for the validation retry loop. + * Mirrors the `PostValidation` conditional-edge function in `DbQueryGraph._addEdges()`. 
+ */ +export type PostValidationCondition = + | 'accepted' + | 'fixSql' + | 'reselectTables' + | 'failed'; + +/** + * Evaluates merged validation state and returns the next routing decision. + * The `feedbackCount` guard (`>= MAX_ATTEMPTS`) is handled by the caller + * before this function is invoked. + * + * Mirrors the `PostValidation` conditional edge in `DbQueryGraph`. + */ +export function checkPostValidationConditions( + state: DbQueryState, +): PostValidationCondition { + if (state.status === EvaluationResult.Pass) return 'accepted'; + if (state.status === EvaluationResult.TableError) return 'reselectTables'; + if (state.status === EvaluationResult.QueryError) return 'fixSql'; + return 'failed'; +} diff --git a/src/mastra/db-query/workflow/conditions/index.ts b/src/mastra/db-query/workflow/conditions/index.ts new file mode 100644 index 0000000..4626552 --- /dev/null +++ b/src/mastra/db-query/workflow/conditions/index.ts @@ -0,0 +1,8 @@ +export { + checkPostCacheAndTablesConditions, + checkPostValidationConditions, +} from './db-query.conditions'; +export type { + PostCacheCondition, + PostValidationCondition, +} from './db-query.conditions'; diff --git a/src/mastra/db-query/workflow/index.ts b/src/mastra/db-query/workflow/index.ts new file mode 100644 index 0000000..3debcdd --- /dev/null +++ b/src/mastra/db-query/workflow/index.ts @@ -0,0 +1,2 @@ +export * from './steps'; +export * from './conditions'; diff --git a/src/mastra/db-query/workflow/steps/check-cache.step.ts b/src/mastra/db-query/workflow/steps/check-cache.step.ts new file mode 100644 index 0000000..e2354c6 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/check-cache.step.ts @@ -0,0 +1,198 @@ +import {generateText} from 'ai'; +import {DataSetHelper} from '../../../../components/db-query/services'; +import {DatasetActionType} from '../../../../components/db-query/constant'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {CacheResults} from 
'../../../../components/db-query/types'; +import {LLMStreamEventType, ToolStatus} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {DatasetSearchService} from '../../services/dataset-search.service'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')('ai-integration:mastra:db-query:check-cache'); + +const CACHE_PROMPT = ` + +You are an expert Semantic analyser, you will be given a prompt from the user and a list of past prompts that were handled successfully, along with description of the sql generated from those prompts. +You need to return the most relevant prompt from the list and in which of the following ways is it relevant - +- return '${CacheResults.AsIs}' if the prompt's result would contain the information the user is looking for without any changes in the result, and can be used as it is. +- return '${CacheResults.Similar}' if the prompt's result would be similar to the question in the new prompt but not exactly, and can be modified to get the data user needs. +- return '${CacheResults.NotRelevant}' if the prompt is not relevant to the new prompt at all. +Remember that if the cached prompt has extra information, then still the old prompt could be considered exactly same as long as it does not contradict the new prompt. + + +{prompt} + + +{queries} + + +format - +relevant index-of-query-starting-from-1 +examples - +${CacheResults.AsIs} 2 + +${CacheResults.Similar} 1 + +${CacheResults.NotRelevant} + + + +Do not return any other text or explanation, just the output in the above format. +If no queries are relevant, return '${CacheResults.NotRelevant}' and nothing else. 
+`; + +export type CheckCacheStepDeps = { + datasetSearch: DatasetSearchService; + llm: LLMProvider; + dataSetHelper: DataSetHelper; +}; + +/** + * Searches the dataset vector index for semantically similar past queries and + * uses the LLM to classify the relevance. + */ +export async function checkCacheStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: CheckCacheStepDeps, +): Promise> { + debug('step start', {prompt: state.prompt, hasSampleSql: !!state.sampleSql}); + + if (state.sampleSql) { + debug('sampleSql already set — skipping cache check'); + return {}; + } + + const relevantDocs = await deps.datasetSearch.search(state.prompt); + if (relevantDocs.length === 0) { + debug('no documents in cache for prompt'); + return {}; + } + + const queriesText = relevantDocs + .map( + (doc, index) => + `\n\n${doc.pageContent}\n\n${doc.metadata.description}`, + ) + .join('\n'); + + const content = buildPrompt(CACHE_PROMPT, { + prompt: state.prompt, + queries: queriesText, + }); + + debug('invoking LLM for cache relevance classification'); + const {text, usage} = await generateText({ + model: deps.llm, + messages: [{role: 'user', content}], + }); + context.onUsage?.(usage.inputTokens ?? 0, usage.outputTokens ?? 0, 'unknown'); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 
0, + }); + + const response = stripThinkingFromText(text); + const [relevance, index] = response.split(' '); + const indexNum = parseInt(index, 10) - 1; + + if (relevance === CacheResults.NotRelevant) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'No relevant queries found in cache for this prompt', + }); + return {}; + } + + if (indexNum >= relevantDocs.length || indexNum < 0 || isNaN(indexNum)) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Index ${index} is out of bounds for the list of relevant queries.`, + }); + return {}; + } + + if (relevance === CacheResults.AsIs) { + const missingPermissions = await deps.dataSetHelper.checkPermissions( + relevantDocs[indexNum].metadata.datasetId, + ); + if (missingPermissions.length > 0) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Found relevant query in cache, but missing permissions: ${missingPermissions.join(', ')} so generating new query`, + }); + return {}; + } + + const [dataset] = await deps.dataSetHelper.find({ + where: {id: relevantDocs[indexNum].metadata.datasetId}, + include: [{relation: 'actions'}], + }); + + if ( + !dataset || + (dataset.actions?.length && + dataset.actions?.some(a => a.action === DatasetActionType.Disliked)) + ) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'Found relevant query in cache, but the dataset was not found or was disliked by the user, so generating new query', + }); + return {}; + } + + const datasetId = relevantDocs[indexNum].metadata.datasetId; + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'Found relevant query in cache, using it as is', + }); + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: 'Found relevant query in cache'}, + }); + + if (!state.directCall) { + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: { + status: ToolStatus.Completed, + data: {datasetId}, + }, + }); + } + + const result = { + fromCache: true, + datasetId, + replyToUser: `I 
found this dataset in the cache - ${relevantDocs[indexNum].pageContent}`, + }; + debug('step result (AsIs cache hit)', result); + return result; + } + + if (relevance === CacheResults.Similar) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'Found similar query in cache, using it as example', + }); + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: 'Found similar query in cache, using it as example'}, + }); + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: 'Found relevant query in cache'}, + }); + + const result = { + sampleSql: relevantDocs[indexNum].metadata.query, + sampleSqlPrompt: relevantDocs[indexNum].pageContent, + }; + debug('step result (Similar cache hit)', result); + return result; + } + + return {}; +} diff --git a/src/mastra/db-query/workflow/steps/check-permissions.step.ts b/src/mastra/db-query/workflow/steps/check-permissions.step.ts new file mode 100644 index 0000000..084dbe3 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/check-permissions.step.ts @@ -0,0 +1,90 @@ +import {generateText} from 'ai'; +import {PermissionHelper} from '../../../../components/db-query/services'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {Errors} from '../../../../components/db-query/types'; +import {LLMProvider} from '../../../../types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')( + 'ai-integration:mastra:db-query:check-permissions', +); + +const PERMISSIONS_PROMPT = `You are an AI assistant that received the following request from the user - + {prompt} + + But as this request requires access to the following tables - + {tables} + + and user the does not have permissions for the following tables - + {missingPermissions} + + You must return an error message that explains the user that they 
do not have permissions to access the required tables and cannot proceed with the request, and then asking him to give a new request. + Do not give direct tables names or any technical details, use plain language to explain the error. + Do not return any other text, comments, or explanations. Only return a simple error message with request for new prompt. + `; + +export type CheckPermissionsStepDeps = { + llm: LLMProvider; + permissions: PermissionHelper; +}; + +/** + * Verifies the current user's RBAC permissions against the tables in the + * resolved schema. If any permissions are missing, uses the LLM to compose + * a plain-language error message and sets `state.status = Errors.PermissionError`. + */ +export async function checkPermissionsStep( + state: DbQueryState, + _context: MastraDbQueryContext, + deps: CheckPermissionsStepDeps, +): Promise> { + debug('step start', {schema: Object.keys(state.schema?.tables ?? {})}); + + const tableNames = getTableNames(state); + const missingPermissions = + deps.permissions.findMissingPermissions(tableNames); + + if (missingPermissions.length === 0) { + debug('all permissions granted'); + return {}; + } + + debug('missing permissions for tables: %o', missingPermissions); + + const content = buildPrompt(PERMISSIONS_PROMPT, { + prompt: state.prompt, + tables: tableNames.join(', '), + missingPermissions: missingPermissions.join(', '), + }); + + const {text, usage} = await generateText({ + model: deps.llm, + messages: [{role: 'user', content}], + }); + _context.onUsage?.( + usage.inputTokens ?? 0, + usage.outputTokens ?? 0, + 'unknown', + ); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 
0, + }); + + const response = stripThinkingFromText(text); + + const result = { + status: Errors.PermissionError, + replyToUser: response, + }; + debug('step result', result); + return result; +} + +function getTableNames(state: DbQueryState): string[] { + return Object.keys(state.schema?.tables ?? {}).map(table => + table.toLowerCase().slice(table.indexOf('.') + 1), + ); +} diff --git a/src/mastra/db-query/workflow/steps/check-templates.step.ts b/src/mastra/db-query/workflow/steps/check-templates.step.ts new file mode 100644 index 0000000..177118a --- /dev/null +++ b/src/mastra/db-query/workflow/steps/check-templates.step.ts @@ -0,0 +1,193 @@ +import {generateText} from 'ai'; +import {PermissionHelper} from '../../../../components/db-query/services'; +import {SchemaStore} from '../../../../components/db-query/services/schema.store'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {MastraTemplateHelperService} from '../../services/mastra-template-helper.service'; +import {TemplateSearchService} from '../../services/template-search.service'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')( + 'ai-integration:mastra:db-query:check-templates', +); + +const MATCH_PROMPT = ` + +You are an expert at matching user prompts to query templates. +Given a user prompt and a list of query templates with their canonical prompts and placeholders, determine if any template can EXACTLY fulfill the user's request. 
+ +A template is a match ONLY if ALL of the following are true: +- The template produces exactly the data the user is asking for — not more, not less +- The user's intent is identical to the template's purpose, just with different parameter values +- All non-optional placeholders can be filled from the user's prompt or have defaults +- The template does not include extra filters, columns, or logic that the user did not ask for +- The template does not omit any filters, columns, or logic that the user is asking for + +Do NOT match if: +- The template is only similar or partially relevant +- The template would need structural changes beyond placeholder substitution to answer the question +- The user is asking for something the template cannot express through its placeholders alone + + +{prompt} + + +{templates} + + +If a template is an exact match, return: match +If no template exactly matches, return: no_match + +Do not return any other text or explanation. +`; + +export type CheckTemplatesStepDeps = { + templateSearch: TemplateSearchService; + llm: LLMProvider; + permissionHelper: PermissionHelper; + templateHelper: MastraTemplateHelperService; + schemaStore: SchemaStore; +}; + +/** + * Performs a vector similarity search for matching SQL templates, then uses + * the LLM to confirm an exact semantic match. If a match is found, resolves + * all placeholders via `MastraTemplateHelperService`. 
+ */ +export async function checkTemplatesStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: CheckTemplatesStepDeps, +): Promise> { + debug('step start', {prompt: state.prompt}); + + const relevantDocs = await deps.templateSearch.search(state.prompt); + + if (relevantDocs.length === 0) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'No templates found for this prompt', + }); + return {}; + } + + const templatesText = relevantDocs + .map((doc, index) => { + const metadata = doc.metadata; + const placeholders = JSON.parse(metadata.placeholders); + const placeholderText = placeholders + .map( + (p: {name: string; type: string; description: string}) => + ` - {{${p.name}}} (${p.type}): ${p.description}`, + ) + .join('\n'); + return ` +${doc.pageContent} + +${placeholderText} + +`; + }) + .join('\n'); + + const content = buildPrompt(MATCH_PROMPT, { + prompt: state.prompt, + templates: templatesText, + }); + + debug('invoking LLM for template matching'); + const {text, usage} = await generateText({ + model: deps.llm, + messages: [{role: 'user', content}], + }); + context.onUsage?.(usage.inputTokens ?? 0, usage.outputTokens ?? 0, 'unknown'); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 
0, + }); + + const trimmed = stripThinkingFromText(text).trim(); + + if (trimmed === 'no_match') { + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'No matching template found for this prompt', + }); + return {}; + } + + const matchResult = trimmed.match(/^match\s+(\d+)$/); + if (!matchResult) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Unexpected template match response: ${trimmed}`, + }); + return {}; + } + + const matchIndex = Number.parseInt(matchResult[1], 10) - 1; + if (matchIndex < 0 || matchIndex >= relevantDocs.length) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Template match index ${matchResult[1]} out of bounds`, + }); + return {}; + } + + const matchedDoc = relevantDocs[matchIndex]; + const template = deps.templateHelper.parseTemplateMetadata( + matchedDoc.metadata, + ); + + const missingPermissions = deps.permissionHelper.findMissingPermissions( + template.tables, + ); + if (missingPermissions.length > 0) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Template matched but missing permissions: ${missingPermissions.join(', ')}`, + }); + return {}; + } + + try { + const schema = deps.schemaStore.filteredSchema(template.tables); + // resolveTemplate expects a RunnableConfig-compatible object; MastraDbQueryContext + // satisfies that structurally (writer + signal). 
+ const resolved = await deps.templateHelper.resolveTemplate( + template, + state.prompt, + context as Parameters[2], + schema, + ); + + debug('template matched: %s', template.description); + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Template matched: ${template.description}`, + }); + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: 'Matched query template'}, + }); + + const result = { + sql: resolved.sql, + description: resolved.description, + fromTemplate: true, + templateId: template.id, + }; + debug('step result', result); + return result; + } catch (error) { + debug('error', error); + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Template resolution failed: ${(error as Error).message}`, + }); + return {}; + } +} diff --git a/src/mastra/db-query/workflow/steps/classify-change.step.ts b/src/mastra/db-query/workflow/steps/classify-change.step.ts new file mode 100644 index 0000000..7c421d1 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/classify-change.step.ts @@ -0,0 +1,98 @@ +import {generateText} from 'ai'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {ChangeType} from '../../../../components/db-query/types'; +import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; +import {MastraDbQueryContext} from '../../types/db-query.types'; + +const debug = require('debug')( + 'ai-integration:mastra:db-query:classify-change', +); + +const CLASSIFY_PROMPT = ` + +You are given the original description of a SQL query and a new description that includes user feedback. +Your task is to classify the level of change required to transform the original query into the new one. 
+ +Classify as one of: +- **minor**: Small tweaks such as changing a filter value, adjusting a limit, adding/removing a single condition, or renaming an alias. +- **major**: Structural changes like adding/removing joins, changing grouping logic, adding subqueries, or significantly altering the WHERE clause. +- **rewrite**: The intent of the query has fundamentally changed, requiring a completely new query from scratch. + + + +{originalDescription} + + + +{newDescription} + + + +Return ONLY one of: minor, major, rewrite +Do not include any other text, explanation, or formatting. +`; + +export type ClassifyChangeStepDeps = { + llm: LLMProvider; +}; + +/** + * Classifies the magnitude of change between a cached SQL query and the + * user's new request. The classification guides downstream step selection + * (e.g. cheap vs. smart LLM in `sqlGenerationStep`). + */ +export async function classifyChangeStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: ClassifyChangeStepDeps, +): Promise> { + debug('step start', {hasSampleSql: !!state.sampleSql}); + + if (!state.sampleSql) { + debug('no sampleSql — skipping change classification'); + return {}; + } + + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'Classifying the level of change required for the query.', + }); + + const content = buildPrompt(CLASSIFY_PROMPT, { + originalDescription: state.sampleSqlPrompt ?? '', + newDescription: state.prompt, + }); + + debug('classifying change'); + const {text, usage} = await generateText({ + model: deps.llm, + messages: [{role: 'user', content}], + }); + context.onUsage?.(usage.inputTokens ?? 0, usage.outputTokens ?? 0, 'unknown'); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 
0, + }); + + const response = stripThinkingFromText(text).trim().toLowerCase(); + const changeType = parseChangeType(response); + + debug('change classified as: %s', changeType); + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Change classified as: ${changeType}`, + }); + + const result = {changeType}; + debug('step result', result); + return result; +} + +function parseChangeType(response: string): ChangeType { + if (response.includes(ChangeType.Minor)) return ChangeType.Minor; + if (response.includes(ChangeType.Rewrite)) return ChangeType.Rewrite; + return ChangeType.Major; +} diff --git a/src/mastra/db-query/workflow/steps/failed.step.ts b/src/mastra/db-query/workflow/steps/failed.step.ts new file mode 100644 index 0000000..190471e --- /dev/null +++ b/src/mastra/db-query/workflow/steps/failed.step.ts @@ -0,0 +1,32 @@ +import {DbQueryState} from '../../../../components/db-query/state'; +import {LLMStreamEventType, ToolStatus} from '../../../../types/events'; +import {MastraDbQueryContext} from '../../types/db-query.types'; + +const debug = require('debug')('ai-integration:mastra:db-query:failed'); + +/** + * Emits a `ToolStatus.Failed` SSE event and ensures `state.replyToUser` is + * set to a human-readable error summary. No LLM call is made. + */ +export async function failedStep( + state: DbQueryState, + context: MastraDbQueryContext, +): Promise> { + debug('step start', {feedbacks: state.feedbacks}); + + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: ToolStatus.Failed}, + }); + + const result: Partial = { + replyToUser: + state.replyToUser ?? + `I am sorry, I was not able to generate a valid SQL query for your request. ` + + `Please try again with a more detailed or a more specific prompt.\n` + + `These were the errors I encountered:\n${state.feedbacks?.join('\n') ?? 
'No errors reported.'}`, + }; + + debug('step result', result); + return result; +} diff --git a/src/mastra/db-query/workflow/steps/fix-query.step.ts b/src/mastra/db-query/workflow/steps/fix-query.step.ts new file mode 100644 index 0000000..e63b648 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/fix-query.step.ts @@ -0,0 +1,199 @@ +import {generateText} from 'ai'; +import {DbSchemaHelperService} from '../../../../components/db-query/services'; +import {DbQueryState} from '../../../../components/db-query/state'; +import { + DatabaseSchema, + DbQueryConfig, + EvaluationResult, + GenerationError, +} from '../../../../components/db-query/types'; +import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider, SupportedDBs} from '../../../../types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')('ai-integration:mastra:db-query:fix-query'); + +const FIX_PROMPT = ` + +You are an expert AI assistant that fixes SQL query errors. +You are given a SQL query that has validation errors related to specific tables. +Your task is to fix ONLY the parts of the query related to the listed error tables. +DO NOT change any part of the query that does not involve the error tables. +Preserve the overall structure, logic, and all other table references exactly as they are. + +Rules: +- Only modify clauses, joins, columns, or conditions that involve the error tables. +- Do not add, remove, or reorder columns or tables that are not related to the error. +- Do not change aliases, formatting, or logic for unrelated parts of the query. +- **DO NOT make any DML statements** (INSERT, UPDATE, DELETE, DROP etc.) to the database. +- Use the provided schema for the error-related tables to write correct SQL. +- The dialect is {dialect}. 
+ + + +{question} + + + +{currentQuery} + + + +{errorSchema} + + + +{errorFeedback} + + +{checks} + +{historicalErrors} + + +Output should only be a valid SQL query with no other special character or formatting. +Contains the required valid SQL with the error fixed. +It should have no other character or symbol or character that is not part of SQLs. +`; + +export type FixQueryStepDeps = { + llm: LLMProvider; + config: DbQueryConfig; + schemaHelper: DbSchemaHelperService; +}; + +/** + * Repairs the SQL query based on validation error feedback, targeting only the + * tables identified as problematic. Uses a trimmed schema (error tables only) + * to guide the fix. + */ +export async function fixQueryStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: FixQueryStepDeps, +): Promise> { + debug('step start', {sql: state.sql, feedbacks: state.feedbacks?.length}); + + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: 'Fixing SQL query based on validation errors'}, + }); + + const errorTables = [ + ...(state.syntacticErrorTables ?? []), + ...(state.semanticErrorTables ?? []), + ]; + + const trimmedSchema = + errorTables.length > 0 + ? trimSchema(state.schema, errorTables) + : state.schema; + + const lastFeedback = state.feedbacks?.length + ? state.feedbacks[state.feedbacks.length - 1] + : 'Unknown validation error'; + + const historicalFeedbacks = state.feedbacks?.slice(0, -1) ?? []; + + const content = buildPrompt(FIX_PROMPT, { + dialect: deps.config.db?.dialect ?? SupportedDBs.PostgreSQL, + question: state.prompt, + currentQuery: state.sql ?? '', + errorSchema: deps.schemaHelper.asString(trimmedSchema), + errorFeedback: lastFeedback, + checks: buildChecks(state, trimmedSchema, deps), + historicalErrors: historicalFeedbacks.length + ? 
[ + '', + 'You also faced these issues in previous attempts -', + historicalFeedbacks.join('\n'), + '', + ].join('\n') + : '', + }); + + debug('invoking LLM to fix query'); + const {text, usage} = await generateText({ + model: deps.llm, + messages: [{role: 'user', content}], + }); + context.onUsage?.(usage.inputTokens ?? 0, usage.outputTokens ?? 0, 'unknown'); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 0, + }); + + const response = stripThinkingFromText(text); + const sql = + response + .replace(/^```(?:sql)?\s*/i, '') + .replace(/```\s*$/, '') + .trim() || undefined; + + if (!sql) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `SQL fix failed: ${response}`, + }); + return { + status: GenerationError.Failed, + replyToUser: + 'Failed to fix SQL query. Please try rephrasing your question or provide more details.', + }; + } + + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Fixed SQL query: ${sql}`, + }); + + const result = {status: EvaluationResult.Pass, sql}; + debug('step result', {sql}); + return result; +} + +function trimSchema( + fullSchema: DatabaseSchema, + errorTables: string[], +): DatabaseSchema { + const errorTableSet = new Set(errorTables); + const trimmedTables: DatabaseSchema['tables'] = {}; + + for (const tableName of Object.keys(fullSchema.tables)) { + if (errorTableSet.has(tableName)) { + trimmedTables[tableName] = fullSchema.tables[tableName]; + } + } + + const trimmedRelations = fullSchema.relations.filter( + rel => + errorTableSet.has(rel.table) || errorTableSet.has(rel.referencedTable), + ); + + return {tables: trimmedTables, relations: trimmedRelations}; +} + +function buildChecks( + state: DbQueryState, + trimmedSchema: DatabaseSchema, + deps: FixQueryStepDeps, +): string { + if (state.validationChecklist) { + return [ + '', + 'You must keep these additional details in mind while fixing the query -', + 
...state.validationChecklist.split('\n').map(check => `- ${check}`), + '', + ].join('\n'); + } + const context = deps.schemaHelper.getTablesContext(trimmedSchema); + if (context.length === 0) return ''; + return [ + '', + 'You must keep these additional details in mind while fixing the query -', + ...context.map(check => `- ${check}`), + '', + ].join('\n'); +} diff --git a/src/mastra/db-query/workflow/steps/generate-checklist.step.ts b/src/mastra/db-query/workflow/steps/generate-checklist.step.ts new file mode 100644 index 0000000..e539f71 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/generate-checklist.step.ts @@ -0,0 +1,173 @@ +import {generateText} from 'ai'; +import {DbSchemaHelperService} from '../../../../components/db-query/services'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {DbQueryConfig} from '../../../../components/db-query/types'; +import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')( + 'ai-integration:mastra:db-query:generate-checklist', +); + +const CHECKLIST_PROMPT = ` + +You are given a user question, the tables selected for SQL generation, the relevant database schema, and a numbered list of rules/checks. +Return ONLY the indexes of the rules that are relevant to the user's question, the selected tables, and the given schema. + +A rule is relevant if: +- It directly affects how a correct SQL query should be written for this question. +- It is a dependency of another relevant rule (e.g. if rule 3 requires a currency conversion, and rule 5 defines how currency conversion works, both must be included). +- It applies to any of the selected tables or their relationships. 
+ +After selecting relevant rules, review your selection and ensure: +- Any rule that is referenced by, or is a prerequisite for, another selected rule is also included. +- Do not include rules that are completely unrelated to the question, schema, or selected tables. + + + +{prompt} + + + +{tables} + + + +{schema} + + + +{indexedChecks} + + + +Return only a comma-separated list of the relevant rule indexes. +Do not include any other text, explanation, or formatting. +Example: 1,3,5 +If no rules are relevant, return: none +`; + +export type GenerateChecklistStepDeps = { + llm: LLMProvider; + config: DbQueryConfig; + schemaHelper: DbSchemaHelperService; + checks?: string[]; +}; + +/** + * Filters the global validation checklist down to rules that are relevant to + * the current query context. Runs the LLM `parallelism` times concurrently + * with `Promise.all()` and merges the result sets by union. + */ +export async function generateChecklistStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: GenerateChecklistStepDeps, +): Promise> { + debug('step start', {tables: Object.keys(state.schema?.tables ?? {})}); + + if (deps.config.nodes?.generateChecklistNode?.enabled === false) { + debug('generateChecklistNode disabled by config'); + return {}; + } + + if (state.validationChecklist) { + debug('validationChecklist already set — skipping'); + return {}; + } + + const tableCount = Object.keys(state.schema?.tables ?? {}).length; + if (tableCount <= 2) { + debug('too few tables (%d) — skipping checklist generation', tableCount); + return {}; + } + + const allChecks = [ + ...(deps.checks ?? 
[]), + ...deps.schemaHelper.getTablesContext(state.schema), + ]; + + if (allChecks.length === 0) { + debug('no checks available — skipping'); + return {}; + } + + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'Filtering validation checklist for semantic validation.', + }); + + const mergedIndexes = await runParallelChecklist( + state, + allChecks, + deps, + context, + ); + + if (mergedIndexes.size === 0) { + return {}; + } + + const validationChecklist = Array.from(mergedIndexes) + .sort((a, b) => a - b) + .map(i => allChecks[i - 1]) + .join('\n'); + + debug('generated checklist with %d rules', mergedIndexes.size); + return {validationChecklist}; +} + +async function runParallelChecklist( + state: DbQueryState, + allChecks: string[], + deps: GenerateChecklistStepDeps, + context: MastraDbQueryContext, +): Promise> { + const indexedChecks = allChecks + .map((check, i) => `${i + 1}. ${check}`) + .join('\n'); + const parallelism = + deps.config.nodes?.generateChecklistNode?.parallelism ?? 1; + + const content = buildPrompt(CHECKLIST_PROMPT, { + prompt: state.prompt, + tables: Object.keys(state.schema?.tables ?? {}).join(', '), + schema: deps.schemaHelper.asString(state.schema), + indexedChecks, + }); + + const results = await Promise.all( + Array.from({length: parallelism}, () => + generateText({model: deps.llm, messages: [{role: 'user', content}]}), + ), + ); + + const mergedIndexes = new Set(); + for (const {text, usage} of results) { + context.onUsage?.( + usage.inputTokens ?? 0, + usage.outputTokens ?? 0, + 'unknown', + ); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 
0, + }); + parseIndexes(stripThinkingFromText(text), allChecks.length).forEach(n => + mergedIndexes.add(n), + ); + } + return mergedIndexes; +} + +function parseIndexes(response: string, maxIndex: number): number[] { + const trimmed = response.trim(); + if (!trimmed || trimmed === 'none') return []; + return trimmed + .split(',') + .map(s => Number.parseInt(s.trim(), 10)) + .filter(n => !Number.isNaN(n) && n >= 1 && n <= maxIndex); +} diff --git a/src/mastra/db-query/workflow/steps/generate-description.step.ts b/src/mastra/db-query/workflow/steps/generate-description.step.ts new file mode 100644 index 0000000..a01a8c0 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/generate-description.step.ts @@ -0,0 +1,122 @@ +import {streamText} from 'ai'; +import {DbSchemaHelperService} from '../../../../components/db-query/services'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {DbQueryConfig} from '../../../../components/db-query/types'; +import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')( + 'ai-integration:mastra:db-query:generate-description', +); + +const DESCRIPTION_PROMPT = ` + +You are an AI assistant that describes what a SQL query does in plain english. +Analyze the actual query below and write a concise, bulleted summary of the data it retrieves and any filters/conditions it applies. +Write in plain english. No SQL, no technical jargon, no table/column names. + + + +{prompt} + + + +{sql} + + + +{schema} + + +{checks} + + +Return a short bulleted list where each bullet is one condition, filter, or piece of data the query retrieves. +- Use plain, non-technical language a business user would understand. 
+- Do NOT mention tables, columns, joins, CTEs, enums, or any DB concepts. +- Keep each bullet to one line. +- Do not add any preamble, heading, or closing text — just the bullets. +`; + +export type GenerateDescriptionStepDeps = { + llm: LLMProvider; + config: DbQueryConfig; + schemaHelper: DbSchemaHelperService; + checks?: string[]; +}; + +/** + * Streams a plain-English description of the generated SQL query and forwards + * each text chunk as a `ToolStatus` SSE event. Uses `streamText()` from the + * Vercel AI SDK. Runs concurrently with the syntactic and semantic validators + * in the workflow's `Promise.all()` fan-out. + */ +export async function generateDescriptionStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: GenerateDescriptionStepDeps, +): Promise> { + debug('step start', {sql: state.sql}); + + const generateDesc = + deps.config.nodes?.sqlGenerationNode?.generateDescription !== false; + + if (!generateDesc || !state.sql) { + debug('description generation skipped'); + return {}; + } + + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'Generating query description.', + }); + + const content = buildPrompt(DESCRIPTION_PROMPT, { + prompt: state.prompt, + sql: state.sql, + schema: deps.schemaHelper.asString(state.schema), + checks: [ + '', + ...(deps.checks ?? []), + ...deps.schemaHelper.getTablesContext(state.schema), + '', + ].join('\n'), + }); + + debug('streaming description'); + const result = streamText({ + model: deps.llm, + messages: [{role: 'user', content}], + }); + + let accumulated = ''; + for await (const chunk of result.textStream) { + if (chunk) { + accumulated += chunk; + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {thinkingToken: chunk}, + }); + } + } + + const usage = await result.usage; + context.onUsage?.(usage.inputTokens ?? 0, usage.outputTokens ?? 0, 'unknown'); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 
0, + }); + + const description = stripThinkingFromText(accumulated); + + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Query description: ${description}`, + }); + + debug('step result description length=%d', description.length); + return {description}; +} diff --git a/src/mastra/db-query/workflow/steps/get-columns.step.ts b/src/mastra/db-query/workflow/steps/get-columns.step.ts new file mode 100644 index 0000000..736a0a8 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/get-columns.step.ts @@ -0,0 +1,299 @@ +import {generateText} from 'ai'; +import {DbSchemaHelperService} from '../../../../components/db-query/services'; +import {DbQueryState} from '../../../../components/db-query/state'; +import { + ColumnSchema, + DatabaseSchema, + DbQueryConfig, + GenerationError, + TableSchema, +} from '../../../../components/db-query/types'; +import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')('ai-integration:mastra:db-query:get-columns'); + +const GET_COLUMNS_PROMPT = ` + +You are an AI assistant that identifies relevant columns from database tables based on a user's query. +Given a set of tables with their columns, you need to identify which columns are relevant to answer the user's query. + +For each table, return only the column names that are relevant to the query. Include: +1. Columns directly mentioned or implied in the query +2. Primary key columns (always needed for joins and identification) +3. Foreign key columns (needed for relationships) +4. Columns that might be needed for filtering, sorting, or calculations +5. It is better to include a few extra relevant columns than to miss important ones. 
+ +Do not include: +- Columns that are clearly irrelevant to the query +- Descriptions, types, or any other metadata about the columns + +Return the result as a JSON object where each table name is a key and the value is an array of relevant column names. +If you are not sure about which columns to select, return your doubt asking the user for more details in the following format: +failed attempt: + + + +{tablesWithColumns} + + + +{query} + + +{checks} + +{feedbacks} + + +Return a valid JSON object with table names as keys and arrays of column names as values. +Example format (do not copy these exact values): +{{ + "table_name1": ["column1", "column2", "column3"], + "table_name2": ["column1", "column2"] +}} + +In case of failure, return the failure message in the format: +failed attempt: +`; + +const FEEDBACK_PROMPT = ` + +We also need to consider the errors from last attempt at query generation. + +In the last attempt, these were the columns selected: +{lastColumns} + +But it was rejected with the following errors: +{feedback} + +Use these errors to refine your column selection. Consider if you need additional columns for joins, filtering, or calculations. + +`; + +export type GetColumnsStepDeps = { + llm: LLMProvider; + schemaHelper: DbSchemaHelperService; + config: DbQueryConfig; + checks?: string[]; +}; + +/** + * Selects the minimal set of columns needed to answer the user's query. + * Implements the same three-attempt retry loop as the LangGraph version, + * validating that all returned column names exist in the schema. + */ +export async function getColumnsStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: GetColumnsStepDeps, +): Promise> { + debug('step start', {tables: Object.keys(state.schema?.tables ?? 
{})}); + + if (!deps.config.columnSelection) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'Skipping column selection as per configuration', + }); + return {}; + } + + if (!state.schema?.tables || Object.keys(state.schema.tables).length === 0) { + throw new Error( + 'No tables found in the schema. Please ensure the get-tables step was completed successfully.', + ); + } + + const tablesWithColumns = getTablesWithColumns(state.schema); + + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Selecting relevant columns from ${Object.keys(state.schema.tables).length} tables`, + }); + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: 'Extracting relevant columns from the schema'}, + }); + + const feedbacksText = buildFeedbacks(state); + const content = buildPrompt(GET_COLUMNS_PROMPT, { + tablesWithColumns: tablesWithColumns.join('\n\n'), + query: state.prompt, + feedbacks: feedbacksText, + checks: [ + '', + ...(deps.checks ?? []), + ...deps.schemaHelper.getTablesContext(state.schema), + '', + ].join('\n'), + }); + + let attempts = 0; + let selectedColumns: Record = {}; + + while (attempts < 3) { + attempts++; + debug('column selection attempt %d', attempts); + + const {text, usage} = await generateText({ + model: deps.llm, + messages: [{role: 'user', content}], + }); + context.onUsage?.( + usage.inputTokens ?? 0, + usage.outputTokens ?? 0, + 'unknown', + ); + debug('token usage captured attempt=%d', attempts, { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 
0, + }); + + const output = stripThinkingFromText(text); + + if (output.startsWith('failed attempt:')) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Column selection failed: ${output}`, + }); + return { + status: GenerationError.Failed, + replyToUser: output.replace('failed attempt: ', ''), + }; + } + + try { + const jsonMatch = output.match(/\{[\s\S]*\}/); + if (!jsonMatch) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Failed to find JSON in LLM response, trying again (attempt ${attempts})`, + }); + continue; + } + + selectedColumns = JSON.parse(jsonMatch[0]); + + if (validateColumns(selectedColumns, state.schema)) { + break; + } else { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `LLM returned invalid columns (attempt ${attempts})`, + }); + } + } catch { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Failed to parse JSON response (attempt ${attempts})`, + }); + } + } + + if (Object.keys(selectedColumns).length === 0) { + return { + status: GenerationError.Failed, + replyToUser: + 'Not able to select relevant columns. Please rephrase the question or provide more details.', + }; + } + + const filteredSchema = createFilteredSchema(state.schema, selectedColumns); + debug('step result columns=%o', selectedColumns); + return {schema: filteredSchema}; +} + +function buildFeedbacks(state: DbQueryState): string { + if (!state.feedbacks) return ''; + const lastColumns = getSelectedColumnsFromSchema(state.schema); + return buildPrompt(FEEDBACK_PROMPT, { + lastColumns: JSON.stringify(lastColumns, null, 2), + feedback: state.feedbacks.join('\n'), + }); +} + +function getTablesWithColumns(schema: DatabaseSchema): string[] { + return Object.entries(schema.tables).map(([tableName, table]) => { + const columnDescriptions = Object.entries(table.columns).map( + ([columnName, column]) => { + const details = [ + `${columnName} (${column.type})`, + column.required ? 'NOT NULL' : 'NULL', + column.id ? 
'PRIMARY KEY' : '', + column.description ? `- ${column.description}` : '', + ] + .filter(Boolean) + .join(' '); + return ` - ${details}`; + }, + ); + return `${tableName}: ${table.description}\nColumns:\n${columnDescriptions.join('\n')}`; + }); +} + +function validateColumns( + selectedColumns: Record, + schema: DatabaseSchema, +): boolean { + for (const tableName of Object.keys(selectedColumns)) { + if (!schema.tables[tableName]) return false; + const tableColumns = Object.keys(schema.tables[tableName].columns); + for (const columnName of selectedColumns[tableName]) { + if (!tableColumns.includes(columnName)) return false; + } + } + return true; +} + +function createFilteredSchema( + originalSchema: DatabaseSchema, + selectedColumns: Record, +): DatabaseSchema { + const filteredTables: Record = {}; + + for (const [tableName, columnNames] of Object.entries(selectedColumns)) { + if (originalSchema.tables[tableName]) { + const originalTable = originalSchema.tables[tableName]; + const filteredColumns: Record = {}; + + for (const columnName of columnNames) { + if (originalTable.columns[columnName]) { + filteredColumns[columnName] = originalTable.columns[columnName]; + } + } + + for (const pkColumn of originalTable.primaryKey) { + if (!filteredColumns[pkColumn] && originalTable.columns[pkColumn]) { + filteredColumns[pkColumn] = originalTable.columns[pkColumn]; + } + } + + filteredTables[tableName] = { + ...originalTable, + columns: filteredColumns, + }; + } + } + + const filteredRelations = originalSchema.relations.filter( + relation => + filteredTables[relation.table] && + filteredTables[relation.referencedTable], + ); + + return {tables: filteredTables, relations: filteredRelations}; +} + +function getSelectedColumnsFromSchema( + schema: DatabaseSchema, +): Record { + const result: Record = {}; + for (const [tableName, table] of Object.entries(schema.tables)) { + result[tableName] = Object.keys(table.columns); + } + return result; +} diff --git 
a/src/mastra/db-query/workflow/steps/get-tables.step.ts b/src/mastra/db-query/workflow/steps/get-tables.step.ts new file mode 100644 index 0000000..3eac6db --- /dev/null +++ b/src/mastra/db-query/workflow/steps/get-tables.step.ts @@ -0,0 +1,246 @@ +import {generateText} from 'ai'; +import { + DbSchemaHelperService, + PermissionHelper, +} from '../../../../components/db-query/services'; +import {SchemaStore} from '../../../../components/db-query/services/schema.store'; +import {TableSearchService} from '../../../../components/db-query/services/search/table-search.service'; +import {DbQueryState} from '../../../../components/db-query/state'; +import { + DatabaseSchema, + DbQueryConfig, + GenerationError, +} from '../../../../components/db-query/types'; +import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')('ai-integration:mastra:db-query:get-tables'); + +const GET_TABLES_PROMPT = ` + +You are an AI assistant that extracts table names that are relevant to the users query that will be used to generate an SQL query later. +- Consider not just the user query but also the context and the table descriptions while selecting the tables. +- Carefully consider each and every table before including or excluding it. +- If doubtful about a table's relevance, include it anyway to give the SQL generation step more options to choose from. +- Assume that the table would have appropriate columns for relating them to any other table even if the description does not mention it. 
+- If you are not sure about the tables to select from the given schema, just return your doubt asking the user for more details or to rephrase the question in the following format - +failed attempt: reason for failure + + + +{tables} + + + +{query} + + +{checks} + +{feedbacks} + + +The output should be just a comma separated list of table names with no other text, comments or formatting. +Ensure that table names are exact and match the names in the input including schema if given. + +public.employees, public.departments + +In case of failure, return the failure message in the format - +failed attempt: + +failed attempt: reason for failure + +`; + +const FEEDBACK_PROMPT = ` + +We also need to consider the errors from last attempt at query generation. + +In the last attempt, these were the last tables selected: +{lastTables} + +But it was rejected with the following errors: +{feedback} + +Use these if they are relevant to the table selection, otherwise ignore them, they would be considered again during the SQL generation step. + +`; + +export type GetTablesStepDeps = { + llmCheap: LLMProvider; + llmSmart: LLMProvider; + config: DbQueryConfig; + schemaHelper: DbSchemaHelperService; + schemaStore: SchemaStore; + tableSearchService: TableSearchService; + checks?: string[]; + permissionHelper?: PermissionHelper; +}; + +/** + * Selects relevant tables from the schema using a vector similarity pre-filter + * followed by an LLM classification call. Handles the two-attempt retry loop + * to validate that returned table names exist in the schema. 
+ */ +export async function getTablesStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: GetTablesStepDeps, +): Promise> { + debug('step start', {prompt: state.prompt}); + + const tableList = await deps.tableSearchService.getTables(state.prompt, 10); + const accessibleTables = filterByPermissions( + tableList, + deps.permissionHelper, + ); + + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Selecting from tables: ${accessibleTables}`, + }); + + const dbSchema = deps.schemaStore.filteredSchema(accessibleTables); + const allTables = getTablesFromSchema(dbSchema); + + if (allTables.length === 0) { + throw new Error( + 'No tables found in the provided database schema. Please ensure the schema is valid.', + ); + } + + const useSmartLLM = deps.config.nodes?.getTablesNode?.useSmartLLM ?? false; + const llm = useSmartLLM ? deps.llmSmart : deps.llmCheap; + + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: 'Extracting relevant tables from the schema'}, + }); + + const feedbacksText = await buildFeedbacks(state, deps.schemaHelper); + const content = buildPrompt(GET_TABLES_PROMPT, { + tables: allTables.join('\n\n'), + query: state.prompt, + feedbacks: feedbacksText, + checks: [ + '', + ...(deps.checks ?? []).map(check => `- ${check}`), + ...deps.schemaHelper + .getTablesContext(dbSchema) + .map(check => `- ${check}`), + '', + ].join('\n'), + }); + + let attempts = 0; + let requiredTables: string[] = []; + + while (attempts < 2) { + attempts++; + debug('table selection attempt %d', attempts); + + const {text, usage} = await generateText({ + model: llm, + messages: [{role: 'user', content}], + }); + context.onUsage?.( + usage.inputTokens ?? 0, + usage.outputTokens ?? 0, + 'unknown', + ); + debug('token usage captured attempt=%d', attempts, { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 
0, + }); + + const output = stripThinkingFromText(text); + + if (output.startsWith('failed attempt:')) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Table selection failed: ${output}`, + }); + return { + status: GenerationError.Failed, + replyToUser: output.replace('failed attempt: ', ''), + }; + } + + const lastLine = output.split('\n').pop() ?? ''; + requiredTables = lastLine.split(',').map(t => t.trim()); + + if (validateTables(requiredTables, dbSchema)) { + break; + } + + if (attempts === 2) { + return { + status: GenerationError.Failed, + replyToUser: + 'Not able to select relevant tables from the schema. Please rephrase the question or provide more details.', + }; + } + + context.writer?.({ + type: LLMStreamEventType.Log, + data: `LLM returned invalid tables: ${lastLine}, trying again`, + }); + } + + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Picked tables - ${requiredTables.join(', ')}`, + }); + + if (requiredTables.length === 0) { + throw new Error( + 'LLM did not return a valid comma separated string response.', + ); + } + + const result = {schema: deps.schemaStore.filteredSchema(requiredTables)}; + debug('step result tables=%o', requiredTables); + return result; +} + +async function buildFeedbacks( + state: DbQueryState, + schemaHelper: DbSchemaHelperService, +): Promise { + if (!state.feedbacks) return ''; + return buildPrompt(FEEDBACK_PROMPT, { + lastTables: tableListFromSchema(state.schema).join(', '), + feedback: state.feedbacks.join('\n'), + }); +} + +function tableListFromSchema(schema: DatabaseSchema): string[] { + if (!schema?.tables) return []; + return Object.keys(schema.tables); +} + +function getTablesFromSchema(schema: DatabaseSchema): string[] { + if (!schema?.tables) return []; + return Object.keys(schema.tables).map(tableName => { + const table = schema.tables[tableName]; + return `${tableName}: ${table.description}`; + }); +} + +function filterByPermissions( + tables: string[], + permissionHelper?: 
PermissionHelper, +): string[] { + if (!permissionHelper) return tables; + return tables.filter(t => { + const name = t.toLowerCase().slice(t.indexOf('.') + 1); + return permissionHelper.findMissingPermissions([name]).length === 0; + }); +} + +function validateTables(tables: string[], schema: DatabaseSchema): boolean { + return tables.every(t => schema.tables[t] !== undefined); +} diff --git a/src/mastra/db-query/workflow/steps/index.ts b/src/mastra/db-query/workflow/steps/index.ts new file mode 100644 index 0000000..65fba54 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/index.ts @@ -0,0 +1,32 @@ +export {failedStep} from './failed.step'; +export {isImprovementStep} from './is-improvement.step'; +export type {IsImprovementStepDeps} from './is-improvement.step'; +export {classifyChangeStep} from './classify-change.step'; +export type {ClassifyChangeStepDeps} from './classify-change.step'; +export {checkCacheStep} from './check-cache.step'; +export type {CheckCacheStepDeps} from './check-cache.step'; +export {checkPermissionsStep} from './check-permissions.step'; +export type {CheckPermissionsStepDeps} from './check-permissions.step'; +export {checkTemplatesStep} from './check-templates.step'; +export type {CheckTemplatesStepDeps} from './check-templates.step'; +export {getTablesStep} from './get-tables.step'; +export type {GetTablesStepDeps} from './get-tables.step'; +export {getColumnsStep} from './get-columns.step'; +export type {GetColumnsStepDeps} from './get-columns.step'; +export {generateChecklistStep} from './generate-checklist.step'; +export type {GenerateChecklistStepDeps} from './generate-checklist.step'; +export {verifyChecklistStep} from './verify-checklist.step'; +export type {VerifyChecklistStepDeps} from './verify-checklist.step'; +export {sqlGenerationStep} from './sql-generation.step'; +export type {SqlGenerationStepDeps} from './sql-generation.step'; +export {syntacticValidatorStep} from './syntactic-validator.step'; +export type 
{SyntacticValidatorStepDeps} from './syntactic-validator.step'; +export {semanticValidatorStep} from './semantic-validator.step'; +export type {SemanticValidatorStepDeps} from './semantic-validator.step'; +export {generateDescriptionStep} from './generate-description.step'; +export type {GenerateDescriptionStepDeps} from './generate-description.step'; +export {fixQueryStep} from './fix-query.step'; +export type {FixQueryStepDeps} from './fix-query.step'; +export {saveDatasetStep} from './save-dataset.step'; +export type {SaveDatasetStepDeps} from './save-dataset.step'; +export {mergeValidationResults} from './post-validation.step'; diff --git a/src/mastra/db-query/workflow/steps/is-improvement.step.ts b/src/mastra/db-query/workflow/steps/is-improvement.step.ts new file mode 100644 index 0000000..ee9b2cc --- /dev/null +++ b/src/mastra/db-query/workflow/steps/is-improvement.step.ts @@ -0,0 +1,42 @@ +import {DbQueryState} from '../../../../components/db-query/state'; +import {IDataSetStore} from '../../../../components/db-query/types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; + +const debug = require('debug')('ai-integration:mastra:db-query:is-improvement'); + +export type IsImprovementStepDeps = { + store: IDataSetStore; +}; + +/** + * Detects whether the incoming request is an improvement on an existing + * dataset query. If `state.datasetId` is set, loads the original query from + * the dataset store and enriches the state so that subsequent steps treat + * this run as a modification (not a fresh generation). + * + * No LLM call is made — this is a pure data-retrieval step. 
+ */ +export async function isImprovementStep( + state: DbQueryState, + _context: MastraDbQueryContext, + deps: IsImprovementStepDeps, +): Promise> { + debug('step start', {datasetId: state.datasetId}); + + if (!state.datasetId) { + debug('no datasetId — treating as fresh generation'); + return {}; + } + + debug('loading existing dataset %s for improvement', state.datasetId); + const dataset = await deps.store.findById(state.datasetId); + + const result = { + sampleSql: dataset.query, + sampleSqlPrompt: dataset.prompt, + prompt: `${dataset.prompt}\n also consider following feedback given by user -\n ${state.prompt}\n`, + }; + + debug('step result', result); + return result; +} diff --git a/src/mastra/db-query/workflow/steps/post-validation.step.ts b/src/mastra/db-query/workflow/steps/post-validation.step.ts new file mode 100644 index 0000000..03f4960 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/post-validation.step.ts @@ -0,0 +1,100 @@ +import {DbQueryState} from '../../../../components/db-query/state'; +import {EvaluationResult} from '../../../../components/db-query/types'; + +const debug = require('debug')( + 'ai-integration:mastra:db-query:post-validation', +); + +/** + * Merges syntactic and semantic validation results into a single unified status. + * + * Rules (mirror the LangGraph PostValidation node exactly): + * - Both validators passed → `Pass`; clear all per-round fields. + * - Syntactic failure → use syntactic status/feedback; accumulate feedbacks. + * - Semantic failure only → use semantic status/feedback; accumulate feedbacks. 
+ */ +export function mergeValidationResults( + state: DbQueryState, +): Partial { + const hasSyntacticFailure = isValidationFailure(state.syntacticStatus); + const hasSemanticFailure = isValidationFailure(state.semanticStatus); + + debug('mergeValidationResults', { + syntacticStatus: state.syntacticStatus, + semanticStatus: state.semanticStatus, + hasSyntacticFailure, + hasSemanticFailure, + }); + + if (!hasSyntacticFailure && !hasSemanticFailure) { + debug('result: Pass — both validators cleared'); + return buildPassedResult(state); + } + + debug( + 'result: Failed — syntactic=%s semantic=%s', + hasSyntacticFailure, + hasSemanticFailure, + ); + return buildFailedResult(state, hasSyntacticFailure); +} + +function isValidationFailure(status: DbQueryState['syntacticStatus']): boolean { + return !!status && status !== EvaluationResult.Pass; +} + +function buildPassedResult(state: DbQueryState): Partial { + return { + status: EvaluationResult.Pass, + feedbacks: (state.feedbacks ?? []).filter( + f => !f.startsWith('Query Validation Failed'), + ), + syntacticStatus: undefined, + syntacticFeedback: undefined, + syntacticErrorTables: undefined, + semanticStatus: undefined, + semanticFeedback: undefined, + semanticErrorTables: undefined, + }; +} + +function buildFailedResult( + state: DbQueryState, + hasSyntacticFailure: boolean, +): Partial { + const clearedState = buildClearedState(state); + const baseFeedbacks = state.feedbacks ?? []; + const semanticFb = toArray(state.semanticFeedback); + const syntacticFb = hasSyntacticFailure + ? toArray(state.syntacticFeedback) + : []; + + return { + status: hasSyntacticFailure ? state.syntacticStatus : state.semanticStatus, + feedbacks: [...baseFeedbacks, ...syntacticFb, ...semanticFb], + ...clearedState, + }; +} + +function buildClearedState(state: DbQueryState): Partial { + const mergedErrorTables = [ + ...new Set([ + ...(state.syntacticErrorTables ?? []), + ...(state.semanticErrorTables ?? 
[]), + ]), + ]; + const errorTables = + mergedErrorTables.length > 0 ? mergedErrorTables : undefined; + return { + syntacticStatus: undefined, + syntacticFeedback: undefined, + syntacticErrorTables: errorTables, + semanticStatus: undefined, + semanticFeedback: undefined, + semanticErrorTables: errorTables, + }; +} + +function toArray(value: string | undefined): string[] { + return value ? [value] : []; +} diff --git a/src/mastra/db-query/workflow/steps/save-dataset.step.ts b/src/mastra/db-query/workflow/steps/save-dataset.step.ts new file mode 100644 index 0000000..f098229 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/save-dataset.step.ts @@ -0,0 +1,153 @@ +import {generateText} from 'ai'; +import {HttpErrors} from '@loopback/rest'; +import {IAuthUserWithPermissions} from '@sourceloop/core'; +import {createHash} from 'crypto'; +import {AnyObject} from '@loopback/repository'; +import {DbSchemaHelperService} from '../../../../components/db-query/services'; +import {DbQueryState} from '../../../../components/db-query/state'; +import { + DatabaseSchema, + DbQueryConfig, + IDataSetStore, +} from '../../../../components/db-query/types'; +import {DEFAULT_MAX_READ_ROWS_FOR_AI} from '../../../../components/db-query/constant'; +import {LLMStreamEventType, ToolStatus} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')('ai-integration:mastra:db-query:save-dataset'); + +const DESCRIPTION_FALLBACK_PROMPT = `You are an AI assitant that generates a short description of a query based on a given schema, providing a summary of the query's intent and user's demand in a way that is short but does not miss any importance detail. 
+ + Here is the query that you need to describe - {query} + + And here is the schema that was used to generate the query - + {schema} + + + {checks} + The output should be a valid description of the query that is easy to understand by the user in plain text, without any formatting`; + +export type SaveDatasetStepDeps = { + llm: LLMProvider; + store: IDataSetStore; + config: DbQueryConfig; + user: IAuthUserWithPermissions; + dbSchemaHelper: DbSchemaHelperService; + checks?: string[]; +}; + +/** + * Persists the validated SQL query as a dataset record. If `state.description` + * is already populated (by `generateDescriptionStep`), skips the fallback LLM + * call. Emits a `ToolStatus.Completed` event so the frontend can render the + * data grid. + */ +export async function saveDatasetStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: SaveDatasetStepDeps, +): Promise> { + debug('step start', {sql: state.sql, hasDescription: !!state.description}); + + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'Dataset generated', + }); + + const tenantId = deps.user.tenantId; + if (!tenantId) { + throw new HttpErrors.BadRequest('User does not have a tenantId'); + } + if (!state.sql) { + throw new HttpErrors.InternalServerError(); + } + + let description = state.description; + + if (!description) { + debug('generating fallback description via LLM'); + const content = buildPrompt(DESCRIPTION_FALLBACK_PROMPT, { + query: state.sql, + schema: deps.dbSchemaHelper.asString(state.schema), + checks: [ + 'You must keep these additional details in consideration while describing the query -', + ...(deps.checks ?? []), + ].join('\n'), + }); + + const {text, usage} = await generateText({ + model: deps.llm, + messages: [{role: 'user', content}], + }); + context.onUsage?.( + usage.inputTokens ?? 0, + usage.outputTokens ?? 0, + 'unknown', + ); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 
0, + }); + description = stripThinkingFromText(text); + } + + const dataset = await deps.store.create({ + query: state.sql, + tenantId, + description, + prompt: state.prompt, + tables: getTableList(state.schema), + schemaHash: hashSchema(state.schema), + votes: 0, + }); + + if (!state.directCall) { + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: { + status: ToolStatus.Completed, + data: {datasetId: dataset.id}, + }, + }); + } + + let result: undefined | AnyObject[] = undefined; + if (deps.config.readAccessForAI && dataset.id) { + result = await deps.store.getData( + dataset.id, + deps.config.maxRowsForAI ?? DEFAULT_MAX_READ_ROWS_FOR_AI, + ); + } + + const stepResult = { + datasetId: dataset.id, + replyToUser: description, + done: true, + resultArray: result, + }; + debug('step result', {datasetId: dataset.id}); + return stepResult; +} + +function hashSchema(schema: DatabaseSchema): string { + const hash = createHash('sha256'); + const tableList = getTableList(schema).sort((a, b) => a.localeCompare(b)); + tableList.forEach(table => { + hash.update(table); + const columns = schema.tables[table]?.columns || {}; + Object.keys(columns) + .sort((a, b) => a.localeCompare(b)) + .forEach(column => { + hash.update(`${column}:${columns[column].type}`); + }); + }); + return hash.digest('hex'); +} + +function getTableList(schema: DatabaseSchema): string[] { + if (!schema?.tables) return []; + return Object.keys(schema.tables); +} diff --git a/src/mastra/db-query/workflow/steps/semantic-validator.step.ts b/src/mastra/db-query/workflow/steps/semantic-validator.step.ts new file mode 100644 index 0000000..690def3 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/semantic-validator.step.ts @@ -0,0 +1,191 @@ +import {generateText} from 'ai'; +import { + DbSchemaHelperService, + PermissionHelper, +} from '../../../../components/db-query/services'; +import {TableSearchService} from '../../../../components/db-query/services/search/table-search.service'; +import 
{DbQueryState} from '../../../../components/db-query/state'; +import { + DbQueryConfig, + EvaluationResult, +} from '../../../../components/db-query/types'; +import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')( + 'ai-integration:mastra:db-query:semantic-validator', +); + +const SEMANTIC_PROMPT = ` + +You are an AI assistant that validates whether a SQL query satisfies a given checklist. +The query has already been validated for syntax and correctness. +Go through each checklist item and verify it against the SQL query. +DO NOT make up issues that do not exist in the query. + + + +{userPrompt} + + + +{query} + + + +{schema} + + + +{tableNames} + + + +{checklist} + + +{feedbacks} + + +If the query satisfies ALL checklist items, return ONLY a valid tag with no other text: + + + + +If any checklist item is NOT satisfied, return your response in two sections: +1. An invalid tag containing each failed item with a detailed explanation of what is wrong and how it should be fixed. +2. A tables tag listing ALL table names from the available tables that are related to the errors. Be generous - include tables directly involved in the error, tables that need to be joined to fix the issue, and any tables that could be relevant. It is better to include extra tables than to miss any. + + + +- Salary values are not converted to USD. The query should join the exchange_rates table using currency_id and multiply salary by the rate. +- Lost and hold deals are not excluded. Add a WHERE condition to filter out deals with status 0 and 2. + +exchange_rates, deals, employees + + +`; + +const FEEDBACK_PROMPT = ` + +We also need to consider the users feedback on the last attempt at query generation. 
+ +But was rejected by validator with the following errors - +{feedback} + +Keep these feedbacks in mind while validating the new query. +`; + +export type SemanticValidatorStepDeps = { + smartLlm: LLMProvider; + cheapLlm: LLMProvider; + config: DbQueryConfig; + tableSearchService: TableSearchService; + schemaHelper: DbSchemaHelperService; + permissionHelper?: PermissionHelper; +}; + +/** + * Validates the generated SQL against the validation checklist using the LLM. + * Selects cheap vs. smart LLM based on config. + */ +export async function semanticValidatorStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: SemanticValidatorStepDeps, +): Promise> { + debug('step start', {sql: state.sql}); + + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: { + status: `Verifying if the query fully satisfies the user's requirement`, + }, + }); + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'Validating the query semantically.', + }); + + const useSmartLLM = + deps.config.nodes?.semanticValidatorNode?.useSmartLLM ?? false; + const llm = useSmartLLM ? deps.smartLlm : deps.cheapLlm; + + const tableList = + (await deps.tableSearchService.getTables(state.prompt)) ?? []; + const accessibleTables = filterByPermissions( + tableList, + deps.permissionHelper, + ); + + const feedbacksText = state.feedbacks?.length + ? buildPrompt(FEEDBACK_PROMPT, {feedback: state.feedbacks.join('\n')}) + : ''; + + const content = buildPrompt(SEMANTIC_PROMPT, { + userPrompt: state.prompt, + query: state.sql ?? '', + schema: deps.schemaHelper.asString(state.schema), + tableNames: accessibleTables.join(', '), + checklist: state.validationChecklist ?? 'No checklist provided.', + feedbacks: feedbacksText, + }); + + debug('invoking LLM for semantic validation'); + const {text, usage} = await generateText({ + model: llm, + messages: [{role: 'user', content}], + }); + context.onUsage?.(usage.inputTokens ?? 0, usage.outputTokens ?? 
0, 'unknown'); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 0, + }); + + const response = stripThinkingFromText(text); + const invalidMatch = /(.*?)<\/invalid>/s.exec(response); + const tablesMatch = /(.*?)<\/tables>/s.exec(response); + const isValid = + response.includes('') || response.includes(''); + + if (isValid && !invalidMatch) { + debug('semantic validation passed'); + return {semanticStatus: EvaluationResult.Pass}; + } + + const reason = invalidMatch ? invalidMatch[1].trim() : response.trim(); + const errorTables = tablesMatch + ? tablesMatch[1] + .split(',') + .map(t => t.trim()) + .filter(t => t.length > 0) + : []; + + debug('semantic validation failed: %s', reason); + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Query Validation Failed by LLM: ${reason}`, + }); + + const result = { + semanticStatus: EvaluationResult.QueryError, + semanticFeedback: `Query Validation Failed by LLM: ${reason}`, + semanticErrorTables: errorTables, + }; + debug('step result', result); + return result; +} + +function filterByPermissions( + tables: string[], + permissionHelper?: PermissionHelper, +): string[] { + if (!permissionHelper) return tables; + return tables.filter(t => { + const name = t.toLowerCase().slice(t.indexOf('.') + 1); + return permissionHelper.findMissingPermissions([name]).length === 0; + }); +} diff --git a/src/mastra/db-query/workflow/steps/sql-generation.step.ts b/src/mastra/db-query/workflow/steps/sql-generation.step.ts new file mode 100644 index 0000000..8fc6589 --- /dev/null +++ b/src/mastra/db-query/workflow/steps/sql-generation.step.ts @@ -0,0 +1,221 @@ +import {generateText} from 'ai'; +import {DbSchemaHelperService} from '../../../../components/db-query/services'; +import {DbQueryState} from '../../../../components/db-query/state'; +import { + ChangeType, + DbQueryConfig, + EvaluationResult, + GenerationError, +} from '../../../../components/db-query/types'; 
+import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider, SupportedDBs} from '../../../../types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')('ai-integration:mastra:db-query:sql-generation'); + +const SQL_GENERATION_PROMPT = ` + +You are an expert AI assistant that generates SQL queries based on user questions and a given database schema. +You try to following the instructions carefully to generate the SQL query that answers the question. +Do not hallucinate details or make up information. +Your task is to convert a question into a SQL query, given a {dialect} database schema. +Adhere to these rules: +- **Deliberately go through the question and database schema word by word** to appropriately answer the question +- **DO NOT make any DML statements** (INSERT, UPDATE, DELETE, DROP etc.) to the database. +- Never query for all the columns from a specific table, only ask for the relevant columns for the given the question. +- You can only generate a single query, so if you need multiple results you can use JOINs, subqueries, CTEs or UNIONS. +- Do not make any assumptions about the user's intent beyond what is explicitly provided in the prompt. +- Ensure proper grouping with brackets for where clauses with multiple conditions using AND and OR. +- Follow each and every single rule in the "must-follow-rules" section carefully while writing the query. DO NOT SKIP ANY RULE. + + +{question} + + + +{dbschema} + + +{checks} + +{exampleQueries} + +{feedbacks} + + +{outputFormat} +`; + +const OUTPUT_FORMAT = ` +Output should only be a valid SQL query with no other special character or formatting. +Contains the required valid SQL satisfying all the constraints. 
+It should have no other character or symbol or character that is not part of SQLs.`; + +const FEEDBACK_PROMPT = ` + +We also need to consider the users feedback on the last attempt at query generation. +Make sure you fix the provided error without introducing any new or past errors. +In the last attempt, you generated this SQL query - + +{query} + + + +{feedback} + + +{historicalErrors} +`; + +export type SqlGenerationStepDeps = { + sqlLLM: LLMProvider; + cheapLLM: LLMProvider; + config: DbQueryConfig; + schemaHelper: DbSchemaHelperService; + checks?: string[]; +}; + +/** + * Selects cheap vs. smart LLM based on query complexity (minor change, single + * table, or validation-fix retry → cheap LLM; otherwise → smart LLM). + * Generates a SQL query from the filtered schema and state context. + */ +export async function sqlGenerationStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: SqlGenerationStepDeps, +): Promise> { + debug('step start', { + prompt: state.prompt, + feedbacks: state.feedbacks?.length, + }); + + const isSingleTable = + state.schema.tables && Object.keys(state.schema.tables).length === 1; + const isValidationFixRetry = + state.feedbacks?.length && + state.feedbacks[state.feedbacks.length - 1].startsWith( + 'Query Validation Failed', + ); + + const llm = + state.changeType === ChangeType.Minor || + isSingleTable || + isValidationFixRetry + ? deps.cheapLLM + : deps.sqlLLM; + + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Generating SQL query from the prompt - ${state.prompt}`, + }); + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: 'Generating SQL query from the prompt'}, + }); + + const content = buildPrompt(SQL_GENERATION_PROMPT, { + dialect: deps.config.db?.dialect ?? SupportedDBs.PostgreSQL, + question: state.prompt, + dbschema: deps.schemaHelper.asString(state.schema), + checks: buildChecks(state, deps), + feedbacks: buildFeedbacks(state), + exampleQueries: state.feedbacks?.length ? 
'' : buildSampleQueries(state), + outputFormat: OUTPUT_FORMAT, + }); + + debug('generating SQL'); + const {text, usage} = await generateText({ + model: llm, + messages: [{role: 'user', content}], + }); + context.onUsage?.(usage.inputTokens ?? 0, usage.outputTokens ?? 0, 'unknown'); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 0, + }); + + const response = stripThinkingFromText(text); + const sql = + response + .replace(/^```(?:sql)?\s*/i, '') + .replace(/```\s*$/, '') + .trim() || undefined; + + if (!sql) { + context.writer?.({ + type: LLMStreamEventType.Log, + data: `SQL generation failed: ${response}`, + }); + return { + status: GenerationError.Failed, + replyToUser: + 'Failed to generate SQL query. Please try rephrasing your question or provide more details.', + }; + } + + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Generated SQL query: ${sql}`, + }); + + const result = {status: EvaluationResult.Pass, sql}; + debug('step result', {sql}); + return result; +} + +function buildFeedbacks(state: DbQueryState): string { + if (!state.feedbacks?.length) return ''; + const lastFeedback = state.feedbacks[state.feedbacks.length - 1]; + const otherFeedbacks = state.feedbacks.slice(0, -1); + return buildPrompt(FEEDBACK_PROMPT, { + query: state.sql ?? '', + feedback: `This was the error in the latest query you generated - \n${lastFeedback}`, + historicalErrors: otherFeedbacks.length + ? 
[ + '', + 'You already faced following issues in the past -', + otherFeedbacks.join('\n'), + '', + ].join('\n') + : '', + }); +} + +function buildSampleQueries(state: DbQueryState): string { + let startTag = ''; + let endTag = ''; + let baseLine = + 'Here is an example query for reference that is similar to the question asked and has been validated by the user'; + if (!state.fromCache) { + startTag = ''; + endTag = ''; + baseLine = + 'Here is the last valid SQL query that was generated for the user that is supposed to be used as the base line for the next query generation.'; + } + return state.sampleSql + ? `${startTag}\n${baseLine} -\n${state.sampleSql}\nThis was generated for the following question - \n${state.sampleSqlPrompt} \n\n${endTag}` + : ''; +} + +function buildChecks(state: DbQueryState, deps: SqlGenerationStepDeps): string { + if (state.validationChecklist) { + return [ + '', + 'You must keep these additional details in mind while writing the query -', + ...state.validationChecklist.split('\n').map(check => `- ${check}`), + '', + ].join('\n'); + } + return [ + '', + 'You must keep these additional details in mind while writing the query -', + ...(deps.checks ?? 
[]).map(check => `- ${check}`), + ...deps.schemaHelper + .getTablesContext(state.schema) + .map(check => `- ${check}`), + '', + ].join('\n'); +} diff --git a/src/mastra/db-query/workflow/steps/syntactic-validator.step.ts b/src/mastra/db-query/workflow/steps/syntactic-validator.step.ts new file mode 100644 index 0000000..e4028ff --- /dev/null +++ b/src/mastra/db-query/workflow/steps/syntactic-validator.step.ts @@ -0,0 +1,122 @@ +import {generateText} from 'ai'; +import {DbQueryState} from '../../../../components/db-query/state'; +import { + EvaluationResult, + IDbConnector, +} from '../../../../components/db-query/types'; +import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')( + 'ai-integration:mastra:db-query:syntactic-validator', +); + +const CATEGORIZE_PROMPT = `You are an AI assistant that categorizes the SQL query error and identifies related tables. + +Here is the SQL query error that you need to categorize - +{error} + +Here is the query that resulted in the error - +{query} + +Here are all the available tables in the database - +{tableNames} + +Categorize the error into one of these two categories: +- table_not_found: Any error that indicates a table or column is missing +- query_error: All other errors + +Also identify ALL tables that are related to the error. Be generous - include tables that are directly involved in the error, tables referenced in the failing part of the query, and tables that might need to be joined or referenced to fix the error. It is better to include extra tables than to miss any. 
+ +Return your response in exactly this format with no other text: +table_not_found or query_error +comma, separated, table, names +`; + +export type SyntacticValidatorStepDeps = { + llm: LLMProvider; + connector: IDbConnector; +}; + +/** + * Executes the SQL against the configured connector to catch database-level + * syntax and schema errors. On failure, uses the LLM to categorize the error + * and identify affected tables for the retry loop. + */ +export async function syntacticValidatorStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: SyntacticValidatorStepDeps, +): Promise> { + debug('step start', {sql: state.sql}); + + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: 'Validating generated SQL query'}, + }); + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'Validating the query syntactically.', + }); + + try { + if (!state.sql) throw new Error('No SQL query generated to validate'); + await deps.connector.validate(state.sql); + debug('syntactic validation passed'); + return {syntacticStatus: EvaluationResult.Pass}; + } catch (error) { + debug('syntactic validation failed: %s', (error as Error).message); + + const tableNames = Object.keys(state.schema?.tables ?? {}); + const content = buildPrompt(CATEGORIZE_PROMPT, { + error: (error as Error).message, + query: state.sql ?? '', + tableNames: tableNames.join(', '), + }); + + const {text, usage} = await generateText({ + model: deps.llm, + messages: [{role: 'user', content}], + }); + context.onUsage?.( + usage.inputTokens ?? 0, + usage.outputTokens ?? 0, + 'unknown', + ); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 0, + }); + + const result = stripThinkingFromText(text); + const categoryMatch = /(.*?)<\/category>/s.exec(result); + const tablesMatch = /(.*?)<\/tables>/s.exec(result); + + const category = categoryMatch + ? 
(categoryMatch[1].trim() as EvaluationResult) + : (result.trim() as EvaluationResult); + + const errorTables = tablesMatch + ? tablesMatch[1] + .split(',') + .map(t => t.trim()) + .filter(t => t.length > 0) + : []; + + context.writer?.({ + type: LLMStreamEventType.Log, + data: `Query Validation Failed by DB: ${category} with error ${(error as Error).message}`, + }); + + const stepResult = { + syntacticStatus: category, + syntacticFeedback: `Query Validation Failed by DB: ${category} with error ${(error as Error).message}`, + syntacticErrorTables: errorTables, + }; + debug('step result', stepResult); + return stepResult; + } +} diff --git a/src/mastra/db-query/workflow/steps/verify-checklist.step.ts b/src/mastra/db-query/workflow/steps/verify-checklist.step.ts new file mode 100644 index 0000000..2f5bcce --- /dev/null +++ b/src/mastra/db-query/workflow/steps/verify-checklist.step.ts @@ -0,0 +1,191 @@ +import {generateText} from 'ai'; +import {DbSchemaHelperService} from '../../../../components/db-query/services'; +import {DbQueryState} from '../../../../components/db-query/state'; +import {DbQueryConfig} from '../../../../components/db-query/types'; +import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {MastraDbQueryContext} from '../../types/db-query.types'; +import {buildPrompt} from '../../utils/prompt.util'; +import {stripThinkingFromText} from '../../utils/thinking.util'; + +const debug = require('debug')( + 'ai-integration:mastra:db-query:verify-checklist', +); + +const BASE_PROMPT = ` + +You are given a user question, the tables selected for SQL generation, the relevant database schema, and a numbered list of rules/checks. +Return ONLY the indexes of the rules that are relevant to the user's question, the selected tables, and the given schema. + +A rule is relevant if: +- It directly affects how a correct SQL query should be written for this question. 
+- It is a dependency of another relevant rule (e.g. if rule 3 requires a currency conversion, and rule 5 defines how currency conversion works, both must be included). +- It applies to any of the selected tables or their relationships. + +Ensure: +- Any rule that is referenced by, or is a prerequisite for, another selected rule is also included. +- Do not include rules that are completely unrelated to the question, schema, or selected tables. + + + +{prompt} + + + +{tables} + + + +{schema} + + + +{indexedChecks} + + +`; + +const EVALUATION_OUTPUT = ` +First, evaluate each rule inside an evaluation tag. For each rule, repeat the full rule text exactly as given, followed by " — Include" or " — Exclude" with a brief reason. +Then, return only the comma-separated list of included rule indexes inside a result tag. + +Example: + +1. When matching names, use ilike with wildcards — Include, query involves name matching +2. Format dates using to_char — Exclude, no date fields in this query +3. Always exclude lost deals — Include, query involves deals + +1,3 + +If no rules are relevant: none +`; + +const SIMPLE_OUTPUT = ` +Return ONLY the comma-separated list of relevant rule indexes inside a result tag. +Do NOT include any reasoning, analysis, or explanation — only the result tag. +Example: +1,3,5 +If no rules are relevant: +none +`; + +export type VerifyChecklistStepDeps = { + smartLlm: LLMProvider; + smartNonThinkingLlm?: LLMProvider; + config: DbQueryConfig; + schemaHelper: DbSchemaHelperService; + checks?: string[]; +}; + +/** + * A second-pass checklist filter that runs only for schemas with more than two + * tables. Supports an optional chain-of-thought "evaluation" mode. Merges + * verified indexes with any checklist already in state. + */ +export async function verifyChecklistStep( + state: DbQueryState, + context: MastraDbQueryContext, + deps: VerifyChecklistStepDeps, +): Promise> { + debug('step start', {tables: Object.keys(state.schema?.tables ?? 
{})}); + + if (deps.config.nodes?.verifyChecklistNode?.enabled === false) { + return {}; + } + + if (state.feedbacks?.length) { + return {}; + } + + const tableCount = Object.keys(state.schema?.tables ?? {}).length; + if (tableCount <= 2) { + return {}; + } + + const allChecks = [ + ...(deps.checks ?? []), + ...deps.schemaHelper.getTablesContext(state.schema), + ]; + + if (allChecks.length === 0) { + return {}; + } + + context.writer?.({ + type: LLMStreamEventType.Log, + data: 'Verifying validation checklist with chain-of-thought.', + }); + + const llm = deps.smartNonThinkingLlm ?? deps.smartLlm; + const indexedChecks = allChecks + .map((check, i) => `${i + 1}. ${check}`) + .join('\n'); + const useEvaluation = + deps.config.nodes?.verifyChecklistNode?.evaluation ?? false; + + const content = buildPrompt( + BASE_PROMPT + (useEvaluation ? EVALUATION_OUTPUT : SIMPLE_OUTPUT), + { + prompt: state.prompt, + tables: Object.keys(state.schema?.tables ?? {}).join(', '), + schema: deps.schemaHelper.asString(state.schema), + indexedChecks, + }, + ); + + debug( + 'invoking LLM for checklist verification (evaluation=%s)', + useEvaluation, + ); + const {text, usage} = await generateText({ + model: llm, + messages: [{role: 'user', content}], + }); + context.onUsage?.(usage.inputTokens ?? 0, usage.outputTokens ?? 0, 'unknown'); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 
0, + }); + + const response = stripThinkingFromText(text).trim(); + const verifiedIndexes = parseVerifiedIndexes(response, allChecks.length); + + if (verifiedIndexes.length === 0) { + return {}; + } + + const validationChecklist = mergeWithExisting( + state.validationChecklist, + verifiedIndexes, + allChecks, + ); + + debug('step result checklist rules=%d', verifiedIndexes.length); + return {validationChecklist}; +} + +function parseVerifiedIndexes(response: string, maxIndex: number): number[] { + const resultMatch = /(.*?)<\/result>/s.exec(response); + const indexStr = resultMatch ? resultMatch[1].trim() : response; + + if (!indexStr || indexStr === 'none') return []; + + return indexStr + .split(',') + .map(s => Number.parseInt(s.trim(), 10)) + .filter(n => !Number.isNaN(n) && n >= 1 && n <= maxIndex); +} + +function mergeWithExisting( + existing: string | undefined, + verifiedIndexes: number[], + allChecks: string[], +): string { + const existingChecks = new Set( + (existing ?? '').split('\n').filter(c => c.length > 0), + ); + for (const check of verifiedIndexes.map(i => allChecks[i - 1])) { + existingChecks.add(check); + } + return Array.from(existingChecks).join('\n'); +} diff --git a/src/mastra/index.ts b/src/mastra/index.ts index 10128b2..45208aa 100644 --- a/src/mastra/index.ts +++ b/src/mastra/index.ts @@ -1,3 +1,5 @@ export * from './chat/mastra-chat.agent'; +export * from './db-query'; +export * from './visualization'; export * from './request-tool-store'; export * from './types'; diff --git a/src/mastra/request-tool-store.ts b/src/mastra/request-tool-store.ts index 7b7d31f..dced6b6 100644 --- a/src/mastra/request-tool-store.ts +++ b/src/mastra/request-tool-store.ts @@ -1,5 +1,5 @@ -import {LLMStreamEvent} from '../graphs/event.types'; -import {IRuntimeTool} from '../graphs/types'; +import {LLMStreamEvent} from '../types/events'; +import {IRuntimeTool} from '../types/tool'; /** * Per-request IRuntimeTool registry. 
diff --git a/src/mastra/types.ts b/src/mastra/types.ts index cc436d0..9e13a3d 100644 --- a/src/mastra/types.ts +++ b/src/mastra/types.ts @@ -1,5 +1,3 @@ -import {AnyObject} from '@loopback/repository'; - /** * A single message in the format accepted by Mastra (compatible with AI SDK `CoreMessage`). * diff --git a/src/mastra/visualization/index.ts b/src/mastra/visualization/index.ts new file mode 100644 index 0000000..1dd4496 --- /dev/null +++ b/src/mastra/visualization/index.ts @@ -0,0 +1,9 @@ +export {MastraVisualizationWorkflow} from './mastra-visualization.workflow'; +export type { + MastraVisualizationContext, + MastraVisualizationState, + VisualizationWorkflowInput, + IMastraVisualizer, +} from './types/visualization.types'; +export * from './services'; +export * from './workflow'; diff --git a/src/mastra/visualization/mastra-visualization.workflow.ts b/src/mastra/visualization/mastra-visualization.workflow.ts new file mode 100644 index 0000000..2d49795 --- /dev/null +++ b/src/mastra/visualization/mastra-visualization.workflow.ts @@ -0,0 +1,204 @@ +import {BindingScope, inject, injectable, service} from '@loopback/core'; +import {IDataSetStore} from '../../components/db-query/types'; +import {DbQueryAIExtensionBindings} from '../../components/db-query/keys'; +import {AiIntegrationBindings} from '../../keys'; +import {LLMProvider} from '../../types'; +import {TokenCounter} from '../../services/token-counter.service'; +import {MastraDbQueryWorkflow} from '../db-query/mastra-db-query.workflow'; +import { + MastraBarVisualizerService, + MastraLineVisualizerService, + MastraPieVisualizerService, +} from './services'; +import { + IMastraVisualizer, + MastraVisualizationContext, + MastraVisualizationState, + VisualizationWorkflowInput, +} from './types/visualization.types'; +import { + callQueryGenerationStep, + checkPostQueryGenerationConditions, + checkPostSelectConditions, + getDatasetDataStep, + renderVisualizationStep, + selectVisualizationStep, +} from 
'./workflow'; + +const debug = require('debug')('mastra:visualization:workflow'); + +/** + * Mastra-path imperative workflow for the Visualization feature. + * + * Injects all services directly (no class-based node wrappers) and delegates + * to pure step functions in `workflow/steps/`. This is the exact same pattern + * used by `MastraDbQueryWorkflow` in Phase 3. + * + * ## Flow (mirrors `VisualizationGraph.build()`) + * ``` + * START + * └─ selectVisualization ─┬─ error → END + * └─ continue + * └─ callQueryGeneration ─┬─ error → END + * └─ continue + * └─ getDatasetData + * └─ renderVisualization → END + * ``` + * + * ## Key design decisions + * - `selectVisualizationStep` resolves the chart type from a list of + * `IMastraVisualizer` instances injected directly (no `findByTag` at runtime). + * - `callQueryGenerationStep` delegates to `MastraDbQueryWorkflow` when no + * `datasetId` is provided, forwarding `context.writer` and `context.signal`. + * - `renderVisualizationStep` calls `visualizer.getConfig()` which uses AI SDK + * `generateObject()` — no LangGraph `withStructuredOutput()` anywhere. + * + * @injectable `BindingScope.REQUEST` — one instance per HTTP request. + */ +@injectable({scope: BindingScope.REQUEST}) +export class MastraVisualizationWorkflow { + /** All registered Mastra-path visualizers, collected for step injection. 
*/ + private readonly mastraVisualizers: IMastraVisualizer[]; + + constructor( + // ── LLM providers ────────────────────────────────────────────────────── + @inject(AiIntegrationBindings.AiSdkCheapLLM) + private readonly cheapLlm: LLMProvider, + + // ── Dataset store ─────────────────────────────────────────────────────── + @inject(DbQueryAIExtensionBindings.DatasetStore) + private readonly datasetStore: IDataSetStore, + + // ── Mastra visualizer services ────────────────────────────────────────── + @service(MastraBarVisualizerService) + private readonly barVisualizer: MastraBarVisualizerService, + @service(MastraLineVisualizerService) + private readonly lineVisualizer: MastraLineVisualizerService, + @service(MastraPieVisualizerService) + private readonly pieVisualizer: MastraPieVisualizerService, + + // ── DbQuery workflow (for generating datasets on-the-fly) ─────────────── + @service(MastraDbQueryWorkflow) + private readonly dbQueryWorkflow: MastraDbQueryWorkflow, + @service(TokenCounter) + private readonly tokenCounter: TokenCounter, + ) { + // Collect all visualizers into a flat array for the selection step. + // New visualizers should be added here AND registered in VisualizerComponent. + this.mastraVisualizers = [barVisualizer, lineVisualizer, pieVisualizer]; + } + + /** + * Execute the full Visualization workflow. + * + * @param input User prompt, optional datasetId, and optional chart type hint. + * @param ctx Execution context: SSE writer and/or AbortSignal. + * @returns Final `MastraVisualizationState` after all steps have run. + */ + async run( + input: VisualizationWorkflowInput, + ctx?: MastraVisualizationContext, + ): Promise { + const context: MastraVisualizationContext = { + ...ctx, + onUsage: (i, o, m) => { + this.tokenCounter.accumulate(i, o, m); + if (ctx?.onUsage) ctx.onUsage(i, o, m); + }, + }; + + debug( + 'Workflow START prompt=%s datasetId=%s type=%s', + input.prompt, + input.datasetId ?? '(none)', + input.type ?? 
'(auto)', + ); + + // ── Initial state ─────────────────────────────────────────────────────── + let state: MastraVisualizationState = { + prompt: input.prompt, + datasetId: input.datasetId, + type: input.type, + }; + + // ── Step 1: SelectVisualization ───────────────────────────────────────── + debug('Executing step: SelectVisualization'); + state = this.merge( + state, + await selectVisualizationStep(state, context, { + llm: this.cheapLlm, + visualizers: this.mastraVisualizers, + }), + ); + debug( + 'Completed step: SelectVisualization visualizer=%s', + state.visualizerName ?? state.error, + ); + + // ── Branch: SelectVisualization error? ────────────────────────────────── + const selectCondition = checkPostSelectConditions(state); + debug('Branch decision (post-select): %s', selectCondition); + + if (selectCondition === 'error') { + debug( + 'Workflow END success=false (no matching visualizer: %s)', + state.error, + ); + return state; + } + + // ── Step 2: CallQueryGeneration ───────────────────────────────────────── + debug('Executing step: CallQueryGeneration'); + state = this.merge( + state, + await callQueryGenerationStep(state, context, { + dbQueryWorkflow: this.dbQueryWorkflow, + }), + ); + debug( + 'Completed step: CallQueryGeneration datasetId=%s', + state.datasetId ?? state.error, + ); + + // ── Branch: CallQueryGeneration error? 
────────────────────────────────── + const queryCondition = checkPostQueryGenerationConditions(state); + debug('Branch decision (post-query-gen): %s', queryCondition); + + if (queryCondition === 'error') { + debug( + 'Workflow END success=false (dataset generation failed: %s)', + state.error, + ); + return state; + } + + // ── Step 3: GetDatasetData ────────────────────────────────────────────── + debug('Executing step: GetDatasetData'); + state = this.merge( + state, + await getDatasetDataStep(state, context, { + store: this.datasetStore, + }), + ); + debug('Completed step: GetDatasetData sql=%s', state.sql?.substring(0, 60)); + + // ── Step 4: RenderVisualization ───────────────────────────────────────── + debug('Executing step: RenderVisualization'); + state = this.merge(state, await renderVisualizationStep(state, context)); + debug('Completed step: RenderVisualization done=%s', state.done); + + debug('Workflow END success=true visualizer=%s', state.visualizerName); + return state; + } + + /** + * Shallow-merge one or more partial states into `base`. + * Last-write-wins per field — same semantics as LangGraph's `Annotation.Root`. 
+ */ + private merge( + base: MastraVisualizationState, + ...partials: Partial[] + ): MastraVisualizationState { + return Object.assign({}, base, ...partials); + } +} diff --git a/src/mastra/visualization/services/bar.visualizer.service.ts b/src/mastra/visualization/services/bar.visualizer.service.ts new file mode 100644 index 0000000..114df4e --- /dev/null +++ b/src/mastra/visualization/services/bar.visualizer.service.ts @@ -0,0 +1,130 @@ +import {injectable, BindingScope, inject} from '@loopback/core'; +import {AnyObject} from '@loopback/repository'; +import {generateObject} from 'ai'; +import {z} from 'zod'; +import {AiIntegrationBindings} from '../../../keys'; +import {LLMProvider} from '../../../types'; +import { + IMastraVisualizer, + MastraVisualizationState, +} from '../types/visualization.types'; + +const debug = require('debug')('ai-integration:mastra:visualization:bar'); + +/** + * Zod schema describing the bar chart configuration returned by the LLM. + * Mirrors the schema used by the LangGraph `BarVisualizer`. + */ +const BAR_CONFIG_SCHEMA = z.object({ + categoryColumn: z + .string() + .describe('Column to be used for categories (x-axis) in the bar chart'), + valueColumn: z + .string() + .describe('Column to be used for values (y-axis) in the bar chart'), + orientation: z + .string() + .default('vertical') + .describe( + 'Orientation of the bar chart: `vertical` or `horizontal` without backticks', + ), +}); + +/** + * Mastra-path bar-chart visualizer. + * + * Replaces the LangGraph `BarVisualizer` by using AI SDK `generateObject()` + * instead of `BaseChatModel.withStructuredOutput()`. Business logic and + * prompt are identical — only the LLM call site changes. + * + * Implements `IMastraVisualizer` so `selectVisualizationStep` can discover + * and rank it alongside other Mastra visualizers. 
+ */ +@injectable({scope: BindingScope.SINGLETON}) +export class MastraBarVisualizerService implements IMastraVisualizer { + /** Unique chart type key — must match the value returned by the LLM. */ + readonly name = 'bar'; + + readonly description = + 'Renders the data in a bar chart format. Best for comparing values across different categories or showing trends over time.'; + + readonly context = + 'A bar chart requires data with at exactly two columns: one for the categories (x-axis) and one for the values (y-axis). Ensure that the category column contains discrete values representing different groups or categories, while the value column contains numerical data that can be compared across these categories. Bar charts can be oriented either vertically or horizontally depending on the data representation needs.'; + + constructor( + @inject(AiIntegrationBindings.AiSdkCheapLLM) + private readonly llm: LLMProvider, + ) {} + + /** + * Uses AI SDK `generateObject()` with `BAR_CONFIG_SCHEMA` to map the SQL + * query's columns to bar-chart axes. + * + * @param state Current visualization state with `sql`, `queryDescription`, + * and `prompt` already populated. + * @returns `{ categoryColumn, valueColumn, orientation }` chart config. + */ + async getConfig( + state: MastraVisualizationState, + onUsage?: ( + inputTokens: number, + outputTokens: number, + model: string, + ) => void, + ): Promise { + if (!state.sql || !state.queryDescription || !state.prompt) { + throw new Error( + 'MastraBarVisualizerService: Invalid State — sql, queryDescription and prompt are required', + ); + } + + debug( + 'Generating bar chart config for sql=%s', + state.sql?.substring(0, 80), + ); + + const systemPrompt = `You are an expert data visualization assistant. Your task is to create a bar chart config based on the provided SQL query, its description and user prompt. Follow these steps: +1. Analyze the SQL query results to understand the data structure. +2. 
Identify the category column (x-axis) and value column (y-axis) for the bar chart. +3. Create a configuration object for the bar chart using the identified columns. +4. Return the bar chart configuration object.`; + + const userPrompt = ` +${state.sql} + + +${state.queryDescription} + + +${state.prompt} +`; + + // Cast to avoid TS2589 (deep overload inference in AI SDK v6) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const result = (await (generateObject as any)({ + model: this.llm, + schema: BAR_CONFIG_SCHEMA, + system: systemPrompt, + prompt: userPrompt, + })) as { + object: { + categoryColumn: string; + valueColumn: string; + orientation: string; + }; + usage: {inputTokens: number; outputTokens: number}; + }; + + onUsage?.( + result.usage.inputTokens ?? 0, + result.usage.outputTokens ?? 0, + 'unknown', + ); + debug('token usage captured', { + promptTokens: result.usage.inputTokens ?? 0, + completionTokens: result.usage.outputTokens ?? 0, + }); + debug('Bar chart config generated: %o', result.object); + return result.object as AnyObject; + } +} diff --git a/src/mastra/visualization/services/index.ts b/src/mastra/visualization/services/index.ts new file mode 100644 index 0000000..aae3015 --- /dev/null +++ b/src/mastra/visualization/services/index.ts @@ -0,0 +1,3 @@ +export {MastraBarVisualizerService} from './bar.visualizer.service'; +export {MastraLineVisualizerService} from './line.visualizer.service'; +export {MastraPieVisualizerService} from './pie.visualizer.service'; diff --git a/src/mastra/visualization/services/line.visualizer.service.ts b/src/mastra/visualization/services/line.visualizer.service.ts new file mode 100644 index 0000000..e415912 --- /dev/null +++ b/src/mastra/visualization/services/line.visualizer.service.ts @@ -0,0 +1,151 @@ +import {injectable, BindingScope, inject} from '@loopback/core'; +import {AnyObject} from '@loopback/repository'; +import {generateObject} from 'ai'; +import {z} from 'zod'; +import 
{AiIntegrationBindings} from '../../../keys'; +import {LLMProvider} from '../../../types'; +import { + IMastraVisualizer, + MastraVisualizationState, +} from '../types/visualization.types'; + +const debug = require('debug')('ai-integration:mastra:visualization:line'); + +/** + * Zod schema describing the line chart configuration returned by the LLM. + * Mirrors the schema used by the LangGraph `LineVisualizer`. + */ +const LINE_CONFIG_SCHEMA = z.object({ + xAxisColumn: z + .string() + .describe( + 'Single column name to be used for x-axis in the line chart (typically time or sequential data)', + ), + yAxisColumn: z + .string() + .describe( + 'Single column name to be used for y-axis values in the line chart', + ), + seriesColumns: z + .string() + .describe( + 'Optional column to group data into multiple lines/series, leave it as empty string if not needed. It can cover multiple columns separated by comma if the query needs to show multiple lines based on multiple columns. The UI supports multiple series in line chart by forming a combined key.', + ), +}); + +/** + * Mastra-path line-chart visualizer. + * + * Replaces the LangGraph `LineVisualizer` by using AI SDK `generateObject()` + * instead of `BaseChatModel.withStructuredOutput()`. Business logic, prompt, + * and post-processing of `seriesColumns` are identical — only the LLM call + * site changes. + * + * Implements `IMastraVisualizer` so `selectVisualizationStep` can discover + * and rank it alongside other Mastra visualizers. + */ +@injectable({scope: BindingScope.SINGLETON}) +export class MastraLineVisualizerService implements IMastraVisualizer { + /** Unique chart type key — must match the value returned by the LLM. */ + readonly name = 'line'; + + readonly description = + 'Renders the data in a line chart format. 
Best for showing trends and changes over time or continuous data.'; + + readonly context = + 'A line chart requires data with exactly 3 columns: one for the x-axis (typically time or sequential data), one for the y-axis (values), and one series type column to distinguish multiple lines/series in the chart. The series type column is important for grouping data into separate lines.'; + + constructor( + @inject(AiIntegrationBindings.AiSdkSmartNonThinkingLLM) + private readonly llm: LLMProvider, + ) {} + + /** + * Uses AI SDK `generateObject()` with `LINE_CONFIG_SCHEMA` to map the SQL + * query's columns to line-chart axes. + * + * Post-processes `seriesColumns`: + * - Empty string / null / undefined → `null` (no series grouping) + * - Comma-separated string → `string[]` (multiple series) + * + * @param state Current visualization state with `sql`, `queryDescription`, + * and `prompt` already populated. + * @returns `{ xAxisColumn, yAxisColumn, seriesColumns }` chart config. + */ + async getConfig( + state: MastraVisualizationState, + onUsage?: ( + inputTokens: number, + outputTokens: number, + model: string, + ) => void, + ): Promise { + if (!state.sql || !state.queryDescription || !state.prompt) { + throw new Error( + 'MastraLineVisualizerService: Invalid State — sql, queryDescription and prompt are required', + ); + } + + debug( + 'Generating line chart config for sql=%s', + state.sql?.substring(0, 80), + ); + + const systemPrompt = `You are an expert data visualization assistant. Your task is to create a line chart config based on the provided SQL query, its description and user prompt. Follow these steps: +1. Analyze the SQL query results to understand the data structure. +2. Identify the x-axis column (typically time or sequential data) and y-axis column (values) for the line chart. +3. Determine if there are multiple series to be plotted (multiple lines) with combination of multiple columns, or single series based on single column. +4. 
Create a configuration object for the line chart using the identified columns. +5. Return the line chart configuration object.`; + + const userPrompt = ` +${state.sql} + + +${state.queryDescription} + + +${state.prompt} +`; + + // Cast to avoid TS2589 (deep overload inference in AI SDK v6) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const result = (await (generateObject as any)({ + model: this.llm, + schema: LINE_CONFIG_SCHEMA, + system: systemPrompt, + prompt: userPrompt, + })) as { + object: {xAxisColumn: string; yAxisColumn: string; seriesColumns: string}; + usage: {inputTokens: number; outputTokens: number}; + }; + + onUsage?.( + result.usage.inputTokens ?? 0, + result.usage.outputTokens ?? 0, + 'unknown', + ); + debug('token usage captured', { + promptTokens: result.usage.inputTokens ?? 0, + completionTokens: result.usage.outputTokens ?? 0, + }); + + // Normalise seriesColumns: empty string → null, CSV string → string[] + const settings: AnyObject = {...result.object}; + if ( + settings.seriesColumns === '' || + settings.seriesColumns === undefined || + settings.seriesColumns === null + ) { + settings.seriesColumns = null; + } else { + settings.seriesColumns = + (settings.seriesColumns as string) + .split(',') + .map((s: string) => s.trim()) ?? 
[]; + } + + debug('Line chart config generated: %o', settings); + return settings; + } +} diff --git a/src/mastra/visualization/services/pie.visualizer.service.ts b/src/mastra/visualization/services/pie.visualizer.service.ts new file mode 100644 index 0000000..de5ed5e --- /dev/null +++ b/src/mastra/visualization/services/pie.visualizer.service.ts @@ -0,0 +1,120 @@ +import {injectable, BindingScope, inject} from '@loopback/core'; +import {AnyObject} from '@loopback/repository'; +import {generateObject} from 'ai'; +import {z} from 'zod'; +import {AiIntegrationBindings} from '../../../keys'; +import {LLMProvider} from '../../../types'; +import { + IMastraVisualizer, + MastraVisualizationState, +} from '../types/visualization.types'; + +const debug = require('debug')('ai-integration:mastra:visualization:pie'); + +/** + * Zod schema describing the pie chart configuration returned by the LLM. + * Mirrors the schema used by the LangGraph `PieVisualizer`. + */ +const PIE_CONFIG_SCHEMA = z.object({ + labelColumn: z + .string() + .describe('Column to be used for labels in the pie chart'), + valueColumn: z + .string() + .describe('Column to be used for values in the pie chart'), +}); + +/** + * Mastra-path pie-chart visualizer. + * + * Replaces the LangGraph `PieVisualizer` by using AI SDK `generateObject()` + * instead of `BaseChatModel.withStructuredOutput()`. Business logic and + * prompt are identical — only the LLM call site changes. + * + * Implements `IMastraVisualizer` so `selectVisualizationStep` can discover + * and rank it alongside other Mastra visualizers. + */ +@injectable({scope: BindingScope.SINGLETON}) +export class MastraPieVisualizerService implements IMastraVisualizer { + /** Unique chart type key — must match the value returned by the LLM. */ + readonly name = 'pie'; + + readonly description = + 'Renders the data in a pie chart format. 
Best for visualizing proportions and percentages among categories.'; + + readonly context = + 'A pie chart requires data with at least two columns: one for the labels (categories) and one for the values (numerical data). Ensure that the values are non-negative and represent parts of a whole, as pie charts are used to visualize proportions and percentages among different categories.'; + + constructor( + @inject(AiIntegrationBindings.AiSdkCheapLLM) + private readonly llm: LLMProvider, + ) {} + + /** + * Uses AI SDK `generateObject()` with `PIE_CONFIG_SCHEMA` to map the SQL + * query's columns to pie-chart segments. + * + * @param state Current visualization state with `sql`, `queryDescription`, + * and `prompt` already populated. + * @returns `{ labelColumn, valueColumn }` chart config. + */ + async getConfig( + state: MastraVisualizationState, + onUsage?: ( + inputTokens: number, + outputTokens: number, + model: string, + ) => void, + ): Promise { + if (!state.sql || !state.queryDescription || !state.prompt) { + throw new Error( + 'MastraPieVisualizerService: Invalid State — sql, queryDescription and prompt are required', + ); + } + + debug( + 'Generating pie chart config for sql=%s', + state.sql?.substring(0, 80), + ); + + const systemPrompt = `You are an expert data visualization assistant. Your task is to create a pie chart config based on the provided SQL query, its description and user prompt. Follow these steps: +1. Analyze the SQL query results to understand the data structure. +2. Identify the key categories and their corresponding values for the pie chart. +3. Create a configuration object for the pie chart using the identified categories and values. +4. 
Return the pie chart configuration object.`; + + const userPrompt = ` +${state.sql} + + +${state.queryDescription} + + +${state.prompt} +`; + + // Cast to avoid TS2589 (deep overload inference in AI SDK v6) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const result = (await (generateObject as any)({ + model: this.llm, + schema: PIE_CONFIG_SCHEMA, + system: systemPrompt, + prompt: userPrompt, + })) as { + object: {labelColumn: string; valueColumn: string}; + usage: {inputTokens: number; outputTokens: number}; + }; + + onUsage?.( + result.usage.inputTokens ?? 0, + result.usage.outputTokens ?? 0, + 'unknown', + ); + debug('token usage captured', { + promptTokens: result.usage.inputTokens ?? 0, + completionTokens: result.usage.outputTokens ?? 0, + }); + debug('Pie chart config generated: %o', result.object); + return result.object as AnyObject; + } +} diff --git a/src/mastra/visualization/types/visualization.types.ts b/src/mastra/visualization/types/visualization.types.ts new file mode 100644 index 0000000..0e34883 --- /dev/null +++ b/src/mastra/visualization/types/visualization.types.ts @@ -0,0 +1,136 @@ +import {AnyObject} from '@loopback/repository'; + +// ── Mastra-path state ──────────────────────────────────────────────────────── + +/** + * Plain-TypeScript state object for the Mastra Visualization workflow. + * + * Mirrors `VisualizationGraphState` (LangGraph) field-for-field but does NOT + * depend on `@langchain/langgraph` `Annotation` — making every step function + * completely LangGraph-free. + * + * `visualizer` references `IMastraVisualizer` (the Mastra-path visualizer + * interface) rather than the LangGraph-coupled `IVisualizer`. + */ +export interface MastraVisualizationState { + /** Natural-language prompt from the user. */ + prompt: string; + /** + * Existing dataset UUID (when provided up-front or set by the + * `callQueryGeneration` step after creating a new dataset). 
+ */ + datasetId?: string; + /** SQL query string fetched from the dataset store. */ + sql?: string; + /** Human-readable description of the dataset's SQL query. */ + queryDescription?: string; + /** + * The resolved `IMastraVisualizer` instance that will render the chart. + * Set by `selectVisualizationStep`. + */ + visualizer?: IMastraVisualizer; + /** Friendly name of the chosen visualizer (e.g. `'bar'`). */ + visualizerName?: string; + /** `true` once the visualization has been successfully rendered. */ + done?: boolean; + /** Final chart configuration object emitted to the SSE transport. */ + visualizerConfig?: AnyObject; + /** + * Non-empty string means the workflow entered an error path. + * The `renderVisualization` step is skipped when `error` is set. + */ + error?: string; + /** + * Optional visualization type hint supplied by the caller. + * When present, `selectVisualizationStep` bypasses the LLM selection call. + */ + type?: string; +} + +// ── Mastra-path visualizer interface ──────────────────────────────────────── + +/** + * Interface for a Mastra-path visualizer. + * + * Mirrors `IVisualizer` (LangGraph component) but: + * - accepts `MastraVisualizationState` instead of `VisualizationGraphState` + * - uses AI SDK `generateObject()` internally (no `withStructuredOutput()`) + * + * Each chart-type service (`MastraBarVisualizerService`, etc.) implements this. + */ +export interface IMastraVisualizer { + /** Unique chart type identifier (e.g. `'bar'`, `'line'`, `'pie'`). */ + name: string; + /** Short description shown to the LLM for visualizer selection. */ + description: string; + /** + * Optional guidance injected into the data-generation prompt so the SQL + * query returns columns shaped for this chart type. + */ + context?: string; + /** + * Generate and return a chart configuration object for the given state. + * Implementations use AI SDK `generateObject()` with a Zod schema. 
+ * @param onUsage Optional callback to report token usage for rate limiting. + */ + getConfig( + state: MastraVisualizationState, + onUsage?: ( + inputTokens: number, + outputTokens: number, + model: string, + ) => void, + ): Promise; +} + +// ── Execution context ──────────────────────────────────────────────────────── + +/** + * Execution context threaded through every step of the Mastra Visualization + * workflow. Mirrors `MastraDbQueryContext` — passed as the second argument to + * every step function so they can emit SSE events and respect cancellation. + */ +export interface MastraVisualizationContext { + /** + * Callback to emit streaming events back to the SSE transport. + * Accepts `unknown` to allow arbitrary event shapes from step functions + * (matches `RunnableConfig.writer` semantics used by the LangGraph path). + */ + writer?: (chunk: unknown) => void; + /** AbortSignal forwarded from the request lifecycle. Optional. */ + signal?: AbortSignal; + /** + * Optional callback invoked after `generateText()` / `generateObject()` to + * report AI SDK token usage for the Mastra path's token counting. + * + * Wire to `TokenCounter.accumulate()` in the workflow runner. + * + * @param inputTokens - Prompt tokens consumed. + * @param outputTokens - Completion tokens produced. + * @param model - Model identifier string. + */ + onUsage?: (inputTokens: number, outputTokens: number, model: string) => void; +} + +// ── Workflow input ─────────────────────────────────────────────────────────── + +/** + * Input accepted by `MastraVisualizationWorkflow.run()`. + * + * Mirrors the Zod schema declared on `GenerateVisualizationTool.inputSchema`. + */ +export interface VisualizationWorkflowInput { + /** Natural-language prompt describing the visualization the user wants. */ + prompt: string; + /** + * Optional existing dataset UUID. When provided, the workflow skips the + * `callQueryGeneration` step and uses this dataset directly. 
+ */ + datasetId?: string; + /** + * Optional visualization type hint (e.g. `'bar'`, `'line'`, `'pie'`). + * When provided, the `selectVisualization` step skips the LLM selection + * call and uses the named visualizer directly. + */ + type?: string; +} diff --git a/src/mastra/visualization/workflow/conditions/index.ts b/src/mastra/visualization/workflow/conditions/index.ts new file mode 100644 index 0000000..2c5409b --- /dev/null +++ b/src/mastra/visualization/workflow/conditions/index.ts @@ -0,0 +1,8 @@ +export { + checkPostSelectConditions, + checkPostQueryGenerationConditions, +} from './visualization.conditions'; +export type { + PostSelectCondition, + PostQueryGenerationCondition, +} from './visualization.conditions'; diff --git a/src/mastra/visualization/workflow/conditions/visualization.conditions.ts b/src/mastra/visualization/workflow/conditions/visualization.conditions.ts new file mode 100644 index 0000000..c13f078 --- /dev/null +++ b/src/mastra/visualization/workflow/conditions/visualization.conditions.ts @@ -0,0 +1,55 @@ +import {MastraVisualizationState} from '../../types/visualization.types'; + +// ── Post-SelectVisualization routing ───────────────────────────────────────── + +/** + * Routing outcomes after `selectVisualizationStep`. + * + * Mirrors the `addConditionalEdges` on `SelectVisualisation` in + * `VisualizationGraph.build()`. + */ +export type PostSelectCondition = 'error' | 'continue'; + +/** + * Evaluates state after `selectVisualizationStep` and returns the routing + * decision. + * + * - `'error'` → `state.error` is set; short-circuit to end without rendering. + * - `'continue'` → proceed to `callQueryGenerationStep`. + * + * Mirrors the conditional edge after `SelectVisualisation` in the LangGraph + * `VisualizationGraph`. 
+ */ +export function checkPostSelectConditions( + state: MastraVisualizationState, +): PostSelectCondition { + if (state.error) return 'error'; + return 'continue'; +} + +// ── Post-CallQueryGeneration routing ───────────────────────────────────────── + +/** + * Routing outcomes after `callQueryGenerationStep`. + * + * Mirrors the `addConditionalEdges` on `CallQueryGeneration` in + * `VisualizationGraph.build()`. + */ +export type PostQueryGenerationCondition = 'error' | 'continue'; + +/** + * Evaluates state after `callQueryGenerationStep` and returns the routing + * decision. + * + * - `'error'` → dataset generation failed; short-circuit to end. + * - `'continue'` → proceed to `getDatasetDataStep`. + * + * Mirrors the conditional edge after `CallQueryGeneration` in the LangGraph + * `VisualizationGraph`. + */ +export function checkPostQueryGenerationConditions( + state: MastraVisualizationState, +): PostQueryGenerationCondition { + if (state.error) return 'error'; + return 'continue'; +} diff --git a/src/mastra/visualization/workflow/index.ts b/src/mastra/visualization/workflow/index.ts new file mode 100644 index 0000000..43e671a --- /dev/null +++ b/src/mastra/visualization/workflow/index.ts @@ -0,0 +1,4 @@ +// Workflow step functions and their deps types +export * from './steps'; +// Routing condition functions and their return types +export * from './conditions'; diff --git a/src/mastra/visualization/workflow/steps/call-query-generation.step.ts b/src/mastra/visualization/workflow/steps/call-query-generation.step.ts new file mode 100644 index 0000000..a995aa0 --- /dev/null +++ b/src/mastra/visualization/workflow/steps/call-query-generation.step.ts @@ -0,0 +1,84 @@ +import {LLMStreamEventType} from '../../../../types/events'; +import {MastraDbQueryWorkflow} from '../../../db-query/mastra-db-query.workflow'; +import { + MastraVisualizationContext, + MastraVisualizationState, +} from '../../types/visualization.types'; + +const debug = require('debug')( + 
'ai-integration:mastra:visualization:call-query-generation', +); + +/** Dependencies injected by `MastraVisualizationWorkflow`. */ +export type CallQueryGenerationStepDeps = { + /** The Mastra DB-Query workflow for generating a new dataset when needed. */ + dbQueryWorkflow: MastraDbQueryWorkflow; +}; + +/** + * Calls the Mastra DbQuery workflow to generate a dataset when one has not + * been provided by the caller. + * + * Short-circuits immediately if `state.datasetId` is already set — this + * matches the LangGraph `CallQueryGenerationNode` behaviour where the node + * was a no-op for pre-existing datasets. + * + * When dataset generation succeeds, the step returns `{ datasetId }`. + * On failure it returns `{ error }` which causes the workflow to short-circuit + * before reaching `renderVisualizationStep`. + * + * The prompt sent to the DbQuery workflow appends the selected visualizer's + * `context` hint (e.g. "ensure exactly two columns") so the generated SQL is + * already shaped for the chosen chart type. + * + * Mirrors `CallQueryGenerationNode.execute()` in the LangGraph path. + * LangGraph coupling removed: `DbQueryGraph.build().invoke()` → + * `MastraDbQueryWorkflow.run()`. + */ +export async function callQueryGenerationStep( + state: MastraVisualizationState, + context: MastraVisualizationContext, + deps: CallQueryGenerationStepDeps, +): Promise> { + debug('step start datasetId=%s', state.datasetId ?? '(none)'); + + // ── Short-circuit: dataset already known ───────────────────────────────── + if (state.datasetId) { + debug('datasetId already set, skipping query generation'); + return {}; + } + + // ── Build dataset-generation prompt with visualizer context hint ───────── + const vizContext = state.visualizer?.context + ? 
` Ensure that the query structure satisfies the following context: ${state.visualizer.context}` + : ''; + + const dbQueryPrompt = `Generate a query to fetch data for visualization based on the following user prompt: ${state.prompt}.${vizContext}`; + + debug('Calling DbQuery workflow prompt=%s', dbQueryPrompt.substring(0, 120)); + + // Forward writer/signal so the nested workflow can emit status events too + const dbQueryResult = await deps.dbQueryWorkflow.run( + {prompt: dbQueryPrompt, directCall: true}, + {writer: context.writer, signal: context.signal}, + ); + + if (!dbQueryResult.datasetId) { + const reason = dbQueryResult.replyToUser ?? 'Unknown error'; + debug('DbQuery workflow failed: %s', reason); + context.writer?.({ + type: LLMStreamEventType.Error, + data: { + status: `Failed to create dataset for visualization: ${reason}`, + }, + }); + return { + error: + dbQueryResult.replyToUser ?? + 'Failed to create dataset for visualization', + }; + } + + debug('Dataset generated: datasetId=%s', dbQueryResult.datasetId); + return {datasetId: dbQueryResult.datasetId}; +} diff --git a/src/mastra/visualization/workflow/steps/get-dataset-data.step.ts b/src/mastra/visualization/workflow/steps/get-dataset-data.step.ts new file mode 100644 index 0000000..61a8fb5 --- /dev/null +++ b/src/mastra/visualization/workflow/steps/get-dataset-data.step.ts @@ -0,0 +1,51 @@ +import {LLMStreamEventType} from '../../../../types/events'; +import {IDataSetStore} from '../../../../components/db-query/types'; +import { + MastraVisualizationContext, + MastraVisualizationState, +} from '../../types/visualization.types'; + +const debug = require('debug')( + 'ai-integration:mastra:visualization:get-dataset-data', +); + +/** Dependencies injected by `MastraVisualizationWorkflow`. */ +export type GetDatasetDataStepDeps = { + /** Dataset store used to fetch the SQL query and description. 
*/ + store: IDataSetStore; +}; + +/** + * Fetches the SQL query and human-readable description from the dataset store + * using `state.datasetId`. + * + * Populates `state.sql` and `state.queryDescription` so that the subsequent + * `renderVisualizationStep` can pass them to the visualizer's `getConfig()`. + * + * Also emits a "Preparing visualization" status event to the SSE transport. + * + * Mirrors `GetDatasetDataNode.execute()` in the LangGraph path. + * LangGraph coupling removed: `@inject(DbQueryAIExtensionBindings.DatasetStore)` → + * explicit `deps.store` parameter. + */ +export async function getDatasetDataStep( + state: MastraVisualizationState, + context: MastraVisualizationContext, + deps: GetDatasetDataStepDeps, +): Promise> { + debug('step start datasetId=%s', state.datasetId); + + const dataset = await deps.store.findById(state.datasetId!); + + debug('Dataset fetched sql=%s', dataset.query?.substring(0, 80)); + + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: 'Preparing visualization'}, + }); + + return { + sql: dataset.query, + queryDescription: dataset.description, + }; +} diff --git a/src/mastra/visualization/workflow/steps/index.ts b/src/mastra/visualization/workflow/steps/index.ts new file mode 100644 index 0000000..c672394 --- /dev/null +++ b/src/mastra/visualization/workflow/steps/index.ts @@ -0,0 +1,11 @@ +export {selectVisualizationStep} from './select-visualization.step'; +export type {SelectVisualizationStepDeps} from './select-visualization.step'; + +export {callQueryGenerationStep} from './call-query-generation.step'; +export type {CallQueryGenerationStepDeps} from './call-query-generation.step'; + +export {getDatasetDataStep} from './get-dataset-data.step'; +export type {GetDatasetDataStepDeps} from './get-dataset-data.step'; + +export {renderVisualizationStep} from './render-visualization.step'; +export type {RenderVisualizationStepDeps} from './render-visualization.step'; diff --git 
a/src/mastra/visualization/workflow/steps/render-visualization.step.ts b/src/mastra/visualization/workflow/steps/render-visualization.step.ts new file mode 100644 index 0000000..c9cc9d5 --- /dev/null +++ b/src/mastra/visualization/workflow/steps/render-visualization.step.ts @@ -0,0 +1,73 @@ +import {LLMStreamEventType, ToolStatus} from '../../../../types/events'; +import { + MastraVisualizationContext, + MastraVisualizationState, +} from '../../types/visualization.types'; + +const debug = require('debug')( + 'ai-integration:mastra:visualization:render-visualization', +); + +/** + * `renderVisualizationStep` has no extra dependencies beyond the state and + * context — the visualizer instance is already resolved in `state.visualizer` + * by `selectVisualizationStep`. + */ +export type RenderVisualizationStepDeps = Record; + +/** + * Renders the final chart configuration by calling the resolved visualizer's + * `getConfig()` method. + * + * Two SSE events are emitted: + * 1. `ToolStatus` — "Configuring " — signals the frontend that rendering + * has started. + * 2. `ToolStatus` — `ToolStatus.Completed` — delivers the final chart config + * (datasetId, visualization name, and config object) for the UI to render. + * + * Returns `{ done: true, visualizerConfig }` on success. + * + * Mirrors `RenderVisualizationNode.execute()` in the LangGraph path. + * LangGraph coupling removed: `LangGraphRunnableConfig` → `MastraVisualizationContext`. 
+ */ +export async function renderVisualizationStep( + state: MastraVisualizationState, + context: MastraVisualizationContext, + _deps?: RenderVisualizationStepDeps, +): Promise> { + debug('step start visualizer=%s', state.visualizerName); + + const visualizer = state.visualizer; + + if (!visualizer || !state.sql || !state.queryDescription) { + throw new Error( + 'renderVisualizationStep: Invalid State — visualizer, sql, and queryDescription are all required', + ); + } + + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: `Configuring ${visualizer.name}`}, + }); + + debug('Calling visualizer.getConfig() for %s', visualizer.name); + const settings = await visualizer.getConfig(state, context.onUsage); + debug('Visualizer config generated: %o', settings); + + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: { + status: ToolStatus.Completed, + data: { + datasetId: state.datasetId, + visualization: visualizer.name, + config: settings ?? {}, + }, + }, + }); + + return { + done: true, + visualizerConfig: settings ?? 
{}, + }; +} diff --git a/src/mastra/visualization/workflow/steps/select-visualization.step.ts b/src/mastra/visualization/workflow/steps/select-visualization.step.ts new file mode 100644 index 0000000..5f312d9 --- /dev/null +++ b/src/mastra/visualization/workflow/steps/select-visualization.step.ts @@ -0,0 +1,148 @@ +import {generateText} from 'ai'; +import {LLMStreamEventType} from '../../../../types/events'; +import {LLMProvider} from '../../../../types'; +import {stripThinkingFromText} from '../../../db-query/utils/thinking.util'; +import { + IMastraVisualizer, + MastraVisualizationContext, + MastraVisualizationState, +} from '../../types/visualization.types'; + +const debug = require('debug')( + 'ai-integration:mastra:visualization:select-visualization', +); + +const SELECT_VISUALIZATION_PROMPT = ` + +You are expert Data Analysis Agent whose job is to suggest visualisations that would be best suited to display the results for a particular user prompt and the data extracted based on that prompt. +You are provided with 2 inputs - +- user prompt +- A list of visualization names with their descriptions that are supported. + +You need to suggest a visualisation from a list of visualisation that would best fit the user's request. + + + +{prompt} + + + +{sql} + + +{description} + + + +{visualizations} + + + + +The output should be a single string that has the name from the visualizations list and nothing else. +If none of the visualizations fit the requirement, return "none" followed by the changes required in the data to be able to render the visualization. +Do not try to force fit the prompt to any visualization if it does not make sense. Prefer to returning none with appropriate reason instead. + + +type-of-visualization + + +none: reason why the visualization is not possible with the current prompt. + + +`; + +/** Dependencies injected by `MastraVisualizationWorkflow`. 
*/ +export type SelectVisualizationStepDeps = { + /** Cheap LLM used for the visualizer-selection decision. */ + llm: LLMProvider; + /** + * All registered Mastra visualizer service instances. + * The step ranks them textually and returns one. + */ + visualizers: IMastraVisualizer[]; +}; + +/** + * Selects the most appropriate chart type for the user's data. + * + * Two code paths: + * 1. **Explicit type** (`state.type`): short-circuit — skip the LLM call and + * resolve the named visualizer directly (mirrors LangGraph `SelectVisualizationNode`). + * 2. **LLM selection**: invokes `generateText()` from the AI SDK with a + * formatted prompt that lists all registered visualizers. + * + * Returns `Partial` with either: + * - `{ visualizer, visualizerName }` on success, OR + * - `{ error }` when no matching visualizer is found. + * + * Mirrors `SelectVisualizationNode.execute()` in the LangGraph path. + * LangGraph coupling removed: `PromptTemplate` / `RunnableSequence` → + * plain `generateText()` + string interpolation. + */ +export async function selectVisualizationStep( + state: MastraVisualizationState, + context: MastraVisualizationContext, + deps: SelectVisualizationStepDeps, +): Promise> { + debug('step start type=%s', state.type ?? '(auto)'); + + const {visualizers} = deps; + + // ── Fast-path: caller specified an exact visualizer type ───────────────── + if (state.type) { + const selected = visualizers.find(v => v.name === state.type); + if (!selected) { + const available = visualizers.map(v => v.name).join(', '); + throw new Error( + `selectVisualizationStep: No visualizer found with name "${state.type}". 
Available: ${available}`, + ); + } + debug('fast-path: using explicit type=%s', state.type); + return {visualizer: selected, visualizerName: selected.name}; + } + + // ── LLM-selection path ─────────────────────────────────────────────────── + context.writer?.({ + type: LLMStreamEventType.ToolStatus, + data: {status: 'Selecting best visualization for the data'}, + }); + + const vizList = visualizers + .map(v => `- ${v.name}: ${v.description}`) + .join('\n'); + + const prompt = SELECT_VISUALIZATION_PROMPT.replace('{prompt}', state.prompt) + .replace('{sql}', state.sql ?? '') + .replace('{description}', state.queryDescription ?? '') + .replace('{visualizations}', vizList); + + debug('Calling LLM for visualization selection'); + const {text: rawOutput, usage} = await generateText({ + model: deps.llm, + prompt, + }); + context.onUsage?.(usage.inputTokens ?? 0, usage.outputTokens ?? 0, 'unknown'); + debug('token usage captured', { + promptTokens: usage.inputTokens ?? 0, + completionTokens: usage.outputTokens ?? 0, + }); + const output = stripThinkingFromText(rawOutput).trim(); + debug('LLM selection result: %s', output); + + // LLM returned "none: " — no suitable visualization + if (output.startsWith('none')) { + return {error: output.substring(4).trim()}; + } + + const selected = visualizers.find(v => v.name === output); + if (!selected) { + const available = visualizers.map(v => v.name).join(', '); + throw new Error( + `selectVisualizationStep: LLM returned unknown visualizer "${output}". 
Available: ${available}`, + ); + } + + debug('Selected visualizer: %s', selected.name); + return {visualizer: selected, visualizerName: selected.name}; +} diff --git a/src/models/message.model.ts b/src/models/message.model.ts index 59570ba..5e75bc2 100644 --- a/src/models/message.model.ts +++ b/src/models/message.model.ts @@ -1,6 +1,6 @@ import {hasMany, model, property} from '@loopback/repository'; import {Message as SourceloopMessage} from '@sourceloop/chat-service'; -import {MessageMetadata} from '../graphs'; +import {MessageMetadata} from '../services/chat-metadata.type'; @model({ name: 'messages', }) diff --git a/src/providers/tools.provider.ts b/src/providers/tools.provider.ts index 96f17b3..7dcb912 100644 --- a/src/providers/tools.provider.ts +++ b/src/providers/tools.provider.ts @@ -5,7 +5,7 @@ import { injectable, Provider, } from '@loopback/core'; -import {IGraphTool} from '../graphs/types'; +import {IGraphTool} from '../types/tool'; import {ToolStore} from '../types'; @injectable({scope: BindingScope.REQUEST}) diff --git a/src/providers/vector-stores/inmemory.vector.ts b/src/providers/vector-stores/inmemory.vector.ts index 3ab6938..6897dde 100644 --- a/src/providers/vector-stores/inmemory.vector.ts +++ b/src/providers/vector-stores/inmemory.vector.ts @@ -1,16 +1,48 @@ -import {inject, Provider, ValueOrPromise} from '@loopback/core'; -import {MemoryVectorStore} from 'langchain/vectorstores/memory'; -import {AiIntegrationBindings} from '../../keys'; -import {EmbeddingProvider} from '../../types'; -import {AnyObject} from '@loopback/repository'; -export class InMemoryVectorStore implements Provider { - constructor( - @inject(AiIntegrationBindings.EmbeddingModel) - private readonly embeddings: EmbeddingProvider, - ) {} - value(): ValueOrPromise { - const memory = new MemoryVectorStore(this.embeddings); - memory.delete = async (params: AnyObject) => {}; - return memory; +import {injectable, BindingScope, Provider} from '@loopback/core'; +import {IVectorStore, 
IVectorStoreDocument} from '../../types'; + +/** + * Simple in-memory vector store for testing. + * No actual vector embeddings — uses string matching. + */ +@injectable({scope: BindingScope.SINGLETON}) +export class InMemoryVectorStore implements Provider { + value(): IVectorStore { + return new InMemoryVectorStoreImpl(); + } +} + +class InMemoryVectorStoreImpl implements IVectorStore { + private docs: IVectorStoreDocument[] = []; + + async addDocuments(docs: IVectorStoreDocument[]): Promise { + this.docs.push(...docs); + } + + async similaritySearch>( + query: string, + k: number, + filter?: Record, + ): Promise[]> { + let results = this.docs as IVectorStoreDocument[]; + if (filter) { + results = results.filter(doc => + Object.entries(filter).every( + ([key, value]) => + (doc.metadata as Record)[key] === value, + ), + ); + } + return results.slice(0, k); + } + + async delete(params: {filter: Record}): Promise { + this.docs = this.docs.filter( + doc => + !Object.entries(params.filter).every( + ([key, value]) => + (doc.metadata as Record)[key] === value, + ), + ); } } diff --git a/src/graphs/chat/chat-metadata.type.ts b/src/services/chat-metadata.type.ts similarity index 100% rename from src/graphs/chat/chat-metadata.type.ts rename to src/services/chat-metadata.type.ts diff --git a/src/graphs/chat/chat.store.ts b/src/services/chat.store.ts similarity index 65% rename from src/graphs/chat/chat.store.ts rename to src/services/chat.store.ts index 7c7119b..4cf78ed 100644 --- a/src/graphs/chat/chat.store.ts +++ b/src/services/chat.store.ts @@ -1,4 +1,3 @@ -import {AIMessage, HumanMessage, ToolMessage} from '@langchain/core/messages'; import {BindingScope, Getter, inject, injectable} from '@loopback/core'; import { AnyObject, @@ -9,18 +8,31 @@ import { import {HttpErrors} from '@loopback/rest'; import {IAuthUserWithPermissions} from '@sourceloop/core'; import {AuthenticationBindings} from 'loopback4-authentication'; -import {CHAT_TITLE_MAX_LENGTH} from '../../constant'; 
-import {Chat, Message} from '../../models'; -import {ChatRepository} from '../../repositories'; -import {ChannelType, TokenMetadata} from '../../types'; -import {getTextContent, mergeAttachments} from '../../utils'; -import {SavedMessage} from '../types'; +import {CHAT_TITLE_MAX_LENGTH} from '../constant'; +import {Chat, Message} from '../models'; +import {ChatRepository} from '../repositories'; +import {ChannelType, TokenMetadata} from '../types'; +import {mergeAttachments} from '../utils'; +import {MastraAgentMessage, MastraAssistantContentPart} from '../mastra/types'; import { MessageMetadata, MessageMetadataType, ToolMessageMetadata, } from './chat-metadata.type'; +/** + * Plain-message type returned by `toMessage()`. + * Compatible with AI SDK `CoreMessage` and Mastra `MastraAgentMessage`. + */ +export type SavedMessage = MastraAgentMessage; + +/** + * Persistence service for Chat and Message records. + * + * All methods that previously accepted/returned `@langchain/core/messages` + * types (HumanMessage, AIMessage, ToolMessage) now use plain strings or + * `MastraAgentMessage` — no @langchain dependency. + */ @injectable({scope: BindingScope.REQUEST}) export class ChatStore { constructor( @@ -112,8 +124,12 @@ export class ChatStore { return newMessage; } - async addHumanMessage(chatId: string, message: HumanMessage) { - return this.addMessage(chatId, getTextContent(message.content), { + /** + * Persists a human/user message to the DB. + * Accepts a plain string instead of a `HumanMessage` instance. + */ + async addHumanMessage(chatId: string, prompt: string) { + return this.addMessage(chatId, prompt, { type: MessageMetadataType.User, }); } @@ -139,10 +155,17 @@ export class ChatStore { ); } - async addAIMessage(chatId: string, message: AIMessage) { - let text = getTextContent(message.content); + /** + * Persists an AI message to the DB. + * Accepts plain content + optional tool calls instead of an `AIMessage` instance. 
+ */ + async addAIMessage( + chatId: string, + content: string, + toolCalls?: {id: string; name: string; args: AnyObject}[], + ) { + let text = content; if (!text.trim()) { - // empty message incase the LLM only returns tool calls text = ' '; } return this.addMessage( @@ -155,20 +178,26 @@ export class ChatStore { ); } + /** + * Persists a tool-result message to the DB. + * Accepts structured parameters instead of a `ToolMessage` instance. + */ async addToolMessage( chatId: string, - message: ToolMessage, + toolCallId: string, + toolName: string, + content: string, metadata: AnyObject, aiMessage: Message, args?: AnyObject, ) { return this.addMessage( chatId, - getTextContent(message.content), + content, { type: MessageMetadataType.Tool, - toolName: message.name!, - id: message.tool_call_id, + toolName, + id: toolCallId, args, ...metadata, }, @@ -177,6 +206,10 @@ export class ChatStore { ); } + /** + * Converts a DB `Message` row into a `MastraAgentMessage` (compatible with + * AI SDK `CoreMessage`). Returns `undefined` for unrecognised message types. + */ async toMessage(message: Message): Promise { if (message.metadata?.type === MessageMetadataType.User) { let messageContent = message.body; @@ -189,37 +222,49 @@ export class ChatStore { ); } } - return new HumanMessage({ - content: messageContent, - }); + return {role: 'user', content: messageContent}; } else if (message.metadata?.type === MessageMetadataType.AI) { - const newMessage = new AIMessage(message.body.trim() ?? undefined); - newMessage.tool_calls = + const text = message.body.trim(); + const toolCalls = message.messages ?.filter( - ( - v, - ): v is Message & { - metadata: ToolMessageMetadata; - } => v.metadata.type === MessageMetadataType.Tool, + (v): v is Message & {metadata: ToolMessageMetadata} => + v.metadata.type === MessageMetadataType.Tool, ) - .map(msg => { - return { - id: msg.metadata.id, - name: msg.metadata.toolName, - args: msg.metadata.args ?? {}, - }; - }) ?? 
[]; - return newMessage; + .map(msg => ({ + id: msg.metadata.id, + name: msg.metadata.toolName, + args: msg.metadata.args ?? {}, + })) ?? []; + + if (toolCalls.length > 0) { + const parts: MastraAssistantContentPart[] = []; + if (text) parts.push({type: 'text', text}); + for (const tc of toolCalls) { + parts.push({ + type: 'tool-call', + toolCallId: tc.id, + toolName: tc.name, + args: tc.args as Record, + }); + } + return {role: 'assistant', content: parts}; + } + return {role: 'assistant', content: text}; } else if (message.metadata?.type === MessageMetadataType.Tool) { - return new ToolMessage({ - name: message.metadata.toolName, - content: message.body, - // eslint-disable-next-line @typescript-eslint/naming-convention - tool_call_id: message.metadata.id, - }); + return { + role: 'tool', + content: [ + { + type: 'tool-result', + toolCallId: message.metadata.id, + toolName: message.metadata.toolName, + result: message.body, + }, + ], + }; } else { - // do nothing for other types + return undefined; } } diff --git a/src/services/generation.service.ts b/src/services/generation.service.ts index ff563d0..3f7f7f5 100644 --- a/src/services/generation.service.ts +++ b/src/services/generation.service.ts @@ -1,6 +1,5 @@ import {BindingScope, inject, injectable, service} from '@loopback/core'; import {MastraChatAgent} from '../mastra'; -import {ChatGraph} from '../graphs/chat/chat.graph'; import {AiIntegrationBindings} from '../keys'; import {ITransport} from '../transports/types'; import {AIIntegrationConfig} from '../types'; @@ -9,8 +8,6 @@ import {ILimitStrategy} from './limit-strategies/types'; @injectable({scope: BindingScope.REQUEST}) export class GenerationService { constructor( - @service(ChatGraph) - private readonly chatGraph: ChatGraph, @service(MastraChatAgent) private readonly mastraChatAgent: MastraChatAgent, @inject(AiIntegrationBindings.Transport) @@ -29,29 +26,7 @@ export class GenerationService { abortController.abort(); }); - if (this.aiConfig?.runtime 
=== 'mastra') { - await this._runMastraFlow(prompt, files, abortController.signal, id); - } else { - await this._runLangGraphFlow(prompt, files, abortController.signal, id); - } - } - - private async _runLangGraphFlow( - prompt: string, - files: Express.Multer.File[], - abort: AbortSignal, - id?: string, - ) { - const stream = await this.chatGraph.execute(prompt, files, abort, id); - try { - for await (const chunk of stream) { - await this.transport.send(chunk); - } - await this.transport.end(); - } catch (error) { - await this.transport.end(error); - throw error; - } + await this._runMastraFlow(prompt, files, abortController.signal, id); } private async _runMastraFlow( diff --git a/src/services/index.ts b/src/services/index.ts index 7c60c0f..9aaa49f 100644 --- a/src/services/index.ts +++ b/src/services/index.ts @@ -1,3 +1,5 @@ +export * from './chat-metadata.type'; +export * from './chat.store'; export * from './generation.service'; export * from './mastra-bridge.observer'; export * from './mastra-bridge.service'; diff --git a/src/services/mastra-bridge.service.ts b/src/services/mastra-bridge.service.ts index f6fcbfb..c4b0823 100644 --- a/src/services/mastra-bridge.service.ts +++ b/src/services/mastra-bridge.service.ts @@ -4,7 +4,7 @@ import { TOOL_NAME, TOOL_TAG, } from '../constant'; -import {IGraphNode, IGraphTool} from '../graphs/types'; +import {IGraphNode, IGraphTool} from '../types/tool'; import {AiIntegrationBindings} from '../keys'; import {AnyObject} from '@loopback/repository'; import {Context, inject, injectable, BindingScope} from '@loopback/core'; diff --git a/src/services/token-counter.service.ts b/src/services/token-counter.service.ts index aa5e42b..c8ae032 100644 --- a/src/services/token-counter.service.ts +++ b/src/services/token-counter.service.ts @@ -1,7 +1,7 @@ -import {AIMessage} from '@langchain/core/messages'; -import {LLMResult} from '@langchain/core/outputs'; import {BindingScope, injectable} from '@loopback/core'; +const debug = 
require('debug')('ai-integration:mastra:token-counter'); + @injectable({scope: BindingScope.REQUEST}) export class TokenCounter { private inputs = 0; @@ -13,41 +13,35 @@ export class TokenCounter { outputTokens: number; } > = new Map(); - private runMap: Map = new Map(); clear() { this.inputs = 0; this.outputs = 0; this.countMap.clear(); - this.runMap.clear(); } - handleLlmStart(runId: string, modelName: string): void { - this.runMap.set(runId, modelName); - } + // ── Mastra path (AI SDK usage field) ──────────────────────────────────── - handleLlmEnd(runId: string, message: LLMResult) { - const llmName = this.runMap.get(runId) ?? 'unknown'; - this.runMap.delete(runId); - const usageMetadata = ( - message.generations[0][0] as unknown as {message: AIMessage} - ).message.usage_metadata; - const prev = this.countMap.get(llmName) ?? { - inputTokens: 0, - outputTokens: 0, - }; - if (usageMetadata) { - this.inputs += usageMetadata.input_tokens ?? 0; - this.outputs += usageMetadata.output_tokens ?? 0; - prev.inputTokens += usageMetadata.input_tokens ?? 0; - prev.outputTokens += usageMetadata.output_tokens ?? 0; - this.countMap.set(llmName, prev); - } - return { - inputTokens: this.inputs, - outputTokens: this.outputs, - }; + /** + * Accumulate token counts directly from an AI SDK `usage` object. + * + * Called by Mastra step functions after every `generateText()` / + * `generateObject()` call — no LangChain callback required. + * + * @param inputTokens - `usage.promptTokens` from the AI SDK response. + * @param outputTokens - `usage.completionTokens` from the AI SDK response. + * @param model - Model identifier (e.g. `deps.llm.modelId`). + */ + accumulate(inputTokens: number, outputTokens: number, model: string): void { + const prev = this.countMap.get(model) ?? 
{inputTokens: 0, outputTokens: 0}; + this.inputs += inputTokens; + this.outputs += outputTokens; + prev.inputTokens += inputTokens; + prev.outputTokens += outputTokens; + this.countMap.set(model, prev); + debug('token usage captured', {inputTokens, outputTokens, model}); } + getCounts() { return { inputs: this.inputs, diff --git a/src/sub-modules/db/postgresql/vector-store/index.ts b/src/sub-modules/db/postgresql/vector-store/index.ts index a36a224..e736409 100644 --- a/src/sub-modules/db/postgresql/vector-store/index.ts +++ b/src/sub-modules/db/postgresql/vector-store/index.ts @@ -1 +1 @@ -export * from './pgvector.store'; +export * from './pgvector-sdk.store'; diff --git a/src/sub-modules/db/postgresql/vector-store/pgvector-sdk.store.ts b/src/sub-modules/db/postgresql/vector-store/pgvector-sdk.store.ts new file mode 100644 index 0000000..cfc0bc3 --- /dev/null +++ b/src/sub-modules/db/postgresql/vector-store/pgvector-sdk.store.ts @@ -0,0 +1,217 @@ +import {embed, embedMany} from 'ai'; +import { + BindingScope, + inject, + injectable, + Provider, + ValueOrPromise, +} from '@loopback/core'; +import {juggler} from '@loopback/repository'; +import * as pg from 'pg'; +import {AiIntegrationBindings} from '../../../../keys'; +import { + AiSdkEmbeddingModel, + IVectorStore, + IVectorStoreDocument, +} from '../../../../types'; + +const debug = require('debug')('mastra:db:pgvector-sdk'); + +/** + * Name of the pgvector table shared with the LangChain `PgVectorStore`. + * Both implementations target the same underlying schema so cached documents + * written by either path are readable by the other. + */ +const TABLE_NAME = 'semantic_cache'; + +// ─── Internal implementation ────────────────────────────────────────────────── + +/** + * Internal stateless implementation of `IVectorStore` backed by a raw pg `Pool`. + * + * All embedding computation uses the AI SDK `embed()` / `embedMany()` helpers — + * there is zero dependency on `@langchain/core` or any LangChain package. 
+ * + * The table structure mirrors what `@langchain/community/vectorstores/pgvector` + * creates so both execution paths can share the same persistent store: + * + * ```sql + * CREATE TABLE {schema}.semantic_cache ( + * id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + * content text, + * metadata jsonb, + * vector vector + * ); + * ``` + */ +class PgVectorSdkStoreImpl implements IVectorStore { + constructor( + private readonly embeddingModel: AiSdkEmbeddingModel, + private readonly pool: pg.Pool, + private readonly schema: string, + ) {} + + /** + * Persist documents to the vector store. + * + * Embeddings are computed via `embedMany()` in a single batched call to reduce + * round-trips to the embedding API. Each document is then inserted with its + * serialised metadata and pgvector-formatted embedding literal. + * + * @param docs - Array of `{pageContent, metadata}` objects to persist. + */ + async addDocuments(docs: IVectorStoreDocument[]): Promise { + if (docs.length === 0) return; + + debug('addDocuments', {count: docs.length}); + + const {embeddings} = await embedMany({ + model: this.embeddingModel, + values: docs.map(d => d.pageContent), + }); + + const client = await this.pool.connect(); + try { + for (let i = 0; i < docs.length; i++) { + // pgvector expects the literal form "[f1,f2,...]" + const vectorLiteral = `[${embeddings[i].join(',')}]`; + await client.query( + `INSERT INTO ${this.schema}.${TABLE_NAME} (id, content, metadata, vector) + VALUES (gen_random_uuid(), $1, $2::jsonb, $3::vector)`, + [ + docs[i].pageContent, + JSON.stringify(docs[i].metadata), + vectorLiteral, + ], + ); + } + } finally { + client.release(); + } + } + + /** + * Return the `k` most semantically similar documents to `query`. + * + * If `filter` is supplied it is applied as a PostgreSQL JSONB containment check + * (`metadata @> filter::jsonb`) before ranking by cosine distance. + * + * @param query - Natural-language query string. 
+ * @param k - Maximum number of results to return. + * @param filter - Optional key-value pairs that must all appear in document metadata. + */ + async similaritySearch>( + query: string, + k: number, + filter?: Record, + ): Promise[]> { + debug('similaritySearch', {k, filter}); + + const {embedding} = await embed({ + model: this.embeddingModel, + value: query, + }); + + const vectorLiteral = `[${embedding.join(',')}]`; + + const params: unknown[] = [vectorLiteral]; + let filterClause = ''; + if (filter && Object.keys(filter).length > 0) { + params.push(JSON.stringify(filter)); + filterClause = `WHERE metadata @> $2::jsonb`; + } + params.push(k); + + const limitParam = `$${params.length}`; + const sql = ` + SELECT content, metadata + FROM ${this.schema}.${TABLE_NAME} + ${filterClause} + ORDER BY vector <=> $1::vector + LIMIT ${limitParam} + `; + + const {rows} = await this.pool.query(sql, params); + + return rows.map(row => ({ + pageContent: row.content as string, + metadata: row.metadata as T, + })); + } + + /** + * Delete all documents whose metadata satisfies the given JSON containment filter. + * + * @param params.filter - Key-value pairs used to match documents for deletion. + */ + async delete(params: {filter: Record}): Promise { + debug('delete', {filter: params.filter}); + + await this.pool.query( + `DELETE FROM ${this.schema}.${TABLE_NAME} WHERE metadata @> $1::jsonb`, + [JSON.stringify(params.filter)], + ); + } +} + +// ─── LoopBack provider ──────────────────────────────────────────────────────── + +/** + * LoopBack provider that creates and returns a Mastra-path `IVectorStore` backed + * by PostgreSQL + pgvector. + * + * **Binding**: register in your component's `providers` map under + * `AiIntegrationBindings.AiSdkVectorStore`. + * + * **Prerequisites**: + * - The `vector` PostgreSQL extension must be installed. 
+ * - The `semantic_cache` table must exist in the configured schema (it is + * created automatically by `PgVectorStore` on the LangGraph path, or by + * running the project migration). + * - `AiIntegrationBindings.AiSdkEmbeddingModel` must be bound to an AI SDK + * `EmbeddingModel` (e.g. `openai.embedding('text-embedding-3-small')`). + * - The `datasources.writerdb` LoopBack datasource must use the `loopback-connector-postgresql` + * connector so that `connector.pg` exposes a `pg.Pool`. + * + * Environment variables (same as `PgVectorStore`): + * - `DB_HOST`, `DB_PORT`, `DB_USER`, `DB_DATABASE` (validated at startup). + */ +@injectable({scope: BindingScope.SINGLETON}) +export class PgVectorSdkStore implements Provider { + constructor( + @inject(AiIntegrationBindings.AiSdkEmbeddingModel) + private readonly embeddingModel: AiSdkEmbeddingModel, + @inject(`datasources.writerdb`) + private readonly pgDataSource: juggler.DataSource, + ) {} + + /** + * Instantiate and return the vector store implementation. + * + * Reads the pg `Pool` and schema name from the injected LoopBack datasource + * so that the same connection pool is shared with the rest of the application — + * no extra connections are opened. + */ + value(): ValueOrPromise { + if ( + !process.env.DB_HOST || + !process.env.DB_PORT || + !process.env.DB_USER || + !process.env.DB_DATABASE + ) { + throw new Error( + 'Database connection details are not set. ' + + 'Please set DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, and DB_DATABASE environment variables.', + ); + } + + const pool = this.pgDataSource.connector?.pg as pg.Pool; + const dsConfig = this.pgDataSource.connector?.settings as + | {schema?: string} + | undefined; + const schema = dsConfig?.schema ?? 
'public'; + + debug('PgVectorSdkStore initialised', {schema}); + return new PgVectorSdkStoreImpl(this.embeddingModel, pool, schema); + } +} diff --git a/src/sub-modules/db/postgresql/vector-store/pgvector.store.ts b/src/sub-modules/db/postgresql/vector-store/pgvector.store.ts deleted file mode 100644 index 0d58297..0000000 --- a/src/sub-modules/db/postgresql/vector-store/pgvector.store.ts +++ /dev/null @@ -1,53 +0,0 @@ -import { - PGVectorStore as PGStore, - PGVectorStoreArgs, -} from '@langchain/community/vectorstores/pgvector'; -import {VectorStore} from '@langchain/core/vectorstores'; -import { - BindingScope, - inject, - injectable, - Provider, - ValueOrPromise, -} from '@loopback/core'; -import * as pg from 'pg'; -import {EmbeddingProvider} from '../../../../types'; -import {AiIntegrationBindings} from '../../../../keys'; -import {juggler} from '@loopback/repository'; -@injectable({scope: BindingScope.SINGLETON}) -export class PgVectorStore implements Provider { - constructor( - @inject(AiIntegrationBindings.EmbeddingModel) - private readonly embeddingModel: EmbeddingProvider, - @inject(`datasources.writerdb`) - private pgDataSource: juggler.DataSource, - ) {} - value(): ValueOrPromise { - if ( - !process.env.DB_HOST || - !process.env.DB_PORT || - !process.env.DB_USER || - !process.env.DB_DATABASE - ) { - throw new Error( - 'Database connection details are not set. 
Please set DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, and DB_DATABASE environment variables.', - ); - } - const reusablePool = this.pgDataSource.connector?.pg as pg.Pool; - const dsConfig = this.pgDataSource.connector?.settings; - - const config: PGVectorStoreArgs = { - pool: reusablePool, - schemaName: dsConfig.schema || 'public', - tableName: 'semantic_cache', - extensionSchemaName: dsConfig.schema || 'public', - columns: { - idColumnName: 'id', - vectorColumnName: 'vector', - contentColumnName: 'content', - metadataColumnName: 'metadata', - }, - }; - return PGStore.initialize(this.embeddingModel, config); - } -} diff --git a/src/sub-modules/obf/langfuse/index.ts b/src/sub-modules/obf/langfuse/index.ts index 6818574..5667219 100644 --- a/src/sub-modules/obf/langfuse/index.ts +++ b/src/sub-modules/obf/langfuse/index.ts @@ -1,2 +1,4 @@ export * from './langfuse.component'; export * from './langfuse.provider'; +export * from './langfuse-core.provider'; +export * from './langfuse-mastra.component'; diff --git a/src/sub-modules/obf/langfuse/langfuse-core.provider.ts b/src/sub-modules/obf/langfuse/langfuse-core.provider.ts new file mode 100644 index 0000000..257c635 --- /dev/null +++ b/src/sub-modules/obf/langfuse/langfuse-core.provider.ts @@ -0,0 +1,38 @@ +import {LangfuseAPIClient} from '@langfuse/core'; +import {Provider, ValueOrPromise} from '@loopback/core'; + +/** + * Mastra-path Langfuse provider. + * + * Returns a `LangfuseAPIClient` from `@langfuse/core` — zero LangChain dependency. + * This is the Mastra-path equivalent of the LangGraph `LangfuseObfProvider` which + * returns a `CallbackHandler` from `@langfuse/langchain`. + * + * The client exposes the full Langfuse REST API (traces, generations, spans, + * scores, datasets, etc.) and can be used inside Mastra step functions for + * structured observability via `client.ingestion.ingest({batch: [...]})`. 
+ * + * Reads configuration from the standard Langfuse environment variables: + * - `LANGFUSE_PUBLIC_KEY` — project public key (required) + * - `LANGFUSE_SECRET_KEY` — project secret key (required) + * - `LANGFUSE_HOST` — API base URL (optional, defaults to `https://cloud.langfuse.com`) + * + * **Binding**: registered automatically by `LangfuseMastraComponent` under + * `AiIntegrationBindings.LangfuseMastraClient`. + */ +export class LangfuseCoreProvider implements Provider { + value(): ValueOrPromise { + if (!process.env.LANGFUSE_PUBLIC_KEY || !process.env.LANGFUSE_SECRET_KEY) { + throw new Error( + 'LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables must be set ' + + 'to use LangfuseMastraComponent.', + ); + } + + return new LangfuseAPIClient({ + environment: process.env.LANGFUSE_HOST ?? 'https://cloud.langfuse.com', + username: process.env.LANGFUSE_PUBLIC_KEY, + password: process.env.LANGFUSE_SECRET_KEY, + }); + } +} diff --git a/src/sub-modules/obf/langfuse/langfuse-mastra.component.ts b/src/sub-modules/obf/langfuse/langfuse-mastra.component.ts new file mode 100644 index 0000000..4d9da31 --- /dev/null +++ b/src/sub-modules/obf/langfuse/langfuse-mastra.component.ts @@ -0,0 +1,43 @@ +import {Component, ProviderMap} from '@loopback/core'; +import {AiIntegrationBindings} from '../../../keys'; +import {LangfuseCoreProvider} from './langfuse-core.provider'; + +/** + * Mastra-path Langfuse observability component. + * + * Registers a `LangfuseAPIClient` (from `@langfuse/core`) under + * `AiIntegrationBindings.LangfuseMastraClient` for use in the Mastra execution path. + * + * This is the Mastra-path equivalent of `LangfuseComponent` which registers a + * LangChain `CallbackHandler` (from `@langfuse/langchain`) under + * `AiIntegrationBindings.ObfHandler`. Both components can coexist in the same + * application so that LangGraph and Mastra paths each get their own Langfuse client. 
+ * + * **Usage** (in your LoopBack application class): + * ```ts + * this.component(LangfuseMastraComponent); + * ``` + * + * **Prerequisites**: + * - Set `LANGFUSE_PUBLIC_KEY`, `LANGFUSE_SECRET_KEY`, and optionally `LANGFUSE_HOST`. + * + * **Injecting the client in a Mastra step**: + * ```ts + * import {inject} from '@loopback/core'; + * import {LangfuseAPIClient} from '@langfuse/core'; + * import {AiIntegrationBindings} from '../keys'; + * + * // Inside a service constructor: + * @inject(AiIntegrationBindings.LangfuseMastraClient, {optional: true}) + * private readonly langfuse?: LangfuseAPIClient, + * ``` + */ +export class LangfuseMastraComponent implements Component { + providers?: ProviderMap; + + constructor() { + this.providers = { + [AiIntegrationBindings.LangfuseMastraClient.key]: LangfuseCoreProvider, + }; + } +} diff --git a/src/sub-modules/obf/langfuse/langfuse.provider.ts b/src/sub-modules/obf/langfuse/langfuse.provider.ts index bc49641..60601ce 100644 --- a/src/sub-modules/obf/langfuse/langfuse.provider.ts +++ b/src/sub-modules/obf/langfuse/langfuse.provider.ts @@ -1,8 +1,24 @@ +import {LangfuseAPIClient} from '@langfuse/core'; import {Provider, ValueOrPromise} from '@loopback/core'; -import {CallbackHandler} from '@langfuse/langchain'; -export class LangfuseObfProvider implements Provider { - value(): ValueOrPromise { - return new CallbackHandler(); +/** + * Langfuse observability provider (LangChain-free). + * + * Previously returned a `CallbackHandler` from `@langfuse/langchain`. + * Now returns a `LangfuseAPIClient` from `@langfuse/core` — same trigger + * point, same binding, zero LangChain dependency. 
+ * + * Reads configuration from the standard Langfuse environment variables: + * - `LANGFUSE_PUBLIC_KEY` — project public key + * - `LANGFUSE_SECRET_KEY` — project secret key + * - `LANGFUSE_HOST` — API base URL (optional) + */ +export class LangfuseObfProvider implements Provider { + value(): ValueOrPromise { + return new LangfuseAPIClient({ + environment: process.env.LANGFUSE_HOST ?? 'https://cloud.langfuse.com', + username: process.env.LANGFUSE_PUBLIC_KEY, + password: process.env.LANGFUSE_SECRET_KEY, + }); } } diff --git a/src/sub-modules/providers/anthropic/llms/anthropic.provider.ts b/src/sub-modules/providers/anthropic/llms/anthropic.provider.ts deleted file mode 100644 index c44ebde..0000000 --- a/src/sub-modules/providers/anthropic/llms/anthropic.provider.ts +++ /dev/null @@ -1,29 +0,0 @@ -import {AnthropicInput, ChatAnthropic} from '@langchain/anthropic'; -import {Provider, ValueOrPromise} from '@loopback/core'; -import {RuntimeLLMProvider} from '../../../../types'; -import {BaseChatModelParams} from '@langchain/core/language_models/chat_models'; - -export class Claude implements Provider { - value(): ValueOrPromise { - if (!process.env.CLAUDE_MODEL || !process.env.CLAUDE_API_KEY) { - throw new Error( - 'CLAUDE_MODEL and CLAUDE_API_KEY environment variables must be set', - ); - } - const config: AnthropicInput & BaseChatModelParams = { - model: process.env.CLAUDE_MODEL!, - apiKey: process.env.CLAUDE_API_KEY, - }; - if (process.env.CLAUDE_THINKING === 'true') { - config.thinking = { - // eslint-disable-next-line @typescript-eslint/naming-convention - budget_tokens: parseInt(process.env.CLAUDE_THINKING_BUDGET ?? '1024'), - type: process.env.CLAUDE_THINKING === 'true' ? 
'enabled' : 'disabled', - }; - } - if (process.env.CLAUDE_TEMPERATURE) { - config.temperature = parseInt(process.env.CLAUDE_TEMPERATURE); - } - return new ChatAnthropic(config); - } -} diff --git a/src/sub-modules/providers/anthropic/llms/claude-sdk.provider.ts b/src/sub-modules/providers/anthropic/llms/claude-sdk.provider.ts new file mode 100644 index 0000000..b948387 --- /dev/null +++ b/src/sub-modules/providers/anthropic/llms/claude-sdk.provider.ts @@ -0,0 +1,32 @@ +import {createAnthropic} from '@ai-sdk/anthropic'; +import {Provider, ValueOrPromise} from '@loopback/core'; +import {LLMProvider} from '../../../../types'; + +/** + * AI SDK (Vercel) provider for Anthropic Claude models. + * + * Returns a `LanguageModel` compatible with `generateText()` / `generateObject()` + * from the `ai` package. Bind to `AiIntegrationBindings.AiSdkSmartLLM` (or the + * other `AiSdk*` keys) for use in the Mastra db-query workflow nodes. + * + * Environment variables: + * - `CLAUDE_MODEL` — model id, e.g. 
`claude-3-5-sonnet-20241022` + * - `CLAUDE_API_KEY` — Anthropic API key + * - `CLAUDE_THINKING` — set to `'true'` to enable extended thinking + * - `CLAUDE_THINKING_BUDGET` — token budget for extended thinking (default 1024) + * - `CLAUDE_TEMPERATURE` — optional temperature override (0–1) + */ +export class ClaudeSdk implements Provider { + value(): ValueOrPromise { + if (!process.env.CLAUDE_MODEL || !process.env.CLAUDE_API_KEY) { + throw new Error( + 'CLAUDE_MODEL and CLAUDE_API_KEY environment variables must be set', + ); + } + const anthropic = createAnthropic({ + apiKey: process.env.CLAUDE_API_KEY, + }); + // thinking / temperature are passed per-call via providerOptions in generateText() + return anthropic(process.env.CLAUDE_MODEL); + } +} diff --git a/src/sub-modules/providers/anthropic/llms/index.ts b/src/sub-modules/providers/anthropic/llms/index.ts index cdf1df1..6ca1cc8 100644 --- a/src/sub-modules/providers/anthropic/llms/index.ts +++ b/src/sub-modules/providers/anthropic/llms/index.ts @@ -1 +1,2 @@ -export * from './anthropic.provider'; +export * from './claude-sdk.provider'; +export {ClaudeSdk as Claude} from './claude-sdk.provider'; diff --git a/src/sub-modules/providers/aws/embedding/bedrock-embedding-sdk.provider.ts b/src/sub-modules/providers/aws/embedding/bedrock-embedding-sdk.provider.ts new file mode 100644 index 0000000..f5eedf9 --- /dev/null +++ b/src/sub-modules/providers/aws/embedding/bedrock-embedding-sdk.provider.ts @@ -0,0 +1,25 @@ +import {createAmazonBedrock} from '@ai-sdk/amazon-bedrock'; +import {Provider, ValueOrPromise} from '@loopback/core'; +import {EmbeddingModel} from 'ai'; + +/** + * AI SDK embedding provider for AWS Bedrock (Titan). + * + * Environment variables: + * - `BEDROCK_EMBEDDING_MODEL` — e.g. 
`amazon.titan-embed-text-v2:0` + * - `BEDROCK_AWS_REGION` + * - `BEDROCK_AWS_ACCESS_KEY_ID` + * - `BEDROCK_AWS_SECRET_ACCESS_KEY` + */ +export class BedrockEmbeddingSdk implements Provider { + value(): ValueOrPromise { + const model = + process.env.BEDROCK_EMBEDDING_MODEL ?? 'amazon.titan-embed-text-v2:0'; + const bedrock = createAmazonBedrock({ + region: process.env.BEDROCK_AWS_REGION ?? 'us-east-1', + accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, + secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, + }); + return bedrock.embedding(model); + } +} diff --git a/src/sub-modules/providers/aws/embedding/bedrock-embedding.provider.ts b/src/sub-modules/providers/aws/embedding/bedrock-embedding.provider.ts deleted file mode 100644 index c746617..0000000 --- a/src/sub-modules/providers/aws/embedding/bedrock-embedding.provider.ts +++ /dev/null @@ -1,21 +0,0 @@ -import {BedrockEmbeddings} from '@langchain/aws'; -import {Provider, ValueOrPromise} from '@loopback/core'; -import {EmbeddingProvider} from '../../../../types'; - -export class BedrockEmbedding implements Provider { - value(): ValueOrPromise { - if (!process.env.BEDROCK_EMBEDDING_MODEL) { - throw new Error( - 'BEDROCK_EMBEDDING_MODEL environment variable is not set', - ); - } - return new BedrockEmbeddings({ - region: process.env.BEDROCK_AWS_REGION!, - credentials: { - accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, - secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, - }, - model: process.env.BEDROCK_EMBEDDING_MODEL!, - }); - } -} diff --git a/src/sub-modules/providers/aws/embedding/index.ts b/src/sub-modules/providers/aws/embedding/index.ts index 71550c8..dc56347 100644 --- a/src/sub-modules/providers/aws/embedding/index.ts +++ b/src/sub-modules/providers/aws/embedding/index.ts @@ -1 +1,2 @@ -export * from './bedrock-embedding.provider'; +export * from './bedrock-embedding-sdk.provider'; +export {BedrockEmbeddingSdk as BedrockEmbedding} from './bedrock-embedding-sdk.provider'; diff 
--git a/src/sub-modules/providers/aws/index.ts b/src/sub-modules/providers/aws/index.ts index ac3fbb3..9060a4e 100644 --- a/src/sub-modules/providers/aws/index.ts +++ b/src/sub-modules/providers/aws/index.ts @@ -1,3 +1,2 @@ export * from './embedding'; export * from './llms'; -export * from './types'; diff --git a/src/sub-modules/providers/aws/llms/bedrock-non-thinking-sdk.provider.ts b/src/sub-modules/providers/aws/llms/bedrock-non-thinking-sdk.provider.ts new file mode 100644 index 0000000..298ff09 --- /dev/null +++ b/src/sub-modules/providers/aws/llms/bedrock-non-thinking-sdk.provider.ts @@ -0,0 +1,19 @@ +import {Provider, ValueOrPromise} from '@loopback/core'; +import {LLMProvider} from '../../../../types'; +import {BedrockSdk} from './bedrock-sdk.provider'; + +/** + * AI SDK (Vercel) provider for AWS Bedrock models with extended thinking disabled. + * + * Identical to `BedrockSdk` but always passes `thinking: false` so the model + * runs without the reasoning budget. Bind to + * `AiIntegrationBindings.AiSdkSmartNonThinkingLLM` or `AiSdkCheapLLM` as needed. 
+ */ +export class BedrockNonThinkingSdk + extends BedrockSdk + implements Provider +{ + value(): ValueOrPromise { + return this._createInstance(false); + } +} diff --git a/src/sub-modules/providers/aws/llms/bedrock-non-thinking.provider.ts b/src/sub-modules/providers/aws/llms/bedrock-non-thinking.provider.ts deleted file mode 100644 index 06fcd75..0000000 --- a/src/sub-modules/providers/aws/llms/bedrock-non-thinking.provider.ts +++ /dev/null @@ -1,12 +0,0 @@ -import {Provider, ValueOrPromise} from '@loopback/core'; -import {RuntimeLLMProvider} from '../../../../types'; -import {Bedrock} from './bedrock.provider'; - -export class BedrockNonThinking - extends Bedrock - implements Provider -{ - value(): ValueOrPromise { - return this._createdInstance(false); - } -} diff --git a/src/sub-modules/providers/aws/llms/bedrock-sdk.provider.ts b/src/sub-modules/providers/aws/llms/bedrock-sdk.provider.ts new file mode 100644 index 0000000..5ed5358 --- /dev/null +++ b/src/sub-modules/providers/aws/llms/bedrock-sdk.provider.ts @@ -0,0 +1,39 @@ +import {createAmazonBedrock} from '@ai-sdk/amazon-bedrock'; +import {Provider, ValueOrPromise} from '@loopback/core'; +import {LLMProvider} from '../../../../types'; + +/** + * AI SDK (Vercel) provider for AWS Bedrock models. + * + * Returns a `LanguageModel` compatible with `generateText()` / `generateObject()` + * from the `ai` package. Bind to `AiIntegrationBindings.AiSdkSmartLLM` (or the + * other `AiSdk*` keys) for use in the Mastra db-query workflow nodes. + * + * Environment variables: + * - `BEDROCK_MODEL` — model id, e.g. 
`anthropic.claude-3-5-sonnet-20241022-v2:0` + * - `BEDROCK_AWS_REGION` — AWS region + * - `BEDROCK_AWS_ACCESS_KEY_ID` — AWS access key + * - `BEDROCK_AWS_SECRET_ACCESS_KEY` — AWS secret key + * - `CLAUDE_THINKING` — set to `'true'` to enable extended thinking + * - `CLAUDE_THINKING_BUDGET` — token budget for extended thinking (default 1024) + */ +export class BedrockSdk implements Provider { + value(): ValueOrPromise { + return this._createInstance(true); + } + + protected _createInstance(thinking: boolean): LLMProvider { + if (!process.env.BEDROCK_MODEL) { + throw new Error('BEDROCK_MODEL environment variable is not set'); + } + const bedrock = createAmazonBedrock({ + region: process.env.BEDROCK_AWS_REGION ?? 'us-east-1', + accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, + secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, + }); + // reasoning_config is passed per-call via providerOptions in generateText() + // thinking flag is reserved for Phase 3 generateText() call sites + thinking; // eslint-disable-line no-unused-expressions + return bedrock(process.env.BEDROCK_MODEL); + } +} diff --git a/src/sub-modules/providers/aws/llms/bedrock.provider.ts b/src/sub-modules/providers/aws/llms/bedrock.provider.ts deleted file mode 100644 index d82da1d..0000000 --- a/src/sub-modules/providers/aws/llms/bedrock.provider.ts +++ /dev/null @@ -1,61 +0,0 @@ -import {ChatBedrockConverse, ChatBedrockConverseInput} from '@langchain/aws'; -import {Provider, ValueOrPromise} from '@loopback/core'; -import {RuntimeLLMProvider} from '../../../../types'; -import {sanitizeFilenameForAwsConverse} from '../utils'; -import {BedrockInstanceConfig} from '../types'; - -export class Bedrock implements Provider { - static createInstance(config: BedrockInstanceConfig): ChatBedrockConverse { - const client = new ChatBedrockConverse(config); - (client as unknown as RuntimeLLMProvider).getFile = ( - file: Express.Multer.File, - ) => { - return { - type: 'document', - document: { - 
format: 'pdf', - name: sanitizeFilenameForAwsConverse(file.originalname), - source: { - bytes: file.buffer, - }, - }, - }; - }; - return client; - } - value(): ValueOrPromise { - return this._createdInstance(true); - } - - protected _createdInstance(thinking: boolean) { - if (!process.env.BEDROCK_MODEL) { - throw new Error( - 'Bedrock model is not specified. Please set the BEDROCK_MODEL environment variable.', - ); - } - const config: ChatBedrockConverseInput = { - model: process.env.BEDROCK_MODEL!, - region: process.env.BEDROCK_AWS_REGION, - credentials: { - accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, - secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, - }, - }; - if (process.env.CLAUDE_THINKING && thinking) { - config.additionalModelRequestFields = { - // eslint-disable-next-line @typescript-eslint/naming-convention - reasoning_config: { - type: 'enabled', - // eslint-disable-next-line @typescript-eslint/naming-convention - budget_tokens: parseInt(process.env.CLAUDE_THINKING_BUDGET ?? '1024'), - }, - }; - } else { - config.temperature = parseInt(process.env.BEDROCK_TEMPERATURE ?? 
'0'); - } - return Bedrock.createInstance({ - model: process.env.BEDROCK_MODEL!, - ...config, - }); - } -} diff --git a/src/sub-modules/providers/aws/llms/index.ts b/src/sub-modules/providers/aws/llms/index.ts index c574640..21331b1 100644 --- a/src/sub-modules/providers/aws/llms/index.ts +++ b/src/sub-modules/providers/aws/llms/index.ts @@ -1,2 +1,4 @@ -export * from './bedrock.provider'; -export * from './bedrock-non-thinking.provider'; +export * from './bedrock-sdk.provider'; +export * from './bedrock-non-thinking-sdk.provider'; +export {BedrockSdk as Bedrock} from './bedrock-sdk.provider'; +export {BedrockNonThinkingSdk as BedrockNonThinking} from './bedrock-non-thinking-sdk.provider'; diff --git a/src/sub-modules/providers/aws/types.ts b/src/sub-modules/providers/aws/types.ts deleted file mode 100644 index 34bb4f0..0000000 --- a/src/sub-modules/providers/aws/types.ts +++ /dev/null @@ -1,6 +0,0 @@ -import {ChatBedrockConverseInput} from '@langchain/aws'; - -export type BedrockInstanceConfig = { - model: string; - config?: Partial; -}; diff --git a/src/sub-modules/providers/cerebras/llm/cerebras-sdk.provider.ts b/src/sub-modules/providers/cerebras/llm/cerebras-sdk.provider.ts new file mode 100644 index 0000000..3895092 --- /dev/null +++ b/src/sub-modules/providers/cerebras/llm/cerebras-sdk.provider.ts @@ -0,0 +1,31 @@ +import {createCerebras} from '@ai-sdk/cerebras'; +import {Provider} from '@loopback/core'; +import {LLMProvider} from '../../../../types'; + +/** + * AI SDK (Vercel) provider for Cerebras models. + * + * Returns a `LanguageModel` compatible with `generateText()` / `generateObject()` + * from the `ai` package. Bind to `AiIntegrationBindings.AiSdkCheapLLM` (or the + * other `AiSdk*` keys) for use in the Mastra db-query workflow nodes. + * + * Environment variables: + * - `CEREBRAS_MODEL` — model id, e.g. 
`llama-4-scout-17b-16e-instruct` + * - `CEREBRAS_KEY` — Cerebras API key + * - `CEREBRAS_TEMPERATURE` — optional temperature (default 0) + * - `CEREBRAS_MAX_TOKENS` — optional max tokens + */ +export class CerebrasSdk implements Provider { + value(): LLMProvider { + if (!process.env.CEREBRAS_MODEL || !process.env.CEREBRAS_KEY) { + throw new Error( + 'CEREBRAS_MODEL and CEREBRAS_KEY environment variables must be set', + ); + } + const cerebras = createCerebras({ + apiKey: process.env.CEREBRAS_KEY, + }); + // temperature / maxTokens are passed per-call in generateText() + return cerebras(process.env.CEREBRAS_MODEL); + } +} diff --git a/src/sub-modules/providers/cerebras/llm/cerebras.provider.ts b/src/sub-modules/providers/cerebras/llm/cerebras.provider.ts deleted file mode 100644 index ee03ac7..0000000 --- a/src/sub-modules/providers/cerebras/llm/cerebras.provider.ts +++ /dev/null @@ -1,25 +0,0 @@ -import {ChatCerebras, ChatCerebrasInput} from '@langchain/cerebras'; -import {Provider} from '@loopback/core'; -import {RuntimeLLMProvider} from '../../../../types'; - -export class Cerebras implements Provider { - value() { - if (!process.env.CEREBRAS_MODEL || !process.env.CEREBRAS_KEY) { - throw new Error( - 'CEREBRAS_MODEL and CEREBRAS_KEY environment variable is not set.', - ); - } - const config: ChatCerebrasInput = { - temperature: parseFloat(process.env.CEREBRAS_TEMPERATURE ?? '0'), - model: process.env.CEREBRAS_MODEL, - apiKey: process.env.CEREBRAS_KEY, // Default value. 
- }; - if (process.env.CEREBRAS_TOP_P) { - config.topP = parseFloat(process.env.CEREBRAS_TOP_P); - } - if (process.env.CEREBRAS_MAX_TOKENS) { - config.maxCompletionTokens = parseInt(process.env.CEREBRAS_MAX_TOKENS); - } - return new ChatCerebras(config); - } -} diff --git a/src/sub-modules/providers/cerebras/llm/index.ts b/src/sub-modules/providers/cerebras/llm/index.ts index e40c9ce..e35ca40 100644 --- a/src/sub-modules/providers/cerebras/llm/index.ts +++ b/src/sub-modules/providers/cerebras/llm/index.ts @@ -1 +1,2 @@ -export * from './cerebras.provider'; +export * from './cerebras-sdk.provider'; +export {CerebrasSdk as Cerebras} from './cerebras-sdk.provider'; diff --git a/src/sub-modules/providers/google/embedding/gemini-embedding-sdk.provider.ts b/src/sub-modules/providers/google/embedding/gemini-embedding-sdk.provider.ts new file mode 100644 index 0000000..a01e49f --- /dev/null +++ b/src/sub-modules/providers/google/embedding/gemini-embedding-sdk.provider.ts @@ -0,0 +1,20 @@ +import {createGoogleGenerativeAI} from '@ai-sdk/google'; +import {Provider, ValueOrPromise} from '@loopback/core'; +import {EmbeddingModel} from 'ai'; + +/** + * AI SDK embedding provider for Google Gemini. + * + * Environment variables: + * - `GEMINI_EMBEDDING_MODEL` — e.g. `text-embedding-004` + * - `GOOGLE_GENERATIVE_AI_API_KEY` + */ +export class GeminiEmbeddingSdk implements Provider { + value(): ValueOrPromise { + const model = process.env.GEMINI_EMBEDDING_MODEL ?? 
'text-embedding-004'; + const google = createGoogleGenerativeAI({ + apiKey: process.env.GOOGLE_GENERATIVE_AI_API_KEY, + }); + return google.textEmbeddingModel(model); + } +} diff --git a/src/sub-modules/providers/google/embedding/gemini-embedding.provider.ts b/src/sub-modules/providers/google/embedding/gemini-embedding.provider.ts deleted file mode 100644 index dc0bae1..0000000 --- a/src/sub-modules/providers/google/embedding/gemini-embedding.provider.ts +++ /dev/null @@ -1,20 +0,0 @@ -import {TaskType} from '@google/generative-ai'; -import {GoogleGenerativeAIEmbeddings} from '@langchain/google-genai'; -import {Provider} from '@loopback/core'; -import {EmbeddingProvider} from '../../../../types'; - -export class GeminiEmbedding implements Provider { - value() { - if (!process.env.GOOGLE_EMBEDDING_MODEL || !process.env.GOOGLE_API_KEY) { - throw new Error( - 'Google embedding model is not specified. Please set the GOOGLE_EMBEDDING_MODEL environment variable.', - ); - } - - return new GoogleGenerativeAIEmbeddings({ - model: process.env.GOOGLE_EMBEDDING_MODEL!, - taskType: TaskType.RETRIEVAL_DOCUMENT, - title: process.env.GOOGLE_EMBEDDING_TITLE ?? 
'AI Integration Embedding', - }); - } -} diff --git a/src/sub-modules/providers/google/embedding/index.ts b/src/sub-modules/providers/google/embedding/index.ts index 83c4fb7..a598761 100644 --- a/src/sub-modules/providers/google/embedding/index.ts +++ b/src/sub-modules/providers/google/embedding/index.ts @@ -1 +1,2 @@ -export * from './gemini-embedding.provider'; +export * from './gemini-embedding-sdk.provider'; +export {GeminiEmbeddingSdk as GeminiEmbedding} from './gemini-embedding-sdk.provider'; diff --git a/src/sub-modules/providers/google/llms/gemini-sdk.provider.ts b/src/sub-modules/providers/google/llms/gemini-sdk.provider.ts new file mode 100644 index 0000000..117c05e --- /dev/null +++ b/src/sub-modules/providers/google/llms/gemini-sdk.provider.ts @@ -0,0 +1,28 @@ +import {createGoogleGenerativeAI} from '@ai-sdk/google'; +import {Provider} from '@loopback/core'; +import {LLMProvider} from '../../../../types'; + +/** + * AI SDK (Vercel) provider for Google Gemini models. + * + * Returns a `LanguageModel` compatible with `generateText()` / `generateObject()` + * from the `ai` package. Bind to `AiIntegrationBindings.AiSdkSmartLLM` (or the + * other `AiSdk*` keys) for use in the Mastra db-query workflow nodes. + * + * Environment variables: + * - `GOOGLE_CHAT_MODEL` — model id, e.g. 
`gemini-2.0-flash` + * - `GOOGLE_API_KEY` — Google Generative AI API key + */ +export class GeminiSdk implements Provider { + value(): LLMProvider { + if (!process.env.GOOGLE_CHAT_MODEL || !process.env.GOOGLE_API_KEY) { + throw new Error( + 'GOOGLE_CHAT_MODEL and GOOGLE_API_KEY environment variables must be set', + ); + } + const google = createGoogleGenerativeAI({ + apiKey: process.env.GOOGLE_API_KEY, + }); + return google(process.env.GOOGLE_CHAT_MODEL); + } +} diff --git a/src/sub-modules/providers/google/llms/gemini.provider.ts b/src/sub-modules/providers/google/llms/gemini.provider.ts deleted file mode 100644 index 2c91cb8..0000000 --- a/src/sub-modules/providers/google/llms/gemini.provider.ts +++ /dev/null @@ -1,17 +0,0 @@ -import {Provider} from '@loopback/core'; -import {RuntimeLLMProvider} from '../../../../types'; -import {ChatGoogleGenerativeAI} from '@langchain/google-genai'; - -export class Gemini implements Provider { - value() { - if (!process.env.GOOGLE_CHAT_MODEL || !process.env.GOOGLE_API_KEY) { - throw new Error( - 'Google chat model is not specified. 
Please set the GOOGLE_CHAT_MODEL and GOOGLE_API_KEY environment variables.', - ); - } - - return new ChatGoogleGenerativeAI({ - model: process.env.GOOGLE_CHAT_MODEL!, - }); - } -} diff --git a/src/sub-modules/providers/google/llms/index.ts b/src/sub-modules/providers/google/llms/index.ts index 50009e2..93ab1cd 100644 --- a/src/sub-modules/providers/google/llms/index.ts +++ b/src/sub-modules/providers/google/llms/index.ts @@ -1 +1,2 @@ -export * from './gemini.provider'; +export * from './gemini-sdk.provider'; +export {GeminiSdk as Gemini} from './gemini-sdk.provider'; diff --git a/src/sub-modules/providers/groq/llms/groq-sdk.provider.ts b/src/sub-modules/providers/groq/llms/groq-sdk.provider.ts new file mode 100644 index 0000000..37e546e --- /dev/null +++ b/src/sub-modules/providers/groq/llms/groq-sdk.provider.ts @@ -0,0 +1,30 @@ +import {createGroq} from '@ai-sdk/groq'; +import {Provider} from '@loopback/core'; +import {LLMProvider} from '../../../../types'; + +/** + * AI SDK (Vercel) provider for Groq models. + * + * Returns a `LanguageModel` compatible with `generateText()` / `generateObject()` + * from the `ai` package. Bind to `AiIntegrationBindings.AiSdkCheapLLM` (or the + * other `AiSdk*` keys) for use in the Mastra db-query workflow nodes. + * + * Environment variables: + * - `GROQ_MODEL` — model id, e.g. 
`llama-3.3-70b-versatile` + * - `GROQ_API_KEY` — Groq API key + * - `GROQ_TEMPERATURE` — optional temperature (default 0) + */ +export class GroqSdk implements Provider { + value(): LLMProvider { + if (!process.env.GROQ_MODEL || !process.env.GROQ_API_KEY) { + throw new Error( + 'GROQ_MODEL and GROQ_API_KEY environment variables must be set', + ); + } + const groq = createGroq({ + apiKey: process.env.GROQ_API_KEY, + }); + // temperature is passed per-call in generateText() + return groq(process.env.GROQ_MODEL); + } +} diff --git a/src/sub-modules/providers/groq/llms/groq.provider.ts b/src/sub-modules/providers/groq/llms/groq.provider.ts deleted file mode 100644 index 5818ae7..0000000 --- a/src/sub-modules/providers/groq/llms/groq.provider.ts +++ /dev/null @@ -1,18 +0,0 @@ -import {Provider} from '@loopback/core'; -import {ChatGroq} from '@langchain/groq'; -import {RuntimeLLMProvider} from '../../../../types'; - -export class Groq implements Provider { - value(): RuntimeLLMProvider { - if (!process.env.GROQ_MODEL || !process.env.GROQ_API_KEY) { - throw new Error( - 'GROQ_MODEL and GROQ_API_KEY environment variable is not set.', - ); - } - return new ChatGroq({ - model: 'llama-3.3-70b-versatile', - temperature: 0, - maxTokens: undefined, - }); - } -} diff --git a/src/sub-modules/providers/groq/llms/index.ts b/src/sub-modules/providers/groq/llms/index.ts index 25de54d..d5a6fb2 100644 --- a/src/sub-modules/providers/groq/llms/index.ts +++ b/src/sub-modules/providers/groq/llms/index.ts @@ -1 +1 @@ -export * from './groq.provider'; +export * from './groq-sdk.provider'; diff --git a/src/sub-modules/providers/ollama/embedding/index.ts b/src/sub-modules/providers/ollama/embedding/index.ts index ce41b1b..bd42173 100644 --- a/src/sub-modules/providers/ollama/embedding/index.ts +++ b/src/sub-modules/providers/ollama/embedding/index.ts @@ -1 +1,2 @@ -export * from './ollama-embedding.provider'; +export * from './ollama-embedding-sdk.provider'; +export {OllamaEmbeddingSdk as 
OllamaEmbedding} from './ollama-embedding-sdk.provider'; diff --git a/src/sub-modules/providers/ollama/embedding/ollama-embedding-sdk.provider.ts b/src/sub-modules/providers/ollama/embedding/ollama-embedding-sdk.provider.ts new file mode 100644 index 0000000..a204cbe --- /dev/null +++ b/src/sub-modules/providers/ollama/embedding/ollama-embedding-sdk.provider.ts @@ -0,0 +1,27 @@ +import {createOpenAI} from '@ai-sdk/openai'; +import {Provider, ValueOrPromise} from '@loopback/core'; +import {EmbeddingModel} from 'ai'; + +/** + * AI SDK embedding provider for Ollama. + * + * Uses @ai-sdk/openai pointed at Ollama's OpenAI-compatible endpoint + * because `ollama-ai-provider` only implements spec v1, + * which is unsupported by AI SDK 6. + * + * Environment variables: + * - `OLLAMA_EMBEDDING_MODEL` — e.g. `nomic-embed-text` + * - `OLLAMA_BASE_URL` — default `http://localhost:11434/api` + */ +export class OllamaEmbeddingSdk implements Provider { + value(): ValueOrPromise { + const model = process.env.OLLAMA_EMBEDDING_MODEL ?? 'nomic-embed-text'; + const baseURL = process.env.OLLAMA_BASE_URL ?? 
'http://localhost:11434'; + // Ollama exposes an OpenAI-compatible API at /v1 + const openai = createOpenAI({ + baseURL: `${baseURL.replace(/\/api$/, '')}/v1`, + apiKey: 'ollama', // Ollama doesn't require a real key + }); + return openai.embedding(model) as unknown as EmbeddingModel; + } +} diff --git a/src/sub-modules/providers/ollama/embedding/ollama-embedding.provider.ts b/src/sub-modules/providers/ollama/embedding/ollama-embedding.provider.ts deleted file mode 100644 index 08de3b6..0000000 --- a/src/sub-modules/providers/ollama/embedding/ollama-embedding.provider.ts +++ /dev/null @@ -1,15 +0,0 @@ -import {OllamaEmbeddings} from '@langchain/ollama'; -import {Provider, ValueOrPromise} from '@loopback/core'; -import {EmbeddingProvider} from '../../../../types'; - -export class OllamaEmbedding implements Provider { - value(): ValueOrPromise { - if (!process.env.OLLAMA_EMBEDDING_MODEL) { - throw new Error('OLLAMA_EMBEDDING_MODEL environment variable is not set'); - } - return new OllamaEmbeddings({ - model: process.env.OLLAMA_EMBEDDING_MODEL!, - baseUrl: process.env.OLLAMA_URL ?? 
'http://localhost:11434', - }); - } -} diff --git a/src/sub-modules/providers/ollama/llms/index.ts b/src/sub-modules/providers/ollama/llms/index.ts index 06cf381..8f8c691 100644 --- a/src/sub-modules/providers/ollama/llms/index.ts +++ b/src/sub-modules/providers/ollama/llms/index.ts @@ -1 +1,2 @@ -export * from './ollama.provider'; +export * from './ollama-sdk.provider'; +export {OllamaSdk as Ollama} from './ollama-sdk.provider'; diff --git a/src/sub-modules/providers/ollama/llms/ollama-sdk.provider.ts b/src/sub-modules/providers/ollama/llms/ollama-sdk.provider.ts new file mode 100644 index 0000000..3383165 --- /dev/null +++ b/src/sub-modules/providers/ollama/llms/ollama-sdk.provider.ts @@ -0,0 +1,30 @@ +import {createOllama} from 'ollama-ai-provider'; +import {Provider, ValueOrPromise} from '@loopback/core'; +import {LLMProvider} from '../../../../types'; + +/** + * AI SDK (Vercel) provider for Ollama models. + * + * Returns a `LanguageModel` compatible with `generateText()` / `generateObject()` + * from the `ai` package. Bind to `AiIntegrationBindings.AiSdkSmartLLM` (or the + * other `AiSdk*` keys) for use in the Mastra db-query workflow nodes. + * + * Environment variables: + * - `OLLAMA_MODEL` — model id, e.g. `llama3.2` + * - `OLLAMA_BASE_URL` — Ollama server URL, e.g. `http://localhost:11434` + */ +export class OllamaSdk implements Provider { + value(): ValueOrPromise { + if (!process.env.OLLAMA_MODEL || !process.env.OLLAMA_BASE_URL) { + throw new Error( + 'OLLAMA_MODEL and OLLAMA_BASE_URL environment variables must be set', + ); + } + const ollama = createOllama({ + baseURL: `${process.env.OLLAMA_BASE_URL}/api`, + }); + // ollama-ai-provider returns LanguageModelV1; cast is safe — it satisfies the + // LanguageModel contract at runtime even though types diverge in ai v6. 
+ return ollama(process.env.OLLAMA_MODEL) as unknown as LLMProvider; + } +} diff --git a/src/sub-modules/providers/ollama/llms/ollama.provider.ts b/src/sub-modules/providers/ollama/llms/ollama.provider.ts deleted file mode 100644 index 0fd79d4..0000000 --- a/src/sub-modules/providers/ollama/llms/ollama.provider.ts +++ /dev/null @@ -1,16 +0,0 @@ -import {ChatOllama} from '@langchain/ollama'; -import {Provider, ValueOrPromise} from '@loopback/core'; - -export class Ollama implements Provider { - value(): ValueOrPromise { - if (!process.env.OLLAMA_MODEL || !process.env.OLLAMA_BASE_URL) { - throw new Error( - 'OLLAMA_MODEL and OLLAMA_BASE_URL environment variables must be set', - ); - } - return new ChatOllama({ - model: process.env.OLLAMA_MODEL, - baseUrl: process.env.OLLAMA_BASE_URL, - }); - } -} diff --git a/src/sub-modules/providers/openai/llms/index.ts b/src/sub-modules/providers/openai/llms/index.ts index 9c3eeea..9e58e1c 100644 --- a/src/sub-modules/providers/openai/llms/index.ts +++ b/src/sub-modules/providers/openai/llms/index.ts @@ -1 +1,2 @@ -export * from './openai.provider'; +export * from './openai-sdk.provider'; +export {OpenAISdk as OpenAI} from './openai-sdk.provider'; diff --git a/src/sub-modules/providers/openai/llms/openai-sdk.provider.ts b/src/sub-modules/providers/openai/llms/openai-sdk.provider.ts new file mode 100644 index 0000000..d42f943 --- /dev/null +++ b/src/sub-modules/providers/openai/llms/openai-sdk.provider.ts @@ -0,0 +1,57 @@ +import {createOpenAI} from '@ai-sdk/openai'; +import {Provider} from '@loopback/core'; +import {LLMProvider} from '../../../../types'; + +export type OpenAISdkInstanceConfig = { + model: string; + config?: { + apiKey?: string; + baseURL?: string; + temperature?: number; + reasoningEffort?: 'low' | 'medium' | 'high'; + reasoningSummary?: 'auto' | 'concise' | 'detailed'; + /** OpenRouter / custom provider overrides */ + configuration?: { + baseURL?: string; + [key: string]: unknown; + }; + reasoning?: { + effort?: 
string | null; + summary?: string | null; + }; + modelKwargs?: Record; + }; +}; + +/** + * AI SDK (Vercel) provider for OpenAI models. + * + * Returns a `LanguageModel` compatible with `generateText()` / `generateObject()` + * from the `ai` package. Bind to `AiIntegrationBindings.AiSdkSmartLLM` (or the + * other `AiSdk*` keys) for use in the Mastra db-query workflow nodes. + */ +export class OpenAISdk implements Provider { + static createInstance(config: OpenAISdkInstanceConfig): LLMProvider { + const openai = createOpenAI({ + apiKey: config.config?.apiKey ?? process.env.OPENAI_API_KEY, + baseURL: config.config?.baseURL ?? config.config?.configuration?.baseURL, + }); + // temperature / reasoningEffort are passed per-call in generateText() — not at model creation + return openai(config.model); + } + + value(): LLMProvider { + if (!process.env.OPENAI_MODEL) { + throw new Error('OPENAI_MODEL environment variable is not set'); + } + return OpenAISdk.createInstance({ + model: process.env.OPENAI_MODEL, + config: { + temperature: process.env.OPENAI_TEMPERATURE + ? 
Number.parseFloat(process.env.OPENAI_TEMPERATURE) + : undefined, + baseURL: process.env.OPENAI_API_BASE_URL, + }, + }); + } +} diff --git a/src/sub-modules/providers/openai/llms/openai.provider.ts b/src/sub-modules/providers/openai/llms/openai.provider.ts deleted file mode 100644 index a194263..0000000 --- a/src/sub-modules/providers/openai/llms/openai.provider.ts +++ /dev/null @@ -1,24 +0,0 @@ -import {Provider} from '@loopback/core'; -import {RuntimeLLMProvider} from '../../../../types'; -import {ChatOpenAI} from '@langchain/openai'; -import {OpenAIInstanceConfig} from '../types'; - -export class OpenAI implements Provider { - static createInstance(config: OpenAIInstanceConfig): ChatOpenAI { - return new ChatOpenAI({ - model: config.model, - ...config.config, - }); - } - value(): RuntimeLLMProvider { - return OpenAI.createInstance({ - model: process.env.OPENAI_MODEL!, - config: { - temperature: Number.parseFloat(process.env.OPENAI_TEMPERATURE ?? '0'), - configuration: { - baseURL: process.env.OPENAI_API_BASE_URL, - }, - }, - }); - } -} diff --git a/src/sub-modules/providers/openai/types.ts b/src/sub-modules/providers/openai/types.ts deleted file mode 100644 index a7f0b57..0000000 --- a/src/sub-modules/providers/openai/types.ts +++ /dev/null @@ -1,6 +0,0 @@ -import {ChatOpenAIFields} from '@langchain/openai'; - -export type OpenAIInstanceConfig = { - model: string; - config: ChatOpenAIFields; -}; diff --git a/src/transports/http.transport.ts b/src/transports/http.transport.ts index ebcdb1c..c4c1847 100644 --- a/src/transports/http.transport.ts +++ b/src/transports/http.transport.ts @@ -1,7 +1,7 @@ import {BindingScope, inject, injectable} from '@loopback/core'; import {HttpErrors, Request, Response, RestBindings} from '@loopback/rest'; import {STATUS_CODE} from '@sourceloop/core'; -import {LLMStreamEvent, LLMStreamEventType} from '../graphs/event.types'; +import {LLMStreamEvent, LLMStreamEventType} from '../types/events'; import {ITransport} from './types'; const 
debug = require('debug')('ai-integration:log-events'); diff --git a/src/transports/sse.transport.ts b/src/transports/sse.transport.ts index cc8f931..67e572a 100644 --- a/src/transports/sse.transport.ts +++ b/src/transports/sse.transport.ts @@ -1,7 +1,7 @@ import {BindingScope, inject, injectable} from '@loopback/core'; import {HttpErrors, Request, Response, RestBindings} from '@loopback/rest'; import {STATUS_CODE} from '@sourceloop/core'; -import {LLMStreamEvent, LLMStreamEventType} from '../graphs/event.types'; +import {LLMStreamEvent, LLMStreamEventType} from '../types/events'; import {ITransport} from './types'; const debug = require('debug')('ai-integration:log-events'); diff --git a/src/transports/types.ts b/src/transports/types.ts index 33f08a8..88cd281 100644 --- a/src/transports/types.ts +++ b/src/transports/types.ts @@ -1,5 +1,5 @@ import {AnyObject} from '@loopback/repository'; -import {LLMStreamEvent} from '../graphs/event.types'; +import {LLMStreamEvent} from '../types/events'; export interface ITransport { start(): Promise; diff --git a/src/types.ts b/src/types.ts index b7295f9..e743644 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,13 +1,7 @@ -import {BedrockEmbeddings} from '@langchain/aws'; -import {GoogleGenerativeAIEmbeddings} from '@langchain/google-genai'; -import {OllamaEmbeddings} from '@langchain/ollama'; -import {OpenAIEmbeddings} from '@langchain/openai'; -import {LanguageModel} from 'ai'; +import {EmbeddingModel, LanguageModel} from 'ai'; import {Provider} from '@loopback/core'; import {AnyObject} from '@loopback/repository'; -import {AIMessage} from '@langchain/core/messages'; -import {RunnableConfig, RunnableInterface} from '@langchain/core/runnables'; -import {IGraphTool} from './graphs/types'; +import {IGraphTool} from './types/tool'; export enum SupportedDBs { PostgreSQL = 'PostgreSQL', @@ -18,7 +12,6 @@ export enum SupportedDBs { * Global component configuration consumed by the LoopBack integration component. 
*/ export type AIIntegrationConfig = { - runtime?: RuntimeEngine; useCustomSequence?: boolean; mountCore?: boolean; mountFileUtils?: boolean; @@ -34,11 +27,6 @@ export type AIIntegrationConfig = { }; }; -/** - * Runtime engine selector used for phased migration and rollbacks. - */ -export type RuntimeEngine = 'langgraph' | 'mastra'; - export type FileMessageBuilder = (file: Express.Multer.File) => AnyObject; /** @@ -47,62 +35,57 @@ export type FileMessageBuilder = (file: Express.Multer.File) => AnyObject; export type LLMProvider = LanguageModel; /** - * Legacy LangGraph-compatible LLM contract used by existing graph implementations. - * - * The structure intentionally mirrors the methods used by current nodes so concrete - * LangChain chat models remain assignable without direct class dependencies. + * @deprecated Use `AiSdkEmbeddingModel` instead. Kept for backward compatibility. */ -export type LegacyLLMProvider = { - bindTools( - tools: unknown[], - ): RunnableInterface< - unknown, - AIMessage, - RunnableConfig> - >; - invoke(input: unknown): Promise; - withStructuredOutput( - schema: unknown, - ): RunnableInterface>>; - getFile?: FileMessageBuilder; -} & RunnableInterface< - unknown, - AIMessage, - RunnableConfig> ->; +export type EmbeddingProvider = EmbeddingModel; /** - * Adapter contract for converting an AI SDK model into the legacy tool-calling interface - * while LangGraph execution remains active. + * AI SDK embedding model type for the Mastra execution path. + * Zero LangChain dependency — use with `embed()` / `embedMany()` from `'ai'`. + * Bind an instance to `AiIntegrationBindings.AiSdkEmbeddingModel`. */ -export interface ILegacyLLMProviderAdapter { - toLegacyLLMProvider(): LegacyLLMProvider; -} +export type AiSdkEmbeddingModel = EmbeddingModel; /** - * Runtime-compatible union used by existing LangGraph execution paths during migration. + * Mastra-path vector store document. 
+ * + * Property names mirror `DocumentInterface` from `@langchain/core/documents` so that + * existing Mastra step callers (`doc.pageContent`, `doc.metadata`) need no changes. */ -export type RuntimeLLMProvider = LegacyLLMProvider & - Partial; +export interface IVectorStoreDocument> { + /** The textual content of the document. */ + pageContent: string; + /** Arbitrary key-value metadata attached to the document. */ + metadata: T; +} /** - * Resolves a runtime-compatible provider into a legacy execution contract. + * Mastra-compatible vector store contract — zero LangChain dependency. + * + * Implemented by `PgVectorSdkStore` for the Mastra execution path. + * The LangGraph path continues to use `VectorStore` from `@langchain/core/vectorstores`. */ -export function resolveLegacyLLMProvider( - provider: RuntimeLLMProvider, -): LegacyLLMProvider { - if (provider.toLegacyLLMProvider) { - return provider.toLegacyLLMProvider(); - } - return provider; +export interface IVectorStore { + /** + * Persist documents (text + metadata) to the underlying store. + * Embeddings are computed internally via the configured AI SDK embedding model. + */ + addDocuments(docs: IVectorStoreDocument[]): Promise; + /** + * Return the `k` most semantically similar documents to `query`, + * optionally filtered by `filter` (matched against document metadata via JSON containment). + */ + similaritySearch>( + query: string, + k: number, + filter?: Record, + ): Promise[]>; + /** + * Delete all documents whose metadata contains every key-value pair in `params.filter`. + */ + delete(params: {filter: Record}): Promise; } -export type EmbeddingProvider = - | OpenAIEmbeddings - | OllamaEmbeddings - | BedrockEmbeddings - | GoogleGenerativeAIEmbeddings; - /** * Runtime persistence contract used by workflow/checkpoint adapters. 
*/ diff --git a/src/graphs/event.types.ts b/src/types/events.ts similarity index 80% rename from src/graphs/event.types.ts rename to src/types/events.ts index 84def04..131dc83 100644 --- a/src/graphs/event.types.ts +++ b/src/types/events.ts @@ -1,5 +1,14 @@ +/** + * SSE stream event types emitted by both the Mastra and legacy execution paths. + * + * Moved from src/graphs/event.types.ts so this file has no dependency on + * @langchain/* or @langchain/langgraph. + */ import {AnyObject} from '@loopback/repository'; +// Re-export ToolStatus so consumers can import everything from one place. +export {ToolStatus} from './tool'; + export enum LLMStreamEventType { Message = 'message', Error = 'error', diff --git a/src/types/tool.ts b/src/types/tool.ts new file mode 100644 index 0000000..f9708ab --- /dev/null +++ b/src/types/tool.ts @@ -0,0 +1,103 @@ +/** + * Shared tool and runtime-config types used by both the Mastra workflow steps + * and the public extension API. + * + * Moved from src/graphs/types.ts — no @langchain/* imports. + */ +import {AnyObject, Command} from '@loopback/repository'; + +/** + * Runtime-agnostic execution config that can carry stream writers and runtime metadata. + */ +export type RunnableConfig = { + configurable?: Record; + signal?: AbortSignal; + writer?: (chunk: unknown) => void; +}; + +/** + * Node step execution function signature. + */ +export type GraphStepExecuteFn = ( + state: T, + config: RunnableConfig, +) => Promise | Command>; + +/** + * Minimal step contract. + */ +export interface IGraphStep { + execute: GraphStepExecuteFn; +} + +/** + * Graph node contract supporting Mastra-style `createStep`. + */ +export interface IGraphNode { + createStep?(config?: RunnableConfig): Promise> | IGraphStep; + execute?: GraphStepExecuteFn; +} + +/** + * Minimal runtime tool contract shared across the Mastra execution path. 
+ */ +export interface IRuntimeTool { + name: string; + description?: string; + schema?: unknown; + invoke(input: TArgs): Promise; +} + +/** + * Tool contract supporting Mastra-style `createTool` and legacy `build` for compatibility. + */ +export interface IGraphTool { + key: string; + description?: string; + inputSchema?: unknown; + createTool?(config: RunnableConfig): Promise; + /** @deprecated Use `createTool()`. */ + build?(config: RunnableConfig): Promise; + getValue?(result: unknown): string; + getMetadata?(result: unknown): AnyObject; + needsReview?: boolean; +} + +/** + * Resolves a runtime tool from the contract while preserving legacy fallback. + */ +export async function resolveGraphTool( + tool: IGraphTool, + config: RunnableConfig, +): Promise { + if (tool.createTool) { + return tool.createTool(config); + } + + if (tool.build) { + return tool.build(config); + } + + throw new Error(`Tool ${tool.key} does not implement createTool().`); +} + +export type IGraphDirectEdge = { + from: string; + to: string; +}; + +export type IGraphConditionalEdge = { + from: string; + toList: string[]; + branchingFunction(state: T): string; +}; + +export type IGraphEdge = + | IGraphDirectEdge + | IGraphConditionalEdge; + +export enum ToolStatus { + Running = 'running', + Completed = 'completed', + Failed = 'failed', +} diff --git a/src/utils.ts b/src/utils.ts index 1173863..ff2840b 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,23 +1,25 @@ -import { - AIMessage, - MessageContent, - MessageContentComplex, - MessageContentText, -} from '@langchain/core/messages'; - -export function isTextContent( - content: MessageContent | MessageContentComplex | string, -): content is MessageContentText { +export function getTextContent(content: string | unknown): string { if (typeof content === 'string') { - return true; - } - if ((content as MessageContentText).text !== undefined) { - return true; + return content; } if (Array.isArray(content)) { - return content.filter(v => v.type 
=== 'text').every(isTextContent); + return content + .map((c: unknown) => { + if (typeof c === 'string') return c; + if ( + c !== null && + typeof c === 'object' && + 'text' in c && + typeof (c as {text: unknown}).text === 'string' + ) { + return (c as {text: string}).text; + } + return ''; + }) + .filter(Boolean) + .join(''); } - return false; + return ''; } export function mergeAttachments( @@ -30,39 +32,24 @@ summary of file - ${fileName}: ${summary}`; } -export function getTextContent(content: MessageContent | string): string { - if (typeof content === 'string') { - return content; - } - if (isTextContent(content)) { - return typeof content === 'string' - ? content - : content - .map(c => (isTextContent(c) ? c.text : '')) - .filter(v => !!v) - .join(''); - } - return ''; -} - -export function stripThinkingTokens(text: AIMessage): string { - const message = getTextContent(text.content ?? text); +/** + * Strips `` / `` tags from a plain text string. + * Previously accepted `AIMessage` — now accepts a plain string so this + * function has zero dependency on @langchain. + */ +export function stripThinkingTokens(text: string): string { // remove all the content between and tags - let stripped = message.replace(/.*?<\/think(ing)?>/gs, ''); - // also strip any string that ends with or + let stripped = text.replace(/.*?<\/think(ing)?>/gs, ''); + // also strip any string that ends with or stripped = stripped.replace(/.*?<\/think(ing)?>/gs, ''); return stripped.trim(); } -export function approxTokenCounter(content: MessageContent): number { +export function approxTokenCounter(content: string | unknown): number { const text = getTextContent(content); // Approximate token count: 1 token ~ 4 characters // This is a rough estimate, actual tokenization may vary - if (typeof text === 'string') { - return Math.ceil(text.length / 4); - } - - return 0; + return Math.ceil(text.length / 4); } export function numericEnumValues(enumType: Object) {