Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions model_eval/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import pydantic
import requests

# Timeout, in seconds, passed as `timeout=` to the `requests.post` calls that
# validate the completions and scoring endpoints.
_DEFAULT_TIMEOUT_SECONDS = 300


class RunnerType(enum.StrEnum):
"""Supported runner types."""
Expand Down Expand Up @@ -120,11 +122,18 @@ def describe_runner_args(cls) -> list[dict[str, Any]]:
return introspection.get_fields(cls.config)

def __enter__(self) -> "AbstractRunner":
  """Enters the runner context, starting the runner on the first entry only.

  Nested `with` entries on the same instance are reference counted in
  `self._ref_count`, so `start()` runs once for the outermost entry and is
  skipped for every nested one.

  Returns:
    This runner instance.
  """
  # `_ref_count` is created lazily here, so a missing attribute means the
  # runner has never been entered.
  ref_count = getattr(self, "_ref_count", 0)
  if ref_count <= 0:
    # Start before recording the entry: if start() raises, the count is left
    # untouched. A non-positive count (e.g. after an unbalanced exit) is
    # treated as "not started" so the runner can always be restarted.
    self.start()
    ref_count = 0
  self._ref_count = ref_count + 1
  return self

def __exit__(self, *_) -> None:
  """Leaves the runner context, stopping the runner on the outermost exit.

  Each exit releases one reference recorded in `self._ref_count`; `stop()`
  runs only when the count drops to zero. An exit with no outstanding entry
  (attribute missing or count already non-positive) is ignored.

  Args:
    *_: Exception type, value and traceback from the `with` block; unused.
  """
  ref_count = getattr(self, "_ref_count", 0)
  if ref_count <= 0:
    # Nothing to release. Returning here keeps the count from going negative,
    # which would make a later entry see a non-zero count and skip start().
    return
  self._ref_count = ref_count - 1
  if self._ref_count == 0:
    self.stop()

def _validate_completions(self) -> None:
"""Verifies the /v1/chat/completions endpoint."""
Expand All @@ -135,7 +144,9 @@ def _validate_completions(self) -> None:
"max_tokens": 1,
}
try:
response = requests.post(url, json=payload, timeout=20)
response = requests.post(
url, json=payload, timeout=_DEFAULT_TIMEOUT_SECONDS
)
response.raise_for_status()
if "choices" not in response.json():
raise ValueError("Response missing 'choices' field.")
Expand All @@ -155,7 +166,9 @@ def _validate_scoring(self) -> None:
],
}
try:
response = requests.post(url, json=payload, timeout=20)
response = requests.post(
url, json=payload, timeout=_DEFAULT_TIMEOUT_SECONDS
)
response.raise_for_status()
data = response.json()
choice = data["choices"][0]
Expand Down
6 changes: 4 additions & 2 deletions model_eval/runners/litert_lm/litert_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _resolve_model_path(path: str) -> str: # pylint: disable=g-doc-args
return path

parts = path.split("/")
if len(parts) >= 3:
if len(parts) >= 3 and not path.startswith("/"):
# Repo ID is the first two parts, i.e. "org/repo".
repo_id = "/".join(parts[:2])
# The model path within the repo can be anything.
Expand Down Expand Up @@ -209,7 +209,9 @@ def start(self) -> None:

self._server_thread = threading.Thread(target=self._server.run, daemon=True)
self._server_thread.start()
_litert_lm_server.wait_for_server(self.server_url)
_litert_lm_server.wait_for_server(
self.server_url, timeout=base._DEFAULT_TIMEOUT_SECONDS
)

if (
self.capabilities.text_generation
Expand Down
30 changes: 28 additions & 2 deletions model_eval/tests/unit/runners/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_validate_completions_success(self, mock_post):
"messages": [{"role": "user", "content": "hi"}],
"max_tokens": 1,
},
timeout=20,
timeout=base._DEFAULT_TIMEOUT_SECONDS,
)

@mock.patch("model_eval.runners.base.requests.post")
Expand All @@ -89,7 +89,7 @@ def test_validate_scoring_success(self, mock_post):
{"role": "assistant", "content": "hi"},
],
},
timeout=20,
timeout=base._DEFAULT_TIMEOUT_SECONDS,
)

@mock.patch("model_eval.runners.base.requests.post")
Expand All @@ -100,6 +100,32 @@ def test_validate_scoring_failure(self, mock_post):
RuntimeError, "Runner failed scoring validation"):
runner._validate_scoring()

def test_reentrancy_guard(self):
  """Nested `with` entries start the runner once and stop it once."""
  runner = DummyRunner()
  start_patch = mock.patch.object(runner, "start", wraps=runner.start)
  stop_patch = mock.patch.object(runner, "stop", wraps=runner.stop)

  with start_patch as start_spy, stop_patch as stop_spy:
    with runner:  # Outer entry.
      self.assertEqual(runner._ref_count, 1)

      with runner:  # Nested entry on the same instance.
        self.assertEqual(runner._ref_count, 2)
        # The nested entry must not trigger a second start.
        start_spy.assert_called_once()
        stop_spy.assert_not_called()

      # Leaving the inner context only drops the count; nothing is stopped.
      self.assertEqual(runner._ref_count, 1)
      stop_spy.assert_not_called()

    # Leaving the outer context brings the count to zero and stops once.
    self.assertEqual(runner._ref_count, 0)
    start_spy.assert_called_once()
    stop_spy.assert_called_once()


if __name__ == "__main__":
unittest.main()
31 changes: 30 additions & 1 deletion model_eval/tests/unit/runners/litert_lm/litert_lm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest
from unittest import mock

from model_eval.runners import base
from model_eval.runners.litert_lm import litert_lm


Expand Down Expand Up @@ -93,7 +94,9 @@ def test_initialization(
)

# Verify wait_for_server.
mock_wait_for_server.assert_called_once_with("http://0.0.0.0:9090")
mock_wait_for_server.assert_called_once_with(
"http://0.0.0.0:9090", timeout=base._DEFAULT_TIMEOUT_SECONDS
)

runner.stop()

Expand Down Expand Up @@ -234,6 +237,32 @@ def test_hf_path_resolution(self, mock_engine, mock_download, mock_exists):
max_num_tokens=mock.ANY,
)

@mock.patch(
    "model_eval.runners.litert_lm.litert_lm.os.path.exists"
)
@mock.patch("huggingface_hub.hf_hub_download")
@mock.patch(
    "model_eval.runners.litert_lm.litert_lm.litert_lm.Engine"
)
def test_absolute_path_resolution(
    self, mock_engine, mock_download, mock_exists
):
  """An absolute model path reaches Engine unchanged, with no HF download."""
  absolute_path = "/Users/foo/bar/model.litertlm"
  mock_exists.return_value = False

  runner = litert_lm.LiteRtLmRunner(
      litert_lm.LiteRtLmRunner.Config(
          runner_type="litert-lm", model_path=absolute_path
      )
  )
  try:
    runner.start()
  except Exception:  # pylint: disable=broad-except
    # Only the calls asserted below matter here; any error raised by start()
    # is deliberately ignored (presumably the rest of startup is not mocked
    # in this test -- confirm).
    pass

  mock_download.assert_not_called()
  mock_engine.assert_called_once_with(
      absolute_path, backend=mock.ANY, max_num_tokens=mock.ANY
  )


if __name__ == "__main__":
unittest.main()