diff --git a/README.md b/README.md index 7d35e0b..655bc72 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ challenge_final_results: bool = False judge_intervention: Optional[str] = None judge_metric: Optional[str] = None judge_endpoint_url: Optional[str] = None +judge_model_name: Optional[str] = None judge_api_key: str = "-" judge_always_intervene: bool = False ``` diff --git a/mallm/__init__.py b/mallm/__init__.py index c521d64..35d8e9b 100644 --- a/mallm/__init__.py +++ b/mallm/__init__.py @@ -1 +1 @@ -__version__='v1.0.4' +__version__ = 'v1.0.4' diff --git a/mallm/models/Chat.py b/mallm/models/Chat.py index 9f03fb8..20cbe7f 100644 --- a/mallm/models/Chat.py +++ b/mallm/models/Chat.py @@ -126,7 +126,8 @@ def _call( # type: ignore log_prob_sum = 0.0 for message in chat_completion: message_str = message.choices[0].delta.content - log_prob_sum += message.choices[0].logprobs.content[0].logprob + if message.choices[0].logprobs: + log_prob_sum += message.choices[0].logprobs.content[0].logprob if message_str and message_str not in self.stop_tokens: collected_messages.append(message_str) log_prob_sum = log_prob_sum / len(collected_messages) diff --git a/mallm/scheduler.py b/mallm/scheduler.py index 5dad52a..67f5303 100644 --- a/mallm/scheduler.py +++ b/mallm/scheduler.py @@ -147,15 +147,18 @@ def __init__(self, config: Config) -> None: self.llm = Chat( client=OpenAI( base_url=self.config.endpoint_url, api_key=self.config.api_key - ) + ), + model=self.config.model_name ) self.judge_llm = None if self.config.judge_endpoint_url: self.judge_llm = Chat( - client=OpenAI( - base_url=self.config.judge_endpoint_url, api_key=self.config.judge_api_key - ) + client=OpenAI( + base_url=self.config.judge_endpoint_url, + api_key=self.config.judge_api_key, + ), + model=self.config.judge_model_name, ) if config.response_generator not in RESPONSE_GENERATORS: diff --git a/mallm/utils/config.py b/mallm/utils/config.py index d01c539..ccdde01 100644 --- a/mallm/utils/config.py +++ b/mallm/utils/config.py @@ -53,6 +53,7 @@ class Config: judge_intervention: Optional[str] = None judge_metric: Optional[str] = None judge_endpoint_url: Optional[str] = None + judge_model_name: Optional[str] = None judge_api_key: str = "-" judge_always_intervene: bool = False @@ -117,15 +118,6 @@ def check_config(self) -> None: if self.endpoint_url.endswith("/"): logger.warning("Removing trailing / from the endpoint url.") self.endpoint_url = self.endpoint_url[:-1] - try: - logger.info("Testing availability of the endpoint...") - page = requests.head(self.endpoint_url.replace("/v1", "")) - logger.info("Status: " + str(page.status_code)) - assert page.status_code == 200 - except Exception as e: - logger.error("HTTP Error: Could not connect to the provided endpoint url.") - logger.error(e) - sys.exit(1) if self.concurrent_api_requests > 250: logger.warning( "concurrent_api_requests is very large. Please make sure the API endpoint you are using can handle that many simultaneous requests." diff --git a/mallm/utils/dicts.py b/mallm/utils/dicts.py index 2d0e07a..17dfb1b 100644 --- a/mallm/utils/dicts.py +++ b/mallm/utils/dicts.py @@ -19,7 +19,9 @@ from mallm.discourse_policy.report import DiscourseReport from mallm.models.discussion.CriticalResponseGenerator import CriticalResponseGenerator from mallm.models.discussion.FreeTextResponseGenerator import FreeTextResponseGenerator -from mallm.models.discussion.ReasoningResponseGenerator import ReasoningResponseGenerator +from mallm.models.discussion.ReasoningResponseGenerator import ( + ReasoningResponseGenerator, +) from mallm.models.discussion.ResponseGenerator import ResponseGenerator from mallm.models.discussion.SimpleResponseGenerator import SimpleResponseGenerator from mallm.models.discussion.SplitFreeTextResponseGenerator import (