Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ SECRET_KEY=
# DEBUG: Enable debug mode (tracebacks, detailed errors). Default is true.
# In production you must set this to false (DEBUG=0 or DEBUG=false).
DEBUG=1
# If true, the /api/v1/retrieve endpoint returns a fixed JSON fixture instead of live search results.
FLAG_RETRIEVE_DEFAULT_JSON=0

# -----------------------------------------------------------------------------
# Database (required)
Expand Down
9 changes: 5 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,9 @@ eu_fact_force/ingestion/parsing/output/benchmark_results_extended.csv
eu_fact_force/ingestion/parsing/output/extraction_scores.csv
eu_fact_force/ingestion/parsing/output/analysis/
s3/
eu_fact_force/exploration/docling/results/html/
eu_fact_force/exploration/docling/results/json/
eu_fact_force/exploration/docling/results/md/
eu_fact_force/exploration/
annotated_pdf/
eu_fact_force/exploration/docling/results/annotated_pdf/

# docker volumes
postgres_data/
rustfs_data
7 changes: 7 additions & 0 deletions eu_fact_force/app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,10 @@ def _get_databases():
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage",
},
}


# When enabled, the retrieve/search endpoint serves a canned JSON fixture
# instead of running a live semantic search (useful for frontend development).
FLAG_RETRIEVE_DEFAULT_JSON = os.getenv(
    "FLAG_RETRIEVE_DEFAULT_JSON", "0"
).lower() in {"true", "1", "yes"}
37 changes: 37 additions & 0 deletions eu_fact_force/ingestion/data_collection/default_search.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"status": "success",
"narrative": "vaccine_autism",
"chunks": [
{
"type": "text",
"content": "...",
"score": 0.94,
"metadata": {
"document_id": "<document_id>",
"page": 0
}
}
],
"documents": {
"<document_id>": {
"link": "https://.../",
"title": "...",
"date": "...",
"journal": "...",
"authors": [
"...",
"..."
],
"doi": "...",
"abstract": "...",
"keywords": [
"keyword_1",
"keyword_2"
],
"evidence": {
"name": "...",
"rank": 0
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
I want quantitative results about the link between vaccines and autism.
51 changes: 50 additions & 1 deletion eu_fact_force/ingestion/search.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
"""Semantic search over ingested document chunks using pgvector."""

from pathlib import Path

from pgvector.django import CosineDistance

from eu_fact_force.ingestion.embedding import embed_query
from eu_fact_force.ingestion.models import DocumentChunk
from pgvector.django import CosineDistance

_PROMPTS_DIR = Path(__file__).resolve().parent / "data_collection" / "prompts"


class NarrativeNotFoundError(FileNotFoundError):
    """Raised when prompts/<narrative>.md does not exist for a narrative keyword."""


def list_prompt_keywords() -> list[str]:
    """Return the stem names of all narrative prompt files (*.md), sorted."""
    stems = [prompt_file.stem for prompt_file in _PROMPTS_DIR.glob("*.md")]
    return sorted(stems)


def search_chunks(query: str, k: int = 10) -> list[tuple[DocumentChunk, float]]:
Expand All @@ -26,3 +40,38 @@ def search_chunks(query: str, k: int = 10) -> list[tuple[DocumentChunk, float]]:
.order_by("distance")[:k]
)
return [(chunk, float(chunk.distance)) for chunk in qs]


def search_narrative(narrative: str, k: int = 10) -> list[tuple[DocumentChunk, float]]:
    """Run semantic search using the stored prompt text for *narrative*.

    Raises NarrativeNotFoundError when no prompts/<narrative>.md file exists.
    """
    prompt_path = _PROMPTS_DIR / f"{narrative}.md"
    if not prompt_path.exists():
        raise NarrativeNotFoundError(f"Prompt file not found: {prompt_path}")
    query_text = prompt_path.read_text()
    return search_chunks(query_text, k)


def chunks_context(top_chunks: list[tuple[DocumentChunk, float]]) -> dict:
    """Serialize ranked (chunk, score) pairs into a JSON-ready context dict.

    Returns {"chunks": [...], "documents": {...}} where "chunks" keeps the
    input order and "documents" maps each source-file id to its metadata
    exactly once (first occurrence wins).
    """
    serialized_chunks = []
    documents: dict = {}
    for chunk, score in top_chunks:
        source_file = chunk.source_file
        serialized_chunks.append(
            {
                "type": "text",
                "content": chunk.content,
                "score": score,
                # Page is hard-coded to -1: this serializer has no page tracking.
                "metadata": {"document_id": source_file.id, "page": -1},
            }
        )
        if source_file.id not in documents:
            documents[source_file.id] = {
                "id": source_file.id,
                "doi": source_file.doi,
                "tags_pubmed": source_file.metadata.tags_pubmed,
            }
    return {"chunks": serialized_chunks, "documents": documents}
1 change: 1 addition & 0 deletions eu_fact_force/ingestion/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
app_name = "ingestion"
urlpatterns = [
path("ingest/", views.ingest, name="ingest"),
path("search/<str:keyword>/", views.search, name="search"),
]
39 changes: 39 additions & 0 deletions eu_fact_force/ingestion/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
import json
from pathlib import Path

from django.http import JsonResponse
from django.shortcuts import render

from eu_fact_force.app.settings import FLAG_RETRIEVE_DEFAULT_JSON
from eu_fact_force.ingestion.search import (
NarrativeNotFoundError,
chunks_context,
list_prompt_keywords,
search_narrative,
)

from .forms import IngestForm
from .services import run_pipeline

_DEFAULT_SEARCH_PATH = (
Path(__file__).resolve().parent / "data_collection" / "default_search.json"
)


def ingest(request):
"""Accept a DOI via form, run the pipeline, display success and count."""
Expand Down Expand Up @@ -31,3 +47,26 @@ def ingest(request):
else:
context["form"] = form
return render(request, "ingestion/ingest.html", context)


def search(request, keyword: str):
    """Semantic search over ingested chunks for a narrative *keyword*.

    When FLAG_RETRIEVE_DEFAULT_JSON is enabled, returns the canned fixture
    JSON verbatim (keyword is ignored in that mode). Otherwise runs a
    semantic search using the prompt file matching *keyword* and returns a
    success payload with the serialized chunks/documents context. Responds
    404 with the list of known keywords when no prompt exists for *keyword*.
    """
    if FLAG_RETRIEVE_DEFAULT_JSON:
        return JsonResponse(
            json.loads(_DEFAULT_SEARCH_PATH.read_text(encoding="utf-8"))
        )
    try:
        chunks = search_narrative(keyword)
    except NarrativeNotFoundError:
        return JsonResponse(
            {
                "error": f"Unknown narrative keyword {keyword!r}; no matching prompt.",
                "keywords": list_prompt_keywords(),
            },
            status=404,
        )

    return JsonResponse(
        {"status": "success", "narrative": keyword, **chunks_context(chunks)}
    )
9 changes: 9 additions & 0 deletions tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from eu_fact_force.ingestion.models import (
EMBEDDING_DIMENSIONS,
DocumentChunk,
FileMetadata,
SourceFile,
)

Expand All @@ -21,6 +22,14 @@ class Meta:
status = SourceFile.Status.STORED


class FileMetadataFactory(DjangoModelFactory):
    """Factory producing FileMetadata instances for tests."""

    class Meta:
        model = FileMetadata

    # Parent SourceFile; a new one is created unless passed in explicitly.
    source_file = factory.SubFactory(SourceFileFactory)
    # Fresh empty list per instance (avoids shared mutable default state).
    tags_pubmed = factory.LazyFunction(list)


def _random_embedding_vector() -> list[float]:
return [random.random() for _ in range(EMBEDDING_DIMENSIONS)]

Expand Down
64 changes: 62 additions & 2 deletions tests/ingestion/test_search.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
"""Tests for semantic search over document chunks."""

from pathlib import Path
from unittest.mock import patch

import pytest

from eu_fact_force.ingestion import models as ingestion_models
from eu_fact_force.ingestion import search as search_module
from eu_fact_force.ingestion.chunking import MAX_CHUNK_CHARS
from eu_fact_force.ingestion.models import EMBEDDING_DIMENSIONS, DocumentChunk
from tests.factories import DocumentChunkFactory, SourceFileFactory
from eu_fact_force.ingestion.models import (
EMBEDDING_DIMENSIONS,
DocumentChunk,
)
from eu_fact_force.ingestion.search import chunks_context
from tests.factories import DocumentChunkFactory, FileMetadataFactory, SourceFileFactory

PROJECT_ROOT = Path(__file__).resolve().parents[2]

# Tolerance for float distance comparison.
DISTANCE_TOLERANCE = 1e-5
Expand Down Expand Up @@ -135,3 +142,56 @@ def _rank2_in_full_space(x: float, y: float) -> list[float]:
pytest.approx(1.0),
]
assert all(r[0].source_file_id == source_file.pk for r in results)


@pytest.mark.django_db
class TestChunksContext:
    """Tests for search.chunks_context serialization."""

    def test_empty_top_chunks(self):
        # No input -> empty chunks list and empty documents mapping.
        assert chunks_context([]) == {"chunks": [], "documents": {}}

    def test_two_chunks_single_source_file(self):
        source = SourceFileFactory(doi="doi/single", s3_key="key/single")
        FileMetadataFactory(source_file=source, tags_pubmed=["mesh:a"])
        first_chunk = DocumentChunkFactory(source_file=source, content="first")
        second_chunk = DocumentChunkFactory(source_file=source, content="second")

        result = chunks_context([(first_chunk, 0.9), (second_chunk, 0.8)])

        expected_chunks = [
            {
                "type": "text",
                "content": text,
                "score": score,
                "metadata": {"document_id": source.id, "page": -1},
            }
            for text, score in (("first", 0.9), ("second", 0.8))
        ]
        assert result["chunks"] == expected_chunks
        # Both chunks share one source file, so documents has a single entry.
        assert result["documents"] == {
            source.id: {
                "id": source.id,
                "doi": "doi/single",
                "tags_pubmed": ["mesh:a"],
            }
        }

    def test_two_chunks_two_source_files(self):
        first_source = SourceFileFactory(doi="doi/one", s3_key="k1")
        FileMetadataFactory(source_file=first_source, tags_pubmed=["t1"])
        second_source = SourceFileFactory(doi="doi/two", s3_key="k2")
        FileMetadataFactory(source_file=second_source, tags_pubmed=["t2", "t3"])

        chunk_one = DocumentChunkFactory(
            source_file=first_source, content="alpha", order=0
        )
        chunk_two = DocumentChunkFactory(
            source_file=second_source, content="beta", order=0
        )

        result = chunks_context([(chunk_one, 0.1), (chunk_two, 0.2)])

        assert [entry["content"] for entry in result["chunks"]] == ["alpha", "beta"]
        # One document entry per distinct source file.
        assert result["documents"] == {
            first_source.id: {
                "id": first_source.id,
                "doi": "doi/one",
                "tags_pubmed": ["t1"],
            },
            second_source.id: {
                "id": second_source.id,
                "doi": "doi/two",
                "tags_pubmed": ["t2", "t3"],
            },
        }
Loading