@@ -82,11 +82,16 @@ def fill_kernel(init_cuda):
8282 return mod .get_kernel ("fill" )
8383
8484
85- def _aligned_half (sm ):
86- """Compute half the SM count, rounded down to min_partition_size alignment."""
85+ def _safe_two_group_count (sm ):
86+ """Return a safe per-group SM count for a 2-group split.
87+
88+ Uses min_partition_size which is always a valid split size regardless
89+ of hardware topology. Returns None if the device doesn't have enough SMs.
90+ """
8791 min_size = sm .min_partition_size
88- half = (sm .sm_count // 2 // min_size ) * min_size
89- return half
92+ if sm .sm_count < 2 * min_size :
93+ return None
94+ return min_size
9095
9196
9297@contextlib .contextmanager
@@ -238,30 +243,33 @@ def test_discovery_respects_alignment(self, sm_resource):
238243 assert groups [0 ].sm_count % sm_resource .coscheduled_alignment == 0
239244
def test_two_groups(self, sm_resource):
    """Split into two equal groups sized by ``_safe_two_group_count``.

    Skips when the helper reports that no 2-group split is possible.
    Otherwise checks that exactly two groups come back, that each holds at
    least the requested number of SMs, and that the groups plus the
    remainder do not exceed the SM count of the source resource.
    """
    per_group = _safe_two_group_count(sm_resource)
    if per_group is None:
        pytest.skip("Not enough SMs for a 2-group split")

    request = SMResourceOptions(count=(per_group, per_group))
    groups, remainder = sm_resource.split(request)

    assert len(groups) == 2
    first, second = groups[0], groups[1]
    assert first.sm_count >= per_group
    assert second.sm_count >= per_group
    # Groups and remainder together must not exceed the original resource.
    used = first.sm_count + second.sm_count + remainder.sm_count
    assert used <= sm_resource.sm_count
253258
def test_two_groups_backfill(self, sm_resource):
    """Two-group split with backfill allows larger partitions.

    Requests two groups of roughly half the SMs each, rounded down to a
    multiple of ``coscheduled_alignment`` (or of ``min_partition_size``
    when the alignment is reported as 0), with ``backfill=True``.  Skips
    when no usable request size can be derived.
    """
    min_size = sm_resource.min_partition_size
    # Fall back to the minimum partition size when no alignment is reported.
    align = sm_resource.coscheduled_alignment or min_size
    if align <= 0:
        # Neither property gives a usable granularity; skipping avoids a
        # ZeroDivisionError in the rounding below.
        pytest.skip("Not enough SMs for a 2-group backfill split")

    half = (sm_resource.sm_count // 2 // align) * align
    # ``half <= 0`` covers the case where min_size is 0 and the rounding
    # collapsed to an empty request, which would make the checks vacuous.
    if half <= 0 or half < min_size:
        pytest.skip("Not enough SMs for a 2-group backfill split")

    groups, rem = sm_resource.split(SMResourceOptions(count=(half, half), backfill=True))

    assert len(groups) == 2
    assert groups[0].sm_count >= half
    assert groups[1].sm_count >= half
265273
266274 def test_dry_run_matches_real (self , sm_resource ):
267275 """Dry-run reports the same SM counts as a real split."""
@@ -352,11 +360,11 @@ def test_green_ctx_sm_resources(self, green_ctx, sm_resource):
352360
353361 def test_green_ctx_resources_reflect_partition (self , init_cuda , sm_resource ):
354362 """Two green contexts should have disjoint SM partitions."""
355- half = _aligned_half (sm_resource )
356- if half < sm_resource . min_partition_size :
363+ count = _safe_two_group_count (sm_resource )
364+ if count is None :
357365 pytest .skip ("Not enough SMs for a 2-group split" )
358366
359- groups , _ = sm_resource .split (SMResourceOptions (count = (half , half )))
367+ groups , _ = sm_resource .split (SMResourceOptions (count = (count , count )))
360368
361369 ctx_a = ctx_b = None
362370 try :
@@ -425,11 +433,11 @@ def test_launch_and_verify(self, init_cuda, green_ctx, fill_kernel):
425433 def test_two_green_contexts_independent (self , init_cuda , sm_resource , fill_kernel ):
426434 """Two SM groups -> two green contexts -> two independent kernels."""
427435 dev = init_cuda
428- half = _aligned_half (sm_resource )
429- if half < sm_resource . min_partition_size :
436+ count = _safe_two_group_count (sm_resource )
437+ if count is None :
430438 pytest .skip ("Not enough SMs for a 2-group split" )
431439
432- groups , _ = sm_resource .split (SMResourceOptions (count = (half , half )))
440+ groups , _ = sm_resource .split (SMResourceOptions (count = (count , count )))
433441 assert len (groups ) == 2
434442
435443 ctx_a = ctx_b = None
0 commit comments