diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 4a46103ec..9ee9f133f 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -292,7 +292,9 @@ def get_prompt(self) -> str: ret += message + "" else: ret += "" - ret = ret.rstrip("") + seps[0] + if ret.endswith(""): + ret = ret[: -len("")] + ret = ret + seps[0] return ret elif self.sep_style == SeparatorStyle.GEMMA: ret = "" diff --git a/tests/test_conversation.py b/tests/test_conversation.py new file mode 100644 index 000000000..7b7a14820 --- /dev/null +++ b/tests/test_conversation.py @@ -0,0 +1,38 @@ +""" +Usage: +python3 -m unittest tests.test_conversation +""" + +import unittest + +from fastchat.conversation import get_conv_template + + +class TestYuan2Template(unittest.TestCase): + def test_message_ending_in_n_is_preserved(self): + """The YUAN2 template must only strip the trailing ```` separator, + not characters (``<``, ``n``, ``>``) that belong to the message.""" + conv = get_conv_template("yuan2") + conv.append_message(conv.roles[0], "What is the value of n") + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + self.assertIn("What is the value of n", prompt) + self.assertEqual(prompt, "What is the value of n") + + def test_message_ending_in_angle_bracket_is_preserved(self): + conv = get_conv_template("yuan2") + conv.append_message(conv.roles[0], "compare a < b") + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + self.assertIn("compare a < b", prompt) + + def test_separator_between_messages_is_kept(self): + conv = get_conv_template("yuan2") + conv.append_message(conv.roles[0], "hello") + conv.append_message(conv.roles[1], "world") + prompt = conv.get_prompt() + self.assertEqual(prompt, "helloworld") + + +if __name__ == "__main__": + unittest.main()