Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/agent/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from src.finance.normalizer import normalize_account_name
from src.finance.signals import generate_signals
from src.finance.utils import table_rows
from src.finance.validators import validate_statements
from src.finance.validators import validate_facts, validate_statements
from src.llm.base import StructuredPromptRequest, get_default_llm_client, get_llm_client, get_llm_model_config, langsmith_metadata
from src.pdf.extractor import PDFExtractor
from src.pdf.tables import extract_tables as extract_tables_from_pdf
Expand Down Expand Up @@ -259,11 +259,15 @@ def extract_financial_statements(state: AgentState) -> AgentState:

def validate_and_reconcile(state: AgentState) -> AgentState:
start_ms = monotonic_ms()
if not state.facts:
state.facts = facts_from_statements(state.doc_meta, state.statements)
state.validation_results = validate_statements(state.statements)
state.fact_validation_results = validate_facts(state.doc_meta, state.facts, state.statements)
severe = any(
issue in {"balance_equation_failed", "balance_missing_totals"} or issue.startswith("unit_mismatch")
for issue in state.validation_results.get("issues", [])
)
severe = severe or any(issue.severity == "high" for issue in (state.fact_validation_results.issues if state.fact_validation_results else []))
# Remove any previous validation_failed entries before deciding, ensuring idempotency
state.errors = [err for err in state.errors if err != "validation_failed"]
if severe:
Expand Down Expand Up @@ -559,8 +563,14 @@ def finalize(state: AgentState) -> AgentState:
store.save_json(state.doc_meta.doc_id, "extracted/tables.json", [t.model_dump() for t in state.tables])
store.save_json(state.doc_meta.doc_id, "extracted/statements.json", {k: v.model_dump() for k, v in state.statements.items()})
if not state.facts:
state.facts = facts_from_statements(state.doc_meta.doc_id, state.statements)
state.facts = facts_from_statements(state.doc_meta, state.statements)
store.save_json(state.doc_meta.doc_id, "extracted/facts.json", [fact.model_dump(mode="json") for fact in state.facts])
if state.fact_validation_results:
store.save_json(
state.doc_meta.doc_id,
"extracted/fact_validation.json",
state.fact_validation_results.model_dump(mode="json"),
)
if state.corrections:
store.save_json(
state.doc_meta.doc_id,
Expand Down
2 changes: 2 additions & 0 deletions src/agent/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DocumentMeta,
EventStudyResult,
ExtractionTrace,
FactValidationResult,
FinancialFact,
FinancialStatement,
KeyNote,
Expand All @@ -33,6 +34,7 @@ class AgentState(BaseModel):
statements: dict[str, FinancialStatement] = Field(default_factory=dict)
notes: list[KeyNote] = Field(default_factory=list)
validation_results: dict[str, Any] = Field(default_factory=dict)
fact_validation_results: FactValidationResult | None = None
risk_signals: list[RiskSignal] = Field(default_factory=list)
facts: list[FinancialFact] = Field(default_factory=list)
corrections: list[Correction] = Field(default_factory=list)
Expand Down
16 changes: 16 additions & 0 deletions src/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ async def create_document(
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
company: str | None = Form(default=None),
ticker: str | None = Form(default=None),
cik: str | None = Form(default=None),
filing_type: str | None = Form(default=None),
period_end: str | None = Form(default=None),
report_type: str | None = Form(default=None),
language: str | None = Form(default=None),
Expand Down Expand Up @@ -174,6 +177,9 @@ async def create_document(
doc_id=doc_id,
filename=safe_filename,
company=company,
ticker=ticker,
cik=cik,
filing_type=filing_type,
period_end=parsed_period,
report_type=report_type,
language=language,
Expand Down Expand Up @@ -274,6 +280,14 @@ async def get_facts(_auth: _AuthDep, doc_id: str):
return _ok(data)


@router.get("/documents/{doc_id}/fact-validation")
async def get_fact_validation(_auth: _AuthDep, doc_id: str):
data = store.load_json(doc_id, "extracted/fact_validation.json")
if data is None:
return _err("not_found", "Fact validation not found")
return _ok(data)


@router.get("/documents/{doc_id}/notes")
async def get_notes(_auth: _AuthDep, doc_id: str):
data = store.load_json(doc_id, "extracted/notes.json")
Expand Down Expand Up @@ -532,6 +546,8 @@ def _save_partial_results(doc_id: str) -> None:
s.save_json(doc_id, "extracted/statements.json", {k: v.model_dump() for k, v in partial.statements.items()})
if partial.facts:
s.save_json(doc_id, "extracted/facts.json", [fact.model_dump(mode="json") for fact in partial.facts])
if partial.fact_validation_results:
s.save_json(doc_id, "extracted/fact_validation.json", partial.fact_validation_results.model_dump(mode="json"))
if partial.notes:
s.save_json(doc_id, "extracted/notes.json", [n.model_dump() for n in partial.notes])
if partial.risk_signals:
Expand Down
6 changes: 6 additions & 0 deletions src/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def analyze(
pdf: str = typer.Option(..., help="Path to PDF file"),
out: str = typer.Option("data", help="Output directory"),
company: str | None = typer.Option(None, help="Company name"),
ticker: str | None = typer.Option(None, help="Ticker symbol"),
cik: str | None = typer.Option(None, help="CIK identifier"),
filing_type: str | None = typer.Option(None, help="Filing type (for example 10-K or 10-Q)"),
period_end: str | None = typer.Option(None, help="Report period end (YYYY-MM-DD)"),
report_type: str | None = typer.Option(None, help="Report type"),
language: str | None = typer.Option(None, help="Language"),
Expand All @@ -38,6 +41,9 @@ def analyze(
doc_id=doc_id,
filename=pdf_path.name,
company=company,
ticker=ticker,
cik=cik,
filing_type=filing_type,
period_end=parsed_period,
report_type=report_type,
language=language,
Expand Down
33 changes: 24 additions & 9 deletions src/finance/facts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from collections.abc import Iterable
from typing import Any

from src.schemas.models import Correction, FinancialFact, FinancialStatement, SourceRef, StatementLineItem
from src.schemas.models import Correction, DocumentMeta, FinancialFact, FinancialStatement, SourceRef, StatementLineItem


def facts_from_statements(doc_id: str, statements: dict[str, FinancialStatement]) -> list[FinancialFact]:
def facts_from_statements(doc_meta: str | DocumentMeta, statements: dict[str, FinancialStatement]) -> list[FinancialFact]:
meta = _coerce_doc_meta(doc_meta)
facts: list[FinancialFact] = []
seen: set[tuple[str, str]] = set()

Expand All @@ -17,7 +18,7 @@ def facts_from_statements(doc_id: str, statements: dict[str, FinancialStatement]
for item in statement.line_items:
if item.value_current is None and not item.source_refs:
continue
fact = _fact_from_line_item(doc_id, statement, item)
fact = _fact_from_line_item(meta, statement, item)
facts.append(fact)
seen.add((statement.statement_type, item.name_norm))

Expand All @@ -27,10 +28,11 @@ def facts_from_statements(doc_id: str, statements: dict[str, FinancialStatement]
continue
facts.append(
_build_fact(
doc_id=doc_id,
doc_meta=meta,
statement=statement,
concept=concept,
label=concept,
raw_label=concept,
value=value,
unit=None,
currency=None,
Expand All @@ -55,16 +57,17 @@ def apply_corrections(facts: Iterable[FinancialFact], corrections: Iterable[Corr


def _fact_from_line_item(
doc_id: str,
doc_meta: DocumentMeta,
statement: FinancialStatement,
item: StatementLineItem,
) -> FinancialFact:
confidence = _evidence_confidence(item.source_refs, statement.extraction_confidence)
return _build_fact(
doc_id=doc_id,
doc_meta=doc_meta,
statement=statement,
concept=item.name_norm,
label=item.name_raw,
raw_label=item.name_raw,
value=item.value_current,
unit=item.unit,
currency=item.currency,
Expand All @@ -76,10 +79,11 @@ def _fact_from_line_item(

def _build_fact(
*,
doc_id: str,
doc_meta: DocumentMeta,
statement: FinancialStatement,
concept: str,
label: str,
raw_label: str,
value: float | None,
unit: str | None,
currency: str | None,
Expand All @@ -88,11 +92,16 @@ def _build_fact(
metadata: dict[str, Any],
) -> FinancialFact:
return FinancialFact(
fact_id=_fact_id(doc_id, statement.statement_type, concept, statement.period_end, label),
doc_id=doc_id,
fact_id=_fact_id(doc_meta.doc_id, statement.statement_type, concept, statement.period_end, raw_label),
doc_id=doc_meta.doc_id,
company=doc_meta.company,
ticker=doc_meta.ticker,
cik=doc_meta.cik,
filing_type=doc_meta.filing_type or doc_meta.report_type,
statement_type=statement.statement_type,
concept=concept,
label=label,
raw_label=raw_label,
value=value,
unit=unit,
scale=_infer_scale(unit),
Expand All @@ -107,6 +116,12 @@ def _build_fact(
)


def _coerce_doc_meta(doc_meta: str | DocumentMeta) -> DocumentMeta:
if isinstance(doc_meta, DocumentMeta):
return doc_meta
return DocumentMeta(doc_id=doc_meta, filename=f"{doc_meta}.pdf")


def _fact_id(doc_id: str, statement_type: str, concept: str, period_end: object, label: str) -> str:
raw = "|".join([doc_id, statement_type, concept, str(period_end or ""), label])
return "fact_" + hashlib.sha1(raw.encode("utf-8")).hexdigest()[:16]
Expand Down
Loading
Loading