diff --git a/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py b/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py index 0cd9f653..211fd671 100644 --- a/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py +++ b/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py @@ -10,6 +10,21 @@ DEFAULT_CONTEXT_WINDOW = 4096 + +def _parse_temperature(prop: LLMProviderProperty | None): + if prop is None: + return omit + + value = prop.value + if value is None: + return omit + + value = str(value).strip() + if value == "": + return omit + + return float(value) + class OpenAICompatibleChatModel(ChatModel): def __init__(self, provider: "OpenAICompatibleLLMProvider"): super().__init__(provider) @@ -19,6 +34,7 @@ def __init__(self, provider: "OpenAICompatibleLLMProvider"): LLMProviderProperty("model_id", "Model", "Model (must support streaming)", "", False), LLMProviderProperty("base_url", "Base URL", "Base URL", "", True), LLMProviderProperty("context_window", "Context window", "Context window length", "", True), + LLMProviderProperty("temperature", "Temperature", "Sampling temperature", "", True), ] @property @@ -46,6 +62,7 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response: base_url = base_url_prop.value if base_url_prop is not None else None base_url = base_url if base_url.strip() != "" else None api_key = self.get_property("api_key").value + temperature = _parse_temperature(self.get_property("temperature")) client = OpenAI(base_url=base_url, api_key=api_key) resp = client.chat.completions.create( @@ -53,6 +70,7 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response: messages=messages.copy(), tools=tools or omit, tool_choice=options.get("tool_choice", omit), + temperature=temperature, stream=stream, ) @@ -94,6 +112,7 @@ def __init__(self, provider: "OpenAICompatibleLLMProvider"): LLMProviderProperty("model_id", "Model", "Model", "", False), LLMProviderProperty("base_url", "Base URL", "Base URL", "", True), LLMProviderProperty("context_window", "Context window", "Context window length", "", True), + LLMProviderProperty("temperature", "Temperature", "Sampling temperature", "", True), ] @property @@ -142,6 +161,7 @@ def inline_completions(self, prefix, suffix, language, filename, context: Comple base_url = base_url_prop.value if base_url_prop is not None else None base_url = base_url if base_url and base_url.strip() != "" else None api_key = self.get_property("api_key").value + temperature = _parse_temperature(self.get_property("temperature")) client = OpenAI(base_url=base_url, api_key=api_key) resp = client.chat.completions.create( @@ -154,6 +174,7 @@ def inline_completions(self, prefix, suffix, language, filename, context: Comple """} ], max_tokens=1000, + temperature=temperature, stream=False, ) diff --git a/tests/test_openai_compatible_llm_provider.py b/tests/test_openai_compatible_llm_provider.py new file mode 100644 index 00000000..5e354859 --- /dev/null +++ b/tests/test_openai_compatible_llm_provider.py @@ -0,0 +1,62 @@ +from unittest.mock import Mock, patch + +from openai import omit + +from notebook_intelligence.llm_providers.openai_compatible_llm_provider import ( + OpenAICompatibleLLMProvider, + _parse_temperature, +) +from notebook_intelligence.api import LLMProviderProperty + + +def test_parse_temperature_omits_missing_values(): + assert _parse_temperature(None) is omit + assert _parse_temperature(LLMProviderProperty("temperature", "Temperature", "", " ", True)) is omit + + +def test_parse_temperature_returns_float_value(): + prop = LLMProviderProperty("temperature", "Temperature", "", "0.25", True) + assert _parse_temperature(prop) == 0.25 + + +@patch("notebook_intelligence.llm_providers.openai_compatible_llm_provider.OpenAI") +def test_chat_completions_passes_temperature(mock_openai_cls): + provider = OpenAICompatibleLLMProvider() + model = provider.chat_models[0] + model.set_property_value("api_key", "test-key") + model.set_property_value("model_id", "gpt-4.1") + model.set_property_value("temperature", "0.2") + + mock_client = Mock() + mock_response = Mock() + mock_response.model_dump_json.return_value = '{"choices": []}' + mock_response.choices = [] + mock_client.chat.completions.create.return_value = mock_response + mock_openai_cls.return_value = mock_client + + model.completions(messages=[{"role": "user", "content": "hi"}]) + + mock_client.chat.completions.create.assert_called_once() + assert mock_client.chat.completions.create.call_args.kwargs["temperature"] == 0.2 + + +@patch("notebook_intelligence.llm_providers.openai_compatible_llm_provider.OpenAI") +def test_inline_completions_omits_blank_temperature(mock_openai_cls): + provider = OpenAICompatibleLLMProvider() + model = provider.inline_completion_models[0] + model.set_property_value("api_key", "test-key") + model.set_property_value("model_id", "gpt-4.1") + model.set_property_value("temperature", "") + + mock_client = Mock() + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content="```python\npass\n```"))] + mock_client.chat.completions.create.return_value = mock_response + mock_openai_cls.return_value = mock_client + + cancel_token = Mock(is_cancel_requested=False) + result = model.inline_completions("", "", "python", "test.py", Mock(), cancel_token) + + assert result.strip() == "pass" + mock_client.chat.completions.create.assert_called_once() + assert mock_client.chat.completions.create.call_args.kwargs["temperature"] is omit