diff --git a/functionary/openai_types.py b/functionary/openai_types.py index e7dc5a4a..f4298433 100644 --- a/functionary/openai_types.py +++ b/functionary/openai_types.py @@ -121,7 +121,7 @@ class ChatCompletionRequest(BaseModel): function_call: Optional[Union[str, Function]] = None tool_choice: Optional[Union[str, Tool]] = None temperature: Optional[float] = 0.6 - top_p: Optional[float] = 1.0 + top_p: Optional[float] = 0.95 n: Optional[int] = 1 max_tokens: Optional[int] = 4096 stop: Optional[Union[str, List[str]]] = Field(default_factory=list) diff --git a/functionary/prompt_template/__init__.py b/functionary/prompt_template/__init__.py index 03870348..cc0ca702 100644 --- a/functionary/prompt_template/__init__.py +++ b/functionary/prompt_template/__init__.py @@ -8,8 +8,15 @@ from functionary.prompt_template.llava_prompt_template import LlavaLlama from functionary.prompt_template.prompt_template_v1 import PromptTemplateV1 from functionary.prompt_template.prompt_template_v2 import PromptTemplateV2 -from functionary.prompt_template.llama31_reasoning_prompt_template import Llama31ReasoningTemplate - +from functionary.prompt_template.llama31_reasoning_prompt_template import ( + Llama31ReasoningTemplate, +) +from functionary.prompt_template.r1_original_template import R1Template +from functionary.prompt_template.r1_distilled_qwen import R1DistilledQwen +from functionary.prompt_template.r1_distilled_llama import R1DistilledLlama +from functionary.prompt_template.qwen25_text_only_template import Qwen25TextOnlyPromptTemplate +from functionary.prompt_template.gemma3_prompt_template import Gemma3Template +from functionary.prompt_template.cogito_prompt_template import CogitoPromptTemplate def get_available_prompt_template_versions() -> List[PromptTemplate]: """This function will get all the available prompt templates in the module. @@ -30,6 +37,10 @@ def get_available_prompt_template_versions() -> List[PromptTemplate]: # we don't use get_prompt_template or this will return the parent class all_templates_obj.append(LlavaLlama.get_prompt_template()) all_templates_obj.append(Llama31ReasoningTemplate.get_prompt_template()) + all_templates_obj.append(R1DistilledLlama.get_prompt_template()) + all_templates_obj.append(R1DistilledQwen.get_prompt_template()) + all_templates_obj.append(Gemma3Template.get_prompt_template()) + all_templates_obj.append(CogitoPromptTemplate.get_prompt_template()) return all_templates_obj diff --git a/functionary/prompt_template/base_template.py b/functionary/prompt_template/base_template.py index 89f69f78..fad33581 100644 --- a/functionary/prompt_template/base_template.py +++ b/functionary/prompt_template/base_template.py @@ -69,13 +69,12 @@ def initialize_grammar_sampling_gen_state(self, tool_choice: Optional[Any]) -> D """ raise NotImplementedError - @abstractmethod def get_additional_tokens(self) -> List[str]: """return list of added tokens if using this template Returns: List[str]: list of tokens, each token is a string """ - raise NotImplementedError + return [] @abstractmethod def get_stop_tokens_for_generation(self) -> List[str]: diff --git a/functionary/prompt_template/cogito_prompt_template.py b/functionary/prompt_template/cogito_prompt_template.py new file mode 100644 index 00000000..dc50bc37 --- /dev/null +++ b/functionary/prompt_template/cogito_prompt_template.py @@ -0,0 +1,17 @@ +from functionary.prompt_template.base_template import PromptTemplate +from functionary.prompt_template.qwen25_text_only_template import ( + Qwen25TextOnlyPromptTemplate, +) +from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from functionary.prompt_template import prompt_utils +from functionary.openai_types import Function, Tool + + +class CogitoPromptTemplate(Qwen25TextOnlyPromptTemplate): + version = "cogito" + + def get_chat_template_jinja(self) -> str: + path_prefix = "./functionary/prompt_template/jinja_templates/" + with open(f"{path_prefix}cogito.txt", "r") as f: + template = f.read() + return template diff --git a/functionary/prompt_template/gemma3_prompt_template.py b/functionary/prompt_template/gemma3_prompt_template.py new file mode 100644 index 00000000..accdd6ec --- /dev/null +++ b/functionary/prompt_template/gemma3_prompt_template.py @@ -0,0 +1,38 @@ +import json +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from functionary.openai_types import Function, Tool +from functionary.prompt_template import prompt_utils +from functionary.prompt_template.qwen25_text_only_template import ( + Qwen25TextOnlyPromptTemplate, +) + + +class Gemma3Template(Qwen25TextOnlyPromptTemplate): + version = "gemma3" + + def get_additional_tokens(self) -> List[str]: + return [] + + def get_assistant_prefixes(self) -> List[str]: + return ["model\n"] + + def get_stop_tokens_for_generation(self) -> List[str]: + return [""] + + def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: + """Order the tool results by the order of tool call ids + + Args: + messages (List[Dict]): List of messages + + Returns: + List[Dict]: List of messages + """ + return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) + + def get_chat_template_jinja(self) -> str: + path_prefix = "./functionary/prompt_template/jinja_templates/" + with open(f"{path_prefix}{self.version}.txt", "r") as f: + template = f.read() + return template diff --git a/functionary/prompt_template/jinja_templates/cogito.txt b/functionary/prompt_template/jinja_templates/cogito.txt new file mode 100644 index 00000000..b96b43fc --- /dev/null +++ b/functionary/prompt_template/jinja_templates/cogito.txt @@ -0,0 +1,80 @@ +{# version=cogito #}{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages and messages[0]['role'] == 'system' %} + {{- 'Enable deep thinking subroutine.\n\n' + messages[0]['content'] }} + {%- else %} + {{- 'Enable deep thinking subroutine.' }} + {%- endif %} + {%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} + {%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} + {%- endif -%} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- if has_reasoning %} + {{- " Please generate reasoning before deciding which tools to use." }} + {%- else %} + {{- "" }} + {%- endif %} + {{- "\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages and messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\nEnable deep thinking subroutine\n\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>system\nEnable deep thinking subroutine.<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if message.role == "user" %} + {%- if message['content'] -%} + {%- if message['content'] is string -%} + {{ '<|im_start|>user\n' + message['content'] }} + {%- else -%} + {{ '<|im_start|>user\n' }} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {{ content['text'] }} + {%- else -%} + {{ '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {{ '<|im_end|>\n' }} + {%- endif -%} + {%- elif (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} \ No newline at end of file diff --git a/functionary/prompt_template/jinja_templates/gemma3.txt b/functionary/prompt_template/jinja_templates/gemma3.txt new file mode 100644 index 00000000..c7db6b81 --- /dev/null +++ b/functionary/prompt_template/jinja_templates/gemma3.txt @@ -0,0 +1,64 @@ +{# version=gemma3 #}{{- bos_token -}}{%- if tools %} + {{- 'system\nYou are capable of executing available function(s) if required.' }} + {%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} + {%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} + {%- endif -%} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- if has_reasoning %} + {{- " Please generate reasoning before deciding which tools to use." }} + {%- else %} + {{- "" }} + {%- endif %} + {{- "\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n\n" }} +{%- else %} + {%- if messages and messages[0]['role'] == 'system' %} + {{- 'system\n' + messages[0]['content'] + '\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system") -%} + {{- '' + message.role + '\n' + message.content + '' + '\n' }} + {%- elif message.role == "assistant" %} + {{- 'model' }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + + {%- if "tool_calls" in message and message["tool_calls"] is not none -%} + {%- set tool_calls = message["tool_calls"] -%} + {%-else-%} + {%- set tool_calls = [] -%} + {%- endif -%} + + {%- for tool_call in tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments }} + {{- '}\n' }} + {%- endfor %} + {{- '\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- 'tool' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- 'model\n' }} +{%- endif %} \ No newline at end of file diff --git a/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt b/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt new file mode 100644 index 00000000..4b453552 --- /dev/null +++ b/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt @@ -0,0 +1,80 @@ +{# version=qwen2.5-text-only #}{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages and messages[0]['role'] == 'system' %} + {{- messages[0]['content'] }} + {%- else %} + {{- 'You are a helpful assistant.' }} + {%- endif %} + {%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} + {%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} + {%- endif -%} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- if has_reasoning %} + {{- " Please generate reasoning before deciding which tools to use." }} + {%- else %} + {{- "" }} + {%- endif %} + {{- "\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages and messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if message.role == "user" %} + {%- if message['content'] -%} + {%- if message['content'] is string -%} + {{ '<|im_start|>user\n' + message['content'] }} + {%- else -%} + {{ '<|im_start|>user\n' }} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {{ content['text'] }} + {%- else -%} + {{ '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {{ '<|im_end|>\n' }} + {%- endif -%} + {%- elif (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} \ No newline at end of file diff --git a/functionary/prompt_template/jinja_templates/r1.txt b/functionary/prompt_template/jinja_templates/r1.txt new file mode 100644 index 00000000..eba6a224 --- /dev/null +++ b/functionary/prompt_template/jinja_templates/r1.txt @@ -0,0 +1,102 @@ +{# version=r1 #}{%- if not add_generation_prompt is defined -%} + {%- set add_generation_prompt = false -%} +{%- endif -%} +{{- bos_token }} +{%- if not tools is defined -%} + {%- set tools = none -%} +{%- endif -%} + +{%- set has_code_interpreter = tools | selectattr("type", "equalto", "code_interpreter") | list | length > 0 -%} +{%- if has_code_interpreter -%} + {%- set tools = tools | rejectattr("type", "equalto", "code_interpreter") | list -%} +{%- endif -%} +{%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} +{%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} +{%- endif -%} +{#- System message + builtin tools -#} +{#- If at lease: has_reasoning or has_code_interpreter or tools -#} +{%- if has_reasoning or has_code_interpreter or tools -%} + {{- "<|System|>\n" }} +{%- endif -%} +{%- if has_reasoning -%} + {{- "Reasoning Mode: On\n\n" }} +{%- endif -%} +{%- if has_code_interpreter -%} + {{- "Environment: ipython\n\n" }} +{%- endif %} +{%- if tools %} + {{- "\nYou have access to the following functions:\n\n" }} + {%- for t in tools %} + {%- if "type" in t -%} + {{ "Use the function '" + t["function"]["name"] + "' to '" + t["function"]["description"] + "'\n" + t["function"] | tojson() }} + {%- else -%} + {{ "Use the function '" + t["name"] + "' to '" + t["description"] + "'\n" + t | tojson }} + {%- endif -%} + {{- "\n\n" }} + {%- endfor %} + {{- '\nThink very carefully before calling functions.\n\n' -}} +{%- endif %} +{%- set ns = namespace( + is_first=false, + is_tool=false, + is_output_first=true, + system_prompt='', + is_first_sp=true, + last_message_is_tool=false +) -%} +{#- Process messages -#} +{%- for message in messages -%} + {%- if message['role'] == 'user'-%} + {# Reset tool and output first flags #} + {%- set ns.last_message_is_tool = false -%} + {%- set ns.is_tool = false -%} + {{-'<|User|>' + message['content'] -}} + {%- elif message['role'] == 'system' -%} + {%- set ns.last_message_is_tool = false -%} + {{- '<|System|>' + message['content'] + '\n' -}} + {%- elif message['role'] == 'assistant' -%} + {%- set ns.last_message_is_tool = false -%} + {# Need to check need to add <|tool▁outputs▁end|> if the previous message is tool message #} + {%- if ns.is_tool -%} + {{- '<|tool▁outputs▁end|>' -}} + {%- set ns.is_tool = false -%} + {%- endif -%} + {{-'<|Assistant|>'}} + {%- if message['content'] and message['content']|length > 0 -%} + {{ message['content'] }} + {%- endif -%} + {#Check if there is tool call#} + {%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls']|length > 0 -%} + {%- set ns.is_tool = true -%} + {%- set ns.is_output_first = true -%} + {{- '<|tool▁calls▁begin|>'}} + {%- for tool_call in message['tool_calls'] -%} + {%- if tool_call["function"]["name"] == "python" -%} + {%- set prefix = 'python' -%} + {%- else -%} + {%- set prefix = 'json' -%} + {%- endif -%} + {{- '<|tool▁call▁begin|>function<|tool▁sep|>' + tool_call['function']['name'] + '\n```' + prefix + '\n' + tool_call['function']['arguments'] + '\n```<|tool▁call▁end|>'}} + {%- endfor -%} + {{ '<|tool▁calls▁end|><|end▁of▁sentence|>' }} + {%- elif message['content'] and message['content']|length > 0 -%} + {{ '<|end▁of▁sentence|>' }} + {%- endif -%} + {%- elif message['role'] == 'tool' -%} + {%- set ns.last_message_is_tool = true -%} + {# Check if this is the result of the first tool call #} + {%- if ns.is_output_first -%} + {{- '<|tool▁outputs▁begin|>'}} + {%- set ns.is_output_first = false -%} + {%- endif -%} + {{- '<|tool▁output▁begin|>' + message["content"] + '<|tool▁output▁end|>'}} + {%- endif -%} +{%- endfor -%} +{# Need to add closing tag for list of tool messages #} +{%- if ns.last_message_is_tool -%} + {{- '<|tool▁outputs▁end|>' -}} +{%- endif -%} +{%- if add_generation_prompt and not ns.is_tool -%} + {{- '<|Assistant|>'}} +{% endif %} diff --git a/functionary/prompt_template/jinja_templates/r1_distilled_llama.txt b/functionary/prompt_template/jinja_templates/r1_distilled_llama.txt new file mode 100644 index 00000000..5b614d64 --- /dev/null +++ b/functionary/prompt_template/jinja_templates/r1_distilled_llama.txt @@ -0,0 +1,67 @@ +{# version=r1_distilled_llama #}{%- if not tools is defined -%} + {%- set tools = none -%} +{%- endif -%} + +{%- set has_code_interpreter = tools | selectattr("type", "equalto", "code_interpreter") | list | length > 0 -%} +{%- if has_code_interpreter -%} + {%- set tools = tools | rejectattr("type", "equalto", "code_interpreter") | list -%} +{%- endif -%} + +{%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} +{%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} +{%- endif -%} + +{#- System message + builtin tools #} +{{- bos_token }} +{%- if has_reasoning or has_code_interpreter or tools -%} + {{- "<|System|>\n" }} +{%- endif -%} +{%- if has_reasoning %} + {{- "Reasoning Mode: On\n\n" }} +{%- endif %} +{%- if has_code_interpreter %} + {{- "Environment: ipython\n\n" }} +{%- endif %} +{%- if tools %} + {{- "\nYou have access to the following functions:\n\n" }} + {%- for t in tools %} + {%- if "type" in t -%} + {{ "Use the function '" + t["function"]["name"] + "' to '" + t["function"]["description"] + "'\n" + t["function"] | tojson() }} + {%- else -%} + {{ "Use the function '" + t["name"] + "' to '" + t["description"] + "'\n" + t | tojson }} + {%- endif -%} + {{- "\n\n" }} + {%- endfor %} + {{- '\nThink very carefully before calling functions.\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- If looking for real time information use relevant functions before falling back to brave_search\n- Function calls MUST follow the specified format, start with \n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line\n\n' -}} +{%- endif %} + +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{-'<|User|>' + message['content'] -}} + {%- elif message['role'] == 'system' -%} + {{- '<|System|>' + message['content'] + '\n' -}} + {%- elif message['role'] == 'tool' -%} + {{ 'ipython\n\n' + message['content'] }} + {%- else -%} + {{-'<|Assistant|>'}} + {%- if message['content'] and message['content']|length > 0 -%} + {{ message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls']|length > 0 -%} + {%- for tool_call in message['tool_calls'] -%} + {%- if tool_call["function"]["name"] == "python" -%} + {{ '<|python_tag|>' + tool_call['function']['arguments'] }} + {%- else -%} + {{ '' + tool_call['function']['arguments'] + '' }} + {%- endif -%} + {%- endfor -%} + {{ '<|eom_id|>' }} + {%- elif message['content'] and message['content']|length > 0 -%} + {{ '<|end▁of▁sentence|>' }} + {%- endif -%} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{ '<|Assistant|>' }} +{%- endif -%} \ No newline at end of file diff --git a/functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt b/functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt new file mode 100644 index 00000000..bea2dc4f --- /dev/null +++ b/functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt @@ -0,0 +1,63 @@ +{# version=r1_distilled_qwen #} +{%- if tools %} + {{- '<|System|>' }} + {%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} + {%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} + {%- endif -%} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- if has_reasoning %} + {{- " Please generate reasoning before deciding which tools to use." }} + {%- else %} + {{- "" }} + {%- endif %} + {{- "\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n\n" }} +{%- else %} +{%- endif %} +{%- for message in messages -%} + {%- if message.role == "user" -%} + {{ '<|User|>' + message['content'] }} + {%- elif message.role == "system" %} + {{ '<|System|>' + message['content'] + '\n'}} + {%- elif message.role == "assistant" %} + {{- '<|Assistant|>'}} + {%- if message.content %} + {{- message.content }} + {%- endif %} + {%- if "tool_calls" in message and message["tool_calls"] is not none -%} + {%- set tool_calls = message["tool_calls"] -%} + {%-else-%} + {%- set tool_calls = [] -%} + {%- endif -%} + {%- for tool_call in tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {{- '\n{"name": "' }} + {%- else %} + {{- '\n\n{"name": "' }} + {%- endif %} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments }} + {{- '}\n' }} + {%- endfor %} + {{- '<|end▁of▁sentence|>' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|User|>' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|Assistant|>' }} +{%- endif %} \ No newline at end of file diff --git a/functionary/prompt_template/prompt_utils.py b/functionary/prompt_template/prompt_utils.py index 92522655..c6674be2 100644 --- a/functionary/prompt_template/prompt_utils.py +++ b/functionary/prompt_template/prompt_utils.py @@ -3,6 +3,7 @@ import random import string from copy import deepcopy +import json from io import BytesIO from typing import Dict, List, Optional, Union @@ -284,3 +285,53 @@ def resolve_json_refs(tools_or_functions): ) return tools + + +def convert_code_interpreter_to_function(tools: List[Dict]) -> List[Dict]: + _tools = [] + if tools: + for tool in tools: + if tool["type"] == "code_interpreter": + _tools.append( + { + "type": "function", + "function": { + "name": "python", + "description": "This tool is used to execute python code. Code will be executed in a stateful Jupyter notebook environment. Python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The python code to run", + } + }, + }, + }, + } + ) + else: + _tools.append(tool) + return _tools + + +def convert_code_interpreter_tool_calls(messages: List[Dict]) -> List[Dict]: + _messages = [] + for message in messages: + n_message = deepcopy(message) + tool_calls = n_message.get("tool_calls", []) or [] + if len(tool_calls) > 0: + for tool_call in tool_calls: + if tool_call["function"]["name"] == "python": + arguments = tool_call["function"][ + "arguments" + ] # currently the code is in string format + # check if argument is a valid JSON string or python code + try: # if this is a valid JSON string --> no need to change anything + json.loads(arguments) + except: + tool_call["function"]["arguments"] = json.dumps( + {"code": arguments}, ensure_ascii=False + ) + _messages.append(n_message) + return _messages diff --git a/functionary/prompt_template/qwen25_text_only_template.py b/functionary/prompt_template/qwen25_text_only_template.py new file mode 100644 index 00000000..fef076d3 --- /dev/null +++ b/functionary/prompt_template/qwen25_text_only_template.py @@ -0,0 +1,393 @@ +from functionary.prompt_template.base_template import PromptTemplate +from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from functionary.prompt_template import prompt_utils +from functionary.openai_types import Function, Tool +import json +import copy +import math +import re + + +class Qwen25TextOnlyPromptTemplate(PromptTemplate): + version = "qwen2.5-text-only" + start_of_turn = "<|im_start|>" + end_of_turn = "<|im_end|>" + + def get_chat_template_jinja(self) -> str: + path_prefix = "./functionary/prompt_template/jinja_templates/" + with open(f"{path_prefix}{self.version}.txt", "r") as f: + template = f.read() + + return template + + def get_tool_choice_required_prefix(self) -> str: + return "\n" + + def get_prompt_from_messages( + self, + messages: List[Dict], + tools_or_functions: Optional[List[Dict]] = None, + bos_token: Optional[str] = "", + add_generation_prompt: bool = False, + ) -> str: + """This function is used to get the complete prompt for list of messages + + Args: + messages (List[Dict]): List of messages + tools_or_functions (Optional[List[Dict]], optional): List of tools or functions. Defaults to None. + + Returns: + str: the prompt for inference/training + """ + # handle code_interpreter + _tools = [] + if tools_or_functions: + _tools = prompt_utils.convert_code_interpreter_to_function( + tools_or_functions + ) + + # find the assistant message that tool_call is python + _messages = prompt_utils.convert_code_interpreter_tool_calls(messages) + prompt = super().get_prompt_from_messages( + messages=_messages, + tools_or_functions=_tools, + bos_token=bos_token, + add_generation_prompt=add_generation_prompt, + ) + return prompt + + def get_additional_tokens(self) -> List[str]: + return [] + + def get_assistant_prefixes(self) -> List[str]: + return [f"{self.start_of_turn}assistant\n"] + + def get_stop_tokens_for_generation(self) -> List[str]: + return [self.end_of_turn] + + def get_force_function_call_prefix(self, function_name: str): + return """ +{"name": "{function_name}", "arguments""".replace( + "{function_name}", function_name + ) + + def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: + """Order the tool results by the order of tool call ids + + Args: + messages (List[Dict]): List of messages + + Returns: + List[Dict]: List of messages + """ + return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) + + def parse_assistant_response( + self, llm_output: str, tool_choice: Any = None + ) -> Dict: + # first remove stop tokens if there exists + for stop in self.get_stop_tokens_for_generation(): + if llm_output.endswith(stop): + llm_output = llm_output[: -len(stop)] + + # add forced-function from tool_choice if exists + llm_output = ( + self.get_generation_prefix_for_tool_choice(tool_choice) + llm_output + ) + + # print(f"+++LLM_OUTPUT: {llm_output}") + # llm_output = post_process_llm_output(llm_output) + # print(f"+++LLM_OUTPUT after post-processing: {llm_output}") + text_content = "" + tool_call_strs = [] + + # Split on tool call tags + parts = llm_output.split("") + + if len(parts) > 0: + # First part is the text content + text_content = parts[0].strip() + + # Process remaining parts as tool calls + for part in parts[1:]: + if "" in part: + tool_call = part.split("")[0].strip() + if tool_call: + tool_call_strs.append(tool_call) + tool_calls = [] + for tool_call_str in tool_call_strs: + tool_call_dic = json.loads(tool_call_str) + if tool_call_dic["name"] == "python": + arguments = tool_call_dic["arguments"]["code"] + else: + arguments = json.dumps(tool_call_dic["arguments"], ensure_ascii=False) + tool_calls.append( + { + "type": "function", + "id": prompt_utils.get_random_tool_call_id(), + "function": { + "name": tool_call_dic["name"], + "arguments": arguments, + }, + } + ) + + return { + "role": "assistant", + "content": text_content if len(text_content) > 0 else None, + "tool_calls": None if len(tool_calls) == 0 else tool_calls, + } + + def initialize_fsm_gen_state( + self, + tool_choice: Union[str, Tool], + curr_text: str, + curr_tokens: Optional[List[int]], + add_code_interpreter: Optional[bool], + ) -> Dict: + """Initializes FSM state for both streaming and grammar sampling + + Args: + tool_choice (str): tool_choice provided by user + curr_text (str): Text to initialize in gen_state + curr_tokens (List[int]): Corresponding tokens of curr_text + add_code_interpreter (bool): Flag indicating whether to add "python" tool in options in "function" stage. + Returns: + Dict: generation state + """ + result = { + "stage": "start", + "func_index": -1, + "curr_text": curr_text, + "curr_tokens": curr_tokens, + "add_code_interpreter": add_code_interpreter, + } + if tool_choice == "required": # a tool must be used + result["stage"] = "function_name" + result["curr_text"] = "" + result["func_index"] += 1 + result["call_id"] = prompt_utils.get_random_tool_call_id() + + elif not isinstance(tool_choice, str): # a predefined tool is used + func_name = ( + tool_choice.name + if hasattr(tool_choice, "name") + else tool_choice.function.name + ) + result["stage"] = "function_name" + result["curr_text"] = '{"name": "{function_name}", "arguments'.replace( + "{function_name}", func_name + ) + result["func_index"] += 1 + result["call_id"] = prompt_utils.get_random_tool_call_id() + + return result + + def stream_delta_text( + self, + gen_state: Dict, + delta_text: str, + finish_reason: Optional[str], + tools_or_functions: List[Dict], + tool_choice: Any, + ) -> Tuple[Dict, Optional[Union[Dict, List[Dict]]]]: + if finish_reason is not None: # handle if finish + if gen_state["stage"] not in ["text_gen"]: + finish_reason = "tool_calls" + + end_response = prompt_utils.get_text_delta_response( + None, False, finish_reason + ) + last_response = None + # still need to check if there is st in buffer + if "buffer" in gen_state and len(gen_state["buffer"]) > 0: + if gen_state["stage"] == "text_gen": + buffer_str = "".join( + gen_state["buffer"] + ).rstrip() # remove \n at the end + last_response = prompt_utils.get_text_delta_response( + buffer_str, False, None + ) + + elif gen_state["stage"] == "function_arguments": + buffer_str = "".join( + gen_state["buffer"] + ).rstrip() # remove \n at the end + if buffer_str.endswith("}}"): + buffer_str = buffer_str[:-1] # remove the last "}" + + if len(buffer_str) > 0: + last_response = prompt_utils.get_function_delta_response( + gen_state, buffer_str, False, False, None + ) + elif gen_state["stage"] == "python": + last_response = return_all_code_from_buffer(gen_state) + + if last_response is not None: + return gen_state, [last_response, end_response] + else: + return gen_state, [end_response] + + current_text = gen_state["curr_text"] + delta_text + gen_state["curr_text"] = current_text + if gen_state["stage"] == "start": + if ( + gen_state.get("end_of_prev_function_call", False) and delta_text == "\n" + ): # ignore \n + gen_state["end_of_prev_function_call"] = False + gen_state["curr_text"] = "" + + elif delta_text == "": + # print(f"delta text: {delta_text}; go to function_name") + gen_state["stage"] = "function_name" + gen_state["curr_text"] = "" + gen_state["func_index"] += 1 + gen_state["call_id"] = prompt_utils.get_random_tool_call_id() + else: + # print(f"delta text: {delta_text}; go to text_gen") + gen_state["stage"] = "text_gen" + gen_state["curr_text"] = current_text + gen_state["buffer"] = ( + [] + ) # put to buffer before we return because we need to check the last item + responses = [ + prompt_utils.get_text_delta_response("", True, finish_reason) + ] + # gen_state["buffer"].extend(["", "\n"]) + if len(delta_text) > 0: + gen_state["buffer"].append(delta_text) + return gen_state, responses + + elif gen_state["stage"] == "function_name": + # wait until we get '{"name": "func_name", "arguments": {' + # print(f"current_text: {current_text}") + pattern = ( + r'\s*{"name"\s*:\s*"(?P.*)"\s*,\s*"arguments"\s*:\s*{' + ) + match = re.search(pattern, current_text) + if match: + _, end_ind = match.start(), match.end() + new_delta = current_text[end_ind - 1 :] + gen_state["curr_text"] = new_delta # -1 to retain "{" + gen_state["func_name"] = match.group("function_name") + gen_state["stage"] = ( + "function_arguments" + if gen_state["func_name"] != "python" + else "python" + ) + responses = [ + prompt_utils.get_function_delta_response( + gen_state, "", True, True, finish_reason + ) + ] # the chunk containing function_name only + gen_state["buffer"] = [] + if gen_state["func_name"] != "python": + gen_state["buffer"].append(new_delta) + + return gen_state, responses + else: + return gen_state, None + + elif gen_state["stage"] == "text_gen": + if delta_text == "": # start a tool call + # print("start a tool call after reasoning") + gen_state["stage"] = "function_name" + gen_state["curr_text"] = "" + gen_state["func_index"] += 1 + gen_state["call_id"] = prompt_utils.get_random_tool_call_id() + buffer_str = "".join( + gen_state["buffer"] + ).rstrip() # remove \n at the end + if len(buffer_str) > 0: + return gen_state, prompt_utils.get_text_delta_response( + buffer_str, False, finish_reason + ) + else: + gen_state["buffer"].append(delta_text) + if len(gen_state["buffer"]) >= 2: + delta_text_item = gen_state["buffer"].pop(0) + return gen_state, prompt_utils.get_text_delta_response( + delta_text_item, False, finish_reason + ) + + elif gen_state["stage"] == "function_arguments": + # check if current function is python, we need to stream the code string inside, not a json + if delta_text == "": + gen_state["stage"] = "start" + gen_state["curr_text"] = "" + gen_state["end_of_prev_function_call"] = True + # return all in the buffer but need to strip and remove the last "}" + buffer_str = "".join( + gen_state["buffer"] + ).rstrip() # remove \n at the end + if buffer_str.endswith("}}\n"): + buffer_str = buffer_str[:-2] # remove the last "}\n" + elif buffer_str.endswith("}}"): + buffer_str = buffer_str[:-1] # remove the last "}" + + return gen_state, prompt_utils.get_function_delta_response( + gen_state, buffer_str, False, False, finish_reason + ) + else: + gen_state["buffer"].append(delta_text) + if len(gen_state["buffer"]) >= 4: + delta_text_item = gen_state["buffer"].pop(0) + return gen_state, prompt_utils.get_function_delta_response( + gen_state, delta_text_item, False, False, finish_reason + ) + + elif gen_state["stage"] == "python": + return streamining_python_code(gen_state, delta_text) + return gen_state, None + + +def return_all_code_from_buffer(gen_state: Dict) -> Optional[Union[Dict, List[Dict]]]: + buffer_str = "".join(gen_state["buffer"]).rstrip() # remove \n at the end + if len(buffer_str) > 0: + return prompt_utils.get_function_delta_response( + gen_state, buffer_str, False, False, None + ) + return None + + +def streamining_python_code( + gen_state: Dict, delta_text: str +) -> Tuple[Dict, Optional[Union[Dict, List[Dict]]]]: + if "current_code" not in gen_state: + gen_state["current_code"] = "" + return gen_state, None + + current_text = gen_state["curr_text"] + current_code = gen_state["current_code"] + # try extracting the latest code from current_text + try: + if delta_text == "": # end of code + full_code_arg_str = current_text.rstrip("").strip() + if full_code_arg_str.endswith("}}"): + full_code_arg_str = full_code_arg_str[:-1] + new_code = json.loads(full_code_arg_str)["code"] + else: + new_code = json.loads(current_text + '"}')["code"] + delta_code = new_code[len(current_code) :] + + gen_state["buffer"].append(delta_code) + gen_state["current_code"] = new_code + except: # nothing changed + return gen_state, None + + if delta_text == "": + return gen_state, return_all_code_from_buffer(gen_state) + else: + if len(gen_state["buffer"]) >= 4: + delta_text_item = gen_state["buffer"].pop(0) + return gen_state, prompt_utils.get_function_delta_response( + gen_state, delta_text_item, False, False, None + ) + return gen_state, None + + +def match_pattern(pattern: str, text: str) -> Tuple[int, int]: + match = re.search(pattern, text) + if match: + return match.start(), match.end() + return -1, -1 diff --git a/functionary/prompt_template/r1_distilled_llama.py b/functionary/prompt_template/r1_distilled_llama.py new file mode 100644 index 00000000..d12cf2a8 --- /dev/null +++ b/functionary/prompt_template/r1_distilled_llama.py @@ -0,0 +1,45 @@ +import datetime +import json +import re +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from functionary.openai_types import Function, Tool +from functionary.prompt_template import prompt_utils +from functionary.prompt_template.llama31_reasoning_prompt_template import ( + Llama31Template, +) +import copy + + +class R1DistilledLlama(Llama31Template): + version = "r1_distilled_llama" + chat_template = None + + def get_additional_tokens(self) -> List[str]: + return [] + + def get_assistant_prefixes(self) -> List[str]: + return [f"<|Assistant|>"] + + def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: + """Order the tool results by the order of tool call ids + + Args: + messages (List[Dict]): List of messages + + Returns: + List[Dict]: List of messages + """ + return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) + + def get_stop_tokens_for_generation(self) -> List[str]: + return ["<|end▁of▁sentence|>", "<|eom_id|>"] + + def get_chat_template_jinja(self) -> str: + if self.chat_template is None: + jinja_template_file = ( + "./functionary/prompt_template/jinja_templates/r1_distilled_llama.txt" + ) + with open(jinja_template_file, "r") as f: + self.chat_template = f.read() + return self.chat_template diff --git a/functionary/prompt_template/r1_distilled_qwen.py b/functionary/prompt_template/r1_distilled_qwen.py new file mode 100644 index 00000000..34541a4d --- /dev/null +++ b/functionary/prompt_template/r1_distilled_qwen.py @@ -0,0 +1,45 @@ +import datetime +import json +import re +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from functionary.openai_types import Function, Tool +from functionary.prompt_template import prompt_utils +from functionary.prompt_template.qwen25_text_only_template import ( + Qwen25TextOnlyPromptTemplate, +) +import copy + + +class R1DistilledQwen(Qwen25TextOnlyPromptTemplate): + version = "r1_distilled_qwen" + chat_template = None + + def get_additional_tokens(self) -> List[str]: + return [] + + def get_assistant_prefixes(self) -> List[str]: + return [f"<|Assistant|>"] + + def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: + """Order the tool results by the order of tool call ids + + Args: + messages (List[Dict]): List of messages + + Returns: + List[Dict]: List of messages + """ + return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) + + def get_stop_tokens_for_generation(self) -> List[str]: + return ["<|end▁of▁sentence|>"] + + def get_chat_template_jinja(self) -> str: + if self.chat_template is None: + jinja_template_file = ( + "./functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt" + ) + with open(jinja_template_file, "r") as f: + self.chat_template = f.read() + return self.chat_template diff --git a/functionary/prompt_template/r1_original_template.py b/functionary/prompt_template/r1_original_template.py new file mode 100644 index 00000000..a58fce4c --- /dev/null +++ b/functionary/prompt_template/r1_original_template.py @@ -0,0 +1,130 @@ +import datetime +import json +import re +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from functionary.openai_types import Function, Tool +from functionary.prompt_template import prompt_utils +from functionary.prompt_template.base_template import PromptTemplate + + +def parse_tool_call(tool_call_str: str) -> Dict: + """parse the tool name and arguments from the tool call string. the format of the tool call string is: + function<|tool▁sep|> + ```type + + ``` + + Args: + tool_call_str (str): _description_ + + Returns: + Dict: {"name": , "args": } + """ + tool_sep_index = tool_call_str.find("<|tool▁sep|>") + brkline_index = tool_call_str.find("\n", tool_sep_index) + function_name = tool_call_str[ + tool_sep_index + len("<|tool▁sep|>") : brkline_index + ].strip() + # parse arguments + arguments_content = tool_call_str[brkline_index:].strip() + # strip ``` at the begining and the end + arguments_content = arguments_content.replace("```", "") + index = arguments_content.find("\n") # ignore: json\n or python\n + arguments_content = arguments_content[index:].strip() + + return { + "id": prompt_utils.get_random_tool_call_id(), + "type": "function", + "function": {"name": function_name, "arguments": arguments_content}, + } + + +def extract_text_inside(start_prefix: str, end_prefix: str, text: str) -> List[str]: + """extract all text inside the start_prefix and end_prefix, return a list of texts inside these two prefixes + + Args: + start_prefix (str): the prefix before the text to extract + end_prefix (str): the prefix after the text to extract + text (str): the text to extract + + Returns: + List[str]: a list of texts inside these two prefixes + """ + result = [] + current_pos = 0 + + while True: + # Find next start position + start_pos = text.find(start_prefix, current_pos) + if start_pos == -1: + break + + # Find matching end position + end_pos = text.find(end_prefix, start_pos + len(start_prefix)) + if end_pos == -1: + break + + # Extract text between prefixes + extracted = text[start_pos + len(start_prefix) : end_pos] + result.append(extracted) + + # Move current position past this match + current_pos = end_pos + len(end_prefix) + + return result + + +class R1Template(PromptTemplate): + version = "r1" + + def get_assistant_prefixes(self) -> List[str]: + return [f"<|Assistant|>"] + + def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: + """Order the tool results by the order of tool call ids + + Args: + messages (List[Dict]): List of messages + + Returns: + List[Dict]: List of messages + """ + return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) + + def get_stop_tokens_for_generation(self) -> List[str]: + return ["<|end▁of▁sentence|>"] + + def parse_assistant_response( + self, llm_output: str, tool_choice: Optional[Any] = None + ) -> Dict: + for stop in self.get_stop_tokens_for_generation(): + if llm_output.endswith(stop): + llm_output = llm_output[: -len(stop)] + + # add forced-function from tool_choice if exists + llm_output = ( + self.get_generation_prefix_for_tool_choice(tool_choice) + llm_output + ) + text_content = llm_output + tool_calls = [] + # parse the tool calls + # first extract the string about tool calls, which is inside: <|tool▁calls▁begin|><|tool▁calls▁end|> + all_tool_calls_list = extract_text_inside( + "<|tool▁calls▁begin|>", "<|tool▁calls▁end|>", llm_output + ) + if len(all_tool_calls_list) > 0: + all_tool_calls_str = all_tool_calls_list[0] + index = llm_output.find("<|tool▁calls▁begin|") + text_content = text_content[:index] + # extract tool calls inside: <|tool▁call▁begin|> & <|tool▁call▁end|> + tool_calls_strs = extract_text_inside( + "<|tool▁call▁begin|>", "<|tool▁call▁end|>", all_tool_calls_str + ) + for tool_call_str in tool_calls_strs: + tool_calls.append(parse_tool_call(tool_call_str)) + return { + "role": "assistant", + "content": text_content if text_content else None, + "tool_calls": tool_calls if tool_calls else None, + } diff --git a/functionary/train/packing/monkey_patch_packing.py b/functionary/train/packing/monkey_patch_packing.py index 2cbfff8c..d9063779 100644 --- a/functionary/train/packing/monkey_patch_packing.py +++ b/functionary/train/packing/monkey_patch_packing.py @@ -7,7 +7,9 @@ def get_max_seqlen_in_batch(attention_mask): max_num = torch.max(attention_mask) - # attention_mask: B x N + max_num = int( + max_num + ) # there are cases that type of attention_mask is bfloat16 such as Gemma3ForConditionalGeneration counts = [] for i in range(1, max_num + 1): counts.append( diff --git a/functionary/train/pyproject.toml b/functionary/train/pyproject.toml index cfb115a2..9733576a 100644 --- a/functionary/train/pyproject.toml +++ b/functionary/train/pyproject.toml @@ -4,11 +4,13 @@ version = "0.0.1" description = "Chat language model that can use tools and interpret the results" requires-python = ">=3.9" dependencies = [ + "wheel==0.45.1", + "setuptools==78.0.2", "accelerate==0.34.0", - "bitsandbytes==0.44.1", + "bitsandbytes==0.45.4", + "packaging", "scipy==1.11.3", "sentencepiece==0.1.99", - "packaging==23.1", "ninja==1.11.1", "einops==0.7.0", "wandb==0.15.11", @@ -18,8 +20,7 @@ dependencies = [ "tensorboard==2.15.1", "aenum==3.1.15", "transformers @ git+https://github.com/huggingface/transformers.git", - "flash-attn==2.6.3", - "json_source_map==1.0.5", + "json_source_map==1.0.5" ] [build-system] @@ -32,9 +33,11 @@ packages = ["train"] [project.optional-dependencies] liger = [ - "liger-kernel==0.3.1", + "liger-kernel==0.5.5", ] lora = [ "peft==0.13.2", "datasets==2.8.0", ] +[tool.pip] +install-options = { "flash-attn" = ["--no-build-isolation"] } \ No newline at end of file diff --git a/functionary/train/train.py b/functionary/train/train.py index ee95e200..2afb2de6 100644 --- a/functionary/train/train.py +++ b/functionary/train/train.py @@ -57,7 +57,7 @@ def lr_lambda(current_step): from functionary.prompt_template import PromptTemplate, get_prompt_template_by_version from functionary.train.custom_datasets import read_dataset from functionary.train import training_utils -from training_utils import print_rank0 +from functionary.train.training_utils import print_rank0 LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) @@ -184,6 +184,8 @@ def train(): model_class = AutoLigerKernelForCausalLM else: model_class = transformers.AutoModelForCausalLM + # from transformers import Gemma3ForConditionalGeneration + # model_class = Gemma3ForConditionalGeneration model = model_class.from_pretrained( model_args.model_name_or_path, @@ -250,7 +252,7 @@ def train(): if training_args.do_eval: print_rank0("***** HERE ARE SOME EXAMPLES FROM EVALUATION ***") - training_utils.print_some_examples(eval_dataset, tokenizer) + # training_utils.print_some_examples(eval_dataset, tokenizer) def preprocess_logits_for_metrics(logits, labels): return training_utils.preprocess_logits_for_metrics( @@ -283,7 +285,7 @@ def compute_metrics(eval_preds): else: trainer.train() - trainer.save_state() + # trainer.save_state() # FSDP requires state_dict_type=FULL_STATE_DICT in order to save the model weights in .bin format if trainer.is_fsdp_enabled: diff --git a/functionary/train/training_utils.py b/functionary/train/training_utils.py index 6671be67..be0e190e 100644 --- a/functionary/train/training_utils.py +++ b/functionary/train/training_utils.py @@ -50,7 +50,7 @@ def initialize_tokenizer( # add chat_template for tokenizer tokenizer.chat_template = prompt_template.get_chat_template_jinja() - print("tokenizer: ", tokenizer) + # print("tokenizer: ", tokenizer) # Resize embedding model.resize_token_embeddings(len(tokenizer)) @@ -218,6 +218,7 @@ def print_some_examples(ds, tokenizer): print_rank0("shape of input_ids: ", batch["input_ids"].shape) # B x L print_rank0("shape of labels: ", batch["labels"].shape) print_rank0("shape of attention_mask: ", batch["attention_mask"].shape) + print_rank0("attetnion_mask", batch["attention_mask"]) # print_rank0('input_ids: ', batch["input_ids"].tolist()) # print_rank0('labels: ', batch["labels"].tolist()) print_rank0("attention mask: ", batch["attention_mask"]) diff --git a/functionary/vllm_inference.py b/functionary/vllm_inference.py index 166d2d66..bbc2bfd8 100644 --- a/functionary/vllm_inference.py +++ b/functionary/vllm_inference.py @@ -5,10 +5,11 @@ from fastapi import BackgroundTasks, Request from fastapi.responses import JSONResponse, StreamingResponse -from vllm.entrypoints.openai.protocol import ( - LoadLoraAdapterRequest, - UnloadLoraAdapterRequest, -) + +# from vllm.entrypoints.openai.protocol import ( +# LoadLoraAdapterRequest, +# UnloadLoraAdapterRequest, +# ) from vllm.inputs import TokensPrompt from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -44,7 +45,9 @@ async def check_length(request, input_ids, model_config): - if hasattr(model_config.hf_config, "max_sequence_length"): + if hasattr(model_config, "max_model_len"): + context_len = model_config.max_model_len + elif hasattr(model_config.hf_config, "max_sequence_length"): context_len = model_config.hf_config.max_sequence_length elif hasattr(model_config.hf_config, "seq_length"): context_len = model_config.hf_config.seq_length @@ -90,7 +93,7 @@ async def check_length(request, input_ids, model_config): async def process_load_lora_adapter( - request: LoadLoraAdapterRequest, + request: Any, served_loras: List[LoRARequest], lora_id_counter: AtomicCounter, ) -> Tuple[Union[str, JSONResponse], List[LoRARequest]]: @@ -128,8 +131,8 @@ async def process_load_lora_adapter( async def process_unload_lora_adapter( - request: UnloadLoraAdapterRequest, served_loras: List[LoRARequest] -) -> Tuple[Union[str, JSONResponse], List[LoRARequest]]: + request: Any, served_loras: List[Any] +) -> Tuple[Union[str, JSONResponse], List[Any]]: # Check if either 'lora_name' or 'lora_int_id' is provided if not request.lora_name and not request.lora_int_id: return ( diff --git a/pyproject.toml b/pyproject.toml index 83067ff5..806fcf35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ packages = ["functionary"] [project.optional-dependencies] vllm = [ - "vllm==0.6.3.post1; sys_platform != 'darwin'", + "vllm==0.8.2; sys_platform != 'darwin'", ] sglang = [ "python-multipart==0.0.12", diff --git a/server_sglang.py b/server_sglang.py index 291ea8fa..0e1a881d 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -172,10 +172,6 @@ async def stream_results() -> AsyncIterator[bytes]: @app.post("/v1/chat/completions") async def openai_v1_chat_completions(raw_request: Request): - global tokenizer_manager, backend - - # if not args.grammar_sampling: - # backend = None return await v1_chat_completions(tokenizer_manager, None, raw_request, served_model) diff --git a/server_vllm.py b/server_vllm.py index 0e5121f2..50ab59dd 100644 --- a/server_vllm.py +++ b/server_vllm.py @@ -32,11 +32,9 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.api_server import mount_metrics from vllm.entrypoints.openai.protocol import ( - LoadLoraAdapterRequest, ModelCard, ModelList, ModelPermission, - UnloadLoraAdapterRequest, ) from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -131,7 +129,7 @@ async def create_chat_completion(raw_request: Request): @app.post("/v1/load_lora_adapter") -async def load_lora_adapter(request: LoadLoraAdapterRequest): +async def load_lora_adapter(request: Any): global served_loras error, served_loras = await process_load_lora_adapter( @@ -145,7 +143,7 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest): @app.post("/v1/unload_lora_adapter") -async def unload_lora_adapter(request: UnloadLoraAdapterRequest): +async def unload_lora_adapter(request: Any): global served_loras error, served_loras = await process_unload_lora_adapter(request, served_loras) diff --git a/tests/prompt_test_cogito.txt b/tests/prompt_test_cogito.txt new file mode 100644 index 00000000..a7b912ae --- /dev/null +++ b/tests/prompt_test_cogito.txt @@ -0,0 +1,87 @@ +<|im_start|>system +Enable deep thinking subroutine. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}}} +{"type": "function", "function": {"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}}} +{"type": "function", "function": {"name": "python", "description": "This tool is used to execute python code. Code will be executed in a stateful Jupyter notebook environment. Python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The python code to run"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +who is the CEO of Meetkai<|im_end|> +<|im_start|>assistant +James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|im_end|> +<|im_start|>user +is the car Song more expensive than car Tang?<|im_end|> +<|im_start|>assistant +I will get the price of 2 cars and compare + +{"name": "get_car_price", "arguments": {"car_name": "Song"}} + + +{"name": "get_car_price", "arguments": {"car_name": "Tang"}} +<|im_end|> +<|im_start|>user + +{"price": {"price": "$25000"}} + + +{"price": {"price": "$20000"}} +<|im_end|> +<|im_start|>assistant +No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|im_end|> +<|im_start|>user +what's the weather like in Hanoi?<|im_end|> +<|im_start|>assistant + +{"name": "get_weather", "arguments": {"location": "Hanoi"}} +<|im_end|> +<|im_start|>user + +{"result": {"temperature": 10}} +<|im_end|> +<|im_start|>assistant +The temperature in Hanoi is: 10 degree Celcious<|im_end|> +<|im_start|>user +Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|im_end|> +<|im_start|>assistant +I'll use code interpreter to handle this + +{"name": "python", "arguments": {"code": "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"}} +<|im_end|> +<|im_start|>user + +ValueError: list.remove(x): x not in list +<|im_end|> +<|im_start|>assistant +I will fix the code + +{"name": "python", "arguments": {"code": "l.remove(round(3.6))\nl"}} +<|im_end|> +<|im_start|>user + +[0,1,2,3,5,] +<|im_end|> +<|im_start|>assistant +The final list is: 0,1,2,3,5<|im_end|> +<|im_start|>user +Thanks! What's the weather in San Francisco?<|im_end|> +<|im_start|>assistant + +{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}} +<|im_end|> +<|im_start|>user + +{"result": {"temperature": 20}} +<|im_end|> +<|im_start|>assistant +The temperature in San Francisco is: 20 degree Celcious<|im_end|> \ No newline at end of file diff --git a/tests/prompt_test_gemma3.txt b/tests/prompt_test_gemma3.txt new file mode 100644 index 00000000..fb0bca16 --- /dev/null +++ b/tests/prompt_test_gemma3.txt @@ -0,0 +1,87 @@ +system +You are capable of executing available function(s) if required. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}}} +{"type": "function", "function": {"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}}} +{"type": "function", "function": {"name": "python", "description": "This tool is used to execute python code. Code will be executed in a stateful Jupyter notebook environment. Python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The python code to run"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +user +who is the CEO of Meetkai +model +James Kaplan is the Co-Founder and CEO of MeetKai Inc. +user +is the car Song more expensive than car Tang? +model +I will get the price of 2 cars and compare + +{"name": "get_car_price", "arguments": {"car_name": "Song"}} + + +{"name": "get_car_price", "arguments": {"car_name": "Tang"}} + +tool + +{"price": {"price": "$25000"}} + + +{"price": {"price": "$20000"}} + +model +No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000. +user +what's the weather like in Hanoi? +model + +{"name": "get_weather", "arguments": {"location": "Hanoi"}} + +tool + +{"result": {"temperature": 10}} + +model +The temperature in Hanoi is: 10 degree Celcious +user +Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most +model +I'll use code interpreter to handle this + +{"name": "python", "arguments": {"code": "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"}} + +tool + +ValueError: list.remove(x): x not in list + +model +I will fix the code + +{"name": "python", "arguments": {"code": "l.remove(round(3.6))\nl"}} + +tool + +[0,1,2,3,5,] + +model +The final list is: 0,1,2,3,5 +user +Thanks! What's the weather in San Francisco? +model + +{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}} + +tool + +{"result": {"temperature": 20}} + +model +The temperature in San Francisco is: 20 degree Celcious \ No newline at end of file diff --git a/tests/prompt_test_qwen2.5-text-only.txt b/tests/prompt_test_qwen2.5-text-only.txt new file mode 100644 index 00000000..6b386b90 --- /dev/null +++ b/tests/prompt_test_qwen2.5-text-only.txt @@ -0,0 +1,87 @@ +<|im_start|>system +You are a helpful assistant. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}}} +{"type": "function", "function": {"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}}} +{"type": "function", "function": {"name": "python", "description": "This tool is used to execute python code. Code will be executed in a stateful Jupyter notebook environment. Python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The python code to run"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +who is the CEO of Meetkai<|im_end|> +<|im_start|>assistant +James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|im_end|> +<|im_start|>user +is the car Song more expensive than car Tang?<|im_end|> +<|im_start|>assistant +I will get the price of 2 cars and compare + +{"name": "get_car_price", "arguments": {"car_name": "Song"}} + + +{"name": "get_car_price", "arguments": {"car_name": "Tang"}} +<|im_end|> +<|im_start|>user + +{"price": {"price": "$25000"}} + + +{"price": {"price": "$20000"}} +<|im_end|> +<|im_start|>assistant +No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|im_end|> +<|im_start|>user +what's the weather like in Hanoi?<|im_end|> +<|im_start|>assistant + +{"name": "get_weather", "arguments": {"location": "Hanoi"}} +<|im_end|> +<|im_start|>user + +{"result": {"temperature": 10}} +<|im_end|> +<|im_start|>assistant +The temperature in Hanoi is: 10 degree Celcious<|im_end|> +<|im_start|>user +Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|im_end|> +<|im_start|>assistant +I'll use code interpreter to handle this + +{"name": "python", "arguments": {"code": "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"}} +<|im_end|> +<|im_start|>user + +ValueError: list.remove(x): x not in list +<|im_end|> +<|im_start|>assistant +I will fix the code + +{"name": "python", "arguments": {"code": "l.remove(round(3.6))\nl"}} +<|im_end|> +<|im_start|>user + +[0,1,2,3,5,] +<|im_end|> +<|im_start|>assistant +The final list is: 0,1,2,3,5<|im_end|> +<|im_start|>user +Thanks! What's the weather in San Francisco?<|im_end|> +<|im_start|>assistant + +{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}} +<|im_end|> +<|im_start|>user + +{"result": {"temperature": 20}} +<|im_end|> +<|im_start|>assistant +The temperature in San Francisco is: 20 degree Celcious<|im_end|> \ No newline at end of file diff --git a/tests/prompt_test_r1.txt b/tests/prompt_test_r1.txt new file mode 100644 index 00000000..904dcc09 --- /dev/null +++ b/tests/prompt_test_r1.txt @@ -0,0 +1,36 @@ +<|System|> +Environment: ipython + + +You have access to the following functions: + +Use the function 'get_car_price' to 'Get the price of a particular car model' +{"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}} + +Use the function 'get_weather' to 'This function's purpose is to get the weather of a location' +{"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}} + + +Think very carefully before calling functions. + +<|User|>who is the CEO of Meetkai<|Assistant|>James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|end▁of▁sentence|><|User|>is the car Song more expensive than car Tang?<|Assistant|>I will get the price of 2 cars and compare<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_car_price +```json +{"car_name": "Song"} +```<|tool▁call▁end|><|tool▁call▁begin|>function<|tool▁sep|>get_car_price +```json +{"car_name": "Tang"} +```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>{"price": {"price": "$25000"}}<|tool▁output▁end|><|tool▁output▁begin|>{"price": {"price": "$20000"}}<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|end▁of▁sentence|><|User|>what's the weather like in Hanoi?<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"location": "Hanoi"} +```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>{"result": {"temperature": 10}}<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>The temperature in Hanoi is: 10 degree Celcious<|end▁of▁sentence|><|User|>Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|Assistant|>I'll use code interpreter to handle this<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>python +```python +l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')] +l.remove(3.6) +```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>ValueError: list.remove(x): x not in list<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>I will fix the code<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>python +```python +l.remove(round(3.6)) +l +```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>[0,1,2,3,5,]<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>The final list is: 0,1,2,3,5<|end▁of▁sentence|><|User|>Thanks! What's the weather in San Francisco?<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"location": "San Francisco, CA"} +```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>{"result": {"temperature": 20}}<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>The temperature in San Francisco is: 20 degree Celcious<|end▁of▁sentence|> \ No newline at end of file diff --git a/tests/prompt_test_r1_distilled_llama.txt b/tests/prompt_test_r1_distilled_llama.txt new file mode 100644 index 00000000..dd2dfd70 --- /dev/null +++ b/tests/prompt_test_r1_distilled_llama.txt @@ -0,0 +1,47 @@ +<|System|> +Environment: ipython + + +You have access to the following functions: + +Use the function 'get_car_price' to 'Get the price of a particular car model' +{"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}} + +Use the function 'get_weather' to 'This function's purpose is to get the weather of a location' +{"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}} + + +Think very carefully before calling functions. +If a you choose to call a function ONLY reply in the following format: +<{start_tag}={function_name}>{parameters}{end_tag} +where + +start_tag => ` a JSON dict with the function argument name as key and function argument value as value. +end_tag => `` + +Here is an example, +{"example_name": "example_value"} + +Reminder: +- If looking for real time information use relevant functions before falling back to brave_search +- Function calls MUST follow the specified format, start with +- Required parameters MUST be specified +- Only call one function at a time +- Put the entire function call reply on one line + +<|User|>who is the CEO of Meetkai<|Assistant|>James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|end▁of▁sentence|><|User|>is the car Song more expensive than car Tang?<|Assistant|>I will get the price of 2 cars and compare{"car_name": "Song"}{"car_name": "Tang"}<|eom_id|>ipython + +{"price": {"price": "$25000"}}ipython + +{"price": {"price": "$20000"}}<|Assistant|>No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|end▁of▁sentence|><|User|>what's the weather like in Hanoi?<|Assistant|>{"location": "Hanoi"}<|eom_id|>ipython + +{"result": {"temperature": 10}}<|Assistant|>The temperature in Hanoi is: 10 degree Celcious<|end▁of▁sentence|><|User|>Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|Assistant|>I'll use code interpreter to handle this<|python_tag|>l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')] +l.remove(3.6)<|eom_id|>ipython + +ValueError: list.remove(x): x not in list<|Assistant|>I will fix the code<|python_tag|>l.remove(round(3.6)) +l<|eom_id|>ipython + +[0,1,2,3,5,]<|Assistant|>The final list is: 0,1,2,3,5<|end▁of▁sentence|><|User|>Thanks! What's the weather in San Francisco?<|Assistant|>{"location": "San Francisco, CA"}<|eom_id|>ipython + +{"result": {"temperature": 20}}<|Assistant|>The temperature in San Francisco is: 20 degree Celcious<|end▁of▁sentence|> \ No newline at end of file diff --git a/tests/prompt_test_r1_distilled_qwen.txt b/tests/prompt_test_r1_distilled_qwen.txt new file mode 100644 index 00000000..45e1d252 --- /dev/null +++ b/tests/prompt_test_r1_distilled_qwen.txt @@ -0,0 +1,49 @@ +<|System|> + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}}} +{"type": "function", "function": {"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}}} +{"type": "function", "function": {"name": "python", "description": "This tool is used to execute python code. Code will be executed in a stateful Jupyter notebook environment. Python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The python code to run"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +<|User|>who is the CEO of Meetkai<|Assistant|>James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|end▁of▁sentence|><|User|>is the car Song more expensive than car Tang?<|Assistant|>I will get the price of 2 cars and compare +{"name": "get_car_price", "arguments": {"car_name": "Song"}} + + +{"name": "get_car_price", "arguments": {"car_name": "Tang"}} +<|end▁of▁sentence|><|User|> + +{"price": {"price": "$25000"}} + + +{"price": {"price": "$20000"}} +<|Assistant|>No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|end▁of▁sentence|><|User|>what's the weather like in Hanoi?<|Assistant|> +{"name": "get_weather", "arguments": {"location": "Hanoi"}} +<|end▁of▁sentence|><|User|> + +{"result": {"temperature": 10}} +<|Assistant|>The temperature in Hanoi is: 10 degree Celcious<|end▁of▁sentence|><|User|>Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|Assistant|>I'll use code interpreter to handle this +{"name": "python", "arguments": {"code": "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"}} +<|end▁of▁sentence|><|User|> + +ValueError: list.remove(x): x not in list +<|Assistant|>I will fix the code +{"name": "python", "arguments": {"code": "l.remove(round(3.6))\nl"}} +<|end▁of▁sentence|><|User|> + +[0,1,2,3,5,] +<|Assistant|>The final list is: 0,1,2,3,5<|end▁of▁sentence|><|User|>Thanks! What's the weather in San Francisco?<|Assistant|> +{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}} +<|end▁of▁sentence|><|User|> + +{"result": {"temperature": 20}} +<|Assistant|>The temperature in San Francisco is: 20 degree Celcious<|end▁of▁sentence|> \ No newline at end of file diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py index 31cfb7d4..18eb7003 100644 --- a/tests/test_prompt_creation.py +++ b/tests/test_prompt_creation.py @@ -38,10 +38,16 @@ def __init__(self, *args, **kwargs): super(TestPromptTemplate, self).__init__(*args, **kwargs) self.template_version_to_model_name = { - "v2": "meetkai/functionary-small-v2.4", - "v2.llama3": "meetkai/functionary-small-v2.5", - "v3.llama3": "meetkai/functionary-medium-v3.0", - "v3-llama3.1": "meetkai/functionary-small-v3.1", + # "v2": "meetkai/functionary-small-v2.4", + # "v2.llama3": "meetkai/functionary-small-v2.5", + # "v3.llama3": "meetkai/functionary-medium-v3.0", + # "v3-llama3.1": "meetkai/functionary-small-v3.1", + # "r1": "deepseek-ai/DeepSeek-R1", + # "r1_distilled_qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", + "r1_distilled_llama": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", + "qwen2.5-text-only": "Qwen/Qwen2.5-32B", + "gemma3": "unsloth/gemma-3-12b-it", + "cogito": "deepcogito/cogito-v1-preview-qwen-32B", } self.image_template_version_to_model_name = { "v3.llava_llama": "meetkai/functionary-vision-small-v0.1" @@ -130,7 +136,7 @@ def run_prepare_training_inputs( ): """this function is used to test function: prepare_training_inputs""" # note that must set legacy=True, read more: https://github.com/huggingface/transformers/issues/25176 - tokenizer = AutoTokenizer.from_pretrained(pretrained, legacy=True) + tokenizer = AutoTokenizer.from_pretrained(pretrained) tokenizer.pad_token = tokenizer.eos_token # first we add stop_tokens to the tokenizer prompt_template = get_prompt_template_by_version(template_version) @@ -138,6 +144,7 @@ def run_prepare_training_inputs( added_tokens = prompt_template.get_additional_tokens() special_tokens = {"additional_special_tokens": added_tokens} tokenizer.add_special_tokens(special_tokens) + tokenizer.chat_template = prompt_template.get_chat_template_jinja() test_case, _ = self.read_example_data(template_version) @@ -145,7 +152,7 @@ def run_prepare_training_inputs( messages=test_case, tokenizer=tokenizer, padding="longest", - max_length=1024, + max_length=2048, return_tensor=False, verbose=True, keep_assistant_prefix=keep_assistant_prefix, diff --git a/tests/test_request_handling.py b/tests/test_request_handling.py index 28a5c642..82e7cb51 100644 --- a/tests/test_request_handling.py +++ b/tests/test_request_handling.py @@ -22,8 +22,8 @@ LlavaLlama, PromptTemplate, PromptTemplateV2, - Llama31ReasoningTemplate, - get_available_prompt_template_versions, + Llama31ReasoningTemplate, + Qwen25TextOnlyPromptTemplate, ) from functionary.prompt_template.prompt_utils import ( enforce_tool_choice, @@ -145,7 +145,7 @@ def __init__(self, *args, **kwargs): {"type": "function", "function": self.default_functions[0]} ] self.default_python_args = "from datetime import date\n# Find today's date\ntoday = date.today()\ntoday" - self.test_prompt_templates = get_available_prompt_template_versions() + self.prompt_template_to_tokenizer_name_mapping = { PromptTemplateV2: "meetkai/functionary-small-v2.4", Llama3Template: "meetkai/functionary-small-v2.5", @@ -153,7 +153,12 @@ def __init__(self, *args, **kwargs): Llama31Template: "meetkai/functionary-small-v3.1", Llama31ReasoningTemplate: "meetkai/functionary-small-v3.1", LlavaLlama: "lmms-lab/llama3-llava-next-8b", + Qwen25TextOnlyPromptTemplate: "Qwen/Qwen2.5-32B-Instruct", } + # only test the prompt templates that we will use, no need to test all the prompt templates + self.test_prompt_templates = list( + self.prompt_template_to_tokenizer_name_mapping.keys() + ) # get_available_prompt_template_versions() self.default_text_str = "Normal text generation" self.default_tool_call_name = "get_weather" self.default_tool_call_args = [ @@ -500,13 +505,14 @@ async def test_prompt_template_to_tokenizer(self): # Test whether all prompt templates are included in template_to_tokenizer_mapping yet for prompt_template in self.test_prompt_templates: self.assertIn( - type(prompt_template), + prompt_template, self.prompt_template_to_tokenizer_name_mapping.keys(), f"Prompt template `{type(prompt_template)}` is not included in template_to_tokenizer_mapping yet.", ) async def test_request_handling(self): - for prompt_template in self.test_prompt_templates: + for prompt_template_class in self.test_prompt_templates: + prompt_template = prompt_template_class.get_prompt_template() for test_case in self.request_handling_test_cases: raw_response = generate_raw_response( gen_text=test_case["gen_text"], @@ -590,10 +596,10 @@ async def wrap_generator(tokenizer, raw_response, test_case): yield "", "stop" # yield "", test_case["expected_finish_reason"] - for prompt_template in self.test_prompt_templates: - + for prompt_template_class in self.test_prompt_templates: + prompt_template = prompt_template_class.get_prompt_template() tokenizer = AutoTokenizer.from_pretrained( - self.prompt_template_to_tokenizer_name_mapping[type(prompt_template)] + self.prompt_template_to_tokenizer_name_mapping[prompt_template_class] ) special_tokens = { "additional_special_tokens": prompt_template.get_additional_tokens()