From 4b0716f4a6544ec9da9cc6b14ffdc37add60286f Mon Sep 17 00:00:00 2001 From: jonas-becker Date: Mon, 2 Jun 2025 16:56:12 +0100 Subject: [PATCH 1/2] vllm fixes --- mallm/__init__.py | 2 +- mallm/models/Chat.py | 3 ++- mallm/scheduler.py | 11 +++++++---- mallm/utils/config.py | 10 +--------- mallm/utils/dicts.py | 4 +++- 5 files changed, 14 insertions(+), 16 deletions(-) diff --git a/mallm/__init__.py b/mallm/__init__.py index c521d64c..35d8e9b2 100644 --- a/mallm/__init__.py +++ b/mallm/__init__.py @@ -1 +1 @@ -__version__='v1.0.4' +__version__ = 'v1.0.4' diff --git a/mallm/models/Chat.py b/mallm/models/Chat.py index 9f03fb86..20cbe7fc 100644 --- a/mallm/models/Chat.py +++ b/mallm/models/Chat.py @@ -126,7 +126,8 @@ def _call( # type: ignore log_prob_sum = 0.0 for message in chat_completion: message_str = message.choices[0].delta.content - log_prob_sum += message.choices[0].logprobs.content[0].logprob + if message.choices[0].logprobs: + log_prob_sum += message.choices[0].logprobs.content[0].logprob if message_str and message_str not in self.stop_tokens: collected_messages.append(message_str) log_prob_sum = log_prob_sum / len(collected_messages) diff --git a/mallm/scheduler.py b/mallm/scheduler.py index 5dad52a3..67f53030 100644 --- a/mallm/scheduler.py +++ b/mallm/scheduler.py @@ -147,15 +147,18 @@ def __init__(self, config: Config) -> None: self.llm = Chat( client=OpenAI( base_url=self.config.endpoint_url, api_key=self.config.api_key - ) + ), + model=self.config.model_name ) self.judge_llm = None if self.config.judge_endpoint_url: self.judge_llm = Chat( - client=OpenAI( - base_url=self.config.judge_endpoint_url, api_key=self.config.judge_api_key - ) + client=OpenAI( + base_url=self.config.judge_endpoint_url, + api_key=self.config.judge_api_key, + ), + model=self.config.judge_model_name, ) if config.response_generator not in RESPONSE_GENERATORS: diff --git a/mallm/utils/config.py b/mallm/utils/config.py index d01c539f..ccdde01a 100644 --- a/mallm/utils/config.py +++ b/mallm/utils/config.py @@ -53,6 +53,7 @@ class Config: judge_intervention: Optional[str] = None judge_metric: Optional[str] = None judge_endpoint_url: Optional[str] = None + judge_model_name: Optional[str] = None judge_api_key: str = "-" judge_always_intervene: bool = False @@ -117,15 +118,6 @@ def check_config(self) -> None: if self.endpoint_url.endswith("/"): logger.warning("Removing trailing / from the endpoint url.") self.endpoint_url = self.endpoint_url[:-1] - try: - logger.info("Testing availability of the endpoint...") - page = requests.head(self.endpoint_url.replace("/v1", "")) - logger.info("Status: " + str(page.status_code)) - assert page.status_code == 200 - except Exception as e: - logger.error("HTTP Error: Could not connect to the provided endpoint url.") - logger.error(e) - sys.exit(1) if self.concurrent_api_requests > 250: logger.warning( "concurrent_api_requests is very large. Please make sure the API endpoint you are using can handle that many simultaneous requests." diff --git a/mallm/utils/dicts.py b/mallm/utils/dicts.py index 2d0e07af..17dfb1b9 100644 --- a/mallm/utils/dicts.py +++ b/mallm/utils/dicts.py @@ -19,7 +19,9 @@ from mallm.discourse_policy.report import DiscourseReport from mallm.models.discussion.CriticalResponseGenerator import CriticalResponseGenerator from mallm.models.discussion.FreeTextResponseGenerator import FreeTextResponseGenerator -from mallm.models.discussion.ReasoningResponseGenerator import ReasoningResponseGenerator +from mallm.models.discussion.ReasoningResponseGenerator import ( + ReasoningResponseGenerator, +) from mallm.models.discussion.ResponseGenerator import ResponseGenerator from mallm.models.discussion.SimpleResponseGenerator import SimpleResponseGenerator from mallm.models.discussion.SplitFreeTextResponseGenerator import ( From 78d278cfca0b795d9388c692219b4264d85657a2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 2 Jun 2025 15:01:31 +0000 Subject: [PATCH 2/2] Updated README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 7d35e0bc..655bc72a 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ challenge_final_results: bool = False judge_intervention: Optional[str] = None judge_metric: Optional[str] = None judge_endpoint_url: Optional[str] = None +judge_model_name: Optional[str] = None judge_api_key: str = "-" judge_always_intervene: bool = False ```