Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@

DEFAULT_CONTEXT_WINDOW = 4096


def _parse_temperature(prop: LLMProviderProperty | None):
if prop is None:
return omit

value = prop.value
if value is None:
return omit

value = str(value).strip()
if value == "":
return omit

return float(value)

class OpenAICompatibleChatModel(ChatModel):
def __init__(self, provider: "OpenAICompatibleLLMProvider"):
super().__init__(provider)
Expand All @@ -19,6 +34,7 @@ def __init__(self, provider: "OpenAICompatibleLLMProvider"):
LLMProviderProperty("model_id", "Model", "Model (must support streaming)", "", False),
LLMProviderProperty("base_url", "Base URL", "Base URL", "", True),
LLMProviderProperty("context_window", "Context window", "Context window length", "", True),
LLMProviderProperty("temperature", "Temperature", "Sampling temperature", "", True),
]

@property
Expand Down Expand Up @@ -46,13 +62,15 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response:
base_url = base_url_prop.value if base_url_prop is not None else None
base_url = base_url if base_url.strip() != "" else None
api_key = self.get_property("api_key").value
temperature = _parse_temperature(self.get_property("temperature"))

client = OpenAI(base_url=base_url, api_key=api_key)
resp = client.chat.completions.create(
model=model_id,
messages=messages.copy(),
tools=tools or omit,
tool_choice=options.get("tool_choice", omit),
temperature=temperature,
stream=stream,
)

Expand Down Expand Up @@ -94,6 +112,7 @@ def __init__(self, provider: "OpenAICompatibleLLMProvider"):
LLMProviderProperty("model_id", "Model", "Model", "", False),
LLMProviderProperty("base_url", "Base URL", "Base URL", "", True),
LLMProviderProperty("context_window", "Context window", "Context window length", "", True),
LLMProviderProperty("temperature", "Temperature", "Sampling temperature", "", True),
]

@property
Expand Down Expand Up @@ -142,6 +161,7 @@ def inline_completions(self, prefix, suffix, language, filename, context: Comple
base_url = base_url_prop.value if base_url_prop is not None else None
base_url = base_url if base_url and base_url.strip() != "" else None
api_key = self.get_property("api_key").value
temperature = _parse_temperature(self.get_property("temperature"))

client = OpenAI(base_url=base_url, api_key=api_key)
resp = client.chat.completions.create(
Expand All @@ -154,6 +174,7 @@ def inline_completions(self, prefix, suffix, language, filename, context: Comple
"""}
],
max_tokens=1000,
temperature=temperature,
stream=False,
)

Expand Down
62 changes: 62 additions & 0 deletions tests/test_openai_compatible_llm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from unittest.mock import Mock, patch

from openai import omit

from notebook_intelligence.llm_providers.openai_compatible_llm_provider import (
OpenAICompatibleLLMProvider,
_parse_temperature,
)
from notebook_intelligence.api import LLMProviderProperty


def test_parse_temperature_omits_missing_values():
assert _parse_temperature(None) is omit
assert _parse_temperature(LLMProviderProperty("temperature", "Temperature", "", " ", True)) is omit


def test_parse_temperature_returns_float_value():
prop = LLMProviderProperty("temperature", "Temperature", "", "0.25", True)
assert _parse_temperature(prop) == 0.25


@patch("notebook_intelligence.llm_providers.openai_compatible_llm_provider.OpenAI")
def test_chat_completions_passes_temperature(mock_openai_cls):
provider = OpenAICompatibleLLMProvider()
model = provider.chat_models[0]
model.set_property_value("api_key", "test-key")
model.set_property_value("model_id", "gpt-4.1")
model.set_property_value("temperature", "0.2")

mock_client = Mock()
mock_response = Mock()
mock_response.model_dump_json.return_value = '{"choices": []}'
mock_response.choices = []
mock_client.chat.completions.create.return_value = mock_response
mock_openai_cls.return_value = mock_client

model.completions(messages=[{"role": "user", "content": "hi"}])

mock_client.chat.completions.create.assert_called_once()
assert mock_client.chat.completions.create.call_args.kwargs["temperature"] == 0.2


@patch("notebook_intelligence.llm_providers.openai_compatible_llm_provider.OpenAI")
def test_inline_completions_omits_blank_temperature(mock_openai_cls):
provider = OpenAICompatibleLLMProvider()
model = provider.inline_completion_models[0]
model.set_property_value("api_key", "test-key")
model.set_property_value("model_id", "gpt-4.1")
model.set_property_value("temperature", "")

mock_client = Mock()
mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="```python\npass\n```"))]
mock_client.chat.completions.create.return_value = mock_response
mock_openai_cls.return_value = mock_client

cancel_token = Mock(is_cancel_requested=False)
result = model.inline_completions("", "", "python", "test.py", Mock(), cancel_token)

assert result.strip() == "pass"
mock_client.chat.completions.create.assert_called_once()
assert mock_client.chat.completions.create.call_args.kwargs["temperature"] is omit