Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
venv/
llm_together.egg-info/
__pycache__
.aider*
51 changes: 25 additions & 26 deletions llm_together.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,26 @@

@llm.hookimpl
def register_models(register):
models = Together().client.models.list()
for model in sorted(models, key=lambda m: m.id):
register(Together(model))

together_instance = Together()
model_list = together_instance.get_models()

for model in model_list:
if 'isFeaturedModel' in model and model['isFeaturedModel']:
register(Together(model))

def not_nulls(data) -> dict:
    """Return a dict of the entries in *data* whose value is not None.

    *data* is iterated directly and every item is unpacked as a
    ``(key, value)`` pair. Values that are falsy but not ``None``
    (``0``, ``""``, ``False``) are kept. If a key appears more than once,
    the last non-None value wins.
    """
    kept = {}
    for name, item in data:
        if item is None:
            continue
        kept[name] = item
    return kept

class Together(llm.Model):
class Together(llm.KeyModel):
model_id = "llm-together"
needs_key = "together"
key_env_var = "TOGETHER_API_KEY"
default_stop = "<human>"
default_stop = ["<human>"]
can_stream = True

def get_models(self):
together.api_key = self.get_key()
return together.Models.list()

def __init__(self, model=None):
together.api_key = self.get_key()

if (model is not None):
self.model_id = model["name"]
self.client = together.Client(api_key=self.get_key())
if model:
self.model = model
self.model_id = model.id

class Options(llm.Options):
temperature: Optional[float] = Field(
Expand All @@ -46,7 +38,7 @@ class Options(llm.Options):
default=None,
)
max_tokens: Optional[int] = Field(
description="Maximum number of tokens to generate.", default=256
description="Maximum number of tokens to generate.", default=8192
)
top_p: Optional[float] = Field(
description=(
Expand All @@ -71,7 +63,7 @@ class Options(llm.Options):
default=None,
)

def execute(self, prompt, stream, response, conversation):
def execute(self, prompt, stream, response, conversation, key=None):
kwargs = dict(not_nulls(prompt.options))

user_prompt = "{}\n\n{}".format(prompt.system or "", prompt.prompt)
Expand All @@ -82,7 +74,9 @@ def execute(self, prompt, stream, response, conversation):
if conversation is not None:
for message in conversation.responses:
if 'prompt_format' in self.model["config"] and self.model["config"]['prompt_format']:
history += self.model["config"]["prompt_format"].format(prompt = message.prompt) + " " + message.text() + "\n"
formatted_prompt = self.model["config"]["prompt_format"].format(prompt=message.prompt)
message_text = message.text()
history += formatted_prompt + " " + message_text + "\n"
else:
history += "{}\n\n{}".format(message.prompt, message.text())+ "\n"

Expand All @@ -91,23 +85,28 @@ def execute(self, prompt, stream, response, conversation):


if 'stop' in self.model["config"]:
stop = self.model["config"]["stop"]
config_stop = self.model["config"]["stop"]
if isinstance(config_stop, list):
stop = config_stop
else:
stop = [config_stop]

if stream:
for token in together.Complete.create_streaming(
for chunk in self.client.completions.create(
prompt = history + "\n" + user_prompt,
model = self.model_id,
stream = True,
stop = stop,
**kwargs,
):
yield token

if chunk.choices and len(chunk.choices) > 0:
yield chunk.choices[0].text
else:
output = together.Complete.create(
output = self.client.completions.create(
prompt = history + "\n" + user_prompt,
model = self.model_id,
stop = stop,
**kwargs,
)

return [output['output']['choices'][0]['text']]
if output.choices and len(output.choices) > 0:
yield output.choices[0].text
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
[project]
dynamic = ["optional-dependencies"]
name = "llm-together"
version = "0.4"
version = "0.5"
license = { text="MIT" }
authors = [
{ name="Kévin Quesada" },
{ name="Christian Braun" },
]
description = "llm together plugin"
readme = "README.md"
requires-python = ">=3.7"
dependencies = [
"llm>=0.5", "together==0.2.9"
"llm>=0.22",
"together>=1.4.1"
]
classifiers = [
"License :: OSI Approved :: Apache Software License"
Expand Down