From 1dc1e2c8a3bd1de84e65e3a6abc8e5d4503af729 Mon Sep 17 00:00:00 2001 From: JSap0914 Date: Tue, 16 Jun 2026 17:23:01 +0900 Subject: [PATCH] Fix YUAN2 conversation template stripping message characters The YUAN2 separator style joined messages with the literal '' separator and then called ret.rstrip('') to drop the trailing separator before appending seps[0]. str.rstrip treats its argument as a set of characters, so it stripped every trailing '<', 'n' and '>' character, corrupting any message that ended in one of them (e.g. 'What is the value of n' became 'What is the value of '). Remove only the trailing '' separator with an explicit suffix check so message content is preserved. Adds tests/test_conversation.py covering the YUAN2 template. --- fastchat/conversation.py | 4 +++- tests/test_conversation.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 tests/test_conversation.py diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 4a46103ec..9ee9f133f 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -292,7 +292,9 @@ def get_prompt(self) -> str: ret += message + "" else: ret += "" - ret = ret.rstrip("") + seps[0] + if ret.endswith(""): + ret = ret[: -len("")] + ret = ret + seps[0] return ret elif self.sep_style == SeparatorStyle.GEMMA: ret = "" diff --git a/tests/test_conversation.py b/tests/test_conversation.py new file mode 100644 index 000000000..7b7a14820 --- /dev/null +++ b/tests/test_conversation.py @@ -0,0 +1,38 @@ +""" +Usage: +python3 -m unittest tests.test_conversation +""" + +import unittest + +from fastchat.conversation import get_conv_template + + +class TestYuan2Template(unittest.TestCase): + def test_message_ending_in_n_is_preserved(self): + """The YUAN2 template must only strip the trailing ```` separator, + not characters (``<``, ``n``, ``>``) that belong to the message.""" + conv = get_conv_template("yuan2") + conv.append_message(conv.roles[0], "What is the value of n") + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + self.assertIn("What is the value of n", prompt) + self.assertEqual(prompt, "What is the value of n") + + def test_message_ending_in_angle_bracket_is_preserved(self): + conv = get_conv_template("yuan2") + conv.append_message(conv.roles[0], "compare a < b") + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + self.assertIn("compare a < b", prompt) + + def test_separator_between_messages_is_kept(self): + conv = get_conv_template("yuan2") + conv.append_message(conv.roles[0], "hello") + conv.append_message(conv.roles[1], "world") + prompt = conv.get_prompt() + self.assertEqual(prompt, "helloworld") + + +if __name__ == "__main__": + unittest.main()