@@ -580,17 +580,136 @@ def test_stack(cpp, dtype):
580580 arrays = [random_array ((3 ,), seed = i , dtype = dtype ) for i in range (4 )]
581581 assert_bit_aligned (cpp .stack (arrays ), np .stack (arrays ), "stack" )
582582
583- def test_concatenate (cpp , dtype ):
583+ def test_concatenate_1d (cpp , dtype ):
584584 arrays = [random_array ((3 ,), seed = i , dtype = dtype ) for i in range (3 )]
585- assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concatenate" )
585+ assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concatenate 1d" )
586+
587+ def test_concatenate_2d_axis0 (cpp , dtype ):
588+ arrays = [random_array ((2 , 3 ), seed = i , dtype = dtype ) for i in range (3 )]
589+ assert_bit_aligned (cpp .concatenate (arrays , 0 ), np .concatenate (arrays , axis = 0 ), "concatenate 2d axis=0" )
590+ # Verify default axis=0
591+ assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concatenate 2d default axis" )
592+
593+ def test_concatenate_2d_axis1 (cpp , dtype ):
594+ arrays = [random_array ((3 , 2 ), seed = i , dtype = dtype ) for i in range (3 )]
595+ assert_bit_aligned (cpp .concatenate (arrays , 1 ), np .concatenate (arrays , axis = 1 ), "concatenate 2d axis=1" )
596+
597+ def test_concatenate_2d_axis_neg1 (cpp , dtype ):
598+ arrays = [random_array ((3 , 2 ), seed = i , dtype = dtype ) for i in range (3 )]
599+ assert_bit_aligned (cpp .concatenate (arrays , - 1 ), np .concatenate (arrays , axis = - 1 ), "concatenate 2d axis=-1" )
600+
601+ def test_concatenate_3d_axis0 (cpp , dtype ):
602+ arrays = [random_array ((2 , 3 , 4 ), seed = i , dtype = dtype ) for i in range (2 )]
603+ assert_bit_aligned (cpp .concatenate (arrays , 0 ), np .concatenate (arrays , axis = 0 ), "concatenate 3d axis=0" )
604+
605+ def test_concatenate_3d_axis1 (cpp , dtype ):
606+ arrays = [random_array ((3 , 2 , 4 ), seed = i , dtype = dtype ) for i in range (2 )]
607+ assert_bit_aligned (cpp .concatenate (arrays , 1 ), np .concatenate (arrays , axis = 1 ), "concatenate 3d axis=1" )
608+
609+ def test_concatenate_3d_axis2 (cpp , dtype ):
610+ arrays = [random_array ((3 , 4 , 2 ), seed = i , dtype = dtype ) for i in range (2 )]
611+ assert_bit_aligned (cpp .concatenate (arrays , 2 ), np .concatenate (arrays , axis = 2 ), "concatenate 3d axis=2" )
612+
613+ def test_concatenate_two_arrays (cpp , dtype ):
614+ arrays = [random_array ((5 ,), seed = 0 , dtype = dtype ), random_array ((7 ,), seed = 1 , dtype = dtype )]
615+ assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concatenate two" )
616+
617+ def test_concatenate_single (cpp , dtype ):
618+ arrays = [random_array ((5 ,), dtype = dtype )]
619+ assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concatenate single" )
586620
587621def test_vstack (cpp , dtype ):
588622 arrays = [random_array ((1 , 3 ), seed = i , dtype = dtype ) for i in range (4 )]
589623 assert_bit_aligned (cpp .vstack (arrays ), np .vstack (arrays ), "vstack" )
590624
625+ def test_vstack_1d (cpp , dtype ):
626+ arrays = [random_array ((3 ,), seed = i , dtype = dtype ) for i in range (4 )]
627+ assert_bit_aligned (cpp .vstack (arrays ), np .vstack (arrays ), "vstack 1d" )
628+
591629def test_hstack (cpp , dtype ):
592630 arrays = [random_array ((3 ,), seed = i , dtype = dtype ) for i in range (3 )]
593- assert_bit_aligned (cpp .hstack (arrays ), np .hstack (arrays ), "hstack" )
631+ assert_bit_aligned (cpp .hstack (arrays ), np .hstack (arrays ), "hstack 1d" )
632+
633+ def test_hstack_2d (cpp , dtype ):
634+ arrays = [random_array ((3 , 2 ), seed = i , dtype = dtype ) for i in range (3 )]
635+ assert_bit_aligned (cpp .hstack (arrays ), np .hstack (arrays ), "hstack 2d" )
636+
637+ # -- Concatenate complex / edge-case tests ----------------------------------
638+
639+ def test_concatenate_4d_axis0 (cpp , dtype ):
640+ arrays = [random_array ((2 , 3 , 4 , 5 ), seed = i , dtype = dtype ) for i in range (2 )]
641+ assert_bit_aligned (cpp .concatenate (arrays , 0 ), np .concatenate (arrays , axis = 0 ), "concatenate 4d axis=0" )
642+
643+ def test_concatenate_4d_axis2 (cpp , dtype ):
644+ arrays = [random_array ((2 , 3 , 2 , 5 ), seed = i , dtype = dtype ) for i in range (2 )]
645+ assert_bit_aligned (cpp .concatenate (arrays , 2 ), np .concatenate (arrays , axis = 2 ), "concatenate 4d axis=2" )
646+
647+ def test_concatenate_4d_axis_neg2 (cpp , dtype ):
648+ arrays = [random_array ((2 , 3 , 2 , 5 ), seed = i , dtype = dtype ) for i in range (2 )]
649+ assert_bit_aligned (cpp .concatenate (arrays , - 2 ), np .concatenate (arrays , axis = - 2 ), "concatenate 4d axis=-2" )
650+
651+ def test_concatenate_unequal_axis_sizes (cpp , dtype ):
652+ """Concatenate arrays of different sizes along the concatenation axis."""
653+ a = random_array ((3 , 2 ), seed = 1 , dtype = dtype )
654+ b = random_array ((3 , 4 ), seed = 2 , dtype = dtype )
655+ c = random_array ((3 , 1 ), seed = 3 , dtype = dtype )
656+ assert_bit_aligned (cpp .concatenate ([a , b , c ], 1 ),
657+ np .concatenate ([a , b , c ], axis = 1 ), "concat unequal axis sizes" )
658+
659+ def test_concatenate_many_arrays (cpp , dtype ):
660+ """Concatenate 10 arrays along axis=0."""
661+ arrays = [random_array ((3 ,), seed = i , dtype = dtype ) for i in range (10 )]
662+ assert_bit_aligned (cpp .concatenate (arrays ), np .concatenate (arrays ), "concat 10 arrays" )
663+
664+ def test_concatenate_large_3d (cpp , dtype ):
665+ """Large 3D concatenation along middle axis."""
666+ arrays = [random_array ((50 , 20 , 30 ), seed = i , dtype = dtype ) for i in range (3 )]
667+ assert_bit_aligned (cpp .concatenate (arrays , 1 ), np .concatenate (arrays , axis = 1 ), "concat large 3d axis=1" )
668+
669+ def test_concatenate_large_2d_axis0 (cpp , dtype ):
670+ """Large 2D concatenation — 500 rows each, 4 arrays."""
671+ arrays = [random_array ((500 , 10 ), seed = i , dtype = dtype ) for i in range (4 )]
672+ assert_bit_aligned (cpp .concatenate (arrays , 0 ), np .concatenate (arrays , axis = 0 ), "concat large 2d axis=0" )
673+
674+ def test_concatenate_large_2d_axis1 (cpp , dtype ):
675+ """Large 2D concatenation — 500 cols each, 3 arrays."""
676+ arrays = [random_array ((10 , 500 ), seed = i , dtype = dtype ) for i in range (3 )]
677+ assert_bit_aligned (cpp .concatenate (arrays , 1 ), np .concatenate (arrays , axis = 1 ), "concat large 2d axis=1" )
678+
679+ def test_concatenate_identity (cpp , dtype ):
680+ """Concatenating a single array returns identical copy."""
681+ a = random_array ((3 , 4 ), seed = 42 , dtype = dtype )
682+ assert_bit_aligned (cpp .concatenate ([a ], 0 ), np .concatenate ([a ], axis = 0 ), "concat identity" )
683+ assert_bit_aligned (cpp .concatenate ([a ], 1 ), np .concatenate ([a ], axis = 1 ), "concat identity axis=1" )
684+
685+ def test_concatenate_zeros (cpp , dtype ):
686+ """Concatenate arrays of zeros."""
687+ a = np .zeros ((2 , 3 ), dtype = dtype )
688+ b = np .zeros ((2 , 5 ), dtype = dtype )
689+ assert_bit_aligned (cpp .concatenate ([a , b ], 1 ), np .concatenate ([a , b ], axis = 1 ), "concat zeros" )
690+
691+ def test_concatenate_ones (cpp , dtype ):
692+ """Concatenate arrays of ones."""
693+ a = np .ones ((3 , 2 ), dtype = dtype )
694+ b = np .ones ((5 , 2 ), dtype = dtype )
695+ assert_bit_aligned (cpp .concatenate ([a , b ], 0 ), np .concatenate ([a , b ], axis = 0 ), "concat ones" )
696+
697+ def test_concatenate_3d_axis_neg2 (cpp , dtype ):
698+ """3D concatenate along axis=-2 (middle axis)."""
699+ arrays = [random_array ((2 , 3 , 4 ), seed = i , dtype = dtype ) for i in range (3 )]
700+ assert_bit_aligned (cpp .concatenate (arrays , - 2 ), np .concatenate (arrays , axis = - 2 ), "concat 3d axis=-2" )
701+
702+ def test_concatenate_3d_axis_neg3 (cpp , dtype ):
703+ """3D concatenate along axis=-3 (first axis)."""
704+ arrays = [random_array ((2 , 3 , 4 ), seed = i , dtype = dtype ) for i in range (2 )]
705+ assert_bit_aligned (cpp .concatenate (arrays , - 3 ), np .concatenate (arrays , axis = - 3 ), "concat 3d axis=-3" )
706+
707+ def test_concatenate_5d (cpp , dtype ):
708+ """5D concatenate along various axes."""
709+ arrays = [random_array ((2 , 3 , 2 , 3 , 2 ), seed = i , dtype = dtype ) for i in range (2 )]
710+ assert_bit_aligned (cpp .concatenate (arrays , 0 ), np .concatenate (arrays , axis = 0 ), "concat 5d axis=0" )
711+ assert_bit_aligned (cpp .concatenate (arrays , 2 ), np .concatenate (arrays , axis = 2 ), "concat 5d axis=2" )
712+ assert_bit_aligned (cpp .concatenate (arrays , - 1 ), np .concatenate (arrays , axis = - 1 ), "concat 5d axis=-1" )
594713
595714def test_where_scalar (cpp , dtype ):
596715 cond = np .array ([True , False , True , False , True ])
0 commit comments