diff --git a/.env.template b/.env.template index 6e776a5..677c235 100644 --- a/.env.template +++ b/.env.template @@ -13,6 +13,8 @@ SECRET_KEY= # DEBUG: Enable debug mode (tracebacks, detailed errors). Default is true. # In production you must set this to false (DEBUG=0 or DEBUG=false). DEBUG=1 +# If true, returns a fixed json response for the /api/v1/retrieve endpoint. +FLAG_RETRIEVE_DEFAULT_JSON=0 # ----------------------------------------------------------------------------- # Database (required) diff --git a/.gitignore b/.gitignore index 1d8cf59..de02f97 100644 --- a/.gitignore +++ b/.gitignore @@ -172,8 +172,9 @@ eu_fact_force/ingestion/parsing/output/benchmark_results_extended.csv eu_fact_force/ingestion/parsing/output/extraction_scores.csv eu_fact_force/ingestion/parsing/output/analysis/ s3/ -eu_fact_force/exploration/docling/results/html/ -eu_fact_force/exploration/docling/results/json/ -eu_fact_force/exploration/docling/results/md/ +eu_fact_force/exploration/ annotated_pdf/ -eu_fact_force/exploration/docling/results/annotated_pdf/ + +# docker volumes +postgres_data/ +rustfs_data \ No newline at end of file diff --git a/eu_fact_force/app/settings.py b/eu_fact_force/app/settings.py index 46ea182..8bff827 100644 --- a/eu_fact_force/app/settings.py +++ b/eu_fact_force/app/settings.py @@ -187,3 +187,10 @@ def _get_databases(): "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage", }, } + + +FLAG_RETRIEVE_DEFAULT_JSON = os.getenv("FLAG_RETRIEVE_DEFAULT_JSON", "0").lower() in ( + "true", + "1", + "yes", +) diff --git a/eu_fact_force/ingestion/data_collection/default_search.json b/eu_fact_force/ingestion/data_collection/default_search.json new file mode 100644 index 0000000..03a6de4 --- /dev/null +++ b/eu_fact_force/ingestion/data_collection/default_search.json @@ -0,0 +1,37 @@ +{ + "status": "success", + "narrative": "vaccine_autism", + "chunks": [ + { + "type": "text", + "content": "...", + "score": 0.94, + "metadata": { + "document_id": "", + "page": 0 + } + } + ], + "documents": { + "": { + "link": "https://.../", + "title": "...", + "date": "...", + "journal": "...", + "authors": [ + "...", + "..." + ], + "doi": "...", + "abstract": "...", + "keywords": [ + "keyword_1", + "keyword_2" + ], + "evidence": { + "name": "...", + "rank": 0 + } + } + } +} \ No newline at end of file diff --git a/eu_fact_force/ingestion/data_collection/prompts/vaccine_autism.md b/eu_fact_force/ingestion/data_collection/prompts/vaccine_autism.md new file mode 100644 index 0000000..e34a010 --- /dev/null +++ b/eu_fact_force/ingestion/data_collection/prompts/vaccine_autism.md @@ -0,0 +1 @@ +I want quantitative results about the link between vaccines and autism. \ No newline at end of file diff --git a/eu_fact_force/ingestion/search.py b/eu_fact_force/ingestion/search.py index 70608be..b9888eb 100644 --- a/eu_fact_force/ingestion/search.py +++ b/eu_fact_force/ingestion/search.py @@ -1,8 +1,22 @@ """Semantic search over ingested document chunks using pgvector.""" +from pathlib import Path + +from pgvector.django import CosineDistance + from eu_fact_force.ingestion.embedding import embed_query from eu_fact_force.ingestion.models import DocumentChunk -from pgvector.django import CosineDistance + +_PROMPTS_DIR = Path(__file__).resolve().parent / "data_collection" / "prompts" + + +class NarrativeNotFoundError(FileNotFoundError): + """No prompts/.md for the given narrative keyword.""" + + +def list_prompt_keywords() -> list[str]: + """Basenames of narrative prompts (one .md file per keyword), sorted.""" + return sorted(p.stem for p in _PROMPTS_DIR.glob("*.md")) def search_chunks(query: str, k: int = 10) -> list[tuple[DocumentChunk, float]]: @@ -26,3 +40,38 @@ def search_chunks(query: str, k: int = 10) -> list[tuple[DocumentChunk, float]]: .order_by("distance")[:k] ) return [(chunk, float(chunk.distance)) for chunk in qs] + + +def search_narrative(narrative: str, k: int = 10) -> list[tuple[DocumentChunk, float]]: + prompt = _PROMPTS_DIR / f"{narrative}.md" + if not prompt.exists(): + raise NarrativeNotFoundError(f"Prompt file not found: {prompt}") + return search_chunks(prompt.read_text(), k) + + +def chunks_context(top_chunks: list[tuple[DocumentChunk, float]]) -> dict: + chunks = [ + { + "type": "text", + "content": chunk.content, + "score": score, + "metadata": {"document_id": chunk.source_file.id, "page": -1}, + } + for chunk, score in top_chunks + ] + + documents = {} + for chunk, _ in top_chunks: + source_file = chunk.source_file + if source_file.id in documents: + continue + meta = source_file.metadata + documents[source_file.id] = { + "id": source_file.id, + "doi": source_file.doi, + "tags_pubmed": meta.tags_pubmed, + } + return { + "chunks": chunks, + "documents": documents, + } diff --git a/eu_fact_force/ingestion/urls.py b/eu_fact_force/ingestion/urls.py index 0a3a66c..92b1876 100644 --- a/eu_fact_force/ingestion/urls.py +++ b/eu_fact_force/ingestion/urls.py @@ -5,4 +5,5 @@ app_name = "ingestion" urlpatterns = [ path("ingest/", views.ingest, name="ingest"), + path("search//", views.search, name="search"), ] diff --git a/eu_fact_force/ingestion/views.py b/eu_fact_force/ingestion/views.py index 675555e..9c0f97d 100644 --- a/eu_fact_force/ingestion/views.py +++ b/eu_fact_force/ingestion/views.py @@ -1,8 +1,24 @@ +import json +from pathlib import Path + +from django.http import JsonResponse from django.shortcuts import render +from eu_fact_force.app.settings import FLAG_RETRIEVE_DEFAULT_JSON +from eu_fact_force.ingestion.search import ( + NarrativeNotFoundError, + chunks_context, + list_prompt_keywords, + search_narrative, +) + from .forms import IngestForm from .services import run_pipeline +_DEFAULT_SEARCH_PATH = ( + Path(__file__).resolve().parent / "data_collection" / "default_search.json" +) + def ingest(request): """Accept a DOI via form, run the pipeline, display success and count.""" @@ -31,3 +47,26 @@ def ingest(request): else: context["form"] = form return render(request, "ingestion/ingest.html", context) + + +def search(request, keyword: str): + """Return the default search fixture JSON (keyword reserved for future filtering).""" + _ = keyword + if FLAG_RETRIEVE_DEFAULT_JSON: + return JsonResponse( + json.loads(_DEFAULT_SEARCH_PATH.read_text(encoding="utf-8")) + ) + try: + chunks = search_narrative(keyword) + except NarrativeNotFoundError: + return JsonResponse( + { + "error": f"Unknown narrative keyword {keyword!r}; no matching prompt.", + "keywords": list_prompt_keywords(), + }, + status=404, + ) + + return JsonResponse( + {"status": "success", "narrative": keyword, **chunks_context(chunks)} + ) diff --git a/tests/factories.py b/tests/factories.py index 1759233..600fe96 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -8,6 +8,7 @@ from eu_fact_force.ingestion.models import ( EMBEDDING_DIMENSIONS, DocumentChunk, + FileMetadata, SourceFile, ) @@ -21,6 +22,14 @@ class Meta: status = SourceFile.Status.STORED +class FileMetadataFactory(DjangoModelFactory): + class Meta: + model = FileMetadata + + source_file = factory.SubFactory(SourceFileFactory) + tags_pubmed = factory.LazyFunction(list) + + def _random_embedding_vector() -> list[float]: return [random.random() for _ in range(EMBEDDING_DIMENSIONS)] diff --git a/tests/ingestion/test_search.py b/tests/ingestion/test_search.py index aedf997..a15c9cf 100644 --- a/tests/ingestion/test_search.py +++ b/tests/ingestion/test_search.py @@ -1,5 +1,6 @@ """Tests for semantic search over document chunks.""" +from pathlib import Path from unittest.mock import patch import pytest @@ -7,8 +8,14 @@ from eu_fact_force.ingestion import models as ingestion_models from eu_fact_force.ingestion import search as search_module from eu_fact_force.ingestion.chunking import MAX_CHUNK_CHARS -from eu_fact_force.ingestion.models import EMBEDDING_DIMENSIONS, DocumentChunk -from tests.factories import DocumentChunkFactory, SourceFileFactory +from eu_fact_force.ingestion.models import ( + EMBEDDING_DIMENSIONS, + DocumentChunk, +) +from eu_fact_force.ingestion.search import chunks_context +from tests.factories import DocumentChunkFactory, FileMetadataFactory, SourceFileFactory + +PROJECT_ROOT = Path(__file__).resolve().parents[2] # Tolerance for float distance comparison. DISTANCE_TOLERANCE = 1e-5 @@ -135,3 +142,56 @@ def _rank2_in_full_space(x: float, y: float) -> list[float]: pytest.approx(1.0), ] assert all(r[0].source_file_id == source_file.pk for r in results) + + +@pytest.mark.django_db +class TestChunksContext: + def test_empty_top_chunks(self): + assert chunks_context([]) == {"chunks": [], "documents": {}} + + def test_two_chunks_single_source_file(self): + source = SourceFileFactory(doi="doi/single", s3_key="key/single") + FileMetadataFactory(source_file=source, tags_pubmed=["mesh:a"]) + chunk_a = DocumentChunkFactory(source_file=source, content="first") + chunk_b = DocumentChunkFactory(source_file=source, content="second") + + result = chunks_context([(chunk_a, 0.9), (chunk_b, 0.8)]) + + assert result["chunks"] == [ + { + "type": "text", + "content": "first", + "score": 0.9, + "metadata": {"document_id": source.id, "page": -1}, + }, + { + "type": "text", + "content": "second", + "score": 0.8, + "metadata": {"document_id": source.id, "page": -1}, + }, + ] + assert result["documents"] == { + source.id: { + "id": source.id, + "doi": "doi/single", + "tags_pubmed": ["mesh:a"], + } + } + + def test_two_chunks_two_source_files(self): + src1 = SourceFileFactory(doi="doi/one", s3_key="k1") + FileMetadataFactory(source_file=src1, tags_pubmed=["t1"]) + src2 = SourceFileFactory(doi="doi/two", s3_key="k2") + FileMetadataFactory(source_file=src2, tags_pubmed=["t2", "t3"]) + + c1 = DocumentChunkFactory(source_file=src1, content="alpha", order=0) + c2 = DocumentChunkFactory(source_file=src2, content="beta", order=0) + + result = chunks_context([(c1, 0.1), (c2, 0.2)]) + + assert [x["content"] for x in result["chunks"]] == ["alpha", "beta"] + assert result["documents"] == { + src1.id: {"id": src1.id, "doi": "doi/one", "tags_pubmed": ["t1"]}, + src2.id: {"id": src2.id, "doi": "doi/two", "tags_pubmed": ["t2", "t3"]}, + }