Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ challenge_final_results: bool = False
judge_intervention: Optional[str] = None
judge_metric: Optional[str] = None
judge_endpoint_url: Optional[str] = None
judge_model_name: Optional[str] = None
judge_api_key: str = "-"
judge_always_intervene: bool = False
```
Expand Down
2 changes: 1 addition & 1 deletion mallm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__='v1.0.4'
__version__ = 'v1.0.4'
3 changes: 2 additions & 1 deletion mallm/models/Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def _call( # type: ignore
log_prob_sum = 0.0
for message in chat_completion:
message_str = message.choices[0].delta.content
log_prob_sum += message.choices[0].logprobs.content[0].logprob
if message.choices[0].logprobs:
log_prob_sum += message.choices[0].logprobs.content[0].logprob
if message_str and message_str not in self.stop_tokens:
collected_messages.append(message_str)
log_prob_sum = log_prob_sum / len(collected_messages)
Expand Down
11 changes: 7 additions & 4 deletions mallm/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,18 @@ def __init__(self, config: Config) -> None:
self.llm = Chat(
client=OpenAI(
base_url=self.config.endpoint_url, api_key=self.config.api_key
)
),
model=self.config.model_name
)

self.judge_llm = None
if self.config.judge_endpoint_url:
self.judge_llm = Chat(
client=OpenAI(
base_url=self.config.judge_endpoint_url, api_key=self.config.judge_api_key
)
client=OpenAI(
base_url=self.config.judge_endpoint_url,
api_key=self.config.judge_api_key,
),
model=self.config.judge_model_name,
)

if config.response_generator not in RESPONSE_GENERATORS:
Expand Down
10 changes: 1 addition & 9 deletions mallm/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Config:
judge_intervention: Optional[str] = None
judge_metric: Optional[str] = None
judge_endpoint_url: Optional[str] = None
judge_model_name: Optional[str] = None
judge_api_key: str = "-"
judge_always_intervene: bool = False

Expand Down Expand Up @@ -117,15 +118,6 @@ def check_config(self) -> None:
if self.endpoint_url.endswith("/"):
logger.warning("Removing trailing / from the endpoint url.")
self.endpoint_url = self.endpoint_url[:-1]
try:
logger.info("Testing availability of the endpoint...")
page = requests.head(self.endpoint_url.replace("/v1", ""))
logger.info("Status: " + str(page.status_code))
assert page.status_code == 200
except Exception as e:
logger.error("HTTP Error: Could not connect to the provided endpoint url.")
logger.error(e)
sys.exit(1)
if self.concurrent_api_requests > 250:
logger.warning(
"concurrent_api_requests is very large. Please make sure the API endpoint you are using can handle that many simultaneous requests."
Expand Down
4 changes: 3 additions & 1 deletion mallm/utils/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from mallm.discourse_policy.report import DiscourseReport
from mallm.models.discussion.CriticalResponseGenerator import CriticalResponseGenerator
from mallm.models.discussion.FreeTextResponseGenerator import FreeTextResponseGenerator
from mallm.models.discussion.ReasoningResponseGenerator import ReasoningResponseGenerator
from mallm.models.discussion.ReasoningResponseGenerator import (
ReasoningResponseGenerator,
)
from mallm.models.discussion.ResponseGenerator import ResponseGenerator
from mallm.models.discussion.SimpleResponseGenerator import SimpleResponseGenerator
from mallm.models.discussion.SplitFreeTextResponseGenerator import (
Expand Down