diff --git a/bigdata_briefs/api/models.py b/bigdata_briefs/api/models.py index b30e7e7..edf3b82 100644 --- a/bigdata_briefs/api/models.py +++ b/bigdata_briefs/api/models.py @@ -113,6 +113,29 @@ class BriefCreationRequest(BaseModel): le=10, examples=[settings.API_FRESHNESS_BOOST], ) + sentiment_threshold: float | None = Field( + None, + description=( + "Sentiment filter magnitude for every Bigdata /v1/search in this brief: chunks outside " + "[-1,-t] ∪ [t,1] are excluded. Omit to use the server default " + f"({settings.EXPLORATORY_SENTIMENT_THRESHOLD}). Use 0 to disable sentiment filtering." + ), + ge=0, + le=1, + examples=[settings.EXPLORATORY_SENTIMENT_THRESHOLD], + ) + rerank_threshold: float | None = Field( + None, + description=( + "Reranker score threshold for exploratory and follow-up Bigdata /v1/search calls " + "(ranking_params.reranker). Omit to use built-in defaults: exploratory " + f"{settings.API_RERANK_EXPLORATORY}, follow-up {settings.API_RERANK_FOLLOWUP}. " + "The initial lightweight \"has results\" probe always runs with reranking disabled." + ), + ge=0, + le=1, + examples=[settings.API_RERANK_EXPLORATORY], + ) class BriefAcceptedResponse(BaseModel): diff --git a/bigdata_briefs/api/utils.py b/bigdata_briefs/api/utils.py index 8970756..5193f50 100644 --- a/bigdata_briefs/api/utils.py +++ b/bigdata_briefs/api/utils.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Any, cast from pydantic import BaseModel @@ -7,7 +7,7 @@ from bigdata_briefs.sql_models import SQLBriefReport -def get_example_values_from_schema(schema_model: Type[BaseModel]) -> dict: +def get_example_values_from_schema(schema_model: type[BaseModel]) -> dict: """ Extract example values from a Pydantic model's fields, falling back to defaults if no example is provided. Args: @@ -18,9 +18,9 @@ def get_example_values_from_schema(schema_model: Type[BaseModel]) -> dict: example_values = {} for field_name, field in schema_model.model_fields.items(): example = None - if isinstance(field.json_schema_extra, dict): - if "example" in field.json_schema_extra: - example = field.json_schema_extra["example"] + extra = field.json_schema_extra + if isinstance(extra, dict): + example = cast(dict[str, Any], extra).get("example") elif field.examples: if field.examples: example = field.examples[0] diff --git a/bigdata_briefs/llm_client.py b/bigdata_briefs/llm_client.py index 5e6ab14..9e0f0db 100644 --- a/bigdata_briefs/llm_client.py +++ b/bigdata_briefs/llm_client.py @@ -3,7 +3,7 @@ import openai from pydantic import BaseModel -from bigdata_briefs import logger +from bigdata_briefs import LOG_LEVEL, logger from bigdata_briefs.metrics import LLMMetrics from bigdata_briefs.models import LLMUsage from bigdata_briefs.settings import settings @@ -22,7 +22,10 @@ class FollowUpQuestionsPromptDefaults(BaseModel): class LLMClient: def __init__(self, client: openai.OpenAI | None = None): if client is None: - client = openai.OpenAI() + client = openai.OpenAI( + timeout=settings.OPENAI_TIMEOUT_SECONDS, + max_retries=settings.LLM_RETRIES, + ) self.client = client @log_time @@ -91,6 +94,12 @@ def _call_with_retries(self, func, *args, **kwargs): try: return func(*args, **kwargs) except Exception as e: + if LOG_LEVEL == "DEBUG" and "timeout" in str(e).lower(): + prompt_payload = kwargs.get("input") or kwargs.get("messages") + print( + "\n[DEBUG][LLM TIMEOUT] Prompt payload sent to LLM:\n", + json.dumps(prompt_payload, indent=2, ensure_ascii=False), + ) if attempt >= settings.LLM_RETRIES - 1: raise logger.warning(f"Error calling LLM: {e}. Attempt {attempt + 1}") diff --git a/bigdata_briefs/metrics.py b/bigdata_briefs/metrics.py index 1ea075c..69e9be5 100644 --- a/bigdata_briefs/metrics.py +++ b/bigdata_briefs/metrics.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from queue import Queue from threading import Lock +from typing import Any, ClassVar from bigdata_briefs import logger from bigdata_briefs.models import ( @@ -12,6 +13,9 @@ class Metrics(ABC): + lock: ClassVar[Any] + metrics_queue: ClassVar[Queue[Any]] + @classmethod @abstractmethod def track_usage(cls, usage): ... @@ -66,12 +70,12 @@ class WarningsMetrics(Metrics): lock = Lock() @classmethod - def track_usage(cls, warning_message: str): + def track_usage(cls, usage: str): with cls.lock: # Avoid logging duplicate warnings - if warning_message not in cls.warnings: - logger.info("A warning have been suppressed", warning=warning_message) - cls.warnings.add(warning_message) + if usage not in cls.warnings: + logger.info("A warning have been suppressed", warning=usage) + cls.warnings.add(usage) @classmethod def get_total_usage(cls) -> set[str]: diff --git a/bigdata_briefs/models.py b/bigdata_briefs/models.py index 197e1dd..f213f99 100644 --- a/bigdata_briefs/models.py +++ b/bigdata_briefs/models.py @@ -168,7 +168,9 @@ def from_api(cls, api_document): ts=api_document["timestamp"], document_scope=api_document.get("document_type", "Unknown"), language=api_document.get("language", "Unknown"), - chunks=[Chunk.from_api(api_chunk) for api_chunk in api_document["chunks"]], + chunks=tuple( + Chunk.from_api(api_chunk) for api_chunk in api_document["chunks"] + ), ) @@ -205,6 +207,8 @@ class ValidatedInput(BaseModel): categories: list[str] | None source_rank_boost: int | None freshness_boost: int | None + sentiment_threshold: float + rerank_threshold: float | None = None class FollowUpAnalysis(BaseModel): diff --git a/bigdata_briefs/novelty/embedding_client.py b/bigdata_briefs/novelty/embedding_client.py index 3bbf7ae..7018b71 100644 --- a/bigdata_briefs/novelty/embedding_client.py +++ b/bigdata_briefs/novelty/embedding_client.py @@ -13,7 +13,10 @@ class EmbeddingClient: def __init__(self, model: str, client: openai.OpenAI | None = None): self.model = model if client is None: - client = openai.OpenAI() + client = openai.OpenAI( + timeout=settings.OPENAI_TIMEOUT_SECONDS, + max_retries=settings.EMBEDDING_RETRIES, + ) self.client = client def compute(self, texts: list[str], **kwargs) -> list[list[float]]: diff --git a/bigdata_briefs/query_service/api.py b/bigdata_briefs/query_service/api.py index c6af52d..4d028c8 100644 --- a/bigdata_briefs/query_service/api.py +++ b/bigdata_briefs/query_service/api.py @@ -275,6 +275,13 @@ def run_exploratory_search( self._run_single_exploratory_search, entity_id=entity.id, report_dates=report_dates, + source_filter=source_filter, + categories=categories, + sentiment_threshold=sentiment_threshold, + chunk_limit=chunk_limit, + rerank_threshold=rerank_threshold, + source_rank_boost=source_rank_boost, + freshness_boost=freshness_boost, enable_metric=True, metric_name=f"Exploratory search. Entity {entity.id}", ) @@ -288,6 +295,13 @@ def run_exploratory_search( return self._run_single_exploratory_search( entity_id=entity.id, report_dates=report_dates, + source_filter=source_filter, + categories=categories, + sentiment_threshold=sentiment_threshold, + chunk_limit=chunk_limit, + rerank_threshold=rerank_threshold, + source_rank_boost=source_rank_boost, + freshness_boost=freshness_boost, enable_metric=True, metric_name=f"Exploratory search. Entity {entity.id}", ) @@ -348,6 +362,8 @@ def run_query_with_follow_up_questions( executor: ThreadPoolExecutor, source_rank_boost: int | None = settings.API_SOURCE_RANK_BOOST, freshness_boost: int | None = settings.API_FRESHNESS_BOOST, + sentiment_threshold: float | None = settings.FOLLOWUP_SENTIMENT_THRESHOLD, + rerank_threshold: float | None = settings.API_RERANK_FOLLOWUP, ) -> QAPairs: future_to_question = { executor.submit( @@ -359,6 +375,8 @@ def run_query_with_follow_up_questions( categories=categories, source_rank_boost=source_rank_boost, freshness_boost=freshness_boost, + sentiment_threshold=sentiment_threshold, + rerank_threshold=rerank_threshold, ): question for question in follow_up_questions } diff --git a/bigdata_briefs/query_service/base.py b/bigdata_briefs/query_service/base.py index eec2e72..9631fe2 100644 --- a/bigdata_briefs/query_service/base.py +++ b/bigdata_briefs/query_service/base.py @@ -39,7 +39,7 @@ def check_if_entity_has_results( categories: list[str] | None = None, sentiment_threshold: float | None = None, chunk_limit: int | None = None, - rerank_threshold: float = 0.0, + rerank_threshold: float | None = None, ) -> list[Result]: ... @abstractmethod @@ -106,7 +106,10 @@ def run_query_with_follow_up_questions( follow_up_questions: list[str], report_dates: ReportDates, source_filter: list[str] | None, + categories: list[str] | None, executor: ThreadPoolExecutor, source_rank_boost: int | None, freshness_boost: int | None, + sentiment_threshold: float | None = None, + rerank_threshold: float | None = None, ) -> QAPairs: ... diff --git a/bigdata_briefs/query_service/models.py b/bigdata_briefs/query_service/models.py index 648326d..8198600 100644 --- a/bigdata_briefs/query_service/models.py +++ b/bigdata_briefs/query_service/models.py @@ -1,4 +1,4 @@ -from typing import List, Literal, NotRequired, TypedDict +from typing import Literal, NotRequired, TypedDict class TimestampFilter(TypedDict): @@ -7,21 +7,29 @@ class TimestampFilter(TypedDict): class EntityFilter(TypedDict): - any_of: List[str] + any_of: list[str] -class SentimentFilter(TypedDict): - values: List[Literal["positive", "negative", "neutral"]] +class SentimentRangeBand(TypedDict): + min: float + max: float + + +class SentimentFilter(TypedDict, total=False): + """API supports categorical values or numeric range bands (magnitude filter).""" + + values: list[Literal["positive", "negative", "neutral"]] + ranges: list[SentimentRangeBand] class SourceFilter(TypedDict): mode: Literal["INCLUDE", "EXCLUDE"] - values: List[str] + values: list[str] class CategoryFilter(TypedDict): mode: Literal["INCLUDE", "EXCLUDE"] - values: List[str] + values: list[str] class Filters(TypedDict, total=False): diff --git a/bigdata_briefs/service.py b/bigdata_briefs/service.py index 9a43485..a05b6c5 100644 --- a/bigdata_briefs/service.py +++ b/bigdata_briefs/service.py @@ -4,6 +4,7 @@ from hashlib import sha256 from importlib.metadata import version from threading import Lock +from time import perf_counter from uuid import UUID from bigdata_briefs import logger @@ -194,16 +195,21 @@ def execute_entity_report_pipeline( report_dates: ReportDates, source_rank_boost: int | None, freshness_boost: int | None, + sentiment_threshold: float, executor: ThreadPoolExecutor, + *, + rerank_threshold: float | None = None, ) -> tuple[SingleEntityReport, RetrievedSources]: logger.debug(f"Starting report on {entity}") # Quick initial search to check if there are any results + # Initial probe: keep rerank off (cheap existence check); overrides apply only below. initial_results = self.query_service.check_if_entity_has_results( entity_id=entity.id, report_dates=report_dates, source_filter=source_filter, categories=categories, + sentiment_threshold=sentiment_threshold, ) if not initial_results: @@ -215,19 +221,33 @@ def execute_entity_report_pipeline( ) # If we found results, proceed with full exploratory search + exploratory_kw: dict = { + "entity": entity, + "topics": topics, + "report_dates": report_dates, + "executor": executor, + "enable_metric": True, + "metric_name": "Exploratory search. All entities", + "source_filter": source_filter, + "categories": categories, + "source_rank_boost": source_rank_boost, + "freshness_boost": freshness_boost, + "sentiment_threshold": sentiment_threshold, + } + if rerank_threshold is not None: + exploratory_kw["rerank_threshold"] = rerank_threshold + logger.debug( + f"[diag] {entity.id} entering run_exploratory_search with {len(topics)} topics" + ) + exploratory_t0 = perf_counter() with self.weighted_semaphore(len(topics) + 1): exploratory_search_results = self.query_service.run_exploratory_search( - entity=entity, - topics=topics, - report_dates=report_dates, - executor=executor, - enable_metric=True, - metric_name="Exploratory search. All entities", - source_filter=source_filter, - categories=categories, - source_rank_boost=source_rank_boost, - freshness_boost=freshness_boost, + **exploratory_kw ) + logger.debug( + f"[diag] {entity.id} completed run_exploratory_search in " + f"{perf_counter() - exploratory_t0:.2f}s with {len(exploratory_search_results)} results" + ) if not exploratory_search_results: logger.debug(f"No new information found for {entity}") return self.create_no_info_report( @@ -236,6 +256,8 @@ def execute_entity_report_pipeline( generation_step=NoInfoReportGenerationStep.EXPLORATORY_SEARCH, ) + logger.debug(f"[diag] {entity.id} entering generate_follow_up_questions") + followup_questions_t0 = perf_counter() follow_up_questions = self.generate_follow_up_questions( entity, topics, @@ -244,6 +266,10 @@ def execute_entity_report_pipeline( enable_metric=True, metric_name="Generate follow up questions", ) + logger.debug( + f"[diag] {entity.id} completed generate_follow_up_questions in " + f"{perf_counter() - followup_questions_t0:.2f}s with {len(follow_up_questions)} questions" + ) if not follow_up_questions: logger.debug(f"No follow-up questions generated for {entity}") return self.create_no_info_report( @@ -255,19 +281,32 @@ def execute_entity_report_pipeline( if len(follow_up_questions) != settings.LLM_FOLLOW_UP_QUESTIONS: logger.debug(f"Number of followup questions: {len(follow_up_questions)}") + followup_kw: dict = { + "entity": entity, + "follow_up_questions": follow_up_questions, + "report_dates": report_dates, + "executor": executor, + "enable_metric": True, + "metric_name": "Run follow up questions", + "source_filter": source_filter, + "categories": categories, + "source_rank_boost": source_rank_boost, + "freshness_boost": freshness_boost, + "sentiment_threshold": sentiment_threshold, + } + if rerank_threshold is not None: + followup_kw["rerank_threshold"] = rerank_threshold + logger.debug( + f"[diag] {entity.id} entering run_query_with_follow_up_questions " + f"with {len(follow_up_questions)} questions" + ) + qa_t0 = perf_counter() with self.weighted_semaphore(len(follow_up_questions)): - qa_pairs = self.query_service.run_query_with_follow_up_questions( - entity=entity, - follow_up_questions=follow_up_questions, - report_dates=report_dates, - executor=executor, - enable_metric=True, - metric_name="Run follow up questions", - source_filter=source_filter, - categories=categories, - source_rank_boost=source_rank_boost, - freshness_boost=freshness_boost, - ) + qa_pairs = self.query_service.run_query_with_follow_up_questions(**followup_kw) + logger.debug( + f"[diag] {entity.id} completed run_query_with_follow_up_questions in " + f"{perf_counter() - qa_t0:.2f}s with {len(qa_pairs.pairs)} qa pairs" + ) if not any(pair.answer for pair in qa_pairs.pairs): logger.debug(f"No qa-pairs generated for {entity}") return self.create_no_info_report( @@ -276,6 +315,8 @@ def execute_entity_report_pipeline( generation_step=NoInfoReportGenerationStep.QA_PAIRS, ) + logger.debug(f"[diag] {entity.id} entering generate_new_report") + report_t0 = perf_counter() entity_report, source_mapping = self.generate_new_report( entity, qa_pairs, @@ -284,6 +325,11 @@ def execute_entity_report_pipeline( enable_metric=True, metric_name="Generating report", ) + logger.debug( + f"[diag] {entity.id} completed generate_new_report in " + f"{perf_counter() - report_t0:.2f}s with " + f"{len(entity_report.report_bulletpoints)} bullets" + ) BulletPointMetrics.track_usage( BulletPointsUsage( bullet_points_before_novelty=len(entity_report.report_bulletpoints) @@ -492,8 +538,11 @@ def execute_watchlist_report_pipeline( disable_introduction: bool, source_rank_boost: int | None, freshness_boost: int | None, + sentiment_threshold: float, request_id: UUID, storage_manager: StorageManager, + *, + rerank_threshold: float | None = None, ) -> tuple[WatchlistReport, RetrievedSources]: storage_manager.log_message(request_id, "Generating report per entity") with ThreadPoolExecutor(max_workers=EXECUTOR_WORKERS) as executor: @@ -507,7 +556,9 @@ def execute_watchlist_report_pipeline( report_dates, source_rank_boost, freshness_boost, + sentiment_threshold, executor, + rerank_threshold=rerank_threshold, ): entity for entity in entities } @@ -629,10 +680,12 @@ def generate_brief( record_data.disable_introduction, record_data.source_rank_boost, record_data.freshness_boost, - enable_metric=True, - metric_name="Execute watchlist report pipeline", + record_data.sentiment_threshold, request_id=request_id, storage_manager=storage_manager, + rerank_threshold=record_data.rerank_threshold, + enable_metric=True, + metric_name="Execute watchlist report pipeline", ) n_watchlist_items = len(record_data.entities) @@ -788,6 +841,12 @@ def parse_and_validate( logger.debug(disable_intro_msg) storage_manager.log_message(request_id, disable_intro_msg) + resolved_sentiment = ( + record.sentiment_threshold + if record.sentiment_threshold is not None + else settings.EXPLORATORY_SENTIMENT_THRESHOLD + ) + return ValidatedInput( watchlist=Watchlist( id=watchlist.id, @@ -805,6 +864,8 @@ def parse_and_validate( disable_introduction=record.disable_introduction, source_rank_boost=record.source_rank_boost, freshness_boost=record.freshness_boost, + sentiment_threshold=resolved_sentiment, + rerank_threshold=record.rerank_threshold, ) diff --git a/bigdata_briefs/settings.py b/bigdata_briefs/settings.py index 85bde53..3fe09fe 100644 --- a/bigdata_briefs/settings.py +++ b/bigdata_briefs/settings.py @@ -75,7 +75,7 @@ class Settings(BaseSettings): NOVELTY_LOOKBACK_DAYS: int = 14 NOVELTY_STORAGE_LOOKBACK_HOURS: int = 1 NOVELTY_STORAGE_THRESHOLD: float = 0.8 - EMBEDDING_RETRIES: int = 3 + EMBEDDING_RETRIES: int = 2 # Search configuration API_SIMULTANEOUS_REQUESTS: int = 40 # Reduced to prevent rate limit bursts @@ -93,7 +93,8 @@ class Settings(BaseSettings): # LLM configuration LLM_FOLLOW_UP_QUESTIONS: int = 5 - LLM_RETRIES: int = 3 + LLM_RETRIES: int = 2 + OPENAI_TIMEOUT_SECONDS: int = 30 # Server configuration HOST: str = "0.0.0.0" diff --git a/bigdata_briefs/storage.py b/bigdata_briefs/storage.py index e52dd02..75053a1 100644 --- a/bigdata_briefs/storage.py +++ b/bigdata_briefs/storage.py @@ -56,7 +56,7 @@ def get_report_with_sources( # since some users might have older versions of the database with the report column stored as string # Remove in future versions when a breaking change is acceptable report.brief_report = json.loads(report.brief_report) - brief_report = BriefReport(**report.brief_report) # ty: ignore[missing-argument] + brief_report = BriefReport(**report.brief_report) return brief_report except Exception as e: logger.error(f"Error reconstructing BriefReport from database records: {e}") diff --git a/bigdata_briefs/utils.py b/bigdata_briefs/utils.py index 8276b60..d9ffaca 100644 --- a/bigdata_briefs/utils.py +++ b/bigdata_briefs/utils.py @@ -5,7 +5,7 @@ from datetime import datetime from functools import wraps from time import perf_counter -from typing import Type +from typing import cast from json_repair import repair_json from pydantic import BaseModel, ValidationError @@ -60,13 +60,14 @@ def wrapper( return wrapper -def validate_and_repair_model(json_str: str, model: Type[BaseModel]) -> BaseModel: +def validate_and_repair_model(json_str: str, model: type[BaseModel]) -> BaseModel: try: response = model.model_validate_json(json_str) return response except ValidationError: - # With return_objects=False, it always returns a string, so ignore type checking error - fixed_json_str: str = repair_json(json_str, return_objects=False) # type: ignore[invalid-assignment] + fixed_json_str = cast( + str, repair_json(json_str, return_objects=False) + ) try: response = model.model_validate_json(fixed_json_str) logger.debug( diff --git a/tests/test_llm_client.py b/tests/test_llm_client.py index 92cb73f..d7a7e92 100644 --- a/tests/test_llm_client.py +++ b/tests/test_llm_client.py @@ -9,6 +9,7 @@ from bigdata_briefs.llm_client import ( openai as llm_client_openai, ) +from bigdata_briefs.settings import settings from bigdata_briefs.utils import time as utils_time @@ -141,7 +142,7 @@ def test_call_with_retries_but_failure( # Mock all calls to fail mock_llm_client.client.responses.parse.side_effect = [ Exception("API Error"), - ] * 3 + ] * settings.LLM_RETRIES monkeypatch.setattr(utils_time, "sleep", lambda _: None) with pytest.raises(Exception, match="API Error"): @@ -153,6 +154,6 @@ def test_call_with_retries_but_failure( response_format=DummyResponseFormat, ) - assert mock_llm_client.client.responses.parse.call_count == 3, ( - "Expected 3 retries but got a different count" + assert mock_llm_client.client.responses.parse.call_count == settings.LLM_RETRIES, ( + "Expected retries to match LLM_RETRIES" )