From e44913556b31e34375172234f60b51debf4f4139 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Thu, 20 Feb 2025 10:59:27 +0700 Subject: [PATCH 01/31] add template for chat only --- functionary/prompt_template/__init__.py | 1 + functionary/prompt_template/base_template.py | 3 +- .../jinja_templates/r1-qwen-distill.txt | 89 +++++++++++++++++++ .../r1_qwen_distill_template.py | 29 ++++++ 4 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 functionary/prompt_template/jinja_templates/r1-qwen-distill.txt create mode 100644 functionary/prompt_template/r1_qwen_distill_template.py diff --git a/functionary/prompt_template/__init__.py b/functionary/prompt_template/__init__.py index 03870348..d7e6bd30 100644 --- a/functionary/prompt_template/__init__.py +++ b/functionary/prompt_template/__init__.py @@ -9,6 +9,7 @@ from functionary.prompt_template.prompt_template_v1 import PromptTemplateV1 from functionary.prompt_template.prompt_template_v2 import PromptTemplateV2 from functionary.prompt_template.llama31_reasoning_prompt_template import Llama31ReasoningTemplate +from functionary.prompt_template.r1_qwen_distill_template import R1QwenDistillTemplate def get_available_prompt_template_versions() -> List[PromptTemplate]: diff --git a/functionary/prompt_template/base_template.py b/functionary/prompt_template/base_template.py index 89f69f78..fad33581 100644 --- a/functionary/prompt_template/base_template.py +++ b/functionary/prompt_template/base_template.py @@ -69,13 +69,12 @@ def initialize_grammar_sampling_gen_state(self, tool_choice: Optional[Any]) -> D """ raise NotImplementedError - @abstractmethod def get_additional_tokens(self) -> List[str]: """return list of added tokens if using this template Returns: List[str]: list of tokens, each token is a string """ - raise NotImplementedError + return [] @abstractmethod def get_stop_tokens_for_generation(self) -> List[str]: diff --git a/functionary/prompt_template/jinja_templates/r1-qwen-distill.txt 
b/functionary/prompt_template/jinja_templates/r1-qwen-distill.txt new file mode 100644 index 00000000..1e211d0d --- /dev/null +++ b/functionary/prompt_template/jinja_templates/r1-qwen-distill.txt @@ -0,0 +1,89 @@ +{# version=r1-qwen-distill #}{# Initialize variables #} +{%- if not add_generation_prompt is defined -%} + {%- set add_generation_prompt = false -%} +{%- endif -%} +{%- set ns = namespace( + is_first=false, + is_tool=false, + is_output_first=true, + system_prompt='', + is_first_sp=true +) -%} +{#- Build system prompt #} +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {%- if ns.is_first_sp %} + {%- set ns.system_prompt = ns.system_prompt + message['content'] -%} + {%- set ns.is_first_sp = false -%} + {%- else %} + {%- set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] -%} + {%- endif %} + {%- endif %} +{%- endfor %} +{#- Output system prompt: Check if ns.system_prompt is not empty, add \n to it} +{%- if ns.system_prompt %} + {%- set ns.system_prompt = ns.system_prompt + '\n' %} +{%- endif %} + +{{- bos_token }}{{- ns.system_prompt -}} +{# Process messages #} +{%- for message in messages %} + {#- Handle user messages #} + {%- if message['role'] == 'user' %} + {%- set ns.is_tool = false -%} + {{- '<|User|>' + message['content'] -}} + {%- endif -%} + {# Handle assistant messages with tool calls #} + {%- if message['role'] == 'assistant' and 'tool_calls' in message %} + {%- set ns.is_tool = false %} + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if message['content'] is none %} + {{- '<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + + '<|tool▁call▁end|>'}} + {%- else %} + {{- '<|Assistant|>' + message['content'] + + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + + '```json' + '\n' + 
tool['function']['arguments'] + '\n' + '```' + + '<|tool▁call▁end|>'}} + {%- endif %} + {%- set ns.is_first = true %} + {%- else %} + {{-'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + + tool['function']['name'] + '\n' + '```json' + '\n' + + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- endif %} + {%- endfor %} + {{- '<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif -%} + {# Handle assistant messages without tool calls #} + {%- if message['role'] == 'assistant' and 'tool_calls' not in message %} + {%- if ns.is_tool -%} + {{- '<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>' -}} + {%- set ns.is_tool = false %} + {%- else -%} + {{- '<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>' -}} + {%- endif -%} + {%- endif -%} + {# Handle tool messages #} + {%- if message['role'] == 'tool' -%} + {%- set ns.is_tool = true %} + {%- if ns.is_output_first %} + {{- '<|tool▁outputs▁begin|><|tool▁output▁begin|>' + + message['content'] + '<|tool▁output▁end|>'}} + {%- set ns.is_output_first = false %} + {%- else %} + {{- '<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif %} + {%- endif -%} +{%- endfor -%} +{# Final outputs #} +{%- if ns.is_tool -%} + {{- '<|tool▁outputs▁end|>' -}} +{%- endif -%} +{%- if add_generation_prompt and not ns.is_tool -%} + {{- '<|Assistant|>'}} +{% endif %} diff --git a/functionary/prompt_template/r1_qwen_distill_template.py b/functionary/prompt_template/r1_qwen_distill_template.py new file mode 100644 index 00000000..94fe59f0 --- /dev/null +++ b/functionary/prompt_template/r1_qwen_distill_template.py @@ -0,0 +1,29 @@ +import datetime +import json +import re +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from functionary.openai_types import Function, Tool +from functionary.prompt_template import prompt_utils +from functionary.prompt_template.base_template import PromptTemplate + + +class R1QwenDistillTemplate(PromptTemplate): + 
version = "r1-qwen-distill" + + def get_assistant_prefixes(self) -> List[str]: + return [f"<|Assistant|>"] + + def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: + """Order the tool results by the order of tool call ids + + Args: + messages (List[Dict]): List of messages + + Returns: + List[Dict]: List of messages + """ + return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) + + def get_stop_tokens_for_generation(self) -> List[str]: + return ["<|end▁of▁sentence|>"] \ No newline at end of file From 5f454f356064d7f2ac8e4b2850ee409064d8b2ff Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Mon, 24 Feb 2025 00:16:53 +0700 Subject: [PATCH 02/31] add template for r1 --- functionary/prompt_template/__init__.py | 2 +- .../jinja_templates/r1-qwen-distill.txt | 89 ------------ .../prompt_template/jinja_templates/r1.txt | 97 +++++++++++++ .../r1_qwen_distill_template.py | 29 ---- functionary/prompt_template/r1_template.py | 130 ++++++++++++++++++ tests/prompt_test_r1.txt | 36 +++++ tests/test_prompt_creation.py | 4 +- 7 files changed, 267 insertions(+), 120 deletions(-) delete mode 100644 functionary/prompt_template/jinja_templates/r1-qwen-distill.txt create mode 100644 functionary/prompt_template/jinja_templates/r1.txt delete mode 100644 functionary/prompt_template/r1_qwen_distill_template.py create mode 100644 functionary/prompt_template/r1_template.py create mode 100644 tests/prompt_test_r1.txt diff --git a/functionary/prompt_template/__init__.py b/functionary/prompt_template/__init__.py index d7e6bd30..14c0b0c6 100644 --- a/functionary/prompt_template/__init__.py +++ b/functionary/prompt_template/__init__.py @@ -9,7 +9,7 @@ from functionary.prompt_template.prompt_template_v1 import PromptTemplateV1 from functionary.prompt_template.prompt_template_v2 import PromptTemplateV2 from functionary.prompt_template.llama31_reasoning_prompt_template import Llama31ReasoningTemplate -from functionary.prompt_template.r1_qwen_distill_template 
import R1QwenDistillTemplate +from functionary.prompt_template.r1_template import R1Template def get_available_prompt_template_versions() -> List[PromptTemplate]: diff --git a/functionary/prompt_template/jinja_templates/r1-qwen-distill.txt b/functionary/prompt_template/jinja_templates/r1-qwen-distill.txt deleted file mode 100644 index 1e211d0d..00000000 --- a/functionary/prompt_template/jinja_templates/r1-qwen-distill.txt +++ /dev/null @@ -1,89 +0,0 @@ -{# version=r1-qwen-distill #}{# Initialize variables #} -{%- if not add_generation_prompt is defined -%} - {%- set add_generation_prompt = false -%} -{%- endif -%} -{%- set ns = namespace( - is_first=false, - is_tool=false, - is_output_first=true, - system_prompt='', - is_first_sp=true -) -%} -{#- Build system prompt #} -{%- for message in messages %} - {%- if message['role'] == 'system' %} - {%- if ns.is_first_sp %} - {%- set ns.system_prompt = ns.system_prompt + message['content'] -%} - {%- set ns.is_first_sp = false -%} - {%- else %} - {%- set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] -%} - {%- endif %} - {%- endif %} -{%- endfor %} -{#- Output system prompt: Check if ns.system_prompt is not empty, add \n to it} -{%- if ns.system_prompt %} - {%- set ns.system_prompt = ns.system_prompt + '\n' %} -{%- endif %} - -{{- bos_token }}{{- ns.system_prompt -}} -{# Process messages #} -{%- for message in messages %} - {#- Handle user messages #} - {%- if message['role'] == 'user' %} - {%- set ns.is_tool = false -%} - {{- '<|User|>' + message['content'] -}} - {%- endif -%} - {# Handle assistant messages with tool calls #} - {%- if message['role'] == 'assistant' and 'tool_calls' in message %} - {%- set ns.is_tool = false %} - {%- for tool in message['tool_calls'] %} - {%- if not ns.is_first %} - {%- if message['content'] is none %} - {{- '<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + - tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + - '```json' + '\n' + 
tool['function']['arguments'] + '\n' + '```' + - '<|tool▁call▁end|>'}} - {%- else %} - {{- '<|Assistant|>' + message['content'] + - '<|tool▁calls▁begin|><|tool▁call▁begin|>' + - tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + - '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + - '<|tool▁call▁end|>'}} - {%- endif %} - {%- set ns.is_first = true %} - {%- else %} - {{-'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + - tool['function']['name'] + '\n' + '```json' + '\n' + - tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} - {%- endif %} - {%- endfor %} - {{- '<|tool▁calls▁end|><|end▁of▁sentence|>'}} - {%- endif -%} - {# Handle assistant messages without tool calls #} - {%- if message['role'] == 'assistant' and 'tool_calls' not in message %} - {%- if ns.is_tool -%} - {{- '<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>' -}} - {%- set ns.is_tool = false %} - {%- else -%} - {{- '<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>' -}} - {%- endif -%} - {%- endif -%} - {# Handle tool messages #} - {%- if message['role'] == 'tool' -%} - {%- set ns.is_tool = true %} - {%- if ns.is_output_first %} - {{- '<|tool▁outputs▁begin|><|tool▁output▁begin|>' + - message['content'] + '<|tool▁output▁end|>'}} - {%- set ns.is_output_first = false %} - {%- else %} - {{- '<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} - {%- endif %} - {%- endif -%} -{%- endfor -%} -{# Final outputs #} -{%- if ns.is_tool -%} - {{- '<|tool▁outputs▁end|>' -}} -{%- endif -%} -{%- if add_generation_prompt and not ns.is_tool -%} - {{- '<|Assistant|>'}} -{% endif %} diff --git a/functionary/prompt_template/jinja_templates/r1.txt b/functionary/prompt_template/jinja_templates/r1.txt new file mode 100644 index 00000000..7c87e09f --- /dev/null +++ b/functionary/prompt_template/jinja_templates/r1.txt @@ -0,0 +1,97 @@ +{# version=r1 #}{%- if not add_generation_prompt is defined -%} + {%- set 
add_generation_prompt = false -%} +{%- endif -%} +{{- bos_token }} +{%- if not tools is defined -%} + {%- set tools = none -%} +{%- endif -%} + +{%- set has_code_interpreter = tools | selectattr("type", "equalto", "code_interpreter") | list | length > 0 -%} +{%- if has_code_interpreter -%} + {%- set tools = tools | rejectattr("type", "equalto", "code_interpreter") | list -%} +{%- endif -%} +{%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} +{%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} +{%- endif -%} +{#- System message + builtin tools -#} +{#- If at lease: has_reasoning or has_code_interpreter or tools -#} +{%- if has_reasoning or has_code_interpreter or tools -%} + {{- "<|System|>\n" }} +{%- endif -%} +{%- if has_reasoning -%} + {{- "Reasoning Mode: On\n\n" }} +{%- endif -%} +{%- if has_code_interpreter -%} + {{- "Environment: ipython\n\n" }} +{%- endif %} +{%- if tools %} + {{- "\nYou have access to the following functions:\n\n" }} + {%- for t in tools %} + {%- if "type" in t -%} + {{ "Use the function '" + t["function"]["name"] + "' to '" + t["function"]["description"] + "'\n" + t["function"] | tojson() }} + {%- else -%} + {{ "Use the function '" + t["name"] + "' to '" + t["description"] + "'\n" + t | tojson }} + {%- endif -%} + {{- "\n\n" }} + {%- endfor %} + {{- '\nThink very carefully before calling functions.\n\n' -}} +{%- endif %} +{%- set ns = namespace( + is_first=false, + is_tool=false, + is_output_first=true, + system_prompt='', + is_first_sp=true +) -%} +{#- Process messages -#} +{%- for message in messages -%} + {%- if message['role'] == 'user'-%} + {# Reset tool and output first flags #} + {%- set ns.is_tool = false -%} + {{-'<|User|>' + message['content'] -}} + {%- elif message['role'] == 'system' -%} + {{- '<|System|>' + message['content'] + '\n' -}} + {%- elif message['role'] == 'assistant' -%} + {# Need to check need to add 
<|tool▁outputs▁end|> if the previous message is tool message #} + {%- if ns.is_tool -%} + {{- '<|tool▁outputs▁end|>' -}} + {%- set ns.is_tool = false -%} + {%- endif -%} + {{-'<|Assistant|>'}} + {%- if message['content'] and message['content']|length > 0 -%} + {{ message['content'] }} + {%- endif -%} + {#Check if there is tool call#} + {%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls']|length > 0 -%} + {%- set ns.is_tool = true -%} + {%- set ns.is_output_first = true -%} + {{- '<|tool▁calls▁begin|>'}} + {%- for tool_call in message['tool_calls'] -%} + {%- if tool_call["function"]["name"] == "python" -%} + {%- set prefix = 'python' -%} + {%- else -%} + {%- set prefix = 'json' -%} + {%- endif -%} + {{- '<|tool▁call▁begin|>function<|tool▁sep|>' + tool_call['function']['name'] + '\n```' + prefix + '\n' + tool_call['function']['arguments'] + '\n```<|tool▁call▁end|>'}} + {%- endfor -%} + {{ '<|tool▁calls▁end|><|end▁of▁sentence|>' }} + {%- elif message['content'] and message['content']|length > 0 -%} + {{ '<|end▁of▁sentence|>' }} + {%- endif -%} + {%- elif message['role'] == 'tool' -%} + {# Check if this is the result of the first tool call #} + {%- if ns.is_output_first -%} + {{- '<|tool▁outputs▁begin|>'}} + {%- set ns.is_output_first = false -%} + {%- endif -%} + {{- '<|tool▁output▁begin|>' + message["content"] + '<|tool▁output▁end|>'}} + {%- endif -%} +{%- endfor -%} +{# Need to add closing tag for list of tool messages #} +{%- if ns.is_tool -%} + {{- '<|tool▁outputs▁end|>' -}} +{%- endif -%} +{%- if add_generation_prompt and not ns.is_tool -%} + {{- '<|Assistant|>'}} +{% endif %} diff --git a/functionary/prompt_template/r1_qwen_distill_template.py b/functionary/prompt_template/r1_qwen_distill_template.py deleted file mode 100644 index 94fe59f0..00000000 --- a/functionary/prompt_template/r1_qwen_distill_template.py +++ /dev/null @@ -1,29 +0,0 @@ -import datetime -import json -import re -from typing import Any, Dict, List, Literal, 
Optional, Tuple, Union - -from functionary.openai_types import Function, Tool -from functionary.prompt_template import prompt_utils -from functionary.prompt_template.base_template import PromptTemplate - - -class R1QwenDistillTemplate(PromptTemplate): - version = "r1-qwen-distill" - - def get_assistant_prefixes(self) -> List[str]: - return [f"<|Assistant|>"] - - def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: - """Order the tool results by the order of tool call ids - - Args: - messages (List[Dict]): List of messages - - Returns: - List[Dict]: List of messages - """ - return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) - - def get_stop_tokens_for_generation(self) -> List[str]: - return ["<|end▁of▁sentence|>"] \ No newline at end of file diff --git a/functionary/prompt_template/r1_template.py b/functionary/prompt_template/r1_template.py new file mode 100644 index 00000000..c475833f --- /dev/null +++ b/functionary/prompt_template/r1_template.py @@ -0,0 +1,130 @@ +import datetime +import json +import re +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from functionary.openai_types import Function, Tool +from functionary.prompt_template import prompt_utils +from functionary.prompt_template.base_template import PromptTemplate + + +def parse_tool_call(tool_call_str: str) -> Dict: + """parse the tool name and arguments from the tool call string. 
the format of the tool call string is: + function<|tool▁sep|> + ```type + + ``` + + Args: + tool_call_str (str): _description_ + + Returns: + Dict: {"name": , "args": } + """ + tool_sep_index = tool_call_str.find("<|tool▁sep|>") + brkline_index = tool_call_str.find("\n", tool_sep_index) + function_name = tool_call_str[ + tool_sep_index + len("<|tool▁sep|>") : brkline_index + ].strip() + # parse arguments + arguments_content = tool_call_str[brkline_index:].strip() + # strip ``` at the begining and the end + arguments_content = arguments_content.replace("```", "") + index = arguments_content.find("\n") # ignore: json\n or python\n + arguments_content = arguments_content[index:].strip() + + return { + "id": prompt_utils.get_random_tool_call_id(), + "type": "function", + "function": {"name": function_name, "arguments": arguments_content}, + } + + +def extract_text_inside(start_prefix: str, end_prefix: str, text: str) -> List[str]: + """extract all text inside the start_prefix and end_prefix, return a list of texts inside these two prefixes + + Args: + start_prefix (str): the prefix before the text to extract + end_prefix (str): the prefix after the text to extract + text (str): the text to extract + + Returns: + List[str]: a list of texts inside these two prefixes + """ + result = [] + current_pos = 0 + + while True: + # Find next start position + start_pos = text.find(start_prefix, current_pos) + if start_pos == -1: + break + + # Find matching end position + end_pos = text.find(end_prefix, start_pos + len(start_prefix)) + if end_pos == -1: + break + + # Extract text between prefixes + extracted = text[start_pos + len(start_prefix) : end_pos] + result.append(extracted) + + # Move current position past this match + current_pos = end_pos + len(end_prefix) + + return result + + +class R1Template(PromptTemplate): + version = "r1" + + def get_assistant_prefixes(self) -> List[str]: + return [f"<|Assistant|>"] + + def pre_process_messages_before_inference(self, messages: 
List[Dict]) -> List[Dict]: + """Order the tool results by the order of tool call ids + + Args: + messages (List[Dict]): List of messages + + Returns: + List[Dict]: List of messages + """ + return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) + + def get_stop_tokens_for_generation(self) -> List[str]: + return ["<|end▁of▁sentence|>"] + + def parse_assistant_response( + self, llm_output: str, tool_choice: Any | None + ) -> Dict: + for stop in self.get_stop_tokens_for_generation(): + if llm_output.endswith(stop): + llm_output = llm_output[: -len(stop)] + + # add forced-function from tool_choice if exists + llm_output = ( + self.get_generation_prefix_for_tool_choice(tool_choice) + llm_output + ) + text_content = llm_output + tool_calls = [] + # parse the tool calls + # first extract the string about tool calls, which is inside: <|tool▁calls▁begin|><|tool▁calls▁end|> + all_tool_calls_list = extract_text_inside( + "<|tool▁calls▁begin|>", "<|tool▁calls▁end|>", llm_output + ) + if len(all_tool_calls_list) > 0: + all_tool_calls_str = all_tool_calls_list[0] + index = llm_output.find("<|tool▁calls▁begin|") + text_content = text_content[:index] + # extract tool calls inside: <|tool▁call▁begin|> & <|tool▁call▁end|> + tool_calls_strs = extract_text_inside( + "<|tool▁call▁begin|>", "<|tool▁call▁end|>", all_tool_calls_str + ) + for tool_call_str in tool_calls_strs: + tool_calls.append(parse_tool_call(tool_call_str)) + return { + "role": "assistant", + "content": text_content if text_content else None, + "tool_calls": tool_calls if tool_calls else None + } diff --git a/tests/prompt_test_r1.txt b/tests/prompt_test_r1.txt new file mode 100644 index 00000000..904dcc09 --- /dev/null +++ b/tests/prompt_test_r1.txt @@ -0,0 +1,36 @@ +<|System|> +Environment: ipython + + +You have access to the following functions: + +Use the function 'get_car_price' to 'Get the price of a particular car model' +{"name": "get_car_price", "description": "Get the price of a particular car 
model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}} + +Use the function 'get_weather' to 'This function's purpose is to get the weather of a location' +{"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}} + + +Think very carefully before calling functions. + +<|User|>who is the CEO of Meetkai<|Assistant|>James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|end▁of▁sentence|><|User|>is the car Song more expensive than car Tang?<|Assistant|>I will get the price of 2 cars and compare<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_car_price +```json +{"car_name": "Song"} +```<|tool▁call▁end|><|tool▁call▁begin|>function<|tool▁sep|>get_car_price +```json +{"car_name": "Tang"} +```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>{"price": {"price": "$25000"}}<|tool▁output▁end|><|tool▁output▁begin|>{"price": {"price": "$20000"}}<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>No, the car Tang is less expensive than the car Song. 
The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|end▁of▁sentence|><|User|>what's the weather like in Hanoi?<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"location": "Hanoi"} +```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>{"result": {"temperature": 10}}<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>The temperature in Hanoi is: 10 degree Celcious<|end▁of▁sentence|><|User|>Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|Assistant|>I'll use code interpreter to handle this<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>python +```python +l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')] +l.remove(3.6) +```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>ValueError: list.remove(x): x not in list<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>I will fix the code<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>python +```python +l.remove(round(3.6)) +l +```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>[0,1,2,3,5,]<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>The final list is: 0,1,2,3,5<|end▁of▁sentence|><|User|>Thanks! 
What's the weather in San Francisco?<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +```json +{"location": "San Francisco, CA"} +```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|><|tool▁outputs▁begin|><|tool▁output▁begin|>{"result": {"temperature": 20}}<|tool▁output▁end|><|tool▁outputs▁end|><|Assistant|>The temperature in San Francisco is: 20 degree Celcious<|end▁of▁sentence|> \ No newline at end of file diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py index 31cfb7d4..a69dbc73 100644 --- a/tests/test_prompt_creation.py +++ b/tests/test_prompt_creation.py @@ -42,6 +42,7 @@ def __init__(self, *args, **kwargs): "v2.llama3": "meetkai/functionary-small-v2.5", "v3.llama3": "meetkai/functionary-medium-v3.0", "v3-llama3.1": "meetkai/functionary-small-v3.1", + "r1": "deepseek-ai/DeepSeek-R1" } self.image_template_version_to_model_name = { "v3.llava_llama": "meetkai/functionary-vision-small-v0.1" @@ -130,7 +131,7 @@ def run_prepare_training_inputs( ): """this function is used to test function: prepare_training_inputs""" # note that must set legacy=True, read more: https://github.com/huggingface/transformers/issues/25176 - tokenizer = AutoTokenizer.from_pretrained(pretrained, legacy=True) + tokenizer = AutoTokenizer.from_pretrained(pretrained) tokenizer.pad_token = tokenizer.eos_token # first we add stop_tokens to the tokenizer prompt_template = get_prompt_template_by_version(template_version) @@ -138,6 +139,7 @@ def run_prepare_training_inputs( added_tokens = prompt_template.get_additional_tokens() special_tokens = {"additional_special_tokens": added_tokens} tokenizer.add_special_tokens(special_tokens) + tokenizer.chat_template = prompt_template.get_chat_template_jinja() test_case, _ = self.read_example_data(template_version) From 12f4e8863e684307fb2cce3e7675c04d58264b09 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Tue, 25 Feb 2025 11:30:45 +0700 Subject: [PATCH 03/31] add template for r1-qwen-distilled 
--- functionary/prompt_template/__init__.py | 5 +- .../prompt_template/jinja_templates/r1.txt | 9 +- .../jinja_templates/r1_distilled_qwen.txt | 58 +++++++ .../prompt_template/r1_distilled_qwen.py | 164 ++++++++++++++++++ tests/prompt_test_r1_distilled_qwen.txt | 49 ++++++ tests/test_prompt_creation.py | 3 +- 6 files changed, 284 insertions(+), 4 deletions(-) create mode 100644 functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt create mode 100644 functionary/prompt_template/r1_distilled_qwen.py create mode 100644 tests/prompt_test_r1_distilled_qwen.txt diff --git a/functionary/prompt_template/__init__.py b/functionary/prompt_template/__init__.py index 14c0b0c6..308bdf88 100644 --- a/functionary/prompt_template/__init__.py +++ b/functionary/prompt_template/__init__.py @@ -8,8 +8,11 @@ from functionary.prompt_template.llava_prompt_template import LlavaLlama from functionary.prompt_template.prompt_template_v1 import PromptTemplateV1 from functionary.prompt_template.prompt_template_v2 import PromptTemplateV2 -from functionary.prompt_template.llama31_reasoning_prompt_template import Llama31ReasoningTemplate +from functionary.prompt_template.llama31_reasoning_prompt_template import ( + Llama31ReasoningTemplate, +) from functionary.prompt_template.r1_template import R1Template +from functionary.prompt_template.r1_distilled_qwen import R1DistilledQwen def get_available_prompt_template_versions() -> List[PromptTemplate]: diff --git a/functionary/prompt_template/jinja_templates/r1.txt b/functionary/prompt_template/jinja_templates/r1.txt index 7c87e09f..eba6a224 100644 --- a/functionary/prompt_template/jinja_templates/r1.txt +++ b/functionary/prompt_template/jinja_templates/r1.txt @@ -42,17 +42,21 @@ is_tool=false, is_output_first=true, system_prompt='', - is_first_sp=true + is_first_sp=true, + last_message_is_tool=false ) -%} {#- Process messages -#} {%- for message in messages -%} {%- if message['role'] == 'user'-%} {# Reset tool and output first flags #} + 
{%- set ns.last_message_is_tool = false -%} {%- set ns.is_tool = false -%} {{-'<|User|>' + message['content'] -}} {%- elif message['role'] == 'system' -%} + {%- set ns.last_message_is_tool = false -%} {{- '<|System|>' + message['content'] + '\n' -}} {%- elif message['role'] == 'assistant' -%} + {%- set ns.last_message_is_tool = false -%} {# Need to check need to add <|tool▁outputs▁end|> if the previous message is tool message #} {%- if ns.is_tool -%} {{- '<|tool▁outputs▁end|>' -}} @@ -80,6 +84,7 @@ {{ '<|end▁of▁sentence|>' }} {%- endif -%} {%- elif message['role'] == 'tool' -%} + {%- set ns.last_message_is_tool = true -%} {# Check if this is the result of the first tool call #} {%- if ns.is_output_first -%} {{- '<|tool▁outputs▁begin|>'}} @@ -89,7 +94,7 @@ {%- endif -%} {%- endfor -%} {# Need to add closing tag for list of tool messages #} -{%- if ns.is_tool -%} +{%- if ns.last_message_is_tool -%} {{- '<|tool▁outputs▁end|>' -}} {%- endif -%} {%- if add_generation_prompt and not ns.is_tool -%} diff --git a/functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt b/functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt new file mode 100644 index 00000000..598ba017 --- /dev/null +++ b/functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt @@ -0,0 +1,58 @@ +{# version=r1_distilled_qwen #} +{%- if tools %} + {{- '<|System|>' }} + {%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} + {%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} + {%- endif -%} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- if has_reasoning %} + {{- " Please generate reasoning before deciding which tools to use." 
}} + {%- else %} + {{- "" }} + {%- endif %} + {{- "\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n\n" }} +{%- else %} +{%- endif %} +{%- for message in messages -%} + {%- if message.role == "user" -%} + {{ '<|User|>' + message['content'] }} + {%- elif message.role == "system" %} + {{ '<|System|>' + message['content'] + '\n'}} + {%- elif message.role == "assistant" %} + {{- '<|Assistant|>'}} + {%- if message.content %} + {{- message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {{- '\n{"name": "' }} + {%- else %} + {{- '\n\n{"name": "' }} + {%- endif %} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments }} + {{- '}\n' }} + {%- endfor %} + {{- '<|end▁of▁sentence|>' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|User|>' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|Assistant|>' }} +{%- endif %} \ No newline at end of file diff --git a/functionary/prompt_template/r1_distilled_qwen.py b/functionary/prompt_template/r1_distilled_qwen.py new file mode 100644 index 00000000..67235cd2 --- /dev/null +++ b/functionary/prompt_template/r1_distilled_qwen.py @@ -0,0 +1,164 @@ +import datetime +import json +import re +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from functionary.openai_types import Function, Tool +from functionary.prompt_template import prompt_utils +from functionary.prompt_template.base_template import PromptTemplate +import copy + + +class 
R1DistilledQwen(PromptTemplate): + version = "r1_distilled_qwen" + chat_template = None + + def get_additional_tokens(self) -> List[str]: + return [] + + def get_assistant_prefixes(self) -> List[str]: + return [f"<|Assistant|>"] + + def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: + """Order the tool results by the order of tool call ids + + Args: + messages (List[Dict]): List of messages + + Returns: + List[Dict]: List of messages + """ + return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) + + def get_stop_tokens_for_generation(self) -> List[str]: + return ["<|end▁of▁sentence|>"] + + def parse_assistant_response( + self, llm_output: str, tool_choice: Any = None + ) -> Dict: + # first remove stop tokens if there exists + for stop in self.get_stop_tokens_for_generation(): + if llm_output.endswith(stop): + llm_output = llm_output[: -len(stop)] + + # add forced-function from tool_choice if exists + llm_output = ( + self.get_generation_prefix_for_tool_choice(tool_choice) + llm_output + ) + + tool_calls = [] + text_response = "" + + while len(llm_output) > 0: + start_tool_call_index = llm_output.find("") + if start_tool_call_index >= 0: + end_index = llm_output.find("", start_tool_call_index) + if end_index >= 0: + json_between = llm_output[ + start_tool_call_index + len("") : end_index + ] + func_call = json.loads(json_between) + tool_calls.append( + { + "type": "function", + "id": prompt_utils.get_random_tool_call_id(), + "function": { + "name": func_call["name"], + "arguments": json.dumps( + func_call["arguments"], ensure_ascii=False + ), + }, + } + ) + index = end_index + len("") + + text_response += llm_output[:start_tool_call_index].strip() + llm_output = llm_output[index:] + else: # cannot find at the end + text_response += llm_output + llm_output = "" + else: # cannot find + text_response += llm_output + llm_output = "" + + if not text_response: + text_response = None + elif len(text_response.strip()) == 0: + 
text_response = None + + if not tool_calls: + tool_calls = None + + return {"role": "assistant", "content": text_response, "tool_calls": tool_calls} + + def get_prompt_from_messages( + self, + messages: List[Dict], + tools_or_functions: Optional[List[Dict]] = None, + bos_token: Optional[str] = "", + add_generation_prompt: bool = False, + ) -> str: + """This function is used to get the complete prompt for list of messages + + Args: + messages (List[Dict]): List of messages + tools_or_functions (Optional[List[Dict]], optional): List of tools or functions. Defaults to None. + + Returns: + str: the prompt for inference/training + """ + # qwen 2.5 use transformers chat template, need to convert argument string --> dictionary, this is noted in: https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example + # If you’re familiar with the OpenAI API, you should pay attention to an important difference here - the tool_call is a dict, but in the OpenAI API it’s a JSON string. Passing a string may cause errors or strange model behaviour! 
+ new_messages = copy.deepcopy(messages) + for message in new_messages: + tool_calls = message.get("tool_calls", []) + if tool_calls: + for tool_call in tool_calls: + if type(tool_call["function"]["arguments"]) is str: + if tool_call["function"]["name"] != "python": + tool_call["function"]["arguments"] = json.loads( + tool_call["function"]["arguments"] + ) + else: + tool_call["function"] = { + "name": "python", + "arguments": { + "code": tool_call["function"]["arguments"] + }, + } + # check if contain code_interpreter, replace with python + new_tools = copy.deepcopy(tools_or_functions) + if tools_or_functions is not None and len(tools_or_functions) == 0: + new_tools = None + + if new_tools: + for tool in new_tools: + if tool["type"] == "code_interpreter": + tool["type"] = "function" + tool["function"] = { + "name": "python", + "description": 'When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. 
The drive at "/mnt/data" can be used to save and persist user files.', + "parameters": { + "type": "object", + "properties": { + "code": {"type": "string", "description": "Python code"} + }, + "required": ["code"], + }, + } + + prompt = self._jinja_template.render( + messages=new_messages, + tools=new_tools, + bos_token=bos_token, + add_generation_prompt=add_generation_prompt, + ) + + return prompt + + def get_chat_template_jinja(self) -> str: + if self.chat_template is None: + jinja_template_file = "./functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt" + with open(jinja_template_file, "r") as f: + self.chat_template = f.read() + return self.chat_template \ No newline at end of file diff --git a/tests/prompt_test_r1_distilled_qwen.txt b/tests/prompt_test_r1_distilled_qwen.txt new file mode 100644 index 00000000..281fca93 --- /dev/null +++ b/tests/prompt_test_r1_distilled_qwen.txt @@ -0,0 +1,49 @@ +<|System|> + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}}} +{"type": "function", "function": {"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}}} +{"type": "function", "function": {"name": "python", "description": "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. 
The drive at \"/mnt/data\" can be used to save and persist user files.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "Python code"}}, "required": ["code"]}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +<|User|>who is the CEO of Meetkai<|Assistant|>James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|end▁of▁sentence|><|User|>is the car Song more expensive than car Tang?<|Assistant|>I will get the price of 2 cars and compare +{"name": "get_car_price", "arguments": {'car_name': 'Song'}} + + +{"name": "get_car_price", "arguments": {'car_name': 'Tang'}} +<|end▁of▁sentence|><|User|> + +{"price": {"price": "$25000"}} + + +{"price": {"price": "$20000"}} +<|Assistant|>No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|end▁of▁sentence|><|User|>what's the weather like in Hanoi?<|Assistant|> +{"name": "get_weather", "arguments": {'location': 'Hanoi'}} +<|end▁of▁sentence|><|User|> + +{"result": {"temperature": 10}} +<|Assistant|>The temperature in Hanoi is: 10 degree Celcious<|end▁of▁sentence|><|User|>Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|Assistant|>I'll use code interpreter to handle this +{"name": "python", "arguments": {'code': "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"}} +<|end▁of▁sentence|><|User|> + +ValueError: list.remove(x): x not in list +<|Assistant|>I will fix the code +{"name": "python", "arguments": {'code': 'l.remove(round(3.6))\nl'}} +<|end▁of▁sentence|><|User|> + +[0,1,2,3,5,] +<|Assistant|>The final list is: 0,1,2,3,5<|end▁of▁sentence|><|User|>Thanks! 
What's the weather in San Francisco?<|Assistant|> +{"name": "get_weather", "arguments": {'location': 'San Francisco, CA'}} +<|end▁of▁sentence|><|User|> + +{"result": {"temperature": 20}} +<|Assistant|>The temperature in San Francisco is: 20 degree Celcious<|end▁of▁sentence|> \ No newline at end of file diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py index a69dbc73..a80d25d4 100644 --- a/tests/test_prompt_creation.py +++ b/tests/test_prompt_creation.py @@ -42,7 +42,8 @@ def __init__(self, *args, **kwargs): "v2.llama3": "meetkai/functionary-small-v2.5", "v3.llama3": "meetkai/functionary-medium-v3.0", "v3-llama3.1": "meetkai/functionary-small-v3.1", - "r1": "deepseek-ai/DeepSeek-R1" + "r1": "deepseek-ai/DeepSeek-R1", + "r1_distilled_qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", } self.image_template_version_to_model_name = { "v3.llava_llama": "meetkai/functionary-vision-small-v0.1" From 09fcf046a4f7f209d6ee72b0b16fd344c60ade7b Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Tue, 25 Feb 2025 15:45:46 +0700 Subject: [PATCH 04/31] add template for r1_distilled_llama --- functionary/prompt_template/__init__.py | 2 + .../jinja_templates/r1_distilled_llama.txt | 67 +++++++++++++++++++ .../prompt_template/r1_distilled_llama.py | 45 +++++++++++++ .../prompt_template/r1_distilled_qwen.py | 10 +-- tests/prompt_test_r1_distilled_llama.txt | 47 +++++++++++++ tests/test_prompt_creation.py | 1 + 6 files changed, 168 insertions(+), 4 deletions(-) create mode 100644 functionary/prompt_template/jinja_templates/r1_distilled_llama.txt create mode 100644 functionary/prompt_template/r1_distilled_llama.py create mode 100644 tests/prompt_test_r1_distilled_llama.txt diff --git a/functionary/prompt_template/__init__.py b/functionary/prompt_template/__init__.py index 308bdf88..71ef8a18 100644 --- a/functionary/prompt_template/__init__.py +++ b/functionary/prompt_template/__init__.py @@ -13,6 +13,7 @@ ) from functionary.prompt_template.r1_template import R1Template 
from functionary.prompt_template.r1_distilled_qwen import R1DistilledQwen +from functionary.prompt_template.r1_distilled_llama import R1DistilledLlama def get_available_prompt_template_versions() -> List[PromptTemplate]: @@ -34,6 +35,7 @@ def get_available_prompt_template_versions() -> List[PromptTemplate]: # we don't use get_prompt_template or this will return the parent class all_templates_obj.append(LlavaLlama.get_prompt_template()) all_templates_obj.append(Llama31ReasoningTemplate.get_prompt_template()) + all_templates_obj.append(R1DistilledLlama.get_prompt_template()) return all_templates_obj diff --git a/functionary/prompt_template/jinja_templates/r1_distilled_llama.txt b/functionary/prompt_template/jinja_templates/r1_distilled_llama.txt new file mode 100644 index 00000000..5b614d64 --- /dev/null +++ b/functionary/prompt_template/jinja_templates/r1_distilled_llama.txt @@ -0,0 +1,67 @@ +{# version=r1_distilled_llama #}{%- if not tools is defined -%} + {%- set tools = none -%} +{%- endif -%} + +{%- set has_code_interpreter = tools | selectattr("type", "equalto", "code_interpreter") | list | length > 0 -%} +{%- if has_code_interpreter -%} + {%- set tools = tools | rejectattr("type", "equalto", "code_interpreter") | list -%} +{%- endif -%} + +{%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} +{%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} +{%- endif -%} + +{#- System message + builtin tools #} +{{- bos_token }} +{%- if has_reasoning or has_code_interpreter or tools -%} + {{- "<|System|>\n" }} +{%- endif -%} +{%- if has_reasoning %} + {{- "Reasoning Mode: On\n\n" }} +{%- endif %} +{%- if has_code_interpreter %} + {{- "Environment: ipython\n\n" }} +{%- endif %} +{%- if tools %} + {{- "\nYou have access to the following functions:\n\n" }} + {%- for t in tools %} + {%- if "type" in t -%} + {{ "Use the function '" + t["function"]["name"] + "' to '" + 
t["function"]["description"] + "'\n" + t["function"] | tojson() }} + {%- else -%} + {{ "Use the function '" + t["name"] + "' to '" + t["description"] + "'\n" + t | tojson }} + {%- endif -%} + {{- "\n\n" }} + {%- endfor %} + {{- '\nThink very carefully before calling functions.\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- If looking for real time information use relevant functions before falling back to brave_search\n- Function calls MUST follow the specified format, start with \n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line\n\n' -}} +{%- endif %} + +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{-'<|User|>' + message['content'] -}} + {%- elif message['role'] == 'system' -%} + {{- '<|System|>' + message['content'] + '\n' -}} + {%- elif message['role'] == 'tool' -%} + {{ 'ipython\n\n' + message['content'] }} + {%- else -%} + {{-'<|Assistant|>'}} + {%- if message['content'] and message['content']|length > 0 -%} + {{ message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] and message['tool_calls']|length > 0 -%} + {%- for tool_call in message['tool_calls'] -%} + {%- if tool_call["function"]["name"] == "python" -%} + {{ '<|python_tag|>' + tool_call['function']['arguments'] }} + {%- else -%} + {{ '' + tool_call['function']['arguments'] + '' }} + {%- endif -%} + {%- endfor -%} + {{ '<|eom_id|>' }} + {%- elif message['content'] and message['content']|length > 0 -%} + {{ '<|end▁of▁sentence|>' }} + {%- endif -%} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{ '<|Assistant|>' }} +{%- endif -%} \ No newline at end of file diff 
--git a/functionary/prompt_template/r1_distilled_llama.py b/functionary/prompt_template/r1_distilled_llama.py new file mode 100644 index 00000000..d12cf2a8 --- /dev/null +++ b/functionary/prompt_template/r1_distilled_llama.py @@ -0,0 +1,45 @@ +import datetime +import json +import re +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from functionary.openai_types import Function, Tool +from functionary.prompt_template import prompt_utils +from functionary.prompt_template.llama31_reasoning_prompt_template import ( + Llama31Template, +) +import copy + + +class R1DistilledLlama(Llama31Template): + version = "r1_distilled_llama" + chat_template = None + + def get_additional_tokens(self) -> List[str]: + return [] + + def get_assistant_prefixes(self) -> List[str]: + return [f"<|Assistant|>"] + + def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: + """Order the tool results by the order of tool call ids + + Args: + messages (List[Dict]): List of messages + + Returns: + List[Dict]: List of messages + """ + return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) + + def get_stop_tokens_for_generation(self) -> List[str]: + return ["<|end▁of▁sentence|>", "<|eom_id|>"] + + def get_chat_template_jinja(self) -> str: + if self.chat_template is None: + jinja_template_file = ( + "./functionary/prompt_template/jinja_templates/r1_distilled_llama.txt" + ) + with open(jinja_template_file, "r") as f: + self.chat_template = f.read() + return self.chat_template diff --git a/functionary/prompt_template/r1_distilled_qwen.py b/functionary/prompt_template/r1_distilled_qwen.py index 67235cd2..3c880a7c 100644 --- a/functionary/prompt_template/r1_distilled_qwen.py +++ b/functionary/prompt_template/r1_distilled_qwen.py @@ -12,7 +12,7 @@ class R1DistilledQwen(PromptTemplate): version = "r1_distilled_qwen" chat_template = None - + def get_additional_tokens(self) -> List[str]: return [] @@ -155,10 +155,12 @@ def 
get_prompt_from_messages( ) return prompt - + def get_chat_template_jinja(self) -> str: if self.chat_template is None: - jinja_template_file = "./functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt" + jinja_template_file = ( + "./functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt" + ) with open(jinja_template_file, "r") as f: self.chat_template = f.read() - return self.chat_template \ No newline at end of file + return self.chat_template diff --git a/tests/prompt_test_r1_distilled_llama.txt b/tests/prompt_test_r1_distilled_llama.txt new file mode 100644 index 00000000..dd2dfd70 --- /dev/null +++ b/tests/prompt_test_r1_distilled_llama.txt @@ -0,0 +1,47 @@ +<|System|> +Environment: ipython + + +You have access to the following functions: + +Use the function 'get_car_price' to 'Get the price of a particular car model' +{"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}} + +Use the function 'get_weather' to 'This function's purpose is to get the weather of a location' +{"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}} + + +Think very carefully before calling functions. +If a you choose to call a function ONLY reply in the following format: +<{start_tag}={function_name}>{parameters}{end_tag} +where + +start_tag => ` a JSON dict with the function argument name as key and function argument value as value. 
+end_tag => `` + +Here is an example, +{"example_name": "example_value"} + +Reminder: +- If looking for real time information use relevant functions before falling back to brave_search +- Function calls MUST follow the specified format, start with +- Required parameters MUST be specified +- Only call one function at a time +- Put the entire function call reply on one line + +<|User|>who is the CEO of Meetkai<|Assistant|>James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|end▁of▁sentence|><|User|>is the car Song more expensive than car Tang?<|Assistant|>I will get the price of 2 cars and compare{"car_name": "Song"}{"car_name": "Tang"}<|eom_id|>ipython + +{"price": {"price": "$25000"}}ipython + +{"price": {"price": "$20000"}}<|Assistant|>No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|end▁of▁sentence|><|User|>what's the weather like in Hanoi?<|Assistant|>{"location": "Hanoi"}<|eom_id|>ipython + +{"result": {"temperature": 10}}<|Assistant|>The temperature in Hanoi is: 10 degree Celcious<|end▁of▁sentence|><|User|>Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|Assistant|>I'll use code interpreter to handle this<|python_tag|>l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')] +l.remove(3.6)<|eom_id|>ipython + +ValueError: list.remove(x): x not in list<|Assistant|>I will fix the code<|python_tag|>l.remove(round(3.6)) +l<|eom_id|>ipython + +[0,1,2,3,5,]<|Assistant|>The final list is: 0,1,2,3,5<|end▁of▁sentence|><|User|>Thanks! 
What's the weather in San Francisco?<|Assistant|>{"location": "San Francisco, CA"}<|eom_id|>ipython + +{"result": {"temperature": 20}}<|Assistant|>The temperature in San Francisco is: 20 degree Celcious<|end▁of▁sentence|> \ No newline at end of file diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py index a80d25d4..e0d39963 100644 --- a/tests/test_prompt_creation.py +++ b/tests/test_prompt_creation.py @@ -44,6 +44,7 @@ def __init__(self, *args, **kwargs): "v3-llama3.1": "meetkai/functionary-small-v3.1", "r1": "deepseek-ai/DeepSeek-R1", "r1_distilled_qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", + "r1_distilled_llama": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", } self.image_template_version_to_model_name = { "v3.llava_llama": "meetkai/functionary-vision-small-v0.1" From 0d42f10c9ffd5710a185eb91c0c4e746de8c1266 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Wed, 26 Feb 2025 09:35:56 +0000 Subject: [PATCH 05/31] fix template --- .../prompt_template/jinja_templates/r1_distilled_qwen.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt b/functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt index 598ba017..bea2dc4f 100644 --- a/functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt +++ b/functionary/prompt_template/jinja_templates/r1_distilled_qwen.txt @@ -29,7 +29,12 @@ {%- if message.content %} {{- message.content }} {%- endif %} - {%- for tool_call in message.tool_calls %} + {%- if "tool_calls" in message and message["tool_calls"] is not none -%} + {%- set tool_calls = message["tool_calls"] -%} + {%-else-%} + {%- set tool_calls = [] -%} + {%- endif -%} + {%- for tool_call in tool_calls %} {%- if tool_call.function is defined %} {%- set tool_call = tool_call.function %} {%- endif %} From a77da3bc4a909fc67d2b0163d3e489b2a2781e76 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Mon, 3 Mar 2025 08:55:08 +0700 Subject: [PATCH 06/31] 
add prompt template for qwen2.5 text only --- functionary/prompt_template/__init__.py | 1 + .../jinja_templates/qwen2.5-text-only.txt | 80 ++++ .../qwen25_text_only_template.py | 429 ++++++++++++++++++ tests/prompt_test_qwen2.5-text-only.txt | 87 ++++ tests/test_prompt_creation.py | 1 + 5 files changed, 598 insertions(+) create mode 100644 functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt create mode 100644 functionary/prompt_template/qwen25_text_only_template.py create mode 100644 tests/prompt_test_qwen2.5-text-only.txt diff --git a/functionary/prompt_template/__init__.py b/functionary/prompt_template/__init__.py index 71ef8a18..ea0ad8dd 100644 --- a/functionary/prompt_template/__init__.py +++ b/functionary/prompt_template/__init__.py @@ -14,6 +14,7 @@ from functionary.prompt_template.r1_template import R1Template from functionary.prompt_template.r1_distilled_qwen import R1DistilledQwen from functionary.prompt_template.r1_distilled_llama import R1DistilledLlama +from functionary.prompt_template.qwen25_text_only_template import Qwen25TextOnlyPromptTemplate def get_available_prompt_template_versions() -> List[PromptTemplate]: diff --git a/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt b/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt new file mode 100644 index 00000000..4b453552 --- /dev/null +++ b/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt @@ -0,0 +1,80 @@ +{# version=qwen2.5-text-only #}{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages and messages[0]['role'] == 'system' %} + {{- messages[0]['content'] }} + {%- else %} + {{- 'You are a helpful assistant.' 
}} + {%- endif %} + {%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} + {%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} + {%- endif -%} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- if has_reasoning %} + {{- " Please generate reasoning before deciding which tools to use." }} + {%- else %} + {{- "" }} + {%- endif %} + {{- "\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages and messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if message.role == "user" %} + {%- if message['content'] -%} + {%- if message['content'] is string -%} + {{ '<|im_start|>user\n' + message['content'] }} + {%- else -%} + {{ '<|im_start|>user\n' }} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {{ content['text'] }} + {%- else -%} + {{ '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {{ '<|im_end|>\n' }} + {%- endif -%} + {%- elif (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + 
{%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} \ No newline at end of file diff --git a/functionary/prompt_template/qwen25_text_only_template.py b/functionary/prompt_template/qwen25_text_only_template.py new file mode 100644 index 00000000..c9992de9 --- /dev/null +++ b/functionary/prompt_template/qwen25_text_only_template.py @@ -0,0 +1,429 @@ +from functionary.prompt_template.base_template import PromptTemplate +from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from functionary.prompt_template import prompt_utils +from functionary.openai_types import Function, Tool +import json +import copy +import math +import re + + +class Qwen25TextOnlyPromptTemplate(PromptTemplate): + version = "qwen2.5-text-only" + start_of_turn = "<|im_start|>" + end_of_turn = "<|im_end|>" + + def get_chat_template_jinja(self) -> str: + path_prefix = "./functionary/prompt_template/jinja_templates/" + with open(f"{path_prefix}{self.version}.txt", "r") as f: + template = f.read() + + return template + + def get_tool_choice_required_prefix(self) -> str: + return "\n" + + def get_prompt_from_messages( + self, + messages: List[Dict], + tools_or_functions: Optional[List[Dict]] = None, + bos_token: Optional[str] = "", + add_generation_prompt: bool = False, + ) -> str: + """This function is used to get the complete prompt for list of messages + + Args: + messages (List[Dict]): List of 
messages + tools_or_functions (Optional[List[Dict]], optional): List of tools or functions. Defaults to None. + + Returns: + str: the prompt for inference/training + """ + # handle code_interpreter + _tools = [] + if tools_or_functions: + for tool in tools_or_functions: + if tool["type"] == "code_interpreter": + _tools.append( + { + "type": "function", + "function": { + "name": "python", + "description": "This tool is used to execute python code. Code will be executed in a stateful Jupyter notebook environment. Python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The python code to run", + } + }, + }, + }, + } + ) + else: + _tools.append(tool) + + # find the assistant message that tool_call is python + _messages = [] + for message in messages: + n_message = copy.deepcopy(message) + tool_calls = n_message.get("tool_calls", []) or [] + if len(tool_calls) > 0: + for tool_call in tool_calls: + if tool_call["function"]["name"] == "python": + arguments = tool_call["function"][ + "arguments" + ] # currently the code is in string format + # check if argument is a valid JSON string or python code + try: # if this is a valid JSON string --> no need to change anything + json.loads(arguments) + except: + tool_call["function"]["arguments"] = json.dumps( + {"code": arguments}, ensure_ascii=False + ) + _messages.append(n_message) + prompt = super().get_prompt_from_messages( + messages=_messages, + tools_or_functions=_tools, + bos_token=bos_token, + add_generation_prompt=add_generation_prompt, + ) + return prompt + + def get_additional_tokens(self) -> List[str]: + return [] + + def get_assistant_prefixes(self) -> List[str]: + return [f"{self.start_of_turn}assistant\n"] + + def get_stop_tokens_for_generation(self) -> List[str]: + return [self.end_of_turn] + + def 
get_force_function_call_prefix(self, function_name: str): + return """ +{"name": "{function_name}", "arguments""".replace( + "{function_name}", function_name + ) + + def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: + """Order the tool results by the order of tool call ids + + Args: + messages (List[Dict]): List of messages + + Returns: + List[Dict]: List of messages + """ + return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) + + def parse_assistant_response( + self, llm_output: str, tool_choice: Any = None + ) -> Dict: + # first remove stop tokens if there exists + for stop in self.get_stop_tokens_for_generation(): + if llm_output.endswith(stop): + llm_output = llm_output[: -len(stop)] + + # add forced-function from tool_choice if exists + llm_output = ( + self.get_generation_prefix_for_tool_choice(tool_choice) + llm_output + ) + + # print(f"+++LLM_OUTPUT: {llm_output}") + llm_output = post_process_llm_output(llm_output) + # print(f"+++LLM_OUTPUT after post-processing: {llm_output}") + text_content = "" + tool_call_strs = [] + + # Split on tool call tags + parts = llm_output.split("") + + if len(parts) > 0: + # First part is the text content + text_content = parts[0].strip() + + # Process remaining parts as tool calls + for part in parts[1:]: + if "" in part: + tool_call = part.split("")[0].strip() + if tool_call: + tool_call_strs.append(tool_call) + tool_calls = [] + for tool_call_str in tool_call_strs: + tool_call_dic = json.loads(tool_call_str) + if tool_call_dic["name"] == "python": + arguments = tool_call_dic["arguments"]["code"] + else: + arguments = json.dumps(tool_call_dic["arguments"], ensure_ascii=False) + tool_calls.append( + { + "type": "function", + "id": prompt_utils.get_random_tool_call_id(), + "function": { + "name": tool_call_dic["name"], + "arguments": arguments, + }, + } + ) + + return { + "role": "assistant", + "content": text_content if len(text_content) > 0 else None, + "tool_calls": None if 
len(tool_calls) == 0 else tool_calls, + } + + + def initialize_fsm_gen_state( + self, + tool_choice: Union[str, Tool], + curr_text: str, + curr_tokens: Optional[List[int]], + add_code_interpreter: Optional[bool], + ) -> Dict: + """Initializes FSM state for both streaming and grammar sampling + + Args: + tool_choice (str): tool_choice provided by user + curr_text (str): Text to initialize in gen_state + curr_tokens (List[int]): Corresponding tokens of curr_text + add_code_interpreter (bool): Flag indicating whether to add "python" tool in options in "function" stage. + Returns: + Dict: generation state + """ + result = { + "stage": "start", + "func_index": -1, + "curr_text": curr_text, + "curr_tokens": curr_tokens, + "add_code_interpreter": add_code_interpreter, + } + if tool_choice == "required": # a tool must be used + result["stage"] = "function_name" + result["curr_text"] = "" + result["func_index"] += 1 + result["call_id"] = prompt_utils.get_random_tool_call_id() + + elif not isinstance(tool_choice, str): # a predefined tool is used + func_name = ( + tool_choice.name + if hasattr(tool_choice, "name") + else tool_choice.function.name + ) + result["stage"] = "function_name" + result["curr_text"] = '{"name": "{function_name}", "arguments'.replace( + "{function_name}", func_name + ) + result["func_index"] += 1 + result["call_id"] = prompt_utils.get_random_tool_call_id() + + return result + + def stream_delta_text( + self, + gen_state: Dict, + delta_text: str, + finish_reason: Optional[str], + tools_or_functions: List[Dict], + tool_choice: Any, + ) -> Tuple[Dict, Optional[Union[Dict, List[Dict]]]]: + if finish_reason is not None: # handle if finish + if gen_state["stage"] not in ["text_gen"]: + finish_reason = "tool_calls" + + end_response = prompt_utils.get_text_delta_response( + None, False, finish_reason + ) + last_response = None + # still need to check if there is st in buffer + if "buffer" in gen_state and len(gen_state["buffer"]) > 0: + if gen_state["stage"] 
== "text_gen": + buffer_str = "".join( + gen_state["buffer"] + ).rstrip() # remove \n at the end + last_response = prompt_utils.get_text_delta_response( + buffer_str, False, None + ) + + elif gen_state["stage"] == "function_arguments": + buffer_str = "".join( + gen_state["buffer"] + ).rstrip() # remove \n at the end + if buffer_str.endswith("}}"): + buffer_str = buffer_str[:-1] # remove the last "}" + + if len(buffer_str) > 0: + last_response = prompt_utils.get_function_delta_response( + gen_state, buffer_str, False, False, None + ) + elif gen_state["stage"] == "python": + last_response = return_all_code_from_buffer(gen_state) + + if last_response is not None: + return gen_state, [last_response, end_response] + else: + return gen_state, [end_response] + + current_text = gen_state["curr_text"] + delta_text + gen_state["curr_text"] = current_text + if gen_state["stage"] == "start": + if ( + gen_state.get("end_of_prev_function_call", False) and delta_text == "\n" + ): # ignore \n + gen_state["end_of_prev_function_call"] = False + gen_state["curr_text"] = "" + + elif delta_text == "": + # print(f"delta text: {delta_text}; go to function_name") + gen_state["stage"] = "function_name" + gen_state["curr_text"] = "" + gen_state["func_index"] += 1 + gen_state["call_id"] = prompt_utils.get_random_tool_call_id() + else: + # print(f"delta text: {delta_text}; go to text_gen") + gen_state["stage"] = "text_gen" + gen_state["curr_text"] = current_text + gen_state["buffer"] = ( + [] + ) # put to buffer before we return because we need to check the last item + responses = [ + prompt_utils.get_text_delta_response("", True, finish_reason) + ] + if len(delta_text) > 0: + gen_state["buffer"].append(delta_text) + return gen_state, responses + + elif gen_state["stage"] == "function_name": + # wait until we get '{"name": "func_name", "arguments": {' + # print(f"current_text: {current_text}") + pattern = ( + r'\s*{"name"\s*:\s*"(?P.*)"\s*,\s*"arguments"\s*:\s*{' + ) + match = 
re.search(pattern, current_text) + if match: + _, end_ind = match.start(), match.end() + new_delta = current_text[end_ind - 1 :] + gen_state["curr_text"] = new_delta # -1 to retain "{" + gen_state["func_name"] = match.group("function_name") + gen_state["stage"] = ( + "function_arguments" + if gen_state["func_name"] != "python" + else "python" + ) + responses = [ + prompt_utils.get_function_delta_response( + gen_state, "", True, True, finish_reason + ) + ] # the chunk containing function_name only + gen_state["buffer"] = [] + if gen_state["func_name"] != "python": + gen_state["buffer"].append(new_delta) + + return gen_state, responses + else: + return gen_state, None + + elif gen_state["stage"] == "text_gen": + if delta_text == "": # start a tool call + # print("start a tool call after reasoning") + gen_state["stage"] = "function_name" + gen_state["curr_text"] = "" + gen_state["func_index"] += 1 + gen_state["call_id"] = prompt_utils.get_random_tool_call_id() + buffer_str = "".join( + gen_state["buffer"] + ).rstrip() # remove \n at the end + if len(buffer_str) > 0: + return gen_state, prompt_utils.get_text_delta_response( + buffer_str, False, finish_reason + ) + else: + gen_state["buffer"].append(delta_text) + if len(gen_state["buffer"]) >= 2: + delta_text_item = gen_state["buffer"].pop(0) + return gen_state, prompt_utils.get_text_delta_response( + delta_text_item, False, finish_reason + ) + + elif gen_state["stage"] == "function_arguments": + # check if current function is python, we need to stream the code string inside, not a json + if delta_text == "": + gen_state["stage"] = "start" + gen_state["curr_text"] = "" + gen_state["end_of_prev_function_call"] = True + # return all in the buffer but need to strip and remove the last "}" + buffer_str = "".join( + gen_state["buffer"] + ).rstrip() # remove \n at the end + if buffer_str.endswith("}}\n"): + buffer_str = buffer_str[:-2] # remove the last "}\n" + elif buffer_str.endswith("}}"): + buffer_str = buffer_str[:-1] # 
remove the last "}" + + return gen_state, prompt_utils.get_function_delta_response( + gen_state, buffer_str, False, False, finish_reason + ) + else: + gen_state["buffer"].append(delta_text) + if len(gen_state["buffer"]) >= 4: + delta_text_item = gen_state["buffer"].pop(0) + return gen_state, prompt_utils.get_function_delta_response( + gen_state, delta_text_item, False, False, finish_reason + ) + + elif gen_state["stage"] == "python": + return streamining_python_code(gen_state, delta_text) + return gen_state, None + + +def return_all_code_from_buffer(gen_state: Dict) -> Optional[Union[Dict, List[Dict]]]: + buffer_str = "".join(gen_state["buffer"]).rstrip() # remove \n at the end + if len(buffer_str) > 0: + return prompt_utils.get_function_delta_response( + gen_state, buffer_str, False, False, None + ) + return None + + +def streamining_python_code( + gen_state: Dict, delta_text: str +) -> Tuple[Dict, Optional[Union[Dict, List[Dict]]]]: + if "current_code" not in gen_state: + gen_state["current_code"] = "" + return gen_state, None + + current_text = gen_state["curr_text"] + current_code = gen_state["current_code"] + # try extracting the latest code from current_text + try: + if delta_text == "": # end of code + full_code_arg_str = current_text.rstrip("").strip() + if full_code_arg_str.endswith("}}"): + full_code_arg_str = full_code_arg_str[:-1] + new_code = json.loads(full_code_arg_str)["code"] + else: + new_code = json.loads(current_text + '"}')["code"] + delta_code = new_code[len(current_code) :] + + gen_state["buffer"].append(delta_code) + gen_state["current_code"] = new_code + except: # nothing changed + return gen_state, None + + if delta_text == "": + return gen_state, return_all_code_from_buffer(gen_state) + else: + if len(gen_state["buffer"]) >= 4: + delta_text_item = gen_state["buffer"].pop(0) + return gen_state, prompt_utils.get_function_delta_response( + gen_state, delta_text_item, False, False, None + ) + return gen_state, None + + +def 
match_pattern(pattern: str, text: str) -> Tuple[int, int]: + match = re.search(pattern, text) + if match: + return match.start(), match.end() + return -1, -1 \ No newline at end of file diff --git a/tests/prompt_test_qwen2.5-text-only.txt b/tests/prompt_test_qwen2.5-text-only.txt new file mode 100644 index 00000000..6b386b90 --- /dev/null +++ b/tests/prompt_test_qwen2.5-text-only.txt @@ -0,0 +1,87 @@ +<|im_start|>system +You are a helpful assistant. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}}} +{"type": "function", "function": {"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}}} +{"type": "function", "function": {"name": "python", "description": "This tool is used to execute python code. Code will be executed in a stateful Jupyter notebook environment. Python will respond with the output of the execution or time out after 60.0 seconds. 
The drive at '/mnt/data' can be used to save and persist user files.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The python code to run"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +who is the CEO of Meetkai<|im_end|> +<|im_start|>assistant +James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|im_end|> +<|im_start|>user +is the car Song more expensive than car Tang?<|im_end|> +<|im_start|>assistant +I will get the price of 2 cars and compare + +{"name": "get_car_price", "arguments": {"car_name": "Song"}} + + +{"name": "get_car_price", "arguments": {"car_name": "Tang"}} +<|im_end|> +<|im_start|>user + +{"price": {"price": "$25000"}} + + +{"price": {"price": "$20000"}} +<|im_end|> +<|im_start|>assistant +No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|im_end|> +<|im_start|>user +what's the weather like in Hanoi?<|im_end|> +<|im_start|>assistant + +{"name": "get_weather", "arguments": {"location": "Hanoi"}} +<|im_end|> +<|im_start|>user + +{"result": {"temperature": 10}} +<|im_end|> +<|im_start|>assistant +The temperature in Hanoi is: 10 degree Celcious<|im_end|> +<|im_start|>user +Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|im_end|> +<|im_start|>assistant +I'll use code interpreter to handle this + +{"name": "python", "arguments": {"code": "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"}} +<|im_end|> +<|im_start|>user + +ValueError: list.remove(x): x not in list +<|im_end|> +<|im_start|>assistant +I will fix the code + +{"name": "python", "arguments": {"code": "l.remove(round(3.6))\nl"}} +<|im_end|> +<|im_start|>user + +[0,1,2,3,5,] +<|im_end|> +<|im_start|>assistant +The final list is: 0,1,2,3,5<|im_end|> +<|im_start|>user 
+Thanks! What's the weather in San Francisco?<|im_end|> +<|im_start|>assistant + +{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}} +<|im_end|> +<|im_start|>user + +{"result": {"temperature": 20}} +<|im_end|> +<|im_start|>assistant +The temperature in San Francisco is: 20 degree Celcious<|im_end|> \ No newline at end of file diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py index e0d39963..7271b72c 100644 --- a/tests/test_prompt_creation.py +++ b/tests/test_prompt_creation.py @@ -45,6 +45,7 @@ def __init__(self, *args, **kwargs): "r1": "deepseek-ai/DeepSeek-R1", "r1_distilled_qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", "r1_distilled_llama": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", + "qwen2.5-text-only": "Qwen/Qwen2.5-32B", } self.image_template_version_to_model_name = { "v3.llava_llama": "meetkai/functionary-vision-small-v0.1" From 76bf666cd6559cca200f73203cf8735bfa008236 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Mon, 3 Mar 2025 13:32:14 +0700 Subject: [PATCH 07/31] remove post_process --- functionary/prompt_template/qwen25_text_only_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/functionary/prompt_template/qwen25_text_only_template.py b/functionary/prompt_template/qwen25_text_only_template.py index c9992de9..850209a2 100644 --- a/functionary/prompt_template/qwen25_text_only_template.py +++ b/functionary/prompt_template/qwen25_text_only_template.py @@ -132,7 +132,7 @@ def parse_assistant_response( ) # print(f"+++LLM_OUTPUT: {llm_output}") - llm_output = post_process_llm_output(llm_output) + # llm_output = post_process_llm_output(llm_output) # print(f"+++LLM_OUTPUT after post-processing: {llm_output}") text_content = "" tool_call_strs = [] From 0de866aa92c93d614135d7f576fff570fb7be917 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Wed, 5 Mar 2025 00:31:24 +0700 Subject: [PATCH 08/31] fix r1_distilled_qwen --- functionary/prompt_template/__init__.py | 1 + 
.../prompt_template/r1_distilled_qwen.py | 71 ++----------------- tests/prompt_test_r1_distilled_qwen.txt | 14 ++-- 3 files changed, 12 insertions(+), 74 deletions(-) diff --git a/functionary/prompt_template/__init__.py b/functionary/prompt_template/__init__.py index ea0ad8dd..becf0c4d 100644 --- a/functionary/prompt_template/__init__.py +++ b/functionary/prompt_template/__init__.py @@ -37,6 +37,7 @@ def get_available_prompt_template_versions() -> List[PromptTemplate]: all_templates_obj.append(LlavaLlama.get_prompt_template()) all_templates_obj.append(Llama31ReasoningTemplate.get_prompt_template()) all_templates_obj.append(R1DistilledLlama.get_prompt_template()) + all_templates_obj.append(R1DistilledQwen.get_prompt_template()) return all_templates_obj diff --git a/functionary/prompt_template/r1_distilled_qwen.py b/functionary/prompt_template/r1_distilled_qwen.py index 3c880a7c..40e3908a 100644 --- a/functionary/prompt_template/r1_distilled_qwen.py +++ b/functionary/prompt_template/r1_distilled_qwen.py @@ -5,11 +5,13 @@ from functionary.openai_types import Function, Tool from functionary.prompt_template import prompt_utils -from functionary.prompt_template.base_template import PromptTemplate +from functionary.prompt_template.qwen25_text_only_template import ( + Qwen25TextOnlyPromptTemplate, +) import copy -class R1DistilledQwen(PromptTemplate): +class R1DistilledQwen(Qwen25TextOnlyPromptTemplate): version = "r1_distilled_qwen" chat_template = None @@ -91,71 +93,6 @@ def parse_assistant_response( return {"role": "assistant", "content": text_response, "tool_calls": tool_calls} - def get_prompt_from_messages( - self, - messages: List[Dict], - tools_or_functions: Optional[List[Dict]] = None, - bos_token: Optional[str] = "", - add_generation_prompt: bool = False, - ) -> str: - """This function is used to get the complete prompt for list of messages - - Args: - messages (List[Dict]): List of messages - tools_or_functions (Optional[List[Dict]], optional): List of tools or 
functions. Defaults to None. - - Returns: - str: the prompt for inference/training - """ - # qwen 2.5 use transformers chat template, need to convert argument string --> dictionary, this is noted in: https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example - # If you’re familiar with the OpenAI API, you should pay attention to an important difference here - the tool_call is a dict, but in the OpenAI API it’s a JSON string. Passing a string may cause errors or strange model behaviour! - new_messages = copy.deepcopy(messages) - for message in new_messages: - tool_calls = message.get("tool_calls", []) - if tool_calls: - for tool_call in tool_calls: - if type(tool_call["function"]["arguments"]) is str: - if tool_call["function"]["name"] != "python": - tool_call["function"]["arguments"] = json.loads( - tool_call["function"]["arguments"] - ) - else: - tool_call["function"] = { - "name": "python", - "arguments": { - "code": tool_call["function"]["arguments"] - }, - } - # check if contain code_interpreter, replace with python - new_tools = copy.deepcopy(tools_or_functions) - if tools_or_functions is not None and len(tools_or_functions) == 0: - new_tools = None - - if new_tools: - for tool in new_tools: - if tool["type"] == "code_interpreter": - tool["type"] = "function" - tool["function"] = { - "name": "python", - "description": 'When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. 
The drive at "/mnt/data" can be used to save and persist user files.', - "parameters": { - "type": "object", - "properties": { - "code": {"type": "string", "description": "Python code"} - }, - "required": ["code"], - }, - } - - prompt = self._jinja_template.render( - messages=new_messages, - tools=new_tools, - bos_token=bos_token, - add_generation_prompt=add_generation_prompt, - ) - - return prompt - def get_chat_template_jinja(self) -> str: if self.chat_template is None: jinja_template_file = ( diff --git a/tests/prompt_test_r1_distilled_qwen.txt b/tests/prompt_test_r1_distilled_qwen.txt index 281fca93..45e1d252 100644 --- a/tests/prompt_test_r1_distilled_qwen.txt +++ b/tests/prompt_test_r1_distilled_qwen.txt @@ -8,7 +8,7 @@ You are provided with function signatures within XML tags: {"type": "function", "function": {"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}}} {"type": "function", "function": {"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}}} -{"type": "function", "function": {"name": "python", "description": "When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \"/mnt/data\" can be used to save and persist user files.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "Python code"}}, "required": ["code"]}}} +{"type": "function", "function": {"name": "python", "description": "This tool is used to execute python code. 
Code will be executed in a stateful Jupyter notebook environment. Python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The python code to run"}}}}} For each function call, return a json object with function name and arguments within XML tags: @@ -16,10 +16,10 @@ For each function call, return a json object with function name and arguments wi {"name": , "arguments": } <|User|>who is the CEO of Meetkai<|Assistant|>James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|end▁of▁sentence|><|User|>is the car Song more expensive than car Tang?<|Assistant|>I will get the price of 2 cars and compare -{"name": "get_car_price", "arguments": {'car_name': 'Song'}} +{"name": "get_car_price", "arguments": {"car_name": "Song"}} -{"name": "get_car_price", "arguments": {'car_name': 'Tang'}} +{"name": "get_car_price", "arguments": {"car_name": "Tang"}} <|end▁of▁sentence|><|User|> {"price": {"price": "$25000"}} @@ -27,22 +27,22 @@ For each function call, return a json object with function name and arguments wi {"price": {"price": "$20000"}} <|Assistant|>No, the car Tang is less expensive than the car Song. 
The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|end▁of▁sentence|><|User|>what's the weather like in Hanoi?<|Assistant|> -{"name": "get_weather", "arguments": {'location': 'Hanoi'}} +{"name": "get_weather", "arguments": {"location": "Hanoi"}} <|end▁of▁sentence|><|User|> {"result": {"temperature": 10}} <|Assistant|>The temperature in Hanoi is: 10 degree Celcious<|end▁of▁sentence|><|User|>Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|Assistant|>I'll use code interpreter to handle this -{"name": "python", "arguments": {'code': "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"}} +{"name": "python", "arguments": {"code": "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"}} <|end▁of▁sentence|><|User|> ValueError: list.remove(x): x not in list <|Assistant|>I will fix the code -{"name": "python", "arguments": {'code': 'l.remove(round(3.6))\nl'}} +{"name": "python", "arguments": {"code": "l.remove(round(3.6))\nl"}} <|end▁of▁sentence|><|User|> [0,1,2,3,5,] <|Assistant|>The final list is: 0,1,2,3,5<|end▁of▁sentence|><|User|>Thanks! 
What's the weather in San Francisco?<|Assistant|> -{"name": "get_weather", "arguments": {'location': 'San Francisco, CA'}} +{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}} <|end▁of▁sentence|><|User|> {"result": {"temperature": 20}} From 39df7c2b046682eae89c0a560fb8960039ba1638 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Wed, 5 Mar 2025 10:54:41 +0700 Subject: [PATCH 09/31] update streaming for qwen template --- functionary/prompt_template/__init__.py | 3 +- .../prompt_template/r1_distilled_qwen.py | 58 ------------------- ...r1_template.py => r1_original_template.py} | 0 3 files changed, 1 insertion(+), 60 deletions(-) rename functionary/prompt_template/{r1_template.py => r1_original_template.py} (100%) diff --git a/functionary/prompt_template/__init__.py b/functionary/prompt_template/__init__.py index becf0c4d..d5982dc3 100644 --- a/functionary/prompt_template/__init__.py +++ b/functionary/prompt_template/__init__.py @@ -11,12 +11,11 @@ from functionary.prompt_template.llama31_reasoning_prompt_template import ( Llama31ReasoningTemplate, ) -from functionary.prompt_template.r1_template import R1Template +from functionary.prompt_template.r1_original_template import R1Template from functionary.prompt_template.r1_distilled_qwen import R1DistilledQwen from functionary.prompt_template.r1_distilled_llama import R1DistilledLlama from functionary.prompt_template.qwen25_text_only_template import Qwen25TextOnlyPromptTemplate - def get_available_prompt_template_versions() -> List[PromptTemplate]: """This function will get all the available prompt templates in the module. 
diff --git a/functionary/prompt_template/r1_distilled_qwen.py b/functionary/prompt_template/r1_distilled_qwen.py index 40e3908a..34541a4d 100644 --- a/functionary/prompt_template/r1_distilled_qwen.py +++ b/functionary/prompt_template/r1_distilled_qwen.py @@ -35,64 +35,6 @@ def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Di def get_stop_tokens_for_generation(self) -> List[str]: return ["<|end▁of▁sentence|>"] - def parse_assistant_response( - self, llm_output: str, tool_choice: Any = None - ) -> Dict: - # first remove stop tokens if there exists - for stop in self.get_stop_tokens_for_generation(): - if llm_output.endswith(stop): - llm_output = llm_output[: -len(stop)] - - # add forced-function from tool_choice if exists - llm_output = ( - self.get_generation_prefix_for_tool_choice(tool_choice) + llm_output - ) - - tool_calls = [] - text_response = "" - - while len(llm_output) > 0: - start_tool_call_index = llm_output.find("") - if start_tool_call_index >= 0: - end_index = llm_output.find("", start_tool_call_index) - if end_index >= 0: - json_between = llm_output[ - start_tool_call_index + len("") : end_index - ] - func_call = json.loads(json_between) - tool_calls.append( - { - "type": "function", - "id": prompt_utils.get_random_tool_call_id(), - "function": { - "name": func_call["name"], - "arguments": json.dumps( - func_call["arguments"], ensure_ascii=False - ), - }, - } - ) - index = end_index + len("") - - text_response += llm_output[:start_tool_call_index].strip() - llm_output = llm_output[index:] - else: # cannot find at the end - text_response += llm_output - llm_output = "" - else: # cannot find - text_response += llm_output - llm_output = "" - - if not text_response: - text_response = None - elif len(text_response.strip()) == 0: - text_response = None - - if not tool_calls: - tool_calls = None - - return {"role": "assistant", "content": text_response, "tool_calls": tool_calls} - def get_chat_template_jinja(self) -> str: if 
self.chat_template is None: jinja_template_file = ( diff --git a/functionary/prompt_template/r1_template.py b/functionary/prompt_template/r1_original_template.py similarity index 100% rename from functionary/prompt_template/r1_template.py rename to functionary/prompt_template/r1_original_template.py From 47454aefdc87225984f0caf29ebc316dbcf37fe6 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Thu, 6 Mar 2025 16:24:37 +0700 Subject: [PATCH 10/31] remove saving state --- functionary/train/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/functionary/train/train.py b/functionary/train/train.py index ee95e200..84317edc 100644 --- a/functionary/train/train.py +++ b/functionary/train/train.py @@ -57,7 +57,7 @@ def lr_lambda(current_step): from functionary.prompt_template import PromptTemplate, get_prompt_template_by_version from functionary.train.custom_datasets import read_dataset from functionary.train import training_utils -from training_utils import print_rank0 +from functionary.train.training_utils import print_rank0 LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) @@ -283,7 +283,7 @@ def compute_metrics(eval_preds): else: trainer.train() - trainer.save_state() + # trainer.save_state() # FSDP requires state_dict_type=FULL_STATE_DICT in order to save the model weights in .bin format if trainer.is_fsdp_enabled: From 99fed6d6b1cd5fc8d78bd8c85ce70096d3f318a9 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Mon, 10 Mar 2025 17:07:58 +0000 Subject: [PATCH 11/31] prefill --- .../prompt_template/jinja_templates/qwen2.5-text-only.txt | 2 +- functionary/prompt_template/qwen25_text_only_template.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt b/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt index 4b453552..79d91ea1 100644 --- a/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt +++ 
b/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt @@ -76,5 +76,5 @@ {%- endif %} {%- endfor %} {%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} + {{- '<|im_start|>assistant\n\n' }} {%- endif %} \ No newline at end of file diff --git a/functionary/prompt_template/qwen25_text_only_template.py b/functionary/prompt_template/qwen25_text_only_template.py index 850209a2..6b794a2f 100644 --- a/functionary/prompt_template/qwen25_text_only_template.py +++ b/functionary/prompt_template/qwen25_text_only_template.py @@ -170,7 +170,7 @@ def parse_assistant_response( return { "role": "assistant", - "content": text_content if len(text_content) > 0 else None, + "content": "\n" + text_content if len(text_content) > 0 else None, "tool_calls": None if len(tool_calls) == 0 else tool_calls, } @@ -290,6 +290,7 @@ def stream_delta_text( responses = [ prompt_utils.get_text_delta_response("", True, finish_reason) ] + gen_state["buffer"].extend(["", "\n"]) if len(delta_text) > 0: gen_state["buffer"].append(delta_text) return gen_state, responses From e5a2c8ddf299e94d55fb4725719ff9b2463f81c7 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Tue, 11 Mar 2025 14:35:18 +0700 Subject: [PATCH 12/31] increase default max length --- functionary/openai_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/functionary/openai_types.py b/functionary/openai_types.py index b382c611..e7dc5a4a 100644 --- a/functionary/openai_types.py +++ b/functionary/openai_types.py @@ -123,7 +123,7 @@ class ChatCompletionRequest(BaseModel): temperature: Optional[float] = 0.6 top_p: Optional[float] = 1.0 n: Optional[int] = 1 - max_tokens: Optional[int] = 512 + max_tokens: Optional[int] = 4096 stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False presence_penalty: Optional[float] = 0.0 From 3f3e57c34ed3b8952e6d5f60719b59188cb83c48 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Thu, 13 Mar 2025 10:50:08 +0700 Subject: [PATCH 
13/31] add gemma3 template --- functionary/prompt_template/__init__.py | 3 + .../prompt_template/gemma3_prompt_template.py | 38 +++++++++++++ .../jinja_templates/gemma3.txt | 57 +++++++++++++++++++ tests/test_prompt_creation.py | 19 ++++--- 4 files changed, 108 insertions(+), 9 deletions(-) create mode 100644 functionary/prompt_template/gemma3_prompt_template.py create mode 100644 functionary/prompt_template/jinja_templates/gemma3.txt diff --git a/functionary/prompt_template/__init__.py b/functionary/prompt_template/__init__.py index d5982dc3..29158db6 100644 --- a/functionary/prompt_template/__init__.py +++ b/functionary/prompt_template/__init__.py @@ -15,6 +15,8 @@ from functionary.prompt_template.r1_distilled_qwen import R1DistilledQwen from functionary.prompt_template.r1_distilled_llama import R1DistilledLlama from functionary.prompt_template.qwen25_text_only_template import Qwen25TextOnlyPromptTemplate +from functionary.prompt_template.gemma3_prompt_template import Gemma3Template + def get_available_prompt_template_versions() -> List[PromptTemplate]: """This function will get all the available prompt templates in the module. 
@@ -37,6 +39,7 @@ def get_available_prompt_template_versions() -> List[PromptTemplate]: all_templates_obj.append(Llama31ReasoningTemplate.get_prompt_template()) all_templates_obj.append(R1DistilledLlama.get_prompt_template()) all_templates_obj.append(R1DistilledQwen.get_prompt_template()) + all_templates_obj.append(Gemma3Template.get_prompt_template()) return all_templates_obj diff --git a/functionary/prompt_template/gemma3_prompt_template.py b/functionary/prompt_template/gemma3_prompt_template.py new file mode 100644 index 00000000..6ab69675 --- /dev/null +++ b/functionary/prompt_template/gemma3_prompt_template.py @@ -0,0 +1,38 @@ +import json +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from functionary.openai_types import Function, Tool +from functionary.prompt_template import prompt_utils +from functionary.prompt_template.qwen25_text_only_template import Qwen25TextOnlyPromptTemplate + + +class Gemma3Template(Qwen25TextOnlyPromptTemplate): + version = "gemma3" + + def get_additional_tokens(self) -> List[str]: + return [] + + def get_assistant_prefixes(self) -> List[str]: + return [ + "model\n" + ] + + def get_stop_tokens_for_generation(self) -> List[str]: + return [""] + + def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]: + """Order the tool results by the order of tool call ids + + Args: + messages (List[Dict]): List of messages + + Returns: + List[Dict]: List of messages + """ + return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages) + + def get_chat_template_jinja(self) -> str: + path_prefix = "./functionary/prompt_template/jinja_templates/" + with open(f"{path_prefix}{self.version}.txt", "r") as f: + template = f.read() + return template \ No newline at end of file diff --git a/functionary/prompt_template/jinja_templates/gemma3.txt b/functionary/prompt_template/jinja_templates/gemma3.txt new file mode 100644 index 00000000..3e3dc484 --- /dev/null +++ 
b/functionary/prompt_template/jinja_templates/gemma3.txt @@ -0,0 +1,57 @@ +{# version=gemma3 #}{%- if tools %} + {{- 'system\nYou are capable of executing available function(s) if required.' }} + {%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} + {%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} + {%- endif -%} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- if has_reasoning %} + {{- " Please generate reasoning before deciding which tools to use." }} + {%- else %} + {{- "" }} + {%- endif %} + {{- "\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n\n" }} +{%- else %} + {%- if messages and messages[0]['role'] == 'system' %} + {{- 'system\n' + messages[0]['content'] + '\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system") -%} + {{- '' + message.role + '\n' + message.content + '' + '\n' }} + {%- elif message.role == "assistant" %} + {{- 'model' }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments }} + {{- '}\n' }} + {%- endfor %} + {{- '\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- 'tool' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '\n' }} + {%- endif %} 
+ {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- 'model\n' }} +{%- endif %} \ No newline at end of file diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py index 7271b72c..1d2531d5 100644 --- a/tests/test_prompt_creation.py +++ b/tests/test_prompt_creation.py @@ -38,14 +38,15 @@ def __init__(self, *args, **kwargs): super(TestPromptTemplate, self).__init__(*args, **kwargs) self.template_version_to_model_name = { - "v2": "meetkai/functionary-small-v2.4", - "v2.llama3": "meetkai/functionary-small-v2.5", - "v3.llama3": "meetkai/functionary-medium-v3.0", - "v3-llama3.1": "meetkai/functionary-small-v3.1", - "r1": "deepseek-ai/DeepSeek-R1", - "r1_distilled_qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", - "r1_distilled_llama": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", - "qwen2.5-text-only": "Qwen/Qwen2.5-32B", + # "v2": "meetkai/functionary-small-v2.4", + # "v2.llama3": "meetkai/functionary-small-v2.5", + # "v3.llama3": "meetkai/functionary-medium-v3.0", + # "v3-llama3.1": "meetkai/functionary-small-v3.1", + # "r1": "deepseek-ai/DeepSeek-R1", + # "r1_distilled_qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", + # "r1_distilled_llama": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", + # "qwen2.5-text-only": "Qwen/Qwen2.5-32B", + "gemma3": "google/gemma-3-27b-it", } self.image_template_version_to_model_name = { "v3.llava_llama": "meetkai/functionary-vision-small-v0.1" @@ -150,7 +151,7 @@ def run_prepare_training_inputs( messages=test_case, tokenizer=tokenizer, padding="longest", - max_length=1024, + max_length=2048, return_tensor=False, verbose=True, keep_assistant_prefix=keep_assistant_prefix, From 1dd0ad2d6dbc5c68790d1ca66c6eaef6b97fe5ca Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Thu, 13 Mar 2025 10:51:20 +0700 Subject: [PATCH 14/31] add gemma3 example --- tests/prompt_test_gemma3.txt | 87 ++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 tests/prompt_test_gemma3.txt diff --git 
a/tests/prompt_test_gemma3.txt b/tests/prompt_test_gemma3.txt new file mode 100644 index 00000000..fb0bca16 --- /dev/null +++ b/tests/prompt_test_gemma3.txt @@ -0,0 +1,87 @@ +system +You are capable of executing available function(s) if required. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}}} +{"type": "function", "function": {"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}}} +{"type": "function", "function": {"name": "python", "description": "This tool is used to execute python code. Code will be executed in a stateful Jupyter notebook environment. Python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The python code to run"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + +user +who is the CEO of Meetkai +model +James Kaplan is the Co-Founder and CEO of MeetKai Inc. +user +is the car Song more expensive than car Tang? 
+model +I will get the price of 2 cars and compare + +{"name": "get_car_price", "arguments": {"car_name": "Song"}} + + +{"name": "get_car_price", "arguments": {"car_name": "Tang"}} + +tool + +{"price": {"price": "$25000"}} + + +{"price": {"price": "$20000"}} + +model +No, the car Tang is less expensive than the car Song. The car Song is priced at $25,000, while the car Tang is priced at $20,000. +user +what's the weather like in Hanoi? +model + +{"name": "get_weather", "arguments": {"location": "Hanoi"}} + +tool + +{"result": {"temperature": 10}} + +model +The temperature in Hanoi is: 10 degree Celcious +user +Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most +model +I'll use code interpreter to handle this + +{"name": "python", "arguments": {"code": "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"}} + +tool + +ValueError: list.remove(x): x not in list + +model +I will fix the code + +{"name": "python", "arguments": {"code": "l.remove(round(3.6))\nl"}} + +tool + +[0,1,2,3,5,] + +model +The final list is: 0,1,2,3,5 +user +Thanks! What's the weather in San Francisco? 
+model + +{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}} + +tool + +{"result": {"temperature": 20}} + +model +The temperature in San Francisco is: 20 degree Celcious \ No newline at end of file From 2f761bedfcfccfa104c160f69e5cd35908df351b Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Thu, 13 Mar 2025 06:20:01 +0000 Subject: [PATCH 15/31] fix template --- functionary/prompt_template/jinja_templates/gemma3.txt | 9 ++++++++- functionary/train/packing/assert_packing_loss.py | 3 ++- functionary/train/packing/monkey_patch_packing.py | 2 +- functionary/train/train.py | 4 +++- functionary/train/training_utils.py | 3 ++- 5 files changed, 16 insertions(+), 5 deletions(-) diff --git a/functionary/prompt_template/jinja_templates/gemma3.txt b/functionary/prompt_template/jinja_templates/gemma3.txt index 3e3dc484..cc75ce71 100644 --- a/functionary/prompt_template/jinja_templates/gemma3.txt +++ b/functionary/prompt_template/jinja_templates/gemma3.txt @@ -29,7 +29,14 @@ {%- if message.content %} {{- '\n' + message.content }} {%- endif %} - {%- for tool_call in message.tool_calls %} + + {%- if "tool_calls" in message and message["tool_calls"] is not none -%} + {%- set tool_calls = message["tool_calls"] -%} + {%-else-%} + {%- set tool_calls = [] -%} + {%- endif -%} + + {%- for tool_call in tool_calls %} {%- if tool_call.function is defined %} {%- set tool_call = tool_call.function %} {%- endif %} diff --git a/functionary/train/packing/assert_packing_loss.py b/functionary/train/packing/assert_packing_loss.py index bdb4d813..e2fcd58b 100644 --- a/functionary/train/packing/assert_packing_loss.py +++ b/functionary/train/packing/assert_packing_loss.py @@ -12,6 +12,7 @@ from packed_dataset import PackedDataset from torch.utils.data import DataLoader, Dataset from transformers import AutoTokenizer +from transformers import Gemma3ForConditionalGeneration random.seed(1) torch.manual_seed(3) @@ -79,7 +80,7 @@ def compute_loss_for_model_class( Returns: _type_: 
_description_ """ - model = transformers.AutoModelForCausalLM.from_pretrained( + model = Gemma3ForConditionalGeneration.from_pretrained( pretrained_path, torch_dtype=torch.bfloat16, device_map="auto", diff --git a/functionary/train/packing/monkey_patch_packing.py b/functionary/train/packing/monkey_patch_packing.py index 2cbfff8c..05f34b48 100644 --- a/functionary/train/packing/monkey_patch_packing.py +++ b/functionary/train/packing/monkey_patch_packing.py @@ -7,7 +7,7 @@ def get_max_seqlen_in_batch(attention_mask): max_num = torch.max(attention_mask) - # attention_mask: B x N + max_num = int(max_num) # there are cases that type of attention_mask is bfloat16 such as Gemma3ForConditionalGeneration counts = [] for i in range(1, max_num + 1): counts.append( diff --git a/functionary/train/train.py b/functionary/train/train.py index 84317edc..62da9836 100644 --- a/functionary/train/train.py +++ b/functionary/train/train.py @@ -184,6 +184,8 @@ def train(): model_class = AutoLigerKernelForCausalLM else: model_class = transformers.AutoModelForCausalLM + from transformers import Gemma3ForConditionalGeneration + model_class = Gemma3ForConditionalGeneration model = model_class.from_pretrained( model_args.model_name_or_path, @@ -250,7 +252,7 @@ def train(): if training_args.do_eval: print_rank0("***** HERE ARE SOME EXAMPLES FROM EVALUATION ***") - training_utils.print_some_examples(eval_dataset, tokenizer) + # training_utils.print_some_examples(eval_dataset, tokenizer) def preprocess_logits_for_metrics(logits, labels): return training_utils.preprocess_logits_for_metrics( diff --git a/functionary/train/training_utils.py b/functionary/train/training_utils.py index 6671be67..be0e190e 100644 --- a/functionary/train/training_utils.py +++ b/functionary/train/training_utils.py @@ -50,7 +50,7 @@ def initialize_tokenizer( # add chat_template for tokenizer tokenizer.chat_template = prompt_template.get_chat_template_jinja() - print("tokenizer: ", tokenizer) + # print("tokenizer: ", 
tokenizer) # Resize embedding model.resize_token_embeddings(len(tokenizer)) @@ -218,6 +218,7 @@ def print_some_examples(ds, tokenizer): print_rank0("shape of input_ids: ", batch["input_ids"].shape) # B x L print_rank0("shape of labels: ", batch["labels"].shape) print_rank0("shape of attention_mask: ", batch["attention_mask"].shape) + print_rank0("attetnion_mask", batch["attention_mask"]) # print_rank0('input_ids: ', batch["input_ids"].tolist()) # print_rank0('labels: ', batch["labels"].tolist()) print_rank0("attention mask: ", batch["attention_mask"]) From 9a2737b75932b81fc3fc9c5fc8b7131c124ad6ba Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Wed, 26 Mar 2025 08:04:23 +0700 Subject: [PATCH 16/31] comment gemma3 --- functionary/train/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/functionary/train/train.py b/functionary/train/train.py index 62da9836..2afb2de6 100644 --- a/functionary/train/train.py +++ b/functionary/train/train.py @@ -184,8 +184,8 @@ def train(): model_class = AutoLigerKernelForCausalLM else: model_class = transformers.AutoModelForCausalLM - from transformers import Gemma3ForConditionalGeneration - model_class = Gemma3ForConditionalGeneration + # from transformers import Gemma3ForConditionalGeneration + # model_class = Gemma3ForConditionalGeneration model = model_class.from_pretrained( model_args.model_name_or_path, From b24c4a4dabd01fab722a8956614d9cea8cc71d86 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Wed, 26 Mar 2025 09:58:43 +0700 Subject: [PATCH 17/31] change version of flash-attn --- functionary/train/pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/functionary/train/pyproject.toml b/functionary/train/pyproject.toml index cfb115a2..e149b503 100644 --- a/functionary/train/pyproject.toml +++ b/functionary/train/pyproject.toml @@ -18,8 +18,8 @@ dependencies = [ "tensorboard==2.15.1", "aenum==3.1.15", "transformers @ git+https://github.com/huggingface/transformers.git", - 
"flash-attn==2.6.3", "json_source_map==1.0.5", + "flash-attn==2.7.4.post1" ] [build-system] @@ -38,3 +38,5 @@ lora = [ "peft==0.13.2", "datasets==2.8.0", ] +[tool.pip] +install-options = { "flash-attn" = ["--no-build-isolation"] } \ No newline at end of file From 8cd9a1b9b0185bfb041ea1eb884156b296f34aed Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Thu, 3 Apr 2025 10:57:24 +0700 Subject: [PATCH 18/31] update requirements for training --- functionary/train/pyproject.toml | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/functionary/train/pyproject.toml b/functionary/train/pyproject.toml index e149b503..44502ac0 100644 --- a/functionary/train/pyproject.toml +++ b/functionary/train/pyproject.toml @@ -4,11 +4,13 @@ version = "0.0.1" description = "Chat language model that can use tools and interpret the results" requires-python = ">=3.9" dependencies = [ + "wheel==0.45.1", + "setuptools==78.0.2", "accelerate==0.34.0", - "bitsandbytes==0.44.1", + "bitsandbytes==0.45.4", + "packaging", "scipy==1.11.3", "sentencepiece==0.1.99", - "packaging==23.1", "ninja==1.11.1", "einops==0.7.0", "wandb==0.15.11", @@ -17,7 +19,7 @@ dependencies = [ "typer==0.9.0", "tensorboard==2.15.1", "aenum==3.1.15", - "transformers @ git+https://github.com/huggingface/transformers.git", + "transformers==4.50.3", "json_source_map==1.0.5", "flash-attn==2.7.4.post1" ] @@ -32,7 +34,7 @@ packages = ["train"] [project.optional-dependencies] liger = [ - "liger-kernel==0.3.1", + "liger-kernel==0.5.5", ] lora = [ "peft==0.13.2", From 1efe410816cde9c0cdb3a099baabafb56cc94f58 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Thu, 3 Apr 2025 11:06:29 +0700 Subject: [PATCH 19/31] remove flash-attn from pyproject.toml --- functionary/train/pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/functionary/train/pyproject.toml b/functionary/train/pyproject.toml index 44502ac0..3584080f 100644 --- a/functionary/train/pyproject.toml +++ 
b/functionary/train/pyproject.toml @@ -20,8 +20,7 @@ dependencies = [ "tensorboard==2.15.1", "aenum==3.1.15", "transformers==4.50.3", - "json_source_map==1.0.5", - "flash-attn==2.7.4.post1" + "json_source_map==1.0.5" ] [build-system] From c366fdcd0e70624b19bfc50793e527b054da1981 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Fri, 4 Apr 2025 04:01:56 +0000 Subject: [PATCH 20/31] remove prefill in prompt --- .../prompt_template/jinja_templates/qwen2.5-text-only.txt | 2 +- functionary/prompt_template/qwen25_text_only_template.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt b/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt index 79d91ea1..4b453552 100644 --- a/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt +++ b/functionary/prompt_template/jinja_templates/qwen2.5-text-only.txt @@ -76,5 +76,5 @@ {%- endif %} {%- endfor %} {%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n\n' }} + {{- '<|im_start|>assistant\n' }} {%- endif %} \ No newline at end of file diff --git a/functionary/prompt_template/qwen25_text_only_template.py b/functionary/prompt_template/qwen25_text_only_template.py index 6b794a2f..a7690dc6 100644 --- a/functionary/prompt_template/qwen25_text_only_template.py +++ b/functionary/prompt_template/qwen25_text_only_template.py @@ -170,7 +170,7 @@ def parse_assistant_response( return { "role": "assistant", - "content": "\n" + text_content if len(text_content) > 0 else None, + "content": text_content if len(text_content) > 0 else None, "tool_calls": None if len(tool_calls) == 0 else tool_calls, } @@ -290,7 +290,7 @@ def stream_delta_text( responses = [ prompt_utils.get_text_delta_response("", True, finish_reason) ] - gen_state["buffer"].extend(["", "\n"]) + # gen_state["buffer"].extend(["", "\n"]) if len(delta_text) > 0: gen_state["buffer"].append(delta_text) return gen_state, responses From 
13e93357fe28fffbec74ba688c88e5a01799db14 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Fri, 4 Apr 2025 04:03:17 +0000 Subject: [PATCH 21/31] remove imports to avoid errors --- functionary/vllm_inference.py | 18 ++++++++++-------- server_vllm.py | 6 ++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/functionary/vllm_inference.py b/functionary/vllm_inference.py index 166d2d66..fcbb9366 100644 --- a/functionary/vllm_inference.py +++ b/functionary/vllm_inference.py @@ -5,10 +5,10 @@ from fastapi import BackgroundTasks, Request from fastapi.responses import JSONResponse, StreamingResponse -from vllm.entrypoints.openai.protocol import ( - LoadLoraAdapterRequest, - UnloadLoraAdapterRequest, -) +# from vllm.entrypoints.openai.protocol import ( +# LoadLoraAdapterRequest, +# UnloadLoraAdapterRequest, +# ) from vllm.inputs import TokensPrompt from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -44,7 +44,9 @@ async def check_length(request, input_ids, model_config): - if hasattr(model_config.hf_config, "max_sequence_length"): + if hasattr(model_config, "max_model_len"): + context_len = model_config.max_model_len + elif hasattr(model_config.hf_config, "max_sequence_length"): context_len = model_config.hf_config.max_sequence_length elif hasattr(model_config.hf_config, "seq_length"): context_len = model_config.hf_config.seq_length @@ -90,7 +92,7 @@ async def check_length(request, input_ids, model_config): async def process_load_lora_adapter( - request: LoadLoraAdapterRequest, + request: Any, served_loras: List[LoRARequest], lora_id_counter: AtomicCounter, ) -> Tuple[Union[str, JSONResponse], List[LoRARequest]]: @@ -128,8 +130,8 @@ async def process_load_lora_adapter( async def process_unload_lora_adapter( - request: UnloadLoraAdapterRequest, served_loras: List[LoRARequest] -) -> Tuple[Union[str, JSONResponse], List[LoRARequest]]: + request: Any, served_loras: List[Any] +) -> Tuple[Union[str, JSONResponse], List[Any]]: # Check if 
either 'lora_name' or 'lora_int_id' is provided if not request.lora_name and not request.lora_int_id: return ( diff --git a/server_vllm.py b/server_vllm.py index 0e5121f2..50ab59dd 100644 --- a/server_vllm.py +++ b/server_vllm.py @@ -32,11 +32,9 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.api_server import mount_metrics from vllm.entrypoints.openai.protocol import ( - LoadLoraAdapterRequest, ModelCard, ModelList, ModelPermission, - UnloadLoraAdapterRequest, ) from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -131,7 +129,7 @@ async def create_chat_completion(raw_request: Request): @app.post("/v1/load_lora_adapter") -async def load_lora_adapter(request: LoadLoraAdapterRequest): +async def load_lora_adapter(request: Any): global served_loras error, served_loras = await process_load_lora_adapter( @@ -145,7 +143,7 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest): @app.post("/v1/unload_lora_adapter") -async def unload_lora_adapter(request: UnloadLoraAdapterRequest): +async def unload_lora_adapter(request: Any): global served_loras error, served_loras = await process_unload_lora_adapter(request, served_loras) From de90233bceb94740f8ca33047be890fb3e968fbb Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Fri, 4 Apr 2025 04:03:58 +0000 Subject: [PATCH 22/31] upgrade vllm version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 83067ff5..806fcf35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ packages = ["functionary"] [project.optional-dependencies] vllm = [ - "vllm==0.6.3.post1; sys_platform != 'darwin'", + "vllm==0.8.2; sys_platform != 'darwin'", ] sglang = [ "python-multipart==0.0.12", From b88d55ccee9a76b84d0a56c6cec2250496582407 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Sun, 6 Apr 2025 17:20:26 +0000 Subject: [PATCH 23/31] fix prompt template for gemma3 --- 
functionary/prompt_template/jinja_templates/gemma3.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/functionary/prompt_template/jinja_templates/gemma3.txt b/functionary/prompt_template/jinja_templates/gemma3.txt index cc75ce71..c7db6b81 100644 --- a/functionary/prompt_template/jinja_templates/gemma3.txt +++ b/functionary/prompt_template/jinja_templates/gemma3.txt @@ -1,4 +1,4 @@ -{# version=gemma3 #}{%- if tools %} +{# version=gemma3 #}{{- bos_token -}}{%- if tools %} {{- 'system\nYou are capable of executing available function(s) if required.' }} {%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} {%- if has_reasoning -%} From 9129d9b92e8fbc16568e2d4ec2c4d9ec2b3c4768 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Tue, 8 Apr 2025 08:33:15 +0700 Subject: [PATCH 24/31] set default top_p as 0.95 --- functionary/openai_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/functionary/openai_types.py b/functionary/openai_types.py index e7dc5a4a..f4298433 100644 --- a/functionary/openai_types.py +++ b/functionary/openai_types.py @@ -121,7 +121,7 @@ class ChatCompletionRequest(BaseModel): function_call: Optional[Union[str, Function]] = None tool_choice: Optional[Union[str, Tool]] = None temperature: Optional[float] = 0.6 - top_p: Optional[float] = 1.0 + top_p: Optional[float] = 0.95 n: Optional[int] = 1 max_tokens: Optional[int] = 4096 stop: Optional[Union[str, List[str]]] = Field(default_factory=list) From 28b2ac9ee614785f117d6cf243f4e8327bfd3100 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Mon, 14 Apr 2025 00:02:11 +0700 Subject: [PATCH 25/31] add cogito --- functionary/prompt_template/__init__.py | 3 +- functionary/prompt_template/prompt_utils.py | 51 +++++++++++++++++++ .../qwen25_text_only_template.py | 47 ++--------------- tests/test_prompt_creation.py | 1 + 4 files changed, 59 insertions(+), 43 deletions(-) diff --git a/functionary/prompt_template/__init__.py 
b/functionary/prompt_template/__init__.py index 29158db6..cc0ca702 100644 --- a/functionary/prompt_template/__init__.py +++ b/functionary/prompt_template/__init__.py @@ -16,7 +16,7 @@ from functionary.prompt_template.r1_distilled_llama import R1DistilledLlama from functionary.prompt_template.qwen25_text_only_template import Qwen25TextOnlyPromptTemplate from functionary.prompt_template.gemma3_prompt_template import Gemma3Template - +from functionary.prompt_template.cogito_prompt_template import CogitoPromptTemplate def get_available_prompt_template_versions() -> List[PromptTemplate]: """This function will get all the available prompt templates in the module. @@ -40,6 +40,7 @@ def get_available_prompt_template_versions() -> List[PromptTemplate]: all_templates_obj.append(R1DistilledLlama.get_prompt_template()) all_templates_obj.append(R1DistilledQwen.get_prompt_template()) all_templates_obj.append(Gemma3Template.get_prompt_template()) + all_templates_obj.append(CogitoPromptTemplate.get_prompt_template()) return all_templates_obj diff --git a/functionary/prompt_template/prompt_utils.py b/functionary/prompt_template/prompt_utils.py index 92522655..c6674be2 100644 --- a/functionary/prompt_template/prompt_utils.py +++ b/functionary/prompt_template/prompt_utils.py @@ -3,6 +3,7 @@ import random import string from copy import deepcopy +import json from io import BytesIO from typing import Dict, List, Optional, Union @@ -284,3 +285,53 @@ def resolve_json_refs(tools_or_functions): ) return tools + + +def convert_code_interpreter_to_function(tools: List[Dict]) -> List[Dict]: + _tools = [] + if tools: + for tool in tools: + if tool["type"] == "code_interpreter": + _tools.append( + { + "type": "function", + "function": { + "name": "python", + "description": "This tool is used to execute python code. Code will be executed in a stateful Jupyter notebook environment. Python will respond with the output of the execution or time out after 60.0 seconds. 
The drive at '/mnt/data' can be used to save and persist user files.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The python code to run", + } + }, + }, + }, + } + ) + else: + _tools.append(tool) + return _tools + + +def convert_code_interpreter_tool_calls(messages: List[Dict]) -> List[Dict]: + _messages = [] + for message in messages: + n_message = deepcopy(message) + tool_calls = n_message.get("tool_calls", []) or [] + if len(tool_calls) > 0: + for tool_call in tool_calls: + if tool_call["function"]["name"] == "python": + arguments = tool_call["function"][ + "arguments" + ] # currently the code is in string format + # check if argument is a valid JSON string or python code + try: # if this is a valid JSON string --> no need to change anything + json.loads(arguments) + except: + tool_call["function"]["arguments"] = json.dumps( + {"code": arguments}, ensure_ascii=False + ) + _messages.append(n_message) + return _messages diff --git a/functionary/prompt_template/qwen25_text_only_template.py b/functionary/prompt_template/qwen25_text_only_template.py index a7690dc6..fef076d3 100644 --- a/functionary/prompt_template/qwen25_text_only_template.py +++ b/functionary/prompt_template/qwen25_text_only_template.py @@ -42,48 +42,12 @@ def get_prompt_from_messages( # handle code_interpreter _tools = [] if tools_or_functions: - for tool in tools_or_functions: - if tool["type"] == "code_interpreter": - _tools.append( - { - "type": "function", - "function": { - "name": "python", - "description": "This tool is used to execute python code. Code will be executed in a stateful Jupyter notebook environment. Python will respond with the output of the execution or time out after 60.0 seconds. 
The drive at '/mnt/data' can be used to save and persist user files.", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "The python code to run", - } - }, - }, - }, - } - ) - else: - _tools.append(tool) + _tools = prompt_utils.convert_code_interpreter_to_function( + tools_or_functions + ) # find the assistant message that tool_call is python - _messages = [] - for message in messages: - n_message = copy.deepcopy(message) - tool_calls = n_message.get("tool_calls", []) or [] - if len(tool_calls) > 0: - for tool_call in tool_calls: - if tool_call["function"]["name"] == "python": - arguments = tool_call["function"][ - "arguments" - ] # currently the code is in string format - # check if argument is a valid JSON string or python code - try: # if this is a valid JSON string --> no need to change anything - json.loads(arguments) - except: - tool_call["function"]["arguments"] = json.dumps( - {"code": arguments}, ensure_ascii=False - ) - _messages.append(n_message) + _messages = prompt_utils.convert_code_interpreter_tool_calls(messages) prompt = super().get_prompt_from_messages( messages=_messages, tools_or_functions=_tools, @@ -174,7 +138,6 @@ def parse_assistant_response( "tool_calls": None if len(tool_calls) == 0 else tool_calls, } - def initialize_fsm_gen_state( self, tool_choice: Union[str, Tool], @@ -427,4 +390,4 @@ def match_pattern(pattern: str, text: str) -> Tuple[int, int]: match = re.search(pattern, text) if match: return match.start(), match.end() - return -1, -1 \ No newline at end of file + return -1, -1 diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py index 1d2531d5..0d71a7cb 100644 --- a/tests/test_prompt_creation.py +++ b/tests/test_prompt_creation.py @@ -47,6 +47,7 @@ def __init__(self, *args, **kwargs): # "r1_distilled_llama": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", # "qwen2.5-text-only": "Qwen/Qwen2.5-32B", "gemma3": "google/gemma-3-27b-it", + "cogito": 
"deepcogito/cogito-v1-preview-qwen-32B", } self.image_template_version_to_model_name = { "v3.llava_llama": "meetkai/functionary-vision-small-v0.1" From bfd97e415dc59732cfdb4bcce9efa6b508f895b0 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Mon, 14 Apr 2025 00:35:45 +0700 Subject: [PATCH 26/31] add files for cogito --- .../prompt_template/cogito_prompt_template.py | 14 +++ .../jinja_templates/cogito.txt | 80 +++++++++++++++++ tests/prompt_test_cogito.txt | 87 +++++++++++++++++++ 3 files changed, 181 insertions(+) create mode 100644 functionary/prompt_template/cogito_prompt_template.py create mode 100644 functionary/prompt_template/jinja_templates/cogito.txt create mode 100644 tests/prompt_test_cogito.txt diff --git a/functionary/prompt_template/cogito_prompt_template.py b/functionary/prompt_template/cogito_prompt_template.py new file mode 100644 index 00000000..16d7eea5 --- /dev/null +++ b/functionary/prompt_template/cogito_prompt_template.py @@ -0,0 +1,14 @@ +from functionary.prompt_template.base_template import PromptTemplate +from functionary.prompt_template.qwen25_text_only_template import Qwen25TextOnlyPromptTemplate +from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from functionary.prompt_template import prompt_utils +from functionary.openai_types import Function, Tool + +class CogitoPromptTemplate(Qwen25TextOnlyPromptTemplate): + version = "cogito" + + def get_chat_template_jinja(self) -> str: + path_prefix = "./functionary/prompt_template/jinja_templates/" + with open(f"{path_prefix}cogito.txt", "r") as f: + template = f.read() + return template diff --git a/functionary/prompt_template/jinja_templates/cogito.txt b/functionary/prompt_template/jinja_templates/cogito.txt new file mode 100644 index 00000000..b96b43fc --- /dev/null +++ b/functionary/prompt_template/jinja_templates/cogito.txt @@ -0,0 +1,80 @@ +{# version=cogito #}{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages and messages[0]['role'] == 'system' %} + {{- 
'Enable deep thinking subroutine.\n\n' + messages[0]['content'] }} + {%- else %} + {{- 'Enable deep thinking subroutine.' }} + {%- endif %} + {%- set has_reasoning = tools | selectattr("type", "equalto", "reasoning") | list | length > 0 -%} + {%- if has_reasoning -%} + {%- set tools = tools | rejectattr("type", "equalto", "reasoning") | list -%} + {%- endif -%} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- if has_reasoning %} + {{- " Please generate reasoning before deciding which tools to use." }} + {%- else %} + {{- "" }} + {%- endif %} + {{- "\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages and messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\nEnable deep thinking subroutine\n\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>system\nEnable deep thinking subroutine.<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if message.role == "user" %} + {%- if message['content'] -%} + {%- if message['content'] is string -%} + {{ '<|im_start|>user\n' + message['content'] }} + {%- else -%} + {{ '<|im_start|>user\n' }} + {%- for content in message['content'] -%} + {%- if content['type'] == 'text' -%} + {{ content['text'] }} + {%- else -%} + {{ '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {{ '<|im_end|>\n' }} + {%- endif -%} + {%- elif (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + 
{%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} \ No newline at end of file diff --git a/tests/prompt_test_cogito.txt b/tests/prompt_test_cogito.txt new file mode 100644 index 00000000..a7b912ae --- /dev/null +++ b/tests/prompt_test_cogito.txt @@ -0,0 +1,87 @@ +<|im_start|>system +Enable deep thinking subroutine. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name": "get_car_price", "description": "Get the price of a particular car model", "parameters": {"type": "object", "properties": {"car_name": {"type": "string", "description": "The name of the car model"}}, "required": ["car_name"]}}} +{"type": "function", "function": {"name": "get_weather", "description": "This function's purpose is to get the weather of a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "where to get weather"}}, "required": ["location"]}}} +{"type": "function", "function": {"name": "python", "description": "This tool is used to execute python code. Code will be executed in a stateful Jupyter notebook environment. 
Python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The python code to run"}}}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +<|im_end|> +<|im_start|>user +who is the CEO of Meetkai<|im_end|> +<|im_start|>assistant +James Kaplan is the Co-Founder and CEO of MeetKai Inc.<|im_end|> +<|im_start|>user +is the car Song more expensive than car Tang?<|im_end|> +<|im_start|>assistant +I will get the price of 2 cars and compare + +{"name": "get_car_price", "arguments": {"car_name": "Song"}} + + +{"name": "get_car_price", "arguments": {"car_name": "Tang"}} +<|im_end|> +<|im_start|>user + +{"price": {"price": "$25000"}} + + +{"price": {"price": "$20000"}} +<|im_end|> +<|im_start|>assistant +No, the car Tang is less expensive than the car Song. 
The car Song is priced at $25,000, while the car Tang is priced at $20,000.<|im_end|> +<|im_start|>user +what's the weather like in Hanoi?<|im_end|> +<|im_start|>assistant + +{"name": "get_weather", "arguments": {"location": "Hanoi"}} +<|im_end|> +<|im_start|>user + +{"result": {"temperature": 10}} +<|im_end|> +<|im_start|>assistant +The temperature in Hanoi is: 10 degree Celcious<|im_end|> +<|im_start|>user +Given the list of strings: '0','1','2','3','4','5' remove the number in the list that is close to 3.6 the most<|im_end|> +<|im_start|>assistant +I'll use code interpreter to handle this + +{"name": "python", "arguments": {"code": "l=[int('0'),int('1'),int('2'),int('3'),int('4'),int('5')]\nl.remove(3.6)"}} +<|im_end|> +<|im_start|>user + +ValueError: list.remove(x): x not in list +<|im_end|> +<|im_start|>assistant +I will fix the code + +{"name": "python", "arguments": {"code": "l.remove(round(3.6))\nl"}} +<|im_end|> +<|im_start|>user + +[0,1,2,3,5,] +<|im_end|> +<|im_start|>assistant +The final list is: 0,1,2,3,5<|im_end|> +<|im_start|>user +Thanks! 
What's the weather in San Francisco?<|im_end|> +<|im_start|>assistant + +{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}} +<|im_end|> +<|im_start|>user + +{"result": {"temperature": 20}} +<|im_end|> +<|im_start|>assistant +The temperature in San Francisco is: 20 degree Celcious<|im_end|> \ No newline at end of file From 77f26bcc95b002b92bf894a408ea46e9e7a19671 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Wed, 16 Apr 2025 10:00:18 +0700 Subject: [PATCH 27/31] fix format --- .../prompt_template/cogito_prompt_template.py | 7 +++++-- .../prompt_template/gemma3_prompt_template.py | 12 ++++++------ functionary/prompt_template/r1_original_template.py | 6 +++--- functionary/train/packing/assert_packing_loss.py | 3 +-- functionary/train/packing/monkey_patch_packing.py | 4 +++- functionary/train/pyproject.toml | 2 +- functionary/vllm_inference.py | 1 + server_sglang.py | 2 +- tests/test_prompt_creation.py | 4 ++-- 9 files changed, 23 insertions(+), 18 deletions(-) diff --git a/functionary/prompt_template/cogito_prompt_template.py b/functionary/prompt_template/cogito_prompt_template.py index 16d7eea5..dc50bc37 100644 --- a/functionary/prompt_template/cogito_prompt_template.py +++ b/functionary/prompt_template/cogito_prompt_template.py @@ -1,9 +1,12 @@ from functionary.prompt_template.base_template import PromptTemplate -from functionary.prompt_template.qwen25_text_only_template import Qwen25TextOnlyPromptTemplate +from functionary.prompt_template.qwen25_text_only_template import ( + Qwen25TextOnlyPromptTemplate, +) from typing import Any, Dict, List, Literal, Optional, Tuple, Union from functionary.prompt_template import prompt_utils from functionary.openai_types import Function, Tool + class CogitoPromptTemplate(Qwen25TextOnlyPromptTemplate): version = "cogito" @@ -11,4 +14,4 @@ def get_chat_template_jinja(self) -> str: path_prefix = "./functionary/prompt_template/jinja_templates/" with open(f"{path_prefix}cogito.txt", "r") as f: template = f.read() - 
return template + return template diff --git a/functionary/prompt_template/gemma3_prompt_template.py b/functionary/prompt_template/gemma3_prompt_template.py index 6ab69675..accdd6ec 100644 --- a/functionary/prompt_template/gemma3_prompt_template.py +++ b/functionary/prompt_template/gemma3_prompt_template.py @@ -3,19 +3,19 @@ from functionary.openai_types import Function, Tool from functionary.prompt_template import prompt_utils -from functionary.prompt_template.qwen25_text_only_template import Qwen25TextOnlyPromptTemplate +from functionary.prompt_template.qwen25_text_only_template import ( + Qwen25TextOnlyPromptTemplate, +) class Gemma3Template(Qwen25TextOnlyPromptTemplate): version = "gemma3" - + def get_additional_tokens(self) -> List[str]: return [] def get_assistant_prefixes(self) -> List[str]: - return [ - "model\n" - ] + return ["model\n"] def get_stop_tokens_for_generation(self) -> List[str]: return [""] @@ -35,4 +35,4 @@ def get_chat_template_jinja(self) -> str: path_prefix = "./functionary/prompt_template/jinja_templates/" with open(f"{path_prefix}{self.version}.txt", "r") as f: template = f.read() - return template \ No newline at end of file + return template diff --git a/functionary/prompt_template/r1_original_template.py b/functionary/prompt_template/r1_original_template.py index c475833f..e94e8f35 100644 --- a/functionary/prompt_template/r1_original_template.py +++ b/functionary/prompt_template/r1_original_template.py @@ -124,7 +124,7 @@ def parse_assistant_response( for tool_call_str in tool_calls_strs: tool_calls.append(parse_tool_call(tool_call_str)) return { - "role": "assistant", - "content": text_content if text_content else None, - "tool_calls": tool_calls if tool_calls else None + "role": "assistant", + "content": text_content if text_content else None, + "tool_calls": tool_calls if tool_calls else None, } diff --git a/functionary/train/packing/assert_packing_loss.py b/functionary/train/packing/assert_packing_loss.py index e2fcd58b..bdb4d813 
100644 --- a/functionary/train/packing/assert_packing_loss.py +++ b/functionary/train/packing/assert_packing_loss.py @@ -12,7 +12,6 @@ from packed_dataset import PackedDataset from torch.utils.data import DataLoader, Dataset from transformers import AutoTokenizer -from transformers import Gemma3ForConditionalGeneration random.seed(1) torch.manual_seed(3) @@ -80,7 +79,7 @@ def compute_loss_for_model_class( Returns: _type_: _description_ """ - model = Gemma3ForConditionalGeneration.from_pretrained( + model = transformers.AutoModelForCausalLM.from_pretrained( pretrained_path, torch_dtype=torch.bfloat16, device_map="auto", diff --git a/functionary/train/packing/monkey_patch_packing.py b/functionary/train/packing/monkey_patch_packing.py index 05f34b48..d9063779 100644 --- a/functionary/train/packing/monkey_patch_packing.py +++ b/functionary/train/packing/monkey_patch_packing.py @@ -7,7 +7,9 @@ def get_max_seqlen_in_batch(attention_mask): max_num = torch.max(attention_mask) - max_num = int(max_num) # there are cases that type of attention_mask is bfloat16 such as Gemma3ForConditionalGeneration + max_num = int( + max_num + ) # there are cases that type of attention_mask is bfloat16 such as Gemma3ForConditionalGeneration counts = [] for i in range(1, max_num + 1): counts.append( diff --git a/functionary/train/pyproject.toml b/functionary/train/pyproject.toml index 3584080f..9733576a 100644 --- a/functionary/train/pyproject.toml +++ b/functionary/train/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "typer==0.9.0", "tensorboard==2.15.1", "aenum==3.1.15", - "transformers==4.50.3", + "transformers @ git+https://github.com/huggingface/transformers.git", "json_source_map==1.0.5" ] diff --git a/functionary/vllm_inference.py b/functionary/vllm_inference.py index fcbb9366..bbc2bfd8 100644 --- a/functionary/vllm_inference.py +++ b/functionary/vllm_inference.py @@ -5,6 +5,7 @@ from fastapi import BackgroundTasks, Request from fastapi.responses import JSONResponse, 
StreamingResponse + # from vllm.entrypoints.openai.protocol import ( # LoadLoraAdapterRequest, # UnloadLoraAdapterRequest, diff --git a/server_sglang.py b/server_sglang.py index 291ea8fa..54a330b6 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -172,7 +172,7 @@ async def stream_results() -> AsyncIterator[bytes]: @app.post("/v1/chat/completions") async def openai_v1_chat_completions(raw_request: Request): - global tokenizer_manager, backend + global tokenizer_manager # if not args.grammar_sampling: # backend = None diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py index 0d71a7cb..42b939a9 100644 --- a/tests/test_prompt_creation.py +++ b/tests/test_prompt_creation.py @@ -44,8 +44,8 @@ def __init__(self, *args, **kwargs): # "v3-llama3.1": "meetkai/functionary-small-v3.1", # "r1": "deepseek-ai/DeepSeek-R1", # "r1_distilled_qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", - # "r1_distilled_llama": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", - # "qwen2.5-text-only": "Qwen/Qwen2.5-32B", + "r1_distilled_llama": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", + "qwen2.5-text-only": "Qwen/Qwen2.5-32B", "gemma3": "google/gemma-3-27b-it", "cogito": "deepcogito/cogito-v1-preview-qwen-32B", } From 45f05372a3651f6af005cc6aeaeb44f573ba643b Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Wed, 16 Apr 2025 10:26:49 +0700 Subject: [PATCH 28/31] remove unused global --- server_sglang.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/server_sglang.py b/server_sglang.py index 54a330b6..0e1a881d 100644 --- a/server_sglang.py +++ b/server_sglang.py @@ -172,10 +172,6 @@ async def stream_results() -> AsyncIterator[bytes]: @app.post("/v1/chat/completions") async def openai_v1_chat_completions(raw_request: Request): - global tokenizer_manager - - # if not args.grammar_sampling: - # backend = None return await v1_chat_completions(tokenizer_manager, None, raw_request, served_model) From 3067d1c6486d12189e3c9eebe0d403d50443ec1c Mon Sep 17 00:00:00 2001 From: 
Khai Mai Date: Wed, 16 Apr 2025 10:51:53 +0700 Subject: [PATCH 29/31] fix syntax --- functionary/prompt_template/r1_original_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/functionary/prompt_template/r1_original_template.py b/functionary/prompt_template/r1_original_template.py index e94e8f35..a58fce4c 100644 --- a/functionary/prompt_template/r1_original_template.py +++ b/functionary/prompt_template/r1_original_template.py @@ -96,7 +96,7 @@ def get_stop_tokens_for_generation(self) -> List[str]: return ["<|end▁of▁sentence|>"] def parse_assistant_response( - self, llm_output: str, tool_choice: Any | None + self, llm_output: str, tool_choice: Optional[Any] = None ) -> Dict: for stop in self.get_stop_tokens_for_generation(): if llm_output.endswith(stop): From b867cfbd1f2f9d83f79bbd1c4e5baf347977ee43 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Wed, 16 Apr 2025 11:51:50 +0700 Subject: [PATCH 30/31] fix: only test prompt templates that we added to the list, instead of test all --- tests/test_request_handling.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/test_request_handling.py b/tests/test_request_handling.py index 28a5c642..82e7cb51 100644 --- a/tests/test_request_handling.py +++ b/tests/test_request_handling.py @@ -22,8 +22,8 @@ LlavaLlama, PromptTemplate, PromptTemplateV2, - Llama31ReasoningTemplate, - get_available_prompt_template_versions, + Llama31ReasoningTemplate, + Qwen25TextOnlyPromptTemplate, ) from functionary.prompt_template.prompt_utils import ( enforce_tool_choice, @@ -145,7 +145,7 @@ def __init__(self, *args, **kwargs): {"type": "function", "function": self.default_functions[0]} ] self.default_python_args = "from datetime import date\n# Find today's date\ntoday = date.today()\ntoday" - self.test_prompt_templates = get_available_prompt_template_versions() + self.prompt_template_to_tokenizer_name_mapping = { PromptTemplateV2: "meetkai/functionary-small-v2.4", 
Llama3Template: "meetkai/functionary-small-v2.5", @@ -153,7 +153,12 @@ def __init__(self, *args, **kwargs): Llama31Template: "meetkai/functionary-small-v3.1", Llama31ReasoningTemplate: "meetkai/functionary-small-v3.1", LlavaLlama: "lmms-lab/llama3-llava-next-8b", + Qwen25TextOnlyPromptTemplate: "Qwen/Qwen2.5-32B-Instruct", } + # only test the prompt templates that we will use, no need to test all the prompt templates + self.test_prompt_templates = list( + self.prompt_template_to_tokenizer_name_mapping.keys() + ) # get_available_prompt_template_versions() self.default_text_str = "Normal text generation" self.default_tool_call_name = "get_weather" self.default_tool_call_args = [ @@ -500,13 +505,14 @@ async def test_prompt_template_to_tokenizer(self): # Test whether all prompt templates are included in template_to_tokenizer_mapping yet for prompt_template in self.test_prompt_templates: self.assertIn( - type(prompt_template), + prompt_template, self.prompt_template_to_tokenizer_name_mapping.keys(), f"Prompt template `{type(prompt_template)}` is not included in template_to_tokenizer_mapping yet.", ) async def test_request_handling(self): - for prompt_template in self.test_prompt_templates: + for prompt_template_class in self.test_prompt_templates: + prompt_template = prompt_template_class.get_prompt_template() for test_case in self.request_handling_test_cases: raw_response = generate_raw_response( gen_text=test_case["gen_text"], @@ -590,10 +596,10 @@ async def wrap_generator(tokenizer, raw_response, test_case): yield "", "stop" # yield "", test_case["expected_finish_reason"] - for prompt_template in self.test_prompt_templates: - + for prompt_template_class in self.test_prompt_templates: + prompt_template = prompt_template_class.get_prompt_template() tokenizer = AutoTokenizer.from_pretrained( - self.prompt_template_to_tokenizer_name_mapping[type(prompt_template)] + self.prompt_template_to_tokenizer_name_mapping[prompt_template_class] ) special_tokens = { 
"additional_special_tokens": prompt_template.get_additional_tokens() From faca47a44e20d7671b3ea72610d86bf5656fdee3 Mon Sep 17 00:00:00 2001 From: Khai Mai Date: Wed, 16 Apr 2025 12:20:34 +0700 Subject: [PATCH 31/31] change to public repo --- tests/test_prompt_creation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_prompt_creation.py b/tests/test_prompt_creation.py index 42b939a9..18eb7003 100644 --- a/tests/test_prompt_creation.py +++ b/tests/test_prompt_creation.py @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs): # "r1_distilled_qwen": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", "r1_distilled_llama": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", "qwen2.5-text-only": "Qwen/Qwen2.5-32B", - "gemma3": "google/gemma-3-27b-it", + "gemma3": "unsloth/gemma-3-12b-it", "cogito": "deepcogito/cogito-v1-preview-qwen-32B", } self.image_template_version_to_model_name = {