Skip to content

Commit 0b067ad

Browse files
test(chat): enhance message service pagination tests with branch separation
- Updated the `TestMessageServicePagination` class in `test_message_service.py` to create two separate branches of messages for testing, improving the coverage of pagination scenarios. - Added new tests to validate querying messages from each branch independently, ensuring no overlap between branches and confirming correct behavior when using the `after` parameter. - Enhanced docstrings for clarity regarding the structure of the test messages and the expected outcomes during pagination.
1 parent 74015d0 commit 0b067ad

File tree

1 file changed

+148
-37
lines changed

1 file changed

+148
-37
lines changed

tests/integration/chat/api/test_message_service.py

Lines changed: 148 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,14 @@ def _messages(
5252
_workspace_id: UUID,
5353
_thread_id: str,
5454
) -> list[Message]:
55-
"""Create a chain of 10 messages for testing."""
55+
"""Create two branches of messages for testing.
56+
57+
Branch 1: Messages 0-9 (linear chain from ROOT)
58+
Branch 2: Messages 10-19 (separate linear chain from ROOT)
59+
"""
5660
_created_messages: list[Message] = []
61+
62+
# Create first branch: messages 0-9 (linear chain)
5763
for i in range(10):
5864
_msg = _message_service.create(
5965
workspace_id=_workspace_id,
@@ -69,6 +75,24 @@ def _messages(
6975
),
7076
)
7177
_created_messages.append(_msg)
78+
79+
# Create second branch: messages 10-19 (separate linear chain from ROOT)
80+
for i in range(10, 20):
81+
_msg = _message_service.create(
82+
workspace_id=_workspace_id,
83+
thread_id=_thread_id,
84+
params=MessageCreate(
85+
role="user" if i % 2 == 0 else "assistant",
86+
content=f"Test message {i}",
87+
parent_id=(
88+
ROOT_MESSAGE_PARENT_ID
89+
if i == 10
90+
else _created_messages[i - 1].id
91+
),
92+
),
93+
)
94+
_created_messages.append(_msg)
95+
7296
return _created_messages
7397

7498
def test_list_asc_without_after(
@@ -79,20 +103,21 @@ def test_list_asc_without_after(
79103
_messages: list[Message],
80104
) -> None:
81105
"""Test listing messages in ascending order without 'after' parameter."""
106+
# Without before/after, gets latest branch (branch 2)
82107
_response = _message_service.list_(
83108
workspace_id=_workspace_id,
84109
thread_id=_thread_id,
85110
query=ListQuery(limit=5, order="asc"),
86111
)
87112

88113
assert len(_response.data) == 5
89-
# Should get the first 5 messages in ascending order (0, 1, 2, 3, 4)
114+
# Should get the first 5 messages from branch 2 (10, 11, 12, 13, 14)
90115
assert [_msg.content for _msg in _response.data] == [
91-
"Test message 0",
92-
"Test message 1",
93-
"Test message 2",
94-
"Test message 3",
95-
"Test message 4",
116+
"Test message 10",
117+
"Test message 11",
118+
"Test message 12",
119+
"Test message 13",
120+
"Test message 14",
96121
]
97122
assert _response.has_more is True
98123

@@ -104,7 +129,7 @@ def test_list_asc_with_after(
104129
_messages: list[Message],
105130
) -> None:
106131
"""Test listing messages in ascending order with 'after' parameter."""
107-
# First, get the first page
132+
# First, get the first page from branch 2 (default)
108133
_first_page = _message_service.list_(
109134
workspace_id=_workspace_id,
110135
thread_id=_thread_id,
@@ -113,9 +138,9 @@ def test_list_asc_with_after(
113138

114139
assert len(_first_page.data) == 3
115140
assert [_msg.content for _msg in _first_page.data] == [
116-
"Test message 0",
117-
"Test message 1",
118-
"Test message 2",
141+
"Test message 10",
142+
"Test message 11",
143+
"Test message 12",
119144
]
120145

121146
# Now get the second page using 'after'
@@ -126,11 +151,11 @@ def test_list_asc_with_after(
126151
)
127152

128153
assert len(_second_page.data) == 3
129-
# Should get the next 3 messages (3, 4, 5)
154+
# Should get the next 3 messages (13, 14, 15)
130155
assert [_msg.content for _msg in _second_page.data] == [
131-
"Test message 3",
132-
"Test message 4",
133-
"Test message 5",
156+
"Test message 13",
157+
"Test message 14",
158+
"Test message 15",
134159
]
135160
assert _second_page.has_more is True
136161

@@ -142,11 +167,11 @@ def test_list_asc_with_after(
142167
)
143168

144169
assert len(_third_page.data) == 3
145-
# Should get the next 3 messages (6, 7, 8)
170+
# Should get the next 3 messages (16, 17, 18)
146171
assert [_msg.content for _msg in _third_page.data] == [
147-
"Test message 6",
148-
"Test message 7",
149-
"Test message 8",
172+
"Test message 16",
173+
"Test message 17",
174+
"Test message 18",
150175
]
151176

152177
def test_list_desc_without_after(
@@ -157,20 +182,21 @@ def test_list_desc_without_after(
157182
_messages: list[Message],
158183
) -> None:
159184
"""Test listing messages in descending order without 'after' parameter."""
185+
# Without before/after, gets latest branch (branch 2)
160186
_response = _message_service.list_(
161187
workspace_id=_workspace_id,
162188
thread_id=_thread_id,
163189
query=ListQuery(limit=5, order="desc"),
164190
)
165191

166192
assert len(_response.data) == 5
167-
# Should get the last 5 messages in descending order (9, 8, 7, 6, 5)
193+
# Should get the last 5 messages from branch 2 (19, 18, 17, 16, 15)
168194
assert [_msg.content for _msg in _response.data] == [
169-
"Test message 9",
170-
"Test message 8",
171-
"Test message 7",
172-
"Test message 6",
173-
"Test message 5",
195+
"Test message 19",
196+
"Test message 18",
197+
"Test message 17",
198+
"Test message 16",
199+
"Test message 15",
174200
]
175201
assert _response.has_more is True
176202

@@ -182,7 +208,7 @@ def test_list_desc_with_after(
182208
_messages: list[Message],
183209
) -> None:
184210
"""Test listing messages in descending order with 'after' parameter."""
185-
# First, get the first page
211+
# First, get the first page from branch 2 (default)
186212
_first_page = _message_service.list_(
187213
workspace_id=_workspace_id,
188214
thread_id=_thread_id,
@@ -191,9 +217,9 @@ def test_list_desc_with_after(
191217

192218
assert len(_first_page.data) == 3
193219
assert [_msg.content for _msg in _first_page.data] == [
194-
"Test message 9",
195-
"Test message 8",
196-
"Test message 7",
220+
"Test message 19",
221+
"Test message 18",
222+
"Test message 17",
197223
]
198224

199225
# Now get the second page using 'after'
@@ -204,11 +230,11 @@ def test_list_desc_with_after(
204230
)
205231

206232
assert len(_second_page.data) == 3
207-
# Should get the previous 3 messages (6, 5, 4)
233+
# Should get the previous 3 messages (16, 15, 14)
208234
assert [_msg.content for _msg in _second_page.data] == [
209-
"Test message 6",
210-
"Test message 5",
211-
"Test message 4",
235+
"Test message 16",
236+
"Test message 15",
237+
"Test message 14",
212238
]
213239

214240
def test_iter_asc(
@@ -219,6 +245,7 @@ def test_iter_asc(
219245
_messages: list[Message],
220246
) -> None:
221247
"""Test iterating through messages in ascending order."""
248+
# Without before/after, iter returns the latest branch (branch 2)
222249
_collected_messages: list[Message] = list(
223250
_message_service.iter(
224251
workspace_id=_workspace_id,
@@ -228,10 +255,10 @@ def test_iter_asc(
228255
)
229256
)
230257

231-
# Should get all 10 messages in ascending order
258+
# Should get all 10 messages from branch 2 in ascending order
232259
assert len(_collected_messages) == 10
233260
assert [_msg.content for _msg in _collected_messages] == [
234-
f"Test message {i}" for i in range(10)
261+
f"Test message {i}" for i in range(10, 20)
235262
]
236263

237264
def test_iter_desc(
@@ -242,6 +269,7 @@ def test_iter_desc(
242269
_messages: list[Message],
243270
) -> None:
244271
"""Test iterating through messages in descending order."""
272+
# Without before/after, iter returns the latest branch (branch 2)
245273
_collected_messages: list[Message] = list(
246274
_message_service.iter(
247275
workspace_id=_workspace_id,
@@ -251,10 +279,10 @@ def test_iter_desc(
251279
)
252280
)
253281

254-
# Should get all 10 messages in descending order
282+
# Should get all 10 messages from branch 2 in descending order
255283
assert len(_collected_messages) == 10
256284
assert [_msg.content for _msg in _collected_messages] == [
257-
f"Test message {i}" for i in range(9, -1, -1)
285+
f"Test message {i}" for i in range(19, 9, -1)
258286
]
259287

260288
def test_list_asc_with_before(
@@ -351,3 +379,86 @@ def test_list_desc_with_before_paginated(
351379
"Test message 7",
352380
]
353381
assert _response.has_more is True
382+
383+
def test_list_branch1_with_after(
384+
self,
385+
_message_service: MessageService,
386+
_workspace_id: UUID,
387+
_thread_id: str,
388+
_messages: list[Message],
389+
) -> None:
390+
"""Test querying branch 1 by starting from its first message."""
391+
# Query from the first message of branch 1 downward
392+
_response = _message_service.list_(
393+
workspace_id=_workspace_id,
394+
thread_id=_thread_id,
395+
query=ListQuery(limit=20, order="asc", after=_messages[0].id),
396+
)
397+
398+
# Should get messages 1-9 from branch 1 (excluding message 0)
399+
assert len(_response.data) == 9
400+
assert [_msg.content for _msg in _response.data] == [
401+
f"Test message {i}" for i in range(1, 10)
402+
]
403+
assert _response.has_more is False
404+
405+
def test_list_branch2_with_after(
406+
self,
407+
_message_service: MessageService,
408+
_workspace_id: UUID,
409+
_thread_id: str,
410+
_messages: list[Message],
411+
) -> None:
412+
"""Test querying branch 2 by starting from its first message."""
413+
# Query from the first message of branch 2 downward
414+
_response = _message_service.list_(
415+
workspace_id=_workspace_id,
416+
thread_id=_thread_id,
417+
query=ListQuery(limit=20, order="asc", after=_messages[10].id),
418+
)
419+
420+
# Should get messages 11-19 from branch 2 (excluding message 10)
421+
assert len(_response.data) == 9
422+
assert [_msg.content for _msg in _response.data] == [
423+
f"Test message {i}" for i in range(11, 20)
424+
]
425+
assert _response.has_more is False
426+
427+
def test_list_branches_separately(
428+
self,
429+
_message_service: MessageService,
430+
_workspace_id: UUID,
431+
_thread_id: str,
432+
_messages: list[Message],
433+
) -> None:
434+
"""Test that the two branches are separate by querying from each."""
435+
# Get branch 1: query from branch 1's last message going up
436+
_branch1_response = _message_service.list_(
437+
workspace_id=_workspace_id,
438+
thread_id=_thread_id,
439+
query=ListQuery(limit=20, order="desc", after=_messages[9].id),
440+
)
441+
442+
# Should get messages 9-0 from branch 1 in descending order
443+
assert len(_branch1_response.data) == 9
444+
assert [_msg.content for _msg in _branch1_response.data] == [
445+
f"Test message {i}" for i in range(8, -1, -1)
446+
]
447+
448+
# Get branch 2: query from branch 2's last message going up
449+
_branch2_response = _message_service.list_(
450+
workspace_id=_workspace_id,
451+
thread_id=_thread_id,
452+
query=ListQuery(limit=20, order="desc", after=_messages[19].id),
453+
)
454+
455+
# Should get messages 19-10 from branch 2 in descending order
456+
assert len(_branch2_response.data) == 9
457+
assert [_msg.content for _msg in _branch2_response.data] == [
458+
f"Test message {i}" for i in range(18, 9, -1)
459+
]
460+
461+
# Verify no overlap between branches
462+
_branch1_ids = {_msg.id for _msg in _branch1_response.data}
463+
_branch2_ids = {_msg.id for _msg in _branch2_response.data}
464+
assert _branch1_ids.isdisjoint(_branch2_ids)

0 commit comments

Comments
 (0)