@@ -52,8 +52,14 @@ def _messages(
5252 _workspace_id : UUID ,
5353 _thread_id : str ,
5454 ) -> list [Message ]:
55- """Create a chain of 10 messages for testing."""
55+ """Create two branches of messages for testing.
56+
57+ Branch 1: Messages 0-9 (linear chain from ROOT)
58+ Branch 2: Messages 10-19 (separate linear chain from ROOT)
59+ """
5660 _created_messages : list [Message ] = []
61+
62+ # Create first branch: messages 0-9 (linear chain)
5763 for i in range (10 ):
5864 _msg = _message_service .create (
5965 workspace_id = _workspace_id ,
@@ -69,6 +75,24 @@ def _messages(
6975 ),
7076 )
7177 _created_messages .append (_msg )
78+
79+ # Create second branch: messages 10-19 (separate linear chain from ROOT)
80+ for i in range (10 , 20 ):
81+ _msg = _message_service .create (
82+ workspace_id = _workspace_id ,
83+ thread_id = _thread_id ,
84+ params = MessageCreate (
85+ role = "user" if i % 2 == 0 else "assistant" ,
86+ content = f"Test message { i } " ,
87+ parent_id = (
88+ ROOT_MESSAGE_PARENT_ID
89+ if i == 10
90+ else _created_messages [i - 1 ].id
91+ ),
92+ ),
93+ )
94+ _created_messages .append (_msg )
95+
7296 return _created_messages
7397
7498 def test_list_asc_without_after (
@@ -79,20 +103,21 @@ def test_list_asc_without_after(
79103 _messages : list [Message ],
80104 ) -> None :
81105 """Test listing messages in ascending order without 'after' parameter."""
106+ # Without before/after, gets latest branch (branch 2)
82107 _response = _message_service .list_ (
83108 workspace_id = _workspace_id ,
84109 thread_id = _thread_id ,
85110 query = ListQuery (limit = 5 , order = "asc" ),
86111 )
87112
88113 assert len (_response .data ) == 5
89- # Should get the first 5 messages in ascending order (0, 1, 2, 3, 4 )
114+ # Should get the first 5 messages from branch 2 (10, 11, 12, 13, 14 )
90115 assert [_msg .content for _msg in _response .data ] == [
91- "Test message 0 " ,
92- "Test message 1 " ,
93- "Test message 2 " ,
94- "Test message 3 " ,
95- "Test message 4 " ,
116+ "Test message 10 " ,
117+ "Test message 11 " ,
118+ "Test message 12 " ,
119+ "Test message 13 " ,
120+ "Test message 14 " ,
96121 ]
97122 assert _response .has_more is True
98123
@@ -104,7 +129,7 @@ def test_list_asc_with_after(
104129 _messages : list [Message ],
105130 ) -> None :
106131 """Test listing messages in ascending order with 'after' parameter."""
107- # First, get the first page
132+ # First, get the first page from branch 2 (default)
108133 _first_page = _message_service .list_ (
109134 workspace_id = _workspace_id ,
110135 thread_id = _thread_id ,
@@ -113,9 +138,9 @@ def test_list_asc_with_after(
113138
114139 assert len (_first_page .data ) == 3
115140 assert [_msg .content for _msg in _first_page .data ] == [
116- "Test message 0 " ,
117- "Test message 1 " ,
118- "Test message 2 " ,
141+ "Test message 10 " ,
142+ "Test message 11 " ,
143+ "Test message 12 " ,
119144 ]
120145
121146 # Now get the second page using 'after'
@@ -126,11 +151,11 @@ def test_list_asc_with_after(
126151 )
127152
128153 assert len (_second_page .data ) == 3
129- # Should get the next 3 messages (3, 4, 5 )
154+ # Should get the next 3 messages (13, 14, 15 )
130155 assert [_msg .content for _msg in _second_page .data ] == [
131- "Test message 3 " ,
132- "Test message 4 " ,
133- "Test message 5 " ,
156+ "Test message 13 " ,
157+ "Test message 14 " ,
158+ "Test message 15 " ,
134159 ]
135160 assert _second_page .has_more is True
136161
@@ -142,11 +167,11 @@ def test_list_asc_with_after(
142167 )
143168
144169 assert len (_third_page .data ) == 3
145- # Should get the next 3 messages (6, 7, 8 )
170+ # Should get the next 3 messages (16, 17, 18 )
146171 assert [_msg .content for _msg in _third_page .data ] == [
147- "Test message 6 " ,
148- "Test message 7 " ,
149- "Test message 8 " ,
172+ "Test message 16 " ,
173+ "Test message 17 " ,
174+ "Test message 18 " ,
150175 ]
151176
152177 def test_list_desc_without_after (
@@ -157,20 +182,21 @@ def test_list_desc_without_after(
157182 _messages : list [Message ],
158183 ) -> None :
159184 """Test listing messages in descending order without 'after' parameter."""
185+ # Without before/after, gets latest branch (branch 2)
160186 _response = _message_service .list_ (
161187 workspace_id = _workspace_id ,
162188 thread_id = _thread_id ,
163189 query = ListQuery (limit = 5 , order = "desc" ),
164190 )
165191
166192 assert len (_response .data ) == 5
167- # Should get the last 5 messages in descending order (9, 8, 7, 6, 5 )
193+ # Should get the last 5 messages from branch 2 (19, 18, 17, 16, 15 )
168194 assert [_msg .content for _msg in _response .data ] == [
169- "Test message 9 " ,
170- "Test message 8 " ,
171- "Test message 7 " ,
172- "Test message 6 " ,
173- "Test message 5 " ,
195+ "Test message 19 " ,
196+ "Test message 18 " ,
197+ "Test message 17 " ,
198+ "Test message 16 " ,
199+ "Test message 15 " ,
174200 ]
175201 assert _response .has_more is True
176202
@@ -182,7 +208,7 @@ def test_list_desc_with_after(
182208 _messages : list [Message ],
183209 ) -> None :
184210 """Test listing messages in descending order with 'after' parameter."""
185- # First, get the first page
211+ # First, get the first page from branch 2 (default)
186212 _first_page = _message_service .list_ (
187213 workspace_id = _workspace_id ,
188214 thread_id = _thread_id ,
@@ -191,9 +217,9 @@ def test_list_desc_with_after(
191217
192218 assert len (_first_page .data ) == 3
193219 assert [_msg .content for _msg in _first_page .data ] == [
194- "Test message 9 " ,
195- "Test message 8 " ,
196- "Test message 7 " ,
220+ "Test message 19 " ,
221+ "Test message 18 " ,
222+ "Test message 17 " ,
197223 ]
198224
199225 # Now get the second page using 'after'
@@ -204,11 +230,11 @@ def test_list_desc_with_after(
204230 )
205231
206232 assert len (_second_page .data ) == 3
207- # Should get the previous 3 messages (6, 5, 4 )
233+ # Should get the previous 3 messages (16, 15, 14 )
208234 assert [_msg .content for _msg in _second_page .data ] == [
209- "Test message 6 " ,
210- "Test message 5 " ,
211- "Test message 4 " ,
235+ "Test message 16 " ,
236+ "Test message 15 " ,
237+ "Test message 14 " ,
212238 ]
213239
214240 def test_iter_asc (
@@ -219,6 +245,7 @@ def test_iter_asc(
219245 _messages : list [Message ],
220246 ) -> None :
221247 """Test iterating through messages in ascending order."""
248+ # Without before/after, iter returns the latest branch (branch 2)
222249 _collected_messages : list [Message ] = list (
223250 _message_service .iter (
224251 workspace_id = _workspace_id ,
@@ -228,10 +255,10 @@ def test_iter_asc(
228255 )
229256 )
230257
231- # Should get all 10 messages in ascending order
258+ # Should get all 10 messages from branch 2 in ascending order
232259 assert len (_collected_messages ) == 10
233260 assert [_msg .content for _msg in _collected_messages ] == [
234- f"Test message { i } " for i in range (10 )
261+ f"Test message { i } " for i in range (10 , 20 )
235262 ]
236263
237264 def test_iter_desc (
@@ -242,6 +269,7 @@ def test_iter_desc(
242269 _messages : list [Message ],
243270 ) -> None :
244271 """Test iterating through messages in descending order."""
272+ # Without before/after, iter returns the latest branch (branch 2)
245273 _collected_messages : list [Message ] = list (
246274 _message_service .iter (
247275 workspace_id = _workspace_id ,
@@ -251,10 +279,10 @@ def test_iter_desc(
251279 )
252280 )
253281
254- # Should get all 10 messages in descending order
282+ # Should get all 10 messages from branch 2 in descending order
255283 assert len (_collected_messages ) == 10
256284 assert [_msg .content for _msg in _collected_messages ] == [
257- f"Test message { i } " for i in range (9 , - 1 , - 1 )
285+ f"Test message { i } " for i in range (19 , 9 , - 1 )
258286 ]
259287
260288 def test_list_asc_with_before (
@@ -351,3 +379,86 @@ def test_list_desc_with_before_paginated(
351379 "Test message 7" ,
352380 ]
353381 assert _response .has_more is True
382+
383+ def test_list_branch1_with_after (
384+ self ,
385+ _message_service : MessageService ,
386+ _workspace_id : UUID ,
387+ _thread_id : str ,
388+ _messages : list [Message ],
389+ ) -> None :
390+ """Test querying branch 1 by starting from its first message."""
391+ # Query from the first message of branch 1 downward
392+ _response = _message_service .list_ (
393+ workspace_id = _workspace_id ,
394+ thread_id = _thread_id ,
395+ query = ListQuery (limit = 20 , order = "asc" , after = _messages [0 ].id ),
396+ )
397+
398+ # Should get messages 1-9 from branch 1 (excluding message 0)
399+ assert len (_response .data ) == 9
400+ assert [_msg .content for _msg in _response .data ] == [
401+ f"Test message { i } " for i in range (1 , 10 )
402+ ]
403+ assert _response .has_more is False
404+
405+ def test_list_branch2_with_after (
406+ self ,
407+ _message_service : MessageService ,
408+ _workspace_id : UUID ,
409+ _thread_id : str ,
410+ _messages : list [Message ],
411+ ) -> None :
412+ """Test querying branch 2 by starting from its first message."""
413+ # Query from the first message of branch 2 downward
414+ _response = _message_service .list_ (
415+ workspace_id = _workspace_id ,
416+ thread_id = _thread_id ,
417+ query = ListQuery (limit = 20 , order = "asc" , after = _messages [10 ].id ),
418+ )
419+
420+ # Should get messages 11-19 from branch 2 (excluding message 10)
421+ assert len (_response .data ) == 9
422+ assert [_msg .content for _msg in _response .data ] == [
423+ f"Test message { i } " for i in range (11 , 20 )
424+ ]
425+ assert _response .has_more is False
426+
427+ def test_list_branches_separately (
428+ self ,
429+ _message_service : MessageService ,
430+ _workspace_id : UUID ,
431+ _thread_id : str ,
432+ _messages : list [Message ],
433+ ) -> None :
434+ """Test that the two branches are separate by querying from each."""
435+ # Get branch 1: query from branch 1's last message going up
436+ _branch1_response = _message_service .list_ (
437+ workspace_id = _workspace_id ,
438+ thread_id = _thread_id ,
439+ query = ListQuery (limit = 20 , order = "desc" , after = _messages [9 ].id ),
440+ )
441+
442+ # Should get messages 9-0 from branch 1 in descending order
443+ assert len (_branch1_response .data ) == 9
444+ assert [_msg .content for _msg in _branch1_response .data ] == [
445+ f"Test message { i } " for i in range (8 , - 1 , - 1 )
446+ ]
447+
448+ # Get branch 2: query from branch 2's last message going up
449+ _branch2_response = _message_service .list_ (
450+ workspace_id = _workspace_id ,
451+ thread_id = _thread_id ,
452+ query = ListQuery (limit = 20 , order = "desc" , after = _messages [19 ].id ),
453+ )
454+
455+ # Should get messages 19-10 from branch 2 in descending order
456+ assert len (_branch2_response .data ) == 9
457+ assert [_msg .content for _msg in _branch2_response .data ] == [
458+ f"Test message { i } " for i in range (18 , 9 , - 1 )
459+ ]
460+
461+ # Verify no overlap between branches
462+ _branch1_ids = {_msg .id for _msg in _branch1_response .data }
463+ _branch2_ids = {_msg .id for _msg in _branch2_response .data }
464+ assert _branch1_ids .isdisjoint (_branch2_ids )
0 commit comments