diff --git a/.gitignore b/.gitignore index 06911bb..105d2f8 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ venv/ llm_together.egg-info/ __pycache__ +.aider* diff --git a/llm_together.py b/llm_together.py index 6cf5c86..17b71e3 100644 --- a/llm_together.py +++ b/llm_together.py @@ -5,34 +5,26 @@ @llm.hookimpl def register_models(register): + models = Together().client.models.list() + for model in sorted(models, key=lambda m: m.id): + register(Together(model)) - together_instance = Together() - model_list = together_instance.get_models() - - for model in model_list: - if 'isFeaturedModel' in model and model['isFeaturedModel']: - register(Together(model)) def not_nulls(data) -> dict: return {key: value for key, value in data if value is not None} -class Together(llm.Model): +class Together(llm.KeyModel): model_id = "llm-together" needs_key = "together" key_env_var = "TOGETHER_API_KEY" - default_stop = "" + default_stop = [""] can_stream = True - def get_models(self): - together.api_key = self.get_key() - return together.Models.list() - def __init__(self, model=None): - together.api_key = self.get_key() - - if (model is not None): - self.model_id = model["name"] + self.client = together.Client(api_key=self.get_key()) + if model: self.model = model + self.model_id = model.id class Options(llm.Options): temperature: Optional[float] = Field( @@ -46,7 +38,7 @@ class Options(llm.Options): default=None, ) max_tokens: Optional[int] = Field( - description="Maximum number of tokens to generate.", default=256 + description="Maximum number of tokens to generate.", default=8192 ) top_p: Optional[float] = Field( description=( @@ -71,7 +63,7 @@ class Options(llm.Options): default=None, ) - def execute(self, prompt, stream, response, conversation): + def execute(self, prompt, stream, response, conversation, key=None): kwargs = dict(not_nulls(prompt.options)) user_prompt = "{}\n\n{}".format(prompt.system or "", prompt.prompt) @@ -82,7 +74,9 @@ def execute(self, prompt, stream, response, conversation): if conversation is not None: for message in conversation.responses: if 'prompt_format' in self.model["config"] and self.model["config"]['prompt_format']: - history += self.model["config"]["prompt_format"].format(prompt = message.prompt) + " " + message.text() + "\n" + formatted_prompt = self.model["config"]["prompt_format"].format(prompt=message.prompt) + message_text = message.text() + history += formatted_prompt + " " + message_text + "\n" else: history += "{}\n\n{}".format(message.prompt, message.text())+ "\n" @@ -91,23 +85,28 @@ def execute(self, prompt, stream, response, conversation): if 'stop' in self.model["config"]: - stop = self.model["config"]["stop"] + config_stop = self.model["config"]["stop"] + if isinstance(config_stop, list): + stop = config_stop + else: + stop = [config_stop] if stream: - for token in together.Complete.create_streaming( + for chunk in self.client.completions.create( prompt = history + "\n" + user_prompt, model = self.model_id, + stream = True, stop = stop, **kwargs, ): - yield token - + if chunk.choices and len(chunk.choices) > 0: + yield chunk.choices[0].text else: - output = together.Complete.create( + output = self.client.completions.create( prompt = history + "\n" + user_prompt, model = self.model_id, stop = stop, **kwargs, ) - - return [output['output']['choices'][0]['text']] + if output.choices and len(output.choices) > 0: + yield output.choices[0].text diff --git a/pyproject.toml b/pyproject.toml index 015c7b2..f4551ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,18 @@ [project] dynamic = ["optional-dependencies"] name = "llm-together" -version = "0.4" +version = "0.5" license = { text="MIT" } authors = [ { name="Kévin Quesada" }, + { name="Christian Braun" }, ] description = "llm together plugin" readme = "README.md" requires-python = ">=3.7" dependencies = [ - "llm>=0.5", "together==0.2.9" + "llm>=0.22", + "together>=1.4.1" ] classifiers = [ "License :: OSI Approved :: Apache Software License"