Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ format:
@uvx ruff format src/bigdata_research_tools/ examples/ tutorial/ tests/

format-check:
@uvx ruff format --check bigdata_thematic_screener/ tests/
@uvx ruff format --check src/bigdata_research_tools/ tests/

type-check:
@uvx ty@0.0.1a26 check --python-version 3.13 src/bigdata_research_tools/ examples/ tests/ # tutorial/ # Fix version to 3.13 due to this issue https://github.com/astral-sh/ty/issues/1355 # Ignore tutorials, the issues come from this open issue https://github.com/astral-sh/ty/issues/1297
2 changes: 2 additions & 0 deletions src/bigdata_research_tools/labeler/labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def _run_labeling_prompts(
if provider == "bedrock":
llm = LLMEngine(
model=self.llm_model_config.model,
api_selection=self.llm_model_config.api_selection,
**self.llm_model_config.connection_config,
)
return run_parallel_prompts(
Expand All @@ -177,6 +178,7 @@ def _run_labeling_prompts(
else:
llm = AsyncLLMEngine(
model=self.llm_model_config.model,
api_selection=self.llm_model_config.api_selection,
**self.llm_model_config.connection_config,
)
return run_concurrent_prompts(
Expand Down
37 changes: 27 additions & 10 deletions src/bigdata_research_tools/llm/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ class AsyncAzureProvider(AsyncLLMProvider):
def __init__(
self,
model: str,
api_selection: str | None = None,
**connection_config,
):
super().__init__(model, **connection_config)
super().__init__(model, api_selection=api_selection, **connection_config)
self._client = None
self.configure_azure_client()

Expand Down Expand Up @@ -80,11 +81,18 @@ async def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> st
last_exception = None
for attempt in range(max_retries):
try:
chat_completion = await self._client.chat.completions.create(
messages=chat_history, model=self.model, **kwargs
)
if self.api_selection == "chat" or self.api_selection is None:
chat_completion = await self._client.chat.completions.create(
messages=chat_history, model=self.model, **kwargs
)

return chat_completion.choices[0].message.content
elif self.api_selection == "responses":
response = await self._client.responses.create(
messages=chat_history, model=self.model, **kwargs
) # ty: ignore

return chat_completion.choices[0].message.content
return response.output[0].content[0].text
except Exception as e:
await asyncio.sleep(delay)
delay = 2 * delay + random.random() # exponential backoff
Expand Down Expand Up @@ -170,9 +178,10 @@ class AzureProvider(LLMProvider):
def __init__(
self,
model: str,
api_selection: str | None = None,
**connection_config,
):
super().__init__(model, **connection_config)
super().__init__(model, api_selection=api_selection, **connection_config)
self._client = None
self.configure_azure_client()

Expand Down Expand Up @@ -217,11 +226,19 @@ def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str:
last_exception = None
for attempt in range(max_retries):
try:
chat_completion = self._client.chat.completions.create(
messages=chat_history, model=self.model, **kwargs
)
if self.api_selection == "chat" or self.api_selection is None:
chat_completion = self._client.chat.completions.create(
messages=chat_history, model=self.model, **kwargs
)

return chat_completion.choices[0].message.content
elif self.api_selection == "responses":
response = self._client.responses.create(
messages=chat_history, model=self.model, **kwargs
) # ty: ignore

return response.output[0].content[0].text

return chat_completion.choices[0].message.content
except Exception as e:
time.sleep(delay)
delay = 2 * delay + random.random() # exponential backoff
Expand Down
90 changes: 79 additions & 11 deletions src/bigdata_research_tools/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class LLMConfig(BaseModel):

model: str
response_format: dict = {"type": "json_object"}
text: dict = {"format": {"type": "json_object"}}
temperature: float | None = None
reasoning_effort: str | None = None
top_p: float | None = 1
Expand All @@ -33,6 +34,11 @@ class LLMConfig(BaseModel):
description="A pair of key-value connection configurations for the LLM provider, the contents will be passed as kwargs to the provider client.",
)

api_selection: str | None = Field(
default="chat",
description="API selection for the OpenAI based LLM provider, e.g., 'chat' or 'responses'.",
)

@model_validator(mode="after")
def check_temperature_and_reasoning_effort(self):
## Only one of temperature or reasoning_effort should be set.
Expand All @@ -50,6 +56,20 @@ def check_temperature_and_reasoning_effort(self):
)
return self

@model_validator(mode="after")
def check_api_endpoint(self):
    """Validate ``api_selection`` against the allowed values and the provider.

    ``api_selection`` may be ``None``, ``"chat"`` or ``"responses"``. The
    provider is taken from the ``<provider>::<model>`` prefix of
    ``self.model``.

    Returns:
        The validated model instance (required by pydantic "after" validators).

    Raises:
        ValueError: If ``api_selection`` is not one of the supported values, or
            if ``"responses"`` is requested for a provider other than OpenAI
            or Azure.
    """
    if self.api_selection is None:
        return self

    if self.api_selection not in ("chat", "responses"):
        raise ValueError("Only chat or responses are supported as api_selection.")

    # The field defaults to "chat", so the provider restriction must only be
    # enforced for the non-default "responses" value; otherwise every config
    # for another provider (e.g. bedrock) built with defaults would be rejected.
    provider = self.model.split("::")[0]
    if self.api_selection == "responses" and provider not in ("openai", "azure"):
        raise ValueError(
            "api_selection is only supported for OpenAI and Azure providers."
        )
    return self

@model_validator(mode="after")
def validate_reasoning_config(self):
if any(rm in self.model for rm in REASONING_MODELS):
Expand Down Expand Up @@ -93,13 +113,30 @@ def get_llm_kwargs(
config_dict.pop("connection_config", None)
if remove_timeout:
config_dict.pop("timeout", None)
if self.api_selection is None or self.api_selection == "chat":
config_dict.pop("text", None)
if self.api_selection == "responses":
config_dict.pop("response_format", None)
config_dict.pop("frequency_penalty", None)
config_dict.pop("presence_penalty", None)
config_dict.pop("seed", None)
# Remove None values and model key
return {k: v for k, v in config_dict.items() if v is not None and k != "model"}
return {
k: v
for k, v in config_dict.items()
if v is not None and k != "model" and k != "api_selection"
}


class AsyncLLMProvider(ABC):
def __init__(self, model: str | None = None, **connection_config):
def __init__(
self,
model: str | None = None,
api_selection: str | None = None,
**connection_config,
):
self.model = model
self.api_selection = api_selection
self.connection_config = connection_config

@abstractmethod
Expand Down Expand Up @@ -147,7 +184,12 @@ async def get_stream_response(


class AsyncLLMEngine:
def __init__(self, model: str | None = None, **connection_config):
def __init__(
self,
model: str | None = None,
api_selection: str | None = None,
**connection_config,
):
if model is None:
model = os.getenv("BIGDATA_RESEARCH_DEFAULT_LLM", "openai::gpt-4o-mini")
source = "Environment"
Expand All @@ -166,6 +208,8 @@ def __init__(self, model: str | None = None, **connection_config):
"Invalid model format. It should be `<provider>::<model>`."
)

self.api_selection = api_selection

self.provider = self.load_provider(
provider_name=self.provider_name, **connection_config
)
Expand All @@ -177,16 +221,22 @@ def load_provider(
if provider == "openai":
from bigdata_research_tools.llm.openai import AsyncOpenAIProvider

return AsyncOpenAIProvider(model=self.model, **connection_config)
return AsyncOpenAIProvider(
model=self.model, api_selection=self.api_selection, **connection_config
)

elif provider == "bedrock":
from bigdata_research_tools.llm.bedrock import AsyncBedrockProvider

return AsyncBedrockProvider(model=self.model, **connection_config)
return AsyncBedrockProvider(
model=self.model, api_selection=self.api_selection, **connection_config
)
elif provider == "azure":
from bigdata_research_tools.llm.azure import AsyncAzureProvider

return AsyncAzureProvider(model=self.model, **connection_config)
return AsyncAzureProvider(
model=self.model, api_selection=self.api_selection, **connection_config
)
else:
logger.error(f"Invalid provider: `{self.provider}`")

Expand Down Expand Up @@ -230,8 +280,14 @@ async def get_tools_response(


class LLMProvider(ABC):
def __init__(self, model: str | None = None, **connection_config):
def __init__(
self,
model: str | None = None,
api_selection: str | None = None,
**connection_config,
):
self.model = model
self.api_selection = api_selection
self.connection_config = connection_config

@abstractmethod
Expand Down Expand Up @@ -278,12 +334,18 @@ def get_stream_response(


class LLMEngine:
def __init__(self, model: str | None = None, **connection_config):
def __init__(
self,
model: str | None = None,
api_selection: str | None = None,
**connection_config,
):
if model is None:
model = os.getenv("BIGDATA_RESEARCH_DEFAULT_LLM", "openai::gpt-4o-mini")
source = "Environment"
else:
source = "Argument"
self.api_selection = api_selection

try:
self.provider_name, self.model = model.split("::")
Expand All @@ -306,15 +368,21 @@ def load_provider(self, provider_name: str, **connection_config) -> LLMProvider:
if provider == "openai":
from bigdata_research_tools.llm.openai import OpenAIProvider

return OpenAIProvider(model=self.model, **connection_config)
return OpenAIProvider(
model=self.model, api_selection=self.api_selection, **connection_config
)
elif provider == "bedrock":
from bigdata_research_tools.llm.bedrock import BedrockProvider

return BedrockProvider(model=self.model, **connection_config)
return BedrockProvider(
model=self.model, api_selection=self.api_selection, **connection_config
)
elif provider == "azure":
from bigdata_research_tools.llm.azure import AzureProvider

return AzureProvider(model=self.model, **connection_config)
return AzureProvider(
model=self.model, api_selection=self.api_selection, **connection_config
)
else:
logger.error(f"Invalid provider: `{self.provider}`")

Expand Down
12 changes: 8 additions & 4 deletions src/bigdata_research_tools/llm/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
class AsyncBedrockProvider(AsyncLLMProvider):
# Asynchronous boto3 is tricky, for now use the synchronous client, this will not
# provide the benefits from async, but will at least let our workflows run for now
def __init__(self, model: str, **connection_config):
super().__init__(model, **connection_config)
def __init__(
self, model: str, api_selection: str | None = None, **connection_config
):
super().__init__(model, api_selection=api_selection, **connection_config)
self._client: Session | None = None
self.configure_bedrock_client()

Expand Down Expand Up @@ -179,8 +181,10 @@ async def get_stream_response(


class BedrockProvider(LLMProvider):
def __init__(self, model: str, **connection_config):
super().__init__(model, **connection_config)
def __init__(
self, model: str, api_selection: str | None = None, **connection_config
):
super().__init__(model, api_selection=api_selection, **connection_config)
self._client: Session | None = None
self.configure_bedrock_client()

Expand Down
45 changes: 32 additions & 13 deletions src/bigdata_research_tools/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@


class AsyncOpenAIProvider(AsyncLLMProvider):
def __init__(self, model: str, **connection_config):
super().__init__(model, **connection_config)
def __init__(
self, model: str, api_selection: str | None = None, **connection_config
):
super().__init__(model, api_selection=api_selection, **connection_config)
self._client = None
self.configure_openai_client()

Expand All @@ -36,7 +38,7 @@ def configure_openai_client(self) -> None:
if not self._client:
self._client = AsyncOpenAI(**self.connection_config)

async def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str:
async def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str: # ty: ignore
"""
Get the response from an LLM model from OpenAI.

Expand All @@ -49,11 +51,19 @@ async def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> st
"""
if not self._client:
raise NotInitializedLLMProviderError(self)
chat_completion = await self._client.chat.completions.create(
messages=chat_history, model=self.model, **kwargs
)

return chat_completion.choices[0].message.content
if self.api_selection == "chat" or self.api_selection is None:
chat_completion = await self._client.chat.completions.create(
messages=chat_history, model=self.model, **kwargs
)

return chat_completion.choices[0].message.content

elif self.api_selection == "responses":
response = await self._client.responses.create(
input=chat_history, model=self.model, **kwargs
)
return response.output[0].content[0].text

async def get_tools_response(
self,
Expand Down Expand Up @@ -127,9 +137,10 @@ class OpenAIProvider(LLMProvider):
def __init__(
self,
model: str,
api_selection: str | None = None,
**connection_config,
):
super().__init__(model, **connection_config)
super().__init__(model, api_selection=api_selection, **connection_config)
self._client = None
self.configure_openai_client()

Expand All @@ -145,7 +156,7 @@ def configure_openai_client(self) -> None:
if not self._client:
self._client = OpenAI(**self.connection_config)

def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str:
def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str: # ty: ignore
"""
Get the response from an LLM model from OpenAI.

Expand All @@ -158,11 +169,19 @@ def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str:
"""
if not self._client:
raise NotInitializedLLMProviderError(self)
chat_completion = self._client.chat.completions.create(
messages=chat_history, model=self.model, **kwargs
)

return chat_completion.choices[0].message.content
if self.api_selection == "chat" or self.api_selection is None:
chat_completion = self._client.chat.completions.create(
messages=chat_history, model=self.model, **kwargs
)

return chat_completion.choices[0].message.content
elif self.api_selection == "responses":
response = self._client.responses.create(
input=chat_history, model=self.model, **kwargs
)

return response.output[0].content[0].text

def get_tools_response(
self,
Expand Down