From ba6d1db01b8f0318eb6f271056006f3b5518f5a8 Mon Sep 17 00:00:00 2001 From: AI Edge Eval Team Date: Tue, 26 May 2026 21:24:10 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 921864329 --- model_eval/runners/base.py | 21 ++++++++++--- model_eval/runners/litert_lm/litert_lm.py | 6 ++-- model_eval/tests/unit/runners/base_test.py | 30 ++++++++++++++++-- .../unit/runners/litert_lm/litert_lm_test.py | 31 ++++++++++++++++++- 4 files changed, 79 insertions(+), 9 deletions(-) diff --git a/model_eval/runners/base.py b/model_eval/runners/base.py index 497e3e4..79e3bfb 100644 --- a/model_eval/runners/base.py +++ b/model_eval/runners/base.py @@ -24,6 +24,8 @@ import pydantic import requests +_DEFAULT_TIMEOUT_SECONDS = 300 + class RunnerType(enum.StrEnum): """Supported runner types.""" @@ -120,11 +122,18 @@ def describe_runner_args(cls) -> list[dict[str, Any]]: return introspection.get_fields(cls.config) def __enter__(self) -> "AbstractRunner": - self.start() + # Use reference counting to handle nested context manager entries. + if getattr(self, "_ref_count", 0) == 0: + self.start() + self._ref_count = 0 + self._ref_count += 1 return self def __exit__(self, *_) -> None: - self.stop() + if hasattr(self, "_ref_count"): + self._ref_count -= 1 + if self._ref_count == 0: + self.stop() def _validate_completions(self) -> None: """Verifies the /v1/chat/completions endpoint.""" @@ -135,7 +144,9 @@ def _validate_completions(self) -> None: "max_tokens": 1, } try: - response = requests.post(url, json=payload, timeout=20) + response = requests.post( + url, json=payload, timeout=_DEFAULT_TIMEOUT_SECONDS + ) response.raise_for_status() if "choices" not in response.json(): raise ValueError("Response missing 'choices' field.") @@ -155,7 +166,9 @@ def _validate_scoring(self) -> None: ], } try: - response = requests.post(url, json=payload, timeout=20) + response = requests.post( + url, json=payload, timeout=_DEFAULT_TIMEOUT_SECONDS + ) response.raise_for_status() data = response.json() choice = data["choices"][0] diff --git a/model_eval/runners/litert_lm/litert_lm.py b/model_eval/runners/litert_lm/litert_lm.py index 1626fd8..a0a6036 100644 --- a/model_eval/runners/litert_lm/litert_lm.py +++ b/model_eval/runners/litert_lm/litert_lm.py @@ -45,7 +45,7 @@ def _resolve_model_path(path: str) -> str: # pylint: disable=g-doc-args return path parts = path.split("/") - if len(parts) >= 3: + if len(parts) >= 3 and not path.startswith("/"): # Repo ID is the first two parts, i.e. "org/repo". repo_id = "/".join(parts[:2]) # The model path within the repo can be anything. @@ -209,7 +209,9 @@ def start(self) -> None: self._server_thread = threading.Thread(target=self._server.run, daemon=True) self._server_thread.start() - _litert_lm_server.wait_for_server(self.server_url) + _litert_lm_server.wait_for_server( + self.server_url, timeout=base._DEFAULT_TIMEOUT_SECONDS + ) if ( self.capabilities.text_generation diff --git a/model_eval/tests/unit/runners/base_test.py b/model_eval/tests/unit/runners/base_test.py index 604fbfe..05f81cf 100644 --- a/model_eval/tests/unit/runners/base_test.py +++ b/model_eval/tests/unit/runners/base_test.py @@ -62,7 +62,7 @@ def test_validate_completions_success(self, mock_post): "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1, }, - timeout=20, + timeout=base._DEFAULT_TIMEOUT_SECONDS, ) @mock.patch("model_eval.runners.base.requests.post") @@ -89,7 +89,7 @@ def test_validate_scoring_success(self, mock_post): {"role": "assistant", "content": "hi"}, ], }, - timeout=20, + timeout=base._DEFAULT_TIMEOUT_SECONDS, ) @mock.patch("model_eval.runners.base.requests.post") @@ -100,6 +100,32 @@ def test_validate_scoring_failure(self, mock_post): RuntimeError, "Runner failed scoring validation"): runner._validate_scoring() + def test_reentrancy_guard(self): + runner = DummyRunner() + with mock.patch.object( + runner, "start", wraps=runner.start + ) as mock_start, mock.patch.object( + runner, "stop", wraps=runner.stop + ) as mock_stop: + + # Enter recursively + with runner: + self.assertEqual(runner._ref_count, 1) + with runner: + self.assertEqual(runner._ref_count, 2) + # Server should only start once + mock_start.assert_called_once() + mock_stop.assert_not_called() + + # After inner exit, server should still be running (ref_count = 1) + self.assertEqual(runner._ref_count, 1) + mock_stop.assert_not_called() + + # After outer exit, server should stop (ref_count = 0) + self.assertEqual(runner._ref_count, 0) + mock_start.assert_called_once() + mock_stop.assert_called_once() + if __name__ == "__main__": unittest.main() diff --git a/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py b/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py index 3df3030..ec2bf1b 100644 --- a/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py +++ b/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py @@ -17,6 +17,7 @@ import unittest from unittest import mock +from model_eval.runners import base from model_eval.runners.litert_lm import litert_lm @@ -93,7 +94,9 @@ def test_initialization( ) # Verify wait_for_server. - mock_wait_for_server.assert_called_once_with("http://0.0.0.0:9090") + mock_wait_for_server.assert_called_once_with( + "http://0.0.0.0:9090", timeout=base._DEFAULT_TIMEOUT_SECONDS + ) runner.stop() @@ -234,6 +237,32 @@ def test_hf_path_resolution(self, mock_engine, mock_download, mock_exists): max_num_tokens=mock.ANY, ) + @mock.patch( + "model_eval.runners.litert_lm.litert_lm.os.path.exists" + ) + @mock.patch("huggingface_hub.hf_hub_download") + @mock.patch( + "model_eval.runners.litert_lm.litert_lm.litert_lm.Engine" + ) + def test_absolute_path_resolution( + self, mock_engine, mock_download, mock_exists + ): + mock_exists.return_value = False + config = litert_lm.LiteRtLmRunner.Config( + runner_type="litert-lm", model_path="/Users/foo/bar/model.litertlm" + ) + runner = litert_lm.LiteRtLmRunner(config) + try: + runner.start() + except Exception: # pylint: disable=broad-except + pass + mock_download.assert_not_called() + mock_engine.assert_called_once_with( + "/Users/foo/bar/model.litertlm", + backend=mock.ANY, + max_num_tokens=mock.ANY, + ) + if __name__ == "__main__": unittest.main()