From 6b75d5434e5cfe4521d529015a8014ee4cc3d464 Mon Sep 17 00:00:00 2001 From: iberi22 <10615454+iberi22@users.noreply.github.com> Date: Thu, 16 Apr 2026 22:32:37 +0000 Subject: [PATCH] feat: implement Medical RAG Pipeline with Multi-Hop Search - Added `MemoryPipeline` and `PipelineStage` base classes for modular RAG flows. - Implemented `MedicalRagPipeline` with stages for normalization, decomposition, hybrid retrieval, re-ranking, and citation. - Created `MedicalPromptBuilder` for medical-specific prompts and safety disclaimers in Spanish. - Added comprehensive unit and integration tests in `test/medical_rag_pipeline_test.dart`. - Exported new components in `lib/isar_agent_memory.dart`. --- lib/isar_agent_memory.dart | 3 + lib/src/rag/medical_prompt_builder.dart | 64 +++++++++ lib/src/rag/medical_rag_pipeline.dart | 184 ++++++++++++++++++++++++ lib/src/rag/memory_pipeline.dart | 31 ++++ test/medical_rag_pipeline_test.dart | 96 +++++++++++++ 5 files changed, 378 insertions(+) create mode 100644 lib/src/rag/medical_prompt_builder.dart create mode 100644 lib/src/rag/medical_rag_pipeline.dart create mode 100644 lib/src/rag/memory_pipeline.dart create mode 100644 test/medical_rag_pipeline_test.dart diff --git a/lib/isar_agent_memory.dart b/lib/isar_agent_memory.dart index 27b0206..49dbed0 100644 --- a/lib/isar_agent_memory.dart +++ b/lib/isar_agent_memory.dart @@ -27,6 +27,9 @@ export 'src/sync/sync_backend.dart'; export 'src/sync/firebase_sync_backend.dart'; export 'src/sync/websocket_sync_backend.dart'; export 'src/sync/cross_device_sync_manager.dart'; +export 'src/rag/memory_pipeline.dart'; +export 'src/rag/medical_rag_pipeline.dart'; +export 'src/rag/medical_prompt_builder.dart'; // Advanced features (v0.5.0) export 'src/memory_consolidation.dart'; diff --git a/lib/src/rag/medical_prompt_builder.dart b/lib/src/rag/medical_prompt_builder.dart new file mode 100644 index 0000000..68c9611 --- /dev/null +++ b/lib/src/rag/medical_prompt_builder.dart @@ -0,0 +1,64 @@ 
+import '../models/memory_node.dart';
+
+/// Builds the prompts and the disclaimer text used by the medical RAG flow.
+class MedicalPromptBuilder {
+  /// Default medical safety disclaimer in Spanish.
+  static const String spanishDisclaimer =
+      'AVISO MÉDICO: Esta respuesta es generada por una IA y tiene fines informativos únicamente. '
+      'No sustituye el consejo, diagnóstico o tratamiento médico profesional. '
+      'Siempre busque el consejo de su médico u otro proveedor de salud calificado.';
+
+  /// Returns the answer-generation prompt for [query].
+  ///
+  /// Every entry of [contextNodes] is written as `[n] <content>`, with `n`
+  /// starting at 1, so the answer can cite its sources by number. The
+  /// Spanish template is used when [language] is `'es'`; any other value
+  /// selects the English template.
+  String buildRagPrompt({
+    required String query,
+    required List<MemoryNode> contextNodes,
+    String language = 'es',
+  }) {
+    final contextBuffer = StringBuffer();
+    for (final (index, node) in contextNodes.indexed) {
+      contextBuffer.writeln('[${index + 1}] ${node.content}');
+    }
+
+    if (language == 'es') {
+      return '''Utiliza la siguiente información de contexto para responder a la pregunta médica de forma precisa.
+Si la información no es suficiente, indícalo.
+
+CONTEXTO:
+$contextBuffer
+
+PREGUNTA: $query
+
+RESPUESTA (incluye citas numéricas como [1] si corresponde):''';
+    }
+
+    return '''Use the following context to answer the medical question accurately.
+If the information is not sufficient, please state it.
+
+CONTEXT:
+$contextBuffer
+
+QUESTION: $query
+
+RESPONSE (include numeric citations like [1] if applicable):''';
+  }
+
+  /// Returns the prompt that asks the model to split [query] into simpler
+  /// sub-questions, one per line.
+  ///
+  /// The template is Spanish only.
+  String buildDecompositionPrompt(String query) {
+    return '''Divide la siguiente consulta médica compleja en sub-preguntas más simples que puedan ser buscadas de forma independiente.
+Responde solo con la lista de preguntas, una por línea.
+
+CONSULTA: $query
+
+SUB-PREGUNTAS:''';
+  }
+
+  /// Returns [response] followed by a `---` separator and the disclaimer.
+  ///
+  /// [spanishDisclaimer] is used when [language] is `'es'`; any other value
+  /// selects the English disclaimer.
+  String wrapWithDisclaimer(String response, {String language = 'es'}) {
+    final disclaimer = language == 'es'
+        ? spanishDisclaimer
+        : 'MEDICAL DISCLAIMER: This response is AI-generated and for informational purposes only. '
+            'It is not a substitute for professional medical advice, diagnosis, or treatment.';
+
+    return '$response\n\n---\n$disclaimer';
+  }
+}
diff --git a/lib/src/rag/medical_rag_pipeline.dart b/lib/src/rag/medical_rag_pipeline.dart
new file mode 100644
index 0000000..077e577
--- /dev/null
+++ b/lib/src/rag/medical_rag_pipeline.dart
@@ -0,0 +1,184 @@
+import '../memory_graph.dart';
+import '../llm_adapter.dart';
+import '../models/memory_node.dart';
+import '../reranking_strategy.dart';
+import 'memory_pipeline.dart';
+import 'medical_prompt_builder.dart';
+
+/// Retrieval-augmented generation pipeline for medical questions.
+///
+/// [execute] runs, in order: abbreviation expansion, LLM query decomposition,
+/// hybrid retrieval, optional re-ranking, answer generation and citation
+/// matching, and finally appends the medical disclaimer to the answer.
+class MedicalRagPipeline implements MemoryPipeline<String, RagContext> {
+  /// Graph that is searched for supporting nodes.
+  final MemoryGraph memoryGraph;
+
+  /// Model used for both query decomposition and answer generation.
+  final LLMAdapter llm;
+
+  /// Optional strategy that re-orders retrieved nodes; skipped when null.
+  final ReRankingStrategy? reranker;
+
+  /// Source of the prompt templates and the disclaimer text.
+  final MedicalPromptBuilder promptBuilder;
+
+  /// Creates a pipeline; a default [MedicalPromptBuilder] is used when
+  /// [promptBuilder] is omitted.
+  MedicalRagPipeline({
+    required this.memoryGraph,
+    required this.llm,
+    this.reranker,
+    MedicalPromptBuilder? promptBuilder,
+  }) : promptBuilder = promptBuilder ?? MedicalPromptBuilder();
+
+  /// Answers [query] and returns the fully populated [RagContext].
+  ///
+  /// The returned context's `generatedResponse` is the model's answer with
+  /// the Spanish disclaimer appended, and `citedNodes` holds the retrieved
+  /// nodes whose `[n]` marker appears in that answer.
+  @override
+  Future<RagContext> execute(String query) async {
+    var context = RagContext(originalQuery: query);
+
+    // 1. Normalization
+    context = await QueryNormalizationStage().process(context);
+
+    // 2. Decomposition (optional, for complex queries)
+    context = await MedicalQueryDecompositionStage(llm, promptBuilder)
+        .process(context);
+
+    // 3. Hybrid Retrieval
+    context = await HybridRetrievalStage(memoryGraph).process(context);
+
+    // 4. Re-ranking. The field is copied to a local so that the null check
+    // promotes it and no `!` assertion is needed.
+    final reranker = this.reranker;
+    if (reranker != null) {
+      context = await ReRankingStage(reranker).process(context);
+    }
+
+    // 5. Generation. The node order here defines the `[n]` numbers that the
+    // citation stage looks for below.
+    final prompt = promptBuilder.buildRagPrompt(
+      query: context.currentQuery,
+      contextNodes: [for (final hit in context.retrievedNodes) hit.node],
+    );
+    final answer = await llm.generate(prompt);
+    context.generatedResponse = answer;
+
+    // 6. Evidence Citation. Runs before the disclaimer is appended so that
+    // only the model's own text is scanned for markers.
+    context = await EvidenceCitationStage().process(context);
+
+    // Add disclaimer
+    context.generatedResponse =
+        promptBuilder.wrapWithDisclaimer(context.generatedResponse ?? answer);
+
+    return context;
+  }
+}
+
+/// Stage to expand medical abbreviations.
+///
+/// Every whole-token, case-insensitive occurrence of a key of
+/// [abbreviations] in `currentQuery` is replaced by its Spanish expansion.
+class QueryNormalizationStage extends PipelineStage<RagContext, RagContext> {
+  /// Abbreviations recognised by this stage, mapped to their expansions.
+  static const Map<String, String> abbreviations = {
+    'TA': 'tensión arterial',
+    'HTA': 'hipertensión arterial',
+    'DM': 'diabetes mellitus',
+    'IMC': 'índice de masa corporal',
+    'EPOC': 'enfermedad pulmonar obstructiva crónica',
+    'SNC': 'sistema nervioso central',
+    'PCR': 'proteína C reactiva',
+    'ECG': 'electrocardiograma',
+  };
+
+  /// One matcher per entry of [abbreviations], compiled once and kept in the
+  /// map's declaration order.
+  ///
+  /// The look-arounds reject any adjacent Unicode letter, digit or
+  /// underscore. A plain `\b` only knows ASCII word characters, so it would
+  /// treat the "ñ" in "tañer" as a boundary and expand the leading "ta".
+  static final List<({RegExp pattern, String expansion})> _rules = [
+    for (final entry in abbreviations.entries)
+      (
+        pattern: RegExp(
+          r'(?<![\p{L}\p{N}_])'
+          '${RegExp.escape(entry.key)}'
+          r'(?![\p{L}\p{N}_])',
+          caseSensitive: false,
+          unicode: true,
+        ),
+        expansion: entry.value,
+      ),
+  ];
+
+  /// Rewrites `currentQuery` of [context] in place and returns [context].
+  @override
+  Future<RagContext> process(RagContext context) async {
+    var normalized = context.currentQuery;
+    for (final rule in _rules) {
+      normalized = normalized.replaceAll(rule.pattern, rule.expansion);
+    }
+    context.currentQuery = normalized;
+    return context;
+  }
+}
+
+/// Stage to split complex queries using an LLM.
+///
+/// Leaves `decomposedQueries` holding at least one entry: either the
+/// sub-questions returned by the model or, failing that, the current query.
+class MedicalQueryDecompositionStage
+    extends PipelineStage<RagContext, RagContext> {
+  /// Model asked to produce the sub-questions.
+  final LLMAdapter llm;
+
+  /// Source of the decomposition prompt.
+  final MedicalPromptBuilder promptBuilder;
+
+  /// Leading list marker ("-", "*", "•", "1.", "2)") followed by whitespace.
+  ///
+  /// Models often number or bullet their answer even when asked for one
+  /// question per line; the marker is not part of the sub-question.
+  static final RegExp _listMarker = RegExp(r'^(?:[-*•]|\d+[.)])\s+');
+
+  MedicalQueryDecompositionStage(this.llm, this.promptBuilder);
+
+  /// Fills `decomposedQueries` of [context] and returns [context].
+  @override
+  Future<RagContext> process(RagContext context) async {
+    final query = context.currentQuery;
+
+    // Only decompose if the query is long enough or contains the Spanish
+    // conjunctions "y" / "e"; short single-topic queries skip the LLM call.
+    final needsDecomposition =
+        query.length > 60 || query.contains(' y ') || query.contains(' e ');
+
+    if (needsDecomposition) {
+      final String response =
+          await llm.generate(promptBuilder.buildDecompositionPrompt(query));
+
+      // Lines of five characters or fewer are dropped as noise.
+      context.decomposedQueries = response
+          .split('\n')
+          .map((line) => line.trim().replaceFirst(_listMarker, ''))
+          .where((subQuery) => subQuery.length > 5)
+          .toList();
+    }
+
+    // Fall back to the query itself so retrieval always has something to
+    // search for.
+    if (context.decomposedQueries.isEmpty) {
+      context.decomposedQueries = [query];
+    }
+
+    return context;
+  }
+}
+
+/// Stage to retrieve nodes using hybrid search.
+///
+/// Runs one search per entry of `decomposedQueries`, keeps the best-scoring
+/// hit for each node id, and stores the merged hits in `retrievedNodes`,
+/// highest score first.
+class HybridRetrievalStage extends PipelineStage<RagContext, RagContext> {
+  /// Graph that is searched.
+  final MemoryGraph memoryGraph;
+
+  /// Number of hits requested from [memoryGraph] for each sub-query.
+  final int topKPerQuery;
+
+  /// Value forwarded as `alpha` to every hybrid search.
+  // NOTE(review): presumably the vector/keyword blend weight — confirm
+  // against MemoryGraph.hybridSearch.
+  final double alpha;
+
+  /// Maximum number of merged hits kept in the context.
+  final int maxResults;
+
+  /// Creates the stage; the defaults reproduce the previously hard-coded
+  /// values (5 hits per sub-query, alpha 0.5, 10 merged hits).
+  HybridRetrievalStage(
+    this.memoryGraph, {
+    this.topKPerQuery = 5,
+    this.alpha = 0.5,
+    this.maxResults = 10,
+  });
+
+  /// Fills `retrievedNodes` of [context] and returns [context].
+  @override
+  Future<RagContext> process(RagContext context) async {
+    // Keyed by node id. The key type is left open because the declared type
+    // of `MemoryNode.id` is not visible from this file.
+    final bestByNodeId = <Object?, ({MemoryNode node, double score})>{};
+
+    for (final subQuery in context.decomposedQueries) {
+      final hits = await memoryGraph.hybridSearch(
+        subQuery,
+        topK: topKPerQuery,
+        alpha: alpha,
+      );
+      for (final hit in hits) {
+        // A node found by several sub-queries keeps only its best score.
+        final best = bestByNodeId[hit.node.id];
+        if (best == null || hit.score > best.score) {
+          bestByNodeId[hit.node.id] = hit;
+        }
+      }
+    }
+
+    final ranked = bestByNodeId.values.toList()
+      ..sort((a, b) => b.score.compareTo(a.score));
+
+    context.retrievedNodes = ranked.take(maxResults).toList();
+    return context;
+  }
+}
+
+/// Stage to re-rank retrieved nodes.
+///
+/// Replaces `retrievedNodes` with the order produced by [reranker] for the
+/// current query; an empty result set is passed through untouched.
+class ReRankingStage extends PipelineStage<RagContext, RagContext> {
+  /// Strategy that produces the new ordering.
+  final ReRankingStrategy reranker;
+
+  ReRankingStage(this.reranker);
+
+  /// Re-orders `retrievedNodes` of [context] and returns [context].
+  @override
+  Future<RagContext> process(RagContext context) async {
+    if (context.retrievedNodes.isNotEmpty) {
+      context.retrievedNodes = reranker.reRank(
+        context.retrievedNodes,
+        query: context.currentQuery,
+      );
+    }
+
+    return context;
+  }
+}
+
+/// Stage to attach source nodes to the generated response.
+///
+/// A retrieved node counts as cited when its 1-based position, written as
+/// `[n]`, occurs in the generated response.
+class EvidenceCitationStage extends PipelineStage<RagContext, RagContext> {
+  /// Fills `citedNodes` of [context] and returns [context].
+  ///
+  /// Does nothing when no response has been generated yet.
+  @override
+  Future<RagContext> process(RagContext context) async {
+    // Copied to a local so the null check promotes it.
+    final response = context.generatedResponse;
+    if (response == null) return context;
+
+    context.citedNodes = [
+      for (final (index, hit) in context.retrievedNodes.indexed)
+        if (response.contains('[${index + 1}]')) hit.node,
+    ];
+    return context;
+  }
+}
diff --git a/lib/src/rag/memory_pipeline.dart b/lib/src/rag/memory_pipeline.dart
new file mode 100644
index 0000000..e264304
--- /dev/null
+++ b/lib/src/rag/memory_pipeline.dart
@@ -0,0 +1,31 @@
+import '../models/memory_node.dart';
+
+/// Base class for a pipeline stage.
+///
+/// A stage consumes an input of type [I] and asynchronously produces an
+/// output of type [O].
+abstract class PipelineStage<I, O> {
+  /// Processes [input] and completes with the stage's result.
+  Future<O> process(I input);
+}
+
+/// A generic pipeline that executes a series of stages.
+///
+/// [I] is the type accepted by [execute] and [O] the type it completes with.
+abstract class MemoryPipeline<I, O> {
+  /// Runs the whole pipeline on [input].
+  Future<O> execute(I input);
+}
+
+/// Data structure for RAG context and results.
+///
+/// A single instance is handed from stage to stage; each stage fills in or
+/// rewrites the fields it is responsible for.
+class RagContext {
+  /// The query exactly as it was supplied.
+  final String originalQuery;
+
+  /// Working copy of the query.
+  String currentQuery;
+
+  /// Sub-questions derived from [currentQuery].
+  List<String> decomposedQueries;
+
+  /// Retrieved nodes paired with their scores.
+  List<({MemoryNode node, double score})> retrievedNodes;
+
+  /// Generated answer, or null until one has been produced.
+  String? generatedResponse;
+
+  /// Retrieved nodes that the generated answer cites.
+  List<MemoryNode> citedNodes;
+
+  /// Creates a context for [originalQuery].
+  ///
+  /// [currentQuery] starts as [originalQuery] unless a non-empty value is
+  /// supplied. It used to be accepted and then unconditionally overwritten.
+  ///
+  /// Omitted lists start as fresh, growable lists rather than a shared
+  /// unmodifiable `const []`, so a stage may also add to them in place.
+  RagContext({
+    required this.originalQuery,
+    String currentQuery = '',
+    List<String>? decomposedQueries,
+    List<({MemoryNode node, double score})>? retrievedNodes,
+    List<MemoryNode>? citedNodes,
+  })  : currentQuery = currentQuery.isEmpty ? originalQuery : currentQuery,
+        decomposedQueries = decomposedQueries ?? [],
+        retrievedNodes = retrievedNodes ?? [],
+        citedNodes = citedNodes ?? [];
+}
diff --git a/test/medical_rag_pipeline_test.dart b/test/medical_rag_pipeline_test.dart
new file mode 100644
index 0000000..328b891
--- /dev/null
+++ b/test/medical_rag_pipeline_test.dart
@@ -0,0 +1,96 @@
+import 'dart:io';
+
+import 'package:flutter_test/flutter_test.dart';
+import 'package:isar/isar.dart';
+import 'package:isar_agent_memory/isar_agent_memory.dart';
+
+/// LLM stub that picks a canned answer from markers in the prompt.
+class MockLLMAdapter implements LLMAdapter {
+  @override
+  Future<String> generate(String prompt) async {
+    // Marker emitted by the decomposition prompt.
+    if (prompt.contains('SUB-PREGUNTAS:')) {
+      return '¿Qué es la tensión arterial?\n¿Cómo se mide la tensión arterial?';
+    }
+    // Marker emitted by the Spanish answer prompt; the reply cites node [1].
+    if (prompt.contains('CONTEXTO:')) {
+      return 'La tensión arterial es la fuerza de la sangre contra las paredes de las arterias [1].';
+    }
+    return 'Mocked response';
+  }
+}
+
+/// Embeddings stub producing a deterministic vector from the text length.
+class MockEmbeddingsAdapter implements EmbeddingsAdapter {
+  @override
+  String get providerName => 'mock';
+
+  @override
+  int get dimension => 384;
+
+  @override
+  Future<List<double>> embed(String text) async {
+    return List<double>.generate(
+      dimension,
+      (i) => (text.length + i) / 1000.0,
+    );
+  }
+}
+
+void main() {
+  TestWidgetsFlutterBinding.ensureInitialized();
+
+  late Directory dbDirectory;
+  late Isar isar;
+  late MemoryGraph memoryGraph;
+  late MockLLMAdapter llm;
+
+  // The native core only needs to be initialised once per test run.
+  setUpAll(() async {
+    await Isar.initializeIsarCore(download: true);
+  });
+
+  setUp(() async {
+    // A throw-away directory keeps database files out of the package root.
+    dbDirectory = Directory.systemTemp.createTempSync('medical_rag_test_');
+    isar = await Isar.open(
+      [MemoryNodeSchema, MemoryEdgeSchema],
+      directory: dbDirectory.path,
+      name: 'medical_test_db',
+    );
+    memoryGraph = MemoryGraph(isar, embeddingsAdapter: MockEmbeddingsAdapter());
+    llm = MockLLMAdapter();
+    await isar.writeTxn(() async => await isar.clear());
+  });
+
+  tearDown(() async {
+    await isar.close(deleteFromDisk: true);
+    if (dbDirectory.existsSync()) {
+      dbDirectory.deleteSync(recursive: true);
+    }
+  });
+
+  group('RagContext', () {
+    test('defaults currentQuery to originalQuery', () {
+      final context = RagContext(originalQuery: 'TA alta');
+
+      expect(context.currentQuery, 'TA alta');
+    });
+
+    test('keeps an explicit currentQuery', () {
+      final context = RagContext(
+        originalQuery: 'TA alta',
+        currentQuery: 'tensión arterial alta',
+      );
+
+      expect(context.originalQuery, 'TA alta');
+      expect(context.currentQuery, 'tensión arterial alta');
+    });
+
+    test('default lists are growable and not shared', () {
+      final first = RagContext(originalQuery: 'a');
+      final second = RagContext(originalQuery: 'b');
+
+      first.decomposedQueries.add('sub-pregunta');
+
+      expect(first.decomposedQueries, ['sub-pregunta']);
+      expect(second.decomposedQueries, isEmpty);
+    });
+  });
+
+  group('MedicalRagPipeline Stages', () {
+    test('QueryNormalizationStage expands abbreviations', () async {
+      final stage = QueryNormalizationStage();
+      final context = RagContext(originalQuery: 'Paciente con TA alta y DM');
+
+      final result = await stage.process(context);
+
+      expect(result.currentQuery, contains('tensión arterial'));
+      expect(result.currentQuery, contains('diabetes mellitus'));
+    });
+
+    test('MedicalQueryDecompositionStage splits queries', () async {
+      final stage = MedicalQueryDecompositionStage(llm, MedicalPromptBuilder());
+      final context = RagContext(
+          originalQuery:
+              'Explique qué es la TA y cómo se mide en pacientes con HTA severa');
+
+      final result = await stage.process(context);
+
+      expect(result.decomposedQueries.length, greaterThan(1));
+      expect(result.decomposedQueries[0], contains('tensión arterial'));
+    });
+  });
+
+  test('MedicalRagPipeline full execution', () async {
+    // Seed some data
+    await memoryGraph.storeNode(MemoryNode(
+      content:
+          'La tensión arterial es vital para el funcionamiento del cuerpo.',
+      type: 'medical',
+    ));
+
+    final pipeline = MedicalRagPipeline(
+      memoryGraph: memoryGraph,
+      llm: llm,
+    );
+
+    final result = await pipeline.execute('¿Qué es la TA?');
+
+    expect(result.currentQuery, contains('tensión arterial'));
+    expect(result.generatedResponse, contains('tensión arterial'));
+    expect(result.generatedResponse, contains('AVISO MÉDICO'));
+    expect(result.citedNodes, isNotEmpty);
+  });
+}