Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions examples/grounded_mindmaps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import logging

from dotenv import load_dotenv

from bigdata_research_tools.mindmap.mindmap_generator import MindMapGenerator
from bigdata_research_tools.visuals.mindmap_visuals import plot_mindmap

# Load environment variables for authentication
print(f"Environment variables loaded: {load_dotenv()}")

# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def test_one_shot_mindmap(
main_theme,
focus,
map_type,
instructions,
llm_base_config: str = "openai::gpt-4o-mini",
):
"""Test one-shot mind map generation with base LLM."""
logger.info("=" * 60)
logger.info("TEST 1: One-Shot Mind Map Generation with Base LLM")
logger.info("=" * 60)
mindmap_generator = MindMapGenerator(
llm_model_config_base=llm_base_config,
)
_, mindmap = mindmap_generator.generate_one_shot(
instructions=instructions,
focus=focus,
main_theme=main_theme,
map_type=map_type,
allow_grounding=False,
)
logger.info("Results: %s", mindmap["mindmap_text"])
return mindmap["mindmap_df"], mindmap["mindmap_json"]


def test_refined_mindmap(
main_theme,
focus,
map_type,
instructions,
base_mindmap: str,
llm_base_config: str = "openai::o3-mini",
):
"""Test refined mindmap generation with reasoning LLM sent in the base config."""
logger.info("=" * 60)
logger.info("TEST 2: Refined MindMap Generation with Reasoning LLM in Base Config")
logger.info("=" * 60)
mindmap_generator = MindMapGenerator(
llm_model_config_base=llm_base_config,
)
_, mindmap = mindmap_generator.generate_refined(
focus=focus,
main_theme=main_theme,
initial_mindmap=base_mindmap,
output_dir="./refined_mindmaps",
filename="refined_mindmap.json",
map_type=map_type,
instructions=instructions,
)
logger.info("Results: %s", mindmap["mindmap_text"])


def test_refined_mindmap2(
main_theme,
focus,
map_type,
instructions,
base_mindmap: str,
llm_base_config: str,
llm_reasoning_config: str = "openai::o3-mini",
):
"""Test refined mindmap generation with reasoning LLM sent in the reasoning config."""
logger.info("=" * 60)
logger.info(
"TEST 3: Refined MindMap Generation with Reasoning LLM in Reasoning Config"
)
logger.info("=" * 60)
mindmap_generator = MindMapGenerator(
llm_model_config_base=llm_base_config,
llm_model_config_reasoning=llm_reasoning_config,
)
_, mindmap = mindmap_generator.generate_refined(
focus=focus,
main_theme=main_theme,
initial_mindmap=base_mindmap,
date_range=("2025-10-01", "2025-10-31"),
output_dir="./refined_mindmaps",
filename="refined_mindmap.json",
map_type=map_type,
instructions=instructions,
)
logger.info("Results: %s", mindmap["mindmap_text"])


def test_dynamic_mindmap(
main_theme,
focus,
map_type,
instructions,
llm_base_config: str = "openai::gpt-4o-mini",
llm_reasoning_config: str = "openai::o3-mini",
):
"""Test dynamic mindmap generation with two LLMs."""
logger.info("=" * 60)
logger.info("TEST 4: Dynamic MindMap Generation with Two LLMs")
logger.info("=" * 60)
mindmap_generator = MindMapGenerator(
llm_model_config_base=llm_base_config,
llm_model_config_reasoning=llm_reasoning_config,
)
_, mindmap = mindmap_generator.generate_dynamic(
instructions=instructions,
focus=focus,
main_theme=main_theme,
month_intervals=[
("2025-09-01", "2025-09-30"),
("2025-10-01", "2025-10-31"),
("2025-11-01", "2025-11-30"),
],
month_names=[
"September_2025",
"October_2025",
"November_2025",
],
)
logger.info("Results: %s", mindmap["base_mindmap"]["mindmap_json"])
logger.info("Results: %s", mindmap["October_2025"]["mindmap_json"])
logger.info("")


def main(
MAIN_THEME="Political Change in Japan.",
INSTRUCTIONS="Create a mindmap according to a given risk scenario. Map by risk type for any industry and assess short term impact only.",
FOCUS="Provide a detailed taxonomy of risks related to changes in the Japanese political landscape. Evaluate how the resignation of the Prime Minister and the pre-election of Sanae Takaichi will affect companies, their strategy and operations. Take into consideration their increased conservative stance on immigration, energy, and trade. Add any other risk areas that may arise from these political changes. The mind map should be as comprehensive as possible and cover all major risk areas.",
map_type="risk",
):
"""Run all tests."""
logger.info("Testing Grounded MindMap Generation")
logger.info("=" * 60)

try:
df_mindmap, base_mindmap = test_one_shot_mindmap(
MAIN_THEME,
FOCUS,
map_type,
INSTRUCTIONS,
llm_base_config="openai::gpt-4o-mini",
)
plot_mindmap(df_mindmap, MAIN_THEME)
test_refined_mindmap(
MAIN_THEME,
FOCUS,
map_type,
INSTRUCTIONS,
base_mindmap,
llm_base_config="openai::o3-mini",
)
test_refined_mindmap2(
MAIN_THEME,
FOCUS,
map_type,
INSTRUCTIONS,
base_mindmap,
llm_base_config="openai::o3-mini",
)
test_dynamic_mindmap(
MAIN_THEME,
FOCUS,
map_type,
INSTRUCTIONS,
llm_base_config="openai::gpt-4o-mini",
llm_reasoning_config="openai::o3-mini",
)

logger.info("=" * 60)
logger.info("All tests completed successfully")

except Exception as e:
logger.error("Error during testing: %s", e)
raise


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions examples/risk_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def risk_analyzer_example(
control_entities=control_entities,
focus=focus, # Optional focus to narrow the theme,
llm_model_config=llm_model_config,
ground_mindmap=False,
)

class PrintObserver(Observer):
Expand Down
3 changes: 2 additions & 1 deletion examples/thematic_screener.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def thematic_screener_example(
end_date="2024-02-28",
document_type=DocumentType.TRANSCRIPTS,
fiscal_year=2024,
ground_mindmap=True,
)

class PrintObserver(Observer):
Expand Down Expand Up @@ -59,7 +60,7 @@ def update(self, message: OberserverNotification):
x = thematic_screener_example(
"Chip Manufacturers",
export_path=str(output_path),
llm_model_config="openai::gpt-5-mini",
llm_model_config="openai::gpt-4o-mini",
)
custom_config = {
"company_column": "Company",
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"json-repair>=0.44.1",
"tabulate>=0.9.0,<1.0.0",
"plotly>=6.0.0,<7.0.0",
"matplotlib>=3.10.6,<4.0.0"
]

[project.urls]
Expand Down
4 changes: 2 additions & 2 deletions src/bigdata_research_tools/llm/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> st
async def get_tools_response(
self,
chat_history: list[dict[str, str]],
tools: list[dict[str, str]],
tools: list[dict],
temperature: float = 0,
**kwargs,
) -> dict[str, list[dict] | str]:
Expand Down Expand Up @@ -234,7 +234,7 @@ def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str:
def get_tools_response(
self,
chat_history: list[dict[str, str]],
tools: list[dict[str, str]],
tools: list[dict],
temperature: float = 0,
**kwargs,
) -> dict[str, list[dict] | str]:
Expand Down
15 changes: 6 additions & 9 deletions src/bigdata_research_tools/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def validate_reasoning_config(self):
self.reasoning_effort = (
self.reasoning_effort if self.reasoning_effort is not None else "high"
)
self.max_completion_tokens = None
if self.temperature is not None:
warnings.warn(
"The selected model does not support temperature settings. "
Expand Down Expand Up @@ -112,7 +113,7 @@ async def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> st
async def get_tools_response(
self,
chat_history: list[dict[str, str]],
tools: list[dict[str, str]],
tools: list[dict],
temperature: float = 0,
**kwargs,
) -> dict[str, list[dict] | str]:
Expand Down Expand Up @@ -202,7 +203,7 @@ async def get_stream_response(
async def get_tools_response(
self,
chat_history: list[dict[str, str]],
tools: list[dict[str, str]],
tools: list[dict],
temperature: float = 0,
**kwargs,
) -> dict[str, list[dict] | str]:
Expand Down Expand Up @@ -244,8 +245,7 @@ def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str:
def get_tools_response(
self,
chat_history: list[dict[str, str]],
tools: list[dict[str, str]],
temperature: float = 0,
tools: list[dict],
**kwargs,
) -> dict[str, list[dict] | str]:
"""
Expand Down Expand Up @@ -331,8 +331,7 @@ def get_stream_response(
def get_tools_response(
self,
chat_history: list[dict[str, str]],
tools: list[dict[str, str]],
temperature: float = 0,
tools: list[dict],
**kwargs,
) -> dict[str, list[dict] | str]:
"""
Expand All @@ -352,9 +351,7 @@ def get_tools_response(
- arguments (list[dict]): List of arguments for each function
- text (str): The text content of the message, if any.
"""
return self.provider.get_tools_response(
chat_history, tools, temperature, **kwargs
)
return self.provider.get_tools_response(chat_history, tools, **kwargs)


class NotInitializedLLMProviderError(Exception):
Expand Down
9 changes: 4 additions & 5 deletions src/bigdata_research_tools/llm/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

try:
from boto3 import Session # ty: ignore[unresolved-import]
from botocore import BaseClient # type: ignore[unresolved-import]
except ImportError:
raise ImportError(
"Missing optional dependency for LLM Bedrock provider, "
Expand Down Expand Up @@ -34,7 +33,7 @@ def configure_bedrock_client(self) -> None:
if not self._client:
self._client = Session(**self.connection_config)

def _get_bedrock_client(self) -> BaseClient:
def _get_bedrock_client(self):
if not self._client:
raise NotInitializedLLMProviderError(self)
return self._client.client("bedrock-runtime")
Expand Down Expand Up @@ -113,7 +112,7 @@ async def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> st
async def get_tools_response(
self,
chat_history: list[dict[str, str]],
tools: list[dict[str, str]],
tools: list[dict],
temperature: float = 0,
**kwargs,
) -> dict[str, list[dict] | str]:
Expand Down Expand Up @@ -195,7 +194,7 @@ def configure_bedrock_client(self) -> None:
if not self._client:
self._client = Session(**self.connection_config)

def _get_bedrock_client(self) -> BaseClient:
def _get_bedrock_client(self):
if not self._client:
raise NotInitializedLLMProviderError(self)
return self._client.client("bedrock-runtime")
Expand Down Expand Up @@ -274,7 +273,7 @@ def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str:
def get_tools_response(
self,
chat_history: list[dict[str, str]],
tools: list[dict[str, str]],
tools: list[dict],
temperature: float = 0,
**kwargs,
) -> dict[str, list[dict] | str]:
Expand Down
Loading