-
Notifications
You must be signed in to change notification settings - Fork 1
Medical RAG Pipeline with Multi-Hop Search #44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import '../models/memory_node.dart'; | ||
|
|
||
| /// Builder for medical-specific RAG prompts. | ||
| class MedicalPromptBuilder { | ||
| /// Default medical safety disclaimer in Spanish. | ||
| static const String spanishDisclaimer = | ||
| 'AVISO MÉDICO: Esta respuesta es generada por una IA y tiene fines informativos únicamente. ' | ||
| 'No sustituye el consejo, diagnóstico o tratamiento médico profesional. ' | ||
| 'Siempre busque el consejo de su médico u otro proveedor de salud calificado.'; | ||
|
|
||
| /// Constructs a RAG prompt with the given query and context nodes. | ||
| String buildRagPrompt({ | ||
| required String query, | ||
| required List<MemoryNode> contextNodes, | ||
| String language = 'es', | ||
| }) { | ||
| final contextBuffer = StringBuffer(); | ||
| for (var i = 0; i < contextNodes.length; i++) { | ||
| contextBuffer.writeln('[${i + 1}] ${contextNodes[i].content}'); | ||
| } | ||
|
|
||
| if (language == 'es') { | ||
| return '''Utiliza la siguiente información de contexto para responder a la pregunta médica de forma precisa. | ||
| Si la información no es suficiente, indícalo. | ||
|
|
||
| CONTEXTO: | ||
| ${contextBuffer.toString()} | ||
|
|
||
| PREGUNTA: $query | ||
|
|
||
| RESPUESTA (incluye citas numéricas como [1] si corresponde):'''; | ||
| } else { | ||
| return '''Use the following context to answer the medical question accurately. | ||
| If the information is not sufficient, please state it. | ||
|
|
||
| CONTEXT: | ||
| ${contextBuffer.toString()} | ||
|
|
||
| QUESTION: $query | ||
|
|
||
| RESPONSE (include numeric citations like [1] if applicable):'''; | ||
| } | ||
| } | ||
|
|
||
| /// Builds a prompt for query decomposition. | ||
| String buildDecompositionPrompt(String query) { | ||
| return '''Divide la siguiente consulta médica compleja en sub-preguntas más simples que puedan ser buscadas de forma independiente. | ||
| Responde solo con la lista de preguntas, una por línea. | ||
|
|
||
| CONSULTA: $query | ||
|
|
||
| SUB-PREGUNTAS:'''; | ||
| } | ||
|
|
||
| /// Wraps a response with the medical disclaimer. | ||
| String wrapWithDisclaimer(String response, {String language = 'es'}) { | ||
| final disclaimer = language == 'es' | ||
| ? spanishDisclaimer | ||
| : 'MEDICAL DISCLAIMER: This response is AI-generated and for informational purposes only. ' | ||
| 'It is not a substitute for professional medical advice, diagnosis, or treatment.'; | ||
|
|
||
| return '$response\n\n---\n$disclaimer'; | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,184 @@ | ||||||||||||||||||||||||||||||||||||||||||||||
| import '../memory_graph.dart'; | ||||||||||||||||||||||||||||||||||||||||||||||
| import '../llm_adapter.dart'; | ||||||||||||||||||||||||||||||||||||||||||||||
| import '../models/memory_node.dart'; | ||||||||||||||||||||||||||||||||||||||||||||||
| import '../reranking_strategy.dart'; | ||||||||||||||||||||||||||||||||||||||||||||||
| import 'memory_pipeline.dart'; | ||||||||||||||||||||||||||||||||||||||||||||||
| import 'medical_prompt_builder.dart'; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| /// Medical RAG Pipeline implementation. | ||||||||||||||||||||||||||||||||||||||||||||||
| class MedicalRagPipeline implements MemoryPipeline<String, RagContext> { | ||||||||||||||||||||||||||||||||||||||||||||||
| final MemoryGraph memoryGraph; | ||||||||||||||||||||||||||||||||||||||||||||||
| final LLMAdapter llm; | ||||||||||||||||||||||||||||||||||||||||||||||
| final ReRankingStrategy? reranker; | ||||||||||||||||||||||||||||||||||||||||||||||
| final MedicalPromptBuilder promptBuilder; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| MedicalRagPipeline({ | ||||||||||||||||||||||||||||||||||||||||||||||
| required this.memoryGraph, | ||||||||||||||||||||||||||||||||||||||||||||||
| required this.llm, | ||||||||||||||||||||||||||||||||||||||||||||||
| this.reranker, | ||||||||||||||||||||||||||||||||||||||||||||||
| MedicalPromptBuilder? promptBuilder, | ||||||||||||||||||||||||||||||||||||||||||||||
| }) : promptBuilder = promptBuilder ?? MedicalPromptBuilder(); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| @override | ||||||||||||||||||||||||||||||||||||||||||||||
| Future<RagContext> execute(String query) async { | ||||||||||||||||||||||||||||||||||||||||||||||
| var context = RagContext(originalQuery: query); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // 1. Normalization | ||||||||||||||||||||||||||||||||||||||||||||||
| context = await QueryNormalizationStage().process(context); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // 2. Decomposition (optional, for complex queries) | ||||||||||||||||||||||||||||||||||||||||||||||
| context = await MedicalQueryDecompositionStage(llm, promptBuilder) | ||||||||||||||||||||||||||||||||||||||||||||||
| .process(context); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // 3. Hybrid Retrieval | ||||||||||||||||||||||||||||||||||||||||||||||
| context = await HybridRetrievalStage(memoryGraph).process(context); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // 4. Re-ranking | ||||||||||||||||||||||||||||||||||||||||||||||
| if (reranker != null) { | ||||||||||||||||||||||||||||||||||||||||||||||
| context = await ReRankingStage(reranker!).process(context); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // 5. Generation | ||||||||||||||||||||||||||||||||||||||||||||||
| final prompt = promptBuilder.buildRagPrompt( | ||||||||||||||||||||||||||||||||||||||||||||||
| query: context.currentQuery, | ||||||||||||||||||||||||||||||||||||||||||||||
| contextNodes: context.retrievedNodes.map((r) => r.node).toList(), | ||||||||||||||||||||||||||||||||||||||||||||||
| ); | ||||||||||||||||||||||||||||||||||||||||||||||
| final response = await llm.generate(prompt); | ||||||||||||||||||||||||||||||||||||||||||||||
| context.generatedResponse = response; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // 6. Evidence Citation | ||||||||||||||||||||||||||||||||||||||||||||||
| context = await EvidenceCitationStage().process(context); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| // Add disclaimer | ||||||||||||||||||||||||||||||||||||||||||||||
| context.generatedResponse = | ||||||||||||||||||||||||||||||||||||||||||||||
| promptBuilder.wrapWithDisclaimer(context.generatedResponse!); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| return context; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| /// Stage to expand medical abbreviations. | ||||||||||||||||||||||||||||||||||||||||||||||
| class QueryNormalizationStage extends PipelineStage<RagContext, RagContext> { | ||||||||||||||||||||||||||||||||||||||||||||||
| static const Map<String, String> abbreviations = { | ||||||||||||||||||||||||||||||||||||||||||||||
| 'TA': 'tensión arterial', | ||||||||||||||||||||||||||||||||||||||||||||||
| 'HTA': 'hipertensión arterial', | ||||||||||||||||||||||||||||||||||||||||||||||
| 'DM': 'diabetes mellitus', | ||||||||||||||||||||||||||||||||||||||||||||||
| 'IMC': 'índice de masa corporal', | ||||||||||||||||||||||||||||||||||||||||||||||
| 'EPOC': 'enfermedad pulmonar obstructiva crónica', | ||||||||||||||||||||||||||||||||||||||||||||||
| 'SNC': 'sistema nervioso central', | ||||||||||||||||||||||||||||||||||||||||||||||
| 'PCR': 'proteína C reactiva', | ||||||||||||||||||||||||||||||||||||||||||||||
| 'ECG': 'electrocardiograma', | ||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| @override | ||||||||||||||||||||||||||||||||||||||||||||||
| Future<RagContext> process(RagContext context) async { | ||||||||||||||||||||||||||||||||||||||||||||||
| String normalized = context.currentQuery; | ||||||||||||||||||||||||||||||||||||||||||||||
| abbreviations.forEach((abbrev, expansion) { | ||||||||||||||||||||||||||||||||||||||||||||||
| // Simple regex-like replacement for whole words | ||||||||||||||||||||||||||||||||||||||||||||||
| final regex = RegExp('\\b$abbrev\\b', caseSensitive: false); | ||||||||||||||||||||||||||||||||||||||||||||||
| normalized = normalized.replaceAll(regex, expansion); | ||||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||||
| context.currentQuery = normalized; | ||||||||||||||||||||||||||||||||||||||||||||||
| return context; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| /// Stage to split complex queries using an LLM. | ||||||||||||||||||||||||||||||||||||||||||||||
| class MedicalQueryDecompositionStage | ||||||||||||||||||||||||||||||||||||||||||||||
| extends PipelineStage<RagContext, RagContext> { | ||||||||||||||||||||||||||||||||||||||||||||||
| final LLMAdapter llm; | ||||||||||||||||||||||||||||||||||||||||||||||
| final MedicalPromptBuilder promptBuilder; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| MedicalQueryDecompositionStage(this.llm, this.promptBuilder); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| @override | ||||||||||||||||||||||||||||||||||||||||||||||
| Future<RagContext> process(RagContext context) async { | ||||||||||||||||||||||||||||||||||||||||||||||
| // Only decompose if the query is long enough or contains conjunctions | ||||||||||||||||||||||||||||||||||||||||||||||
| if (context.currentQuery.length > 60 || | ||||||||||||||||||||||||||||||||||||||||||||||
| context.currentQuery.contains(' y ') || | ||||||||||||||||||||||||||||||||||||||||||||||
| context.currentQuery.contains(' e ')) { | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+98
to
+99
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||
| final prompt = | ||||||||||||||||||||||||||||||||||||||||||||||
| promptBuilder.buildDecompositionPrompt(context.currentQuery); | ||||||||||||||||||||||||||||||||||||||||||||||
| final response = await llm.generate(prompt); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| context.decomposedQueries = response | ||||||||||||||||||||||||||||||||||||||||||||||
| .split('\n') | ||||||||||||||||||||||||||||||||||||||||||||||
| .map((s) => s.trim()) | ||||||||||||||||||||||||||||||||||||||||||||||
| .where((s) => s.isNotEmpty && s.length > 5) | ||||||||||||||||||||||||||||||||||||||||||||||
| .toList(); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if (context.decomposedQueries.isEmpty) { | ||||||||||||||||||||||||||||||||||||||||||||||
| context.decomposedQueries = [context.currentQuery]; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| return context; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| /// Stage to retrieve nodes using hybrid search. | ||||||||||||||||||||||||||||||||||||||||||||||
| class HybridRetrievalStage extends PipelineStage<RagContext, RagContext> { | ||||||||||||||||||||||||||||||||||||||||||||||
| final MemoryGraph memoryGraph; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| HybridRetrievalStage(this.memoryGraph); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| @override | ||||||||||||||||||||||||||||||||||||||||||||||
| Future<RagContext> process(RagContext context) async { | ||||||||||||||||||||||||||||||||||||||||||||||
| final allResults = <int, ({MemoryNode node, double score})>{}; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| for (final q in context.decomposedQueries) { | ||||||||||||||||||||||||||||||||||||||||||||||
| final results = await memoryGraph.hybridSearch(q, topK: 5, alpha: 0.5); | ||||||||||||||||||||||||||||||||||||||||||||||
| for (final res in results) { | ||||||||||||||||||||||||||||||||||||||||||||||
| final existing = allResults[res.node.id]; | ||||||||||||||||||||||||||||||||||||||||||||||
| if (existing == null || res.score > existing.score) { | ||||||||||||||||||||||||||||||||||||||||||||||
| allResults[res.node.id] = res; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+129
to
+137
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The sub-queries generated during decomposition are processed sequentially. This can lead to high latency in the retrieval stage. Using
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| final sortedResults = allResults.values.toList() | ||||||||||||||||||||||||||||||||||||||||||||||
| ..sort((a, b) => b.score.compareTo(a.score)); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| context.retrievedNodes = sortedResults.take(10).toList(); | ||||||||||||||||||||||||||||||||||||||||||||||
| return context; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| /// Stage to re-rank retrieved nodes. | ||||||||||||||||||||||||||||||||||||||||||||||
| class ReRankingStage extends PipelineStage<RagContext, RagContext> { | ||||||||||||||||||||||||||||||||||||||||||||||
| final ReRankingStrategy reranker; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| ReRankingStage(this.reranker); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| @override | ||||||||||||||||||||||||||||||||||||||||||||||
| Future<RagContext> process(RagContext context) async { | ||||||||||||||||||||||||||||||||||||||||||||||
| if (context.retrievedNodes.isEmpty) return context; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| context.retrievedNodes = reranker.reRank( | ||||||||||||||||||||||||||||||||||||||||||||||
| context.retrievedNodes, | ||||||||||||||||||||||||||||||||||||||||||||||
| query: context.currentQuery, | ||||||||||||||||||||||||||||||||||||||||||||||
| ); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| return context; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| /// Stage to attach source nodes to the generated response. | ||||||||||||||||||||||||||||||||||||||||||||||
| class EvidenceCitationStage extends PipelineStage<RagContext, RagContext> { | ||||||||||||||||||||||||||||||||||||||||||||||
| @override | ||||||||||||||||||||||||||||||||||||||||||||||
| Future<RagContext> process(RagContext context) async { | ||||||||||||||||||||||||||||||||||||||||||||||
| if (context.generatedResponse == null) return context; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| final cited = <MemoryNode>[]; | ||||||||||||||||||||||||||||||||||||||||||||||
| for (var i = 0; i < context.retrievedNodes.length; i++) { | ||||||||||||||||||||||||||||||||||||||||||||||
| final node = context.retrievedNodes[i].node; | ||||||||||||||||||||||||||||||||||||||||||||||
| final citation = '[${i + 1}]'; | ||||||||||||||||||||||||||||||||||||||||||||||
| if (context.generatedResponse!.contains(citation)) { | ||||||||||||||||||||||||||||||||||||||||||||||
| cited.add(node); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+173
to
+179
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The citation matching logic using
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| context.citedNodes = cited; | ||||||||||||||||||||||||||||||||||||||||||||||
| return context; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| import '../models/memory_node.dart'; | ||
|
|
||
| /// Base class for a pipeline stage. | ||
| abstract class PipelineStage<I, O> { | ||
| Future<O> process(I input); | ||
| } | ||
|
|
||
| /// A generic pipeline that executes a series of stages. | ||
| abstract class MemoryPipeline<I, O> { | ||
| Future<O> execute(I input); | ||
| } | ||
|
|
||
| /// Data structure for RAG context and results. | ||
| class RagContext { | ||
| final String originalQuery; | ||
| String currentQuery; | ||
| List<String> decomposedQueries; | ||
| List<({MemoryNode node, double score})> retrievedNodes; | ||
| String? generatedResponse; | ||
| List<MemoryNode> citedNodes; | ||
|
|
||
| RagContext({ | ||
| required this.originalQuery, | ||
| this.currentQuery = '', | ||
| this.decomposedQueries = const [], | ||
| this.retrievedNodes = const [], | ||
| this.citedNodes = const [], | ||
| }) { | ||
| currentQuery = originalQuery; | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| import 'package:flutter_test/flutter_test.dart'; | ||
| import 'package:isar/isar.dart'; | ||
| import 'package:isar_agent_memory/isar_agent_memory.dart'; | ||
|
|
||
| class MockLLMAdapter implements LLMAdapter { | ||
| @override | ||
| Future<String> generate(String prompt) async { | ||
| if (prompt.contains('SUB-PREGUNTAS:')) { | ||
| return '¿Qué es la tensión arterial?\n¿Cómo se mide la tensión arterial?'; | ||
| } | ||
| if (prompt.contains('CONTEXTO:')) { | ||
| return 'La tensión arterial es la fuerza de la sangre contra las paredes de las arterias [1].'; | ||
| } | ||
| return 'Mocked response'; | ||
| } | ||
| } | ||
|
|
||
| class MockEmbeddingsAdapter implements EmbeddingsAdapter { | ||
| @override | ||
| String get providerName => 'mock'; | ||
| @override | ||
| int get dimension => 384; | ||
| @override | ||
| Future<List<double>> embed(String text) async { | ||
| return List.generate(384, (i) => (text.length + i) / 1000.0); | ||
| } | ||
| } | ||
|
|
||
| void main() { | ||
| TestWidgetsFlutterBinding.ensureInitialized(); | ||
|
|
||
| late Isar isar; | ||
| late MemoryGraph memoryGraph; | ||
| late MockLLMAdapter llm; | ||
|
|
||
| setUp(() async { | ||
| await Isar.initializeIsarCore(download: true); | ||
| isar = await Isar.open( | ||
| [MemoryNodeSchema, MemoryEdgeSchema], | ||
| directory: '.', | ||
| name: 'medical_test_db', | ||
| ); | ||
| memoryGraph = MemoryGraph(isar, embeddingsAdapter: MockEmbeddingsAdapter()); | ||
| llm = MockLLMAdapter(); | ||
| await isar.writeTxn(() async => await isar.clear()); | ||
| }); | ||
|
|
||
| tearDown(() async { | ||
| await isar.close(deleteFromDisk: true); | ||
| }); | ||
|
|
||
| group('MedicalRagPipeline Stages', () { | ||
| test('QueryNormalizationStage expands abbreviations', () async { | ||
| final stage = QueryNormalizationStage(); | ||
| final context = RagContext(originalQuery: 'Paciente con TA alta y DM'); | ||
|
|
||
| final result = await stage.process(context); | ||
|
|
||
| expect(result.currentQuery, contains('tensión arterial')); | ||
| expect(result.currentQuery, contains('diabetes mellitus')); | ||
| }); | ||
|
|
||
| test('MedicalQueryDecompositionStage splits queries', () async { | ||
| final stage = MedicalQueryDecompositionStage(llm, MedicalPromptBuilder()); | ||
| final context = RagContext( | ||
| originalQuery: | ||
| 'Explique qué es la TA y cómo se mide en pacientes con HTA severa'); | ||
|
|
||
| final result = await stage.process(context); | ||
|
|
||
| expect(result.decomposedQueries.length, greaterThan(1)); | ||
| expect(result.decomposedQueries[0], contains('tensión arterial')); | ||
| }); | ||
| }); | ||
|
|
||
| test('MedicalRagPipeline full execution', () async { | ||
| // Seed some data | ||
| await memoryGraph.storeNode(MemoryNode( | ||
| content: | ||
| 'La tensión arterial es vital para el funcionamiento del cuerpo.', | ||
| type: 'medical', | ||
| )); | ||
|
|
||
| final pipeline = MedicalRagPipeline( | ||
| memoryGraph: memoryGraph, | ||
| llm: llm, | ||
| ); | ||
|
|
||
| final result = await pipeline.execute('¿Qué es la TA?'); | ||
|
|
||
| expect(result.currentQuery, contains('tensión arterial')); | ||
| expect(result.generatedResponse, contains('tensión arterial')); | ||
| expect(result.generatedResponse, contains('AVISO MÉDICO')); | ||
| expect(result.citedNodes, isNotEmpty); | ||
| }); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Creating a new
RegExpinstance for every abbreviation on every call toprocessis inefficient. Since the abbreviations are static, these patterns should be pre-compiled to improve performance, especially if this stage is called frequently.