diff --git a/Makefile b/Makefile index 73c358c..21747bd 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ format: @uvx ruff format src/bigdata_research_tools/ examples/ tutorial/ tests/ format-check: - @uvx ruff format --check bigdata_thematic_screener/ tests/ + @uvx ruff format --check src/bigdata_research_tools/ tests/ type-check: @uvx ty@0.0.1a26 check --python-version 3.13 src/bigdata_research_tools/ examples/ tests/ # tutorial/ # Fix version to 3.13 due to this issue https://github.com/astral-sh/ty/issues/1355 # Ignore tutorials, the issues come from this open issue https://github.com/astral-sh/ty/issues/1297 \ No newline at end of file diff --git a/src/bigdata_research_tools/labeler/labeler.py b/src/bigdata_research_tools/labeler/labeler.py index f0b9b38..5172b37 100644 --- a/src/bigdata_research_tools/labeler/labeler.py +++ b/src/bigdata_research_tools/labeler/labeler.py @@ -169,6 +169,7 @@ def _run_labeling_prompts( if provider == "bedrock": llm = LLMEngine( model=self.llm_model_config.model, + api_selection=self.llm_model_config.api_selection, **self.llm_model_config.connection_config, ) return run_parallel_prompts( @@ -177,6 +178,7 @@ def _run_labeling_prompts( else: llm = AsyncLLMEngine( model=self.llm_model_config.model, + api_selection=self.llm_model_config.api_selection, **self.llm_model_config.connection_config, ) return run_concurrent_prompts( diff --git a/src/bigdata_research_tools/llm/azure.py b/src/bigdata_research_tools/llm/azure.py index 37ac3b9..8a7de3e 100644 --- a/src/bigdata_research_tools/llm/azure.py +++ b/src/bigdata_research_tools/llm/azure.py @@ -33,9 +33,10 @@ class AsyncAzureProvider(AsyncLLMProvider): def __init__( self, model: str, + api_selection: str | None = None, **connection_config, ): - super().__init__(model, **connection_config) + super().__init__(model, api_selection=api_selection, **connection_config) self._client = None self.configure_azure_client() @@ -80,11 +81,18 @@ async def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> st last_exception = None for attempt in range(max_retries): try: - chat_completion = await self._client.chat.completions.create( - messages=chat_history, model=self.model, **kwargs - ) + if self.api_selection == "chat" or self.api_selection is None: + chat_completion = await self._client.chat.completions.create( + messages=chat_history, model=self.model, **kwargs + ) + + return chat_completion.choices[0].message.content + elif self.api_selection == "responses": + response = await self._client.responses.create( + messages=chat_history, model=self.model, **kwargs + ) # ty: ignore - return chat_completion.choices[0].message.content + return response.output[0].content[0].text except Exception as e: await asyncio.sleep(delay) delay = 2 * delay + random.random() # exponential backoff @@ -170,9 +178,10 @@ class AzureProvider(LLMProvider): def __init__( self, model: str, + api_selection: str | None = None, **connection_config, ): - super().__init__(model, **connection_config) + super().__init__(model, api_selection=api_selection, **connection_config) self._client = None self.configure_azure_client() @@ -217,11 +226,19 @@ def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str: last_exception = None for attempt in range(max_retries): try: - chat_completion = self._client.chat.completions.create( - messages=chat_history, model=self.model, **kwargs - ) + if self.api_selection == "chat" or self.api_selection is None: + chat_completion = self._client.chat.completions.create( + messages=chat_history, model=self.model, **kwargs + ) + + return chat_completion.choices[0].message.content + elif self.api_selection == "responses": + response = self._client.responses.create( + messages=chat_history, model=self.model, **kwargs + ) # ty: ignore + + return response.output[0].content[0].text - return chat_completion.choices[0].message.content except Exception as e: time.sleep(delay) delay = 2 * delay + random.random() # exponential backoff diff --git a/src/bigdata_research_tools/llm/base.py b/src/bigdata_research_tools/llm/base.py index d919ae1..0667823 100644 --- a/src/bigdata_research_tools/llm/base.py +++ b/src/bigdata_research_tools/llm/base.py @@ -18,6 +18,7 @@ class LLMConfig(BaseModel): model: str response_format: dict = {"type": "json_object"} + text: dict = {"format": {"type": "json_object"}} temperature: float | None = None reasoning_effort: str | None = None top_p: float | None = 1 @@ -33,6 +34,11 @@ class LLMConfig(BaseModel): description="A pair of key-value connection configurations for the LLM provider, the contents will be passed as kwargs to the provider client.", ) + api_selection: str | None = Field( + default="chat", + description="API selection for the OpenAI based LLM provider, e.g., 'chat' or 'responses'.", + ) + @model_validator(mode="after") def check_temperature_and_reasoning_effort(self): ## Only one of temperature or reasoning_effort should be set. @@ -50,6 +56,20 @@ def check_temperature_and_reasoning_effort(self): ) return self + @model_validator(mode="after") + def check_api_endpoint(self): + ## Only one of temperature or reasoning_effort should be set. + if self.api_selection is not None: + if self.api_selection not in ["chat", "responses"]: + raise ValueError( + "Only chat or responses are supported as api_endpoint." + ) + if self.model.split("::")[0] not in ["openai", "azure"]: + raise ValueError( + "api_selection is only supported for OpenAI and Azure providers." + ) + return self + @model_validator(mode="after") def validate_reasoning_config(self): if any(rm in self.model for rm in REASONING_MODELS): @@ -93,13 +113,30 @@ def get_llm_kwargs( config_dict.pop("connection_config", None) if remove_timeout: config_dict.pop("timeout", None) + if self.api_selection is None or self.api_selection == "chat": + config_dict.pop("text", None) + if self.api_selection == "responses": + config_dict.pop("response_format", None) + config_dict.pop("frequency_penalty", None) + config_dict.pop("presence_penalty", None) + config_dict.pop("seed", None) # Remove None values and model key - return {k: v for k, v in config_dict.items() if v is not None and k != "model"} + return { + k: v + for k, v in config_dict.items() + if v is not None and k != "model" and k != "api_selection" + } class AsyncLLMProvider(ABC): - def __init__(self, model: str | None = None, **connection_config): + def __init__( + self, + model: str | None = None, + api_selection: str | None = None, + **connection_config, + ): self.model = model + self.api_selection = api_selection self.connection_config = connection_config @abstractmethod @@ -147,7 +184,12 @@ async def get_stream_response( class AsyncLLMEngine: - def __init__(self, model: str | None = None, **connection_config): + def __init__( + self, + model: str | None = None, + api_selection: str | None = None, + **connection_config, + ): if model is None: model = os.getenv("BIGDATA_RESEARCH_DEFAULT_LLM", "openai::gpt-4o-mini") source = "Environment" @@ -166,6 +208,8 @@ def __init__(self, model: str | None = None, **connection_config): "Invalid model format. It should be `::`." ) + self.api_selection = api_selection + self.provider = self.load_provider( provider_name=self.provider_name, **connection_config ) @@ -177,16 +221,22 @@ def load_provider( if provider == "openai": from bigdata_research_tools.llm.openai import AsyncOpenAIProvider - return AsyncOpenAIProvider(model=self.model, **connection_config) + return AsyncOpenAIProvider( + model=self.model, api_selection=self.api_selection, **connection_config + ) elif provider == "bedrock": from bigdata_research_tools.llm.bedrock import AsyncBedrockProvider - return AsyncBedrockProvider(model=self.model, **connection_config) + return AsyncBedrockProvider( + model=self.model, api_selection=self.api_selection, **connection_config + ) elif provider == "azure": from bigdata_research_tools.llm.azure import AsyncAzureProvider - return AsyncAzureProvider(model=self.model, **connection_config) + return AsyncAzureProvider( + model=self.model, api_selection=self.api_selection, **connection_config + ) else: logger.error(f"Invalid provider: `{self.provider}`") @@ -230,8 +280,14 @@ async def get_tools_response( class LLMProvider(ABC): - def __init__(self, model: str | None = None, **connection_config): + def __init__( + self, + model: str | None = None, + api_selection: str | None = None, + **connection_config, + ): self.model = model + self.api_selection = api_selection self.connection_config = connection_config @abstractmethod @@ -278,12 +334,18 @@ def get_stream_response( class LLMEngine: - def __init__(self, model: str | None = None, **connection_config): + def __init__( + self, + model: str | None = None, + api_selection: str | None = None, + **connection_config, + ): if model is None: model = os.getenv("BIGDATA_RESEARCH_DEFAULT_LLM", "openai::gpt-4o-mini") source = "Environment" else: source = "Argument" + self.api_selection = api_selection try: self.provider_name, self.model = model.split("::") @@ -306,15 +368,21 @@ def load_provider(self, provider_name: str, **connection_config) -> LLMProvider: if provider == "openai": from bigdata_research_tools.llm.openai import OpenAIProvider - return OpenAIProvider(model=self.model, **connection_config) + return OpenAIProvider( + model=self.model, api_selection=self.api_selection, **connection_config + ) elif provider == "bedrock": from bigdata_research_tools.llm.bedrock import BedrockProvider - return BedrockProvider(model=self.model, **connection_config) + return BedrockProvider( + model=self.model, api_selection=self.api_selection, **connection_config + ) elif provider == "azure": from bigdata_research_tools.llm.azure import AzureProvider - return AzureProvider(model=self.model, **connection_config) + return AzureProvider( + model=self.model, api_selection=self.api_selection, **connection_config + ) else: logger.error(f"Invalid provider: `{self.provider}`") diff --git a/src/bigdata_research_tools/llm/bedrock.py b/src/bigdata_research_tools/llm/bedrock.py index f099586..7bc6190 100644 --- a/src/bigdata_research_tools/llm/bedrock.py +++ b/src/bigdata_research_tools/llm/bedrock.py @@ -18,8 +18,10 @@ class AsyncBedrockProvider(AsyncLLMProvider): # Asynchronous boto3 is tricky, for now use the synchronous client, this will not # provide the benefits from async, but will at least let our workflows run for now - def __init__(self, model: str, **connection_config): - super().__init__(model, **connection_config) + def __init__( + self, model: str, api_selection: str | None = None, **connection_config + ): + super().__init__(model, api_selection=api_selection, **connection_config) self._client: Session | None = None self.configure_bedrock_client() @@ -179,8 +181,10 @@ async def get_stream_response( class BedrockProvider(LLMProvider): - def __init__(self, model: str, **connection_config): - super().__init__(model, **connection_config) + def __init__( + self, model: str, api_selection: str | None = None, **connection_config + ): + super().__init__(model, api_selection=api_selection, **connection_config) self._client: Session | None = None self.configure_bedrock_client() diff --git a/src/bigdata_research_tools/llm/openai.py b/src/bigdata_research_tools/llm/openai.py index 54b7116..6e08522 100644 --- a/src/bigdata_research_tools/llm/openai.py +++ b/src/bigdata_research_tools/llm/openai.py @@ -19,8 +19,10 @@ class AsyncOpenAIProvider(AsyncLLMProvider): - def __init__(self, model: str, **connection_config): - super().__init__(model, **connection_config) + def __init__( + self, model: str, api_selection: str | None = None, **connection_config + ): + super().__init__(model, api_selection=api_selection, **connection_config) self._client = None self.configure_openai_client() @@ -36,7 +38,7 @@ def configure_openai_client(self) -> None: if not self._client: self._client = AsyncOpenAI(**self.connection_config) - async def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str: + async def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str: # ty: ignore """ Get the response from an LLM model from OpenAI. @@ -49,11 +51,19 @@ async def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> st """ if not self._client: raise NotInitializedLLMProviderError(self) - chat_completion = await self._client.chat.completions.create( - messages=chat_history, model=self.model, **kwargs - ) - return chat_completion.choices[0].message.content + if self.api_selection == "chat" or self.api_selection is None: + chat_completion = await self._client.chat.completions.create( + messages=chat_history, model=self.model, **kwargs + ) + + return chat_completion.choices[0].message.content + + elif self.api_selection == "responses": + response = await self._client.responses.create( + input=chat_history, model=self.model, **kwargs + ) + return response.output[0].content[0].text async def get_tools_response( self, @@ -127,9 +137,10 @@ class OpenAIProvider(LLMProvider): def __init__( self, model: str, + api_selection: str | None = None, **connection_config, ): - super().__init__(model, **connection_config) + super().__init__(model, api_selection=api_selection, **connection_config) self._client = None self.configure_openai_client() @@ -145,7 +156,7 @@ def configure_openai_client(self) -> None: if not self._client: self._client = OpenAI(**self.connection_config) - def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str: + def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str: # ty: ignore """ Get the response from an LLM model from OpenAI. @@ -158,11 +169,19 @@ def get_response(self, chat_history: list[dict[str, str]], **kwargs) -> str: """ if not self._client: raise NotInitializedLLMProviderError(self) - chat_completion = self._client.chat.completions.create( - messages=chat_history, model=self.model, **kwargs - ) - return chat_completion.choices[0].message.content + if self.api_selection == "chat" or self.api_selection is None: + chat_completion = self._client.chat.completions.create( + messages=chat_history, model=self.model, **kwargs + ) + + return chat_completion.choices[0].message.content + elif self.api_selection == "responses": + response = self._client.responses.create( + input=chat_history, model=self.model, **kwargs + ) + + return response.output[0].content[0].text def get_tools_response( self,