Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lib/isar_agent_memory.dart
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ export 'src/sync/sync_backend.dart';
export 'src/sync/firebase_sync_backend.dart';
export 'src/sync/websocket_sync_backend.dart';
export 'src/sync/cross_device_sync_manager.dart';
export 'src/rag/memory_pipeline.dart';
export 'src/rag/medical_rag_pipeline.dart';
export 'src/rag/medical_prompt_builder.dart';

// Advanced features (v0.5.0)
export 'src/memory_consolidation.dart';
Expand Down
64 changes: 64 additions & 0 deletions lib/src/rag/medical_prompt_builder.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import '../models/memory_node.dart';

/// Builder for medical-specific RAG prompts.
class MedicalPromptBuilder {
/// Default medical safety disclaimer in Spanish.
static const String spanishDisclaimer =
'AVISO MÉDICO: Esta respuesta es generada por una IA y tiene fines informativos únicamente. '
'No sustituye el consejo, diagnóstico o tratamiento médico profesional. '
'Siempre busque el consejo de su médico u otro proveedor de salud calificado.';

/// Constructs a RAG prompt with the given query and context nodes.
String buildRagPrompt({
required String query,
required List<MemoryNode> contextNodes,
String language = 'es',
}) {
final contextBuffer = StringBuffer();
for (var i = 0; i < contextNodes.length; i++) {
contextBuffer.writeln('[${i + 1}] ${contextNodes[i].content}');
}

if (language == 'es') {
return '''Utiliza la siguiente información de contexto para responder a la pregunta médica de forma precisa.
Si la información no es suficiente, indícalo.

CONTEXTO:
${contextBuffer.toString()}

PREGUNTA: $query

RESPUESTA (incluye citas numéricas como [1] si corresponde):''';
} else {
return '''Use the following context to answer the medical question accurately.
If the information is not sufficient, please state it.

CONTEXT:
${contextBuffer.toString()}

QUESTION: $query

RESPONSE (include numeric citations like [1] if applicable):''';
}
}

/// Builds a prompt for query decomposition.
String buildDecompositionPrompt(String query) {
return '''Divide la siguiente consulta médica compleja en sub-preguntas más simples que puedan ser buscadas de forma independiente.
Responde solo con la lista de preguntas, una por línea.

CONSULTA: $query

SUB-PREGUNTAS:''';
}

/// Wraps a response with the medical disclaimer.
String wrapWithDisclaimer(String response, {String language = 'es'}) {
final disclaimer = language == 'es'
? spanishDisclaimer
: 'MEDICAL DISCLAIMER: This response is AI-generated and for informational purposes only. '
'It is not a substitute for professional medical advice, diagnosis, or treatment.';

return '$response\n\n---\n$disclaimer';
}
}
184 changes: 184 additions & 0 deletions lib/src/rag/medical_rag_pipeline.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import '../memory_graph.dart';
import '../llm_adapter.dart';
import '../models/memory_node.dart';
import '../reranking_strategy.dart';
import 'memory_pipeline.dart';
import 'medical_prompt_builder.dart';

/// Medical RAG Pipeline implementation.
class MedicalRagPipeline implements MemoryPipeline<String, RagContext> {
final MemoryGraph memoryGraph;
final LLMAdapter llm;
final ReRankingStrategy? reranker;
final MedicalPromptBuilder promptBuilder;

MedicalRagPipeline({
required this.memoryGraph,
required this.llm,
this.reranker,
MedicalPromptBuilder? promptBuilder,
}) : promptBuilder = promptBuilder ?? MedicalPromptBuilder();

@override
Future<RagContext> execute(String query) async {
var context = RagContext(originalQuery: query);

// 1. Normalization
context = await QueryNormalizationStage().process(context);

// 2. Decomposition (optional, for complex queries)
context = await MedicalQueryDecompositionStage(llm, promptBuilder)
.process(context);

// 3. Hybrid Retrieval
context = await HybridRetrievalStage(memoryGraph).process(context);

// 4. Re-ranking
if (reranker != null) {
context = await ReRankingStage(reranker!).process(context);
}

// 5. Generation
final prompt = promptBuilder.buildRagPrompt(
query: context.currentQuery,
contextNodes: context.retrievedNodes.map((r) => r.node).toList(),
);
final response = await llm.generate(prompt);
context.generatedResponse = response;

// 6. Evidence Citation
context = await EvidenceCitationStage().process(context);

// Add disclaimer
context.generatedResponse =
promptBuilder.wrapWithDisclaimer(context.generatedResponse!);

return context;
}
}

/// Stage to expand medical abbreviations.
class QueryNormalizationStage extends PipelineStage<RagContext, RagContext> {
static const Map<String, String> abbreviations = {
'TA': 'tensión arterial',
'HTA': 'hipertensión arterial',
'DM': 'diabetes mellitus',
'IMC': 'índice de masa corporal',
'EPOC': 'enfermedad pulmonar obstructiva crónica',
'SNC': 'sistema nervioso central',
'PCR': 'proteína C reactiva',
'ECG': 'electrocardiograma',
};

@override
Future<RagContext> process(RagContext context) async {
String normalized = context.currentQuery;
abbreviations.forEach((abbrev, expansion) {
// Simple regex-like replacement for whole words
final regex = RegExp('\\b$abbrev\\b', caseSensitive: false);
normalized = normalized.replaceAll(regex, expansion);
});
context.currentQuery = normalized;
return context;
}
}
Comment on lines +73 to +84
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Creating a new RegExp instance for every abbreviation on every call to process is inefficient. Since the abbreviations are static, these patterns should be pre-compiled to improve performance, especially if this stage is called frequently.

  static final Map<String, RegExp> _regexCache = abbreviations.map(
    (k, v) => MapEntry(k, RegExp(r'\b' + k + r'\b', caseSensitive: false)),
  );

  @override
  Future<RagContext> process(RagContext context) async {
    String normalized = context.currentQuery;
    _regexCache.forEach((abbrev, regex) {
      normalized = normalized.replaceAll(regex, abbreviations[abbrev]!);
    });
    context.currentQuery = normalized;
    return context;
  }


/// Stage to split complex queries using an LLM.
class MedicalQueryDecompositionStage
extends PipelineStage<RagContext, RagContext> {
final LLMAdapter llm;
final MedicalPromptBuilder promptBuilder;

MedicalQueryDecompositionStage(this.llm, this.promptBuilder);

@override
Future<RagContext> process(RagContext context) async {
// Only decompose if the query is long enough or contains conjunctions
if (context.currentQuery.length > 60 ||
context.currentQuery.contains(' y ') ||
context.currentQuery.contains(' e ')) {
Comment on lines +98 to +99
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The decomposition logic relies on hardcoded Spanish conjunctions (' y ', ' e '). This limits the pipeline's effectiveness in other languages, despite other parts of the code attempting to support English. Consider making these triggers language-aware.

final prompt =
promptBuilder.buildDecompositionPrompt(context.currentQuery);
final response = await llm.generate(prompt);

context.decomposedQueries = response
.split('\n')
.map((s) => s.trim())
.where((s) => s.isNotEmpty && s.length > 5)
.toList();
}

if (context.decomposedQueries.isEmpty) {
context.decomposedQueries = [context.currentQuery];
}

return context;
}
}

/// Stage to retrieve nodes using hybrid search.
class HybridRetrievalStage extends PipelineStage<RagContext, RagContext> {
final MemoryGraph memoryGraph;

HybridRetrievalStage(this.memoryGraph);

@override
Future<RagContext> process(RagContext context) async {
final allResults = <int, ({MemoryNode node, double score})>{};

for (final q in context.decomposedQueries) {
final results = await memoryGraph.hybridSearch(q, topK: 5, alpha: 0.5);
for (final res in results) {
final existing = allResults[res.node.id];
if (existing == null || res.score > existing.score) {
allResults[res.node.id] = res;
}
}
}
Comment on lines +129 to +137
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The sub-queries generated during decomposition are processed sequentially. This can lead to high latency in the retrieval stage. Using Future.wait to execute these searches in parallel would significantly improve performance.

Suggested change
for (final q in context.decomposedQueries) {
final results = await memoryGraph.hybridSearch(q, topK: 5, alpha: 0.5);
for (final res in results) {
final existing = allResults[res.node.id];
if (existing == null || res.score > existing.score) {
allResults[res.node.id] = res;
}
}
}
final searchFutures = context.decomposedQueries.map(
(q) => memoryGraph.hybridSearch(q, topK: 5, alpha: 0.5),
);
final resultsList = await Future.wait(searchFutures);
for (final results in resultsList) {
for (final res in results) {
final existing = allResults[res.node.id];
if (existing == null || res.score > existing.score) {
allResults[res.node.id] = res;
}
}
}


final sortedResults = allResults.values.toList()
..sort((a, b) => b.score.compareTo(a.score));

context.retrievedNodes = sortedResults.take(10).toList();
return context;
}
}

/// Stage to re-rank retrieved nodes.
class ReRankingStage extends PipelineStage<RagContext, RagContext> {
final ReRankingStrategy reranker;

ReRankingStage(this.reranker);

@override
Future<RagContext> process(RagContext context) async {
if (context.retrievedNodes.isEmpty) return context;

context.retrievedNodes = reranker.reRank(
context.retrievedNodes,
query: context.currentQuery,
);

return context;
}
}

/// Stage to attach source nodes to the generated response.
class EvidenceCitationStage extends PipelineStage<RagContext, RagContext> {
@override
Future<RagContext> process(RagContext context) async {
if (context.generatedResponse == null) return context;

final cited = <MemoryNode>[];
for (var i = 0; i < context.retrievedNodes.length; i++) {
final node = context.retrievedNodes[i].node;
final citation = '[${i + 1}]';
if (context.generatedResponse!.contains(citation)) {
cited.add(node);
}
}
Comment on lines +173 to +179
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The citation matching logic using String.contains is prone to false positives. For example, a check for [1] will return true if the response contains [10], [11], etc. This results in incorrect evidence attribution. Using a regular expression with a negative lookahead for digits ensures that only the exact citation is matched.

Suggested change
for (var i = 0; i < context.retrievedNodes.length; i++) {
final node = context.retrievedNodes[i].node;
final citation = '[${i + 1}]';
if (context.generatedResponse!.contains(citation)) {
cited.add(node);
}
}
for (var i = 0; i < context.retrievedNodes.length; i++) {
final node = context.retrievedNodes[i].node;
final citation = '[${i + 1}]';
// Use regex with negative lookahead to ensure exact match (e.g., avoid matching [1] in [10])
final pattern = RegExp(RegExp.escape(citation) + r'(?!\d)');
if (context.generatedResponse!.contains(pattern)) {
cited.add(node);
}
}


context.citedNodes = cited;
return context;
}
}
31 changes: 31 additions & 0 deletions lib/src/rag/memory_pipeline.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import '../models/memory_node.dart';

/// Base class for a pipeline stage.
abstract class PipelineStage<I, O> {
Future<O> process(I input);
}

/// A generic pipeline that executes a series of stages.
abstract class MemoryPipeline<I, O> {
Future<O> execute(I input);
}

/// Data structure for RAG context and results.
class RagContext {
final String originalQuery;
String currentQuery;
List<String> decomposedQueries;
List<({MemoryNode node, double score})> retrievedNodes;
String? generatedResponse;
List<MemoryNode> citedNodes;

RagContext({
required this.originalQuery,
this.currentQuery = '',
this.decomposedQueries = const [],
this.retrievedNodes = const [],
this.citedNodes = const [],
}) {
currentQuery = originalQuery;
}
}
96 changes: 96 additions & 0 deletions test/medical_rag_pipeline_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import 'package:flutter_test/flutter_test.dart';
import 'package:isar/isar.dart';
import 'package:isar_agent_memory/isar_agent_memory.dart';

class MockLLMAdapter implements LLMAdapter {
@override
Future<String> generate(String prompt) async {
if (prompt.contains('SUB-PREGUNTAS:')) {
return '¿Qué es la tensión arterial?\n¿Cómo se mide la tensión arterial?';
}
if (prompt.contains('CONTEXTO:')) {
return 'La tensión arterial es la fuerza de la sangre contra las paredes de las arterias [1].';
}
return 'Mocked response';
}
}

class MockEmbeddingsAdapter implements EmbeddingsAdapter {
@override
String get providerName => 'mock';
@override
int get dimension => 384;
@override
Future<List<double>> embed(String text) async {
return List.generate(384, (i) => (text.length + i) / 1000.0);
}
}

void main() {
TestWidgetsFlutterBinding.ensureInitialized();

late Isar isar;
late MemoryGraph memoryGraph;
late MockLLMAdapter llm;

setUp(() async {
await Isar.initializeIsarCore(download: true);
isar = await Isar.open(
[MemoryNodeSchema, MemoryEdgeSchema],
directory: '.',
name: 'medical_test_db',
);
memoryGraph = MemoryGraph(isar, embeddingsAdapter: MockEmbeddingsAdapter());
llm = MockLLMAdapter();
await isar.writeTxn(() async => await isar.clear());
});

tearDown(() async {
await isar.close(deleteFromDisk: true);
});

group('MedicalRagPipeline Stages', () {
test('QueryNormalizationStage expands abbreviations', () async {
final stage = QueryNormalizationStage();
final context = RagContext(originalQuery: 'Paciente con TA alta y DM');

final result = await stage.process(context);

expect(result.currentQuery, contains('tensión arterial'));
expect(result.currentQuery, contains('diabetes mellitus'));
});

test('MedicalQueryDecompositionStage splits queries', () async {
final stage = MedicalQueryDecompositionStage(llm, MedicalPromptBuilder());
final context = RagContext(
originalQuery:
'Explique qué es la TA y cómo se mide en pacientes con HTA severa');

final result = await stage.process(context);

expect(result.decomposedQueries.length, greaterThan(1));
expect(result.decomposedQueries[0], contains('tensión arterial'));
});
});

test('MedicalRagPipeline full execution', () async {
// Seed some data
await memoryGraph.storeNode(MemoryNode(
content:
'La tensión arterial es vital para el funcionamiento del cuerpo.',
type: 'medical',
));

final pipeline = MedicalRagPipeline(
memoryGraph: memoryGraph,
llm: llm,
);

final result = await pipeline.execute('¿Qué es la TA?');

expect(result.currentQuery, contains('tensión arterial'));
expect(result.generatedResponse, contains('tensión arterial'));
expect(result.generatedResponse, contains('AVISO MÉDICO'));
expect(result.citedNodes, isNotEmpty);
});
}
Loading