Skip to content

Commit d81c887

Browse files
author
peng.li24
committed
feat: add N-D concatenate with axis support — bit-exact with numpy
Add axis-aware N-D concatenate to the native core (core.h), updating the pybind11 wrapper (core_py.h) and bindings (module.cpp) to accept an axis parameter. Native implementation uses leading-slice block copies — each slice contributes contiguous elements per array, so a single memcpy per array per slice suffices. Per-array strides correctly account for differing axis dimension sizes. Also improve vstack/hstack wrappers: - vstack: 1D arrays reshaped to (1,N) before stacking - hstack: uses axis=1 for 2D+ arrays Adds 30+ new concatenate tests covering 1D–5D, all axes, float32/64, large arrays, identity, zeros/ones, and edge cases.
1 parent 49647b7 commit d81c887

4 files changed

Lines changed: 296 additions & 15 deletions

File tree

numpy/core.h

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,7 @@ inline void stack(const T* const* arrays, T* dst, size_t n_arrays, size_t elem_s
532532
}
533533

534534
/// numpy.concatenate((a1, a2, ...), axis=0, out=None, dtype=None, casting=...)
535+
/// 1D flat: arrays are treated as flat buffers, concatenated sequentially.
535536
template<typename T>
536537
inline void concatenate(const T* const* arrays, T* dst, const size_t* sizes, size_t n_arrays) {
537538
size_t off = 0;
@@ -541,6 +542,97 @@ inline void concatenate(const T* const* arrays, T* dst, const size_t* sizes, siz
541542
}
542543
}
543544

545+
/// numpy.concatenate((a1, a2, ...), axis=0, ...) — N-D with axis support.
546+
/// All arrays must have identical shape except along `axis`.
547+
/// `shape` is the representative common shape (use first array's shape);
548+
/// `axis_sizes[i]` gives the size of array i along the concatenation axis.
549+
///
550+
/// Strategy: iterate over "leading slices" (product of dims before axis).
551+
/// Within each slice, every array contributes a contiguous block of
552+
/// `axis_sizes[i] * trailing` elements. Since the elements are C-contiguous
553+
/// within each slice, a single memcpy per array per slice suffices.
554+
template<typename T>
555+
inline void concatenate(const T* const* arrays, T* dst,
556+
const ptrdiff_t* shape, int ndim, int axis,
557+
const size_t* axis_sizes, size_t n_arrays) {
558+
if (n_arrays == 0 || ndim == 0) return;
559+
560+
// Normalize axis
561+
if (axis < 0) axis += ndim;
562+
563+
// Trailing product = product of dims after axis (also = stride along axis)
564+
ptrdiff_t trailing = 1;
565+
for (int d = axis + 1; d < ndim; ++d) trailing *= shape[d];
566+
567+
// Total output axis size
568+
ptrdiff_t out_axis = 0;
569+
for (size_t i = 0; i < n_arrays; ++i)
570+
out_axis += static_cast<ptrdiff_t>(axis_sizes[i]);
571+
572+
// Per-array full strides (differ because axis dim sizes differ)
573+
// C-contiguous: stride[d] = stride[d+1] * size_of_dim[d+1].
574+
// Use axis_sizes[k] when d+1 == axis, common shape otherwise.
575+
std::vector<std::vector<ptrdiff_t>> in_stride(n_arrays);
576+
for (size_t k = 0; k < n_arrays; ++k) {
577+
in_stride[k].resize(ndim);
578+
in_stride[k][ndim - 1] = 1;
579+
for (int d = ndim - 2; d >= 0; --d) {
580+
ptrdiff_t s = (d + 1 == axis)
581+
? static_cast<ptrdiff_t>(axis_sizes[k])
582+
: shape[d + 1];
583+
in_stride[k][d] = in_stride[k][d + 1] * s;
584+
}
585+
}
586+
587+
// Output strides
588+
std::vector<ptrdiff_t> out_shape(shape, shape + ndim);
589+
out_shape[axis] = out_axis;
590+
std::vector<ptrdiff_t> out_stride(ndim);
591+
out_stride[ndim - 1] = 1;
592+
for (int d = ndim - 2; d >= 0; --d)
593+
out_stride[d] = out_stride[d + 1] * out_shape[d + 1];
594+
595+
// Per-array per-slice element count (contiguous elements contributed per slice)
596+
std::vector<size_t> slice_n(n_arrays);
597+
for (size_t i = 0; i < n_arrays; ++i)
598+
slice_n[i] = static_cast<size_t>(axis_sizes[i]) * static_cast<size_t>(trailing);
599+
600+
// Number of leading slices
601+
ptrdiff_t n_slices = 1;
602+
for (int d = 0; d < axis; ++d) n_slices *= shape[d];
603+
604+
// Total per-array byte size of one slice (for output position stepping)
605+
ptrdiff_t out_slice_bytes = static_cast<ptrdiff_t>(
606+
static_cast<size_t>(out_axis) * static_cast<size_t>(trailing) * sizeof(T));
607+
608+
// For each leading slice, copy contiguous blocks from each array
609+
for (ptrdiff_t s = 0; s < n_slices; ++s) {
610+
// Decompose slice index → multi-index for dims 0..axis-1
611+
ptrdiff_t rem = s;
612+
613+
// Per-array leading offset within the array
614+
std::vector<size_t> in_off(n_arrays, 0);
615+
616+
for (int d = axis - 1; d >= 0; --d) {
617+
ptrdiff_t idx = rem % shape[d];
618+
rem /= shape[d];
619+
for (size_t k = 0; k < n_arrays; ++k)
620+
in_off[k] += static_cast<size_t>(idx) * static_cast<size_t>(in_stride[k][d]);
621+
}
622+
623+
// Output position for this slice
624+
char* out_slice_start = reinterpret_cast<char*>(dst) + s * out_slice_bytes;
625+
size_t out_byte_off = 0;
626+
627+
for (size_t i = 0; i < n_arrays; ++i) {
628+
size_t bytes = slice_n[i] * sizeof(T);
629+
std::memcpy(out_slice_start + out_byte_off,
630+
arrays[i] + in_off[i], bytes);
631+
out_byte_off += bytes;
632+
}
633+
}
634+
}
635+
544636
/// numpy.where(condition, x, y) — scalar x, y
545637
template<typename T>
546638
inline void where_scalar(const bool* cond, T* dst, size_t n, T x, T y) {

pycpp/core_py.h

Lines changed: 76 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -884,28 +884,93 @@ py::array_t<T> stack(const std::vector<py::array_t<T>>& arrays) {
884884

885885
/// numpy.concatenate((a1, a2, ...), axis=0, out=None, dtype=None, casting=...)
886886
template<typename T>
887-
py::array_t<T> concatenate(const std::vector<py::array_t<T>>& arrays) {
887+
py::array_t<T> concatenate(const std::vector<py::array_t<T>>& arrays, int axis = 0) {
888888
if (arrays.empty()) return py::array_t<T>{};
889-
py::ssize_t total = 0;
890-
for (const auto& arr : arrays) total += arr.request().size;
891-
py::array_t<T> result({total});
892-
T* dst = static_cast<T*>(result.request().ptr);
893-
py::ssize_t off = 0;
889+
890+
auto buf0 = arrays[0].request();
891+
int ndim = static_cast<int>(buf0.ndim);
892+
893+
if (axis < 0) axis += ndim;
894+
if (axis < 0 || axis >= ndim)
895+
throw std::invalid_argument("concatenate: axis out of range");
896+
897+
// Validate that all arrays have same number of dimensions
894898
for (const auto& arr : arrays) {
895-
auto buf = arr.request();
896-
std::memcpy(dst + off, static_cast<const T*>(buf.ptr), buf.size * sizeof(T));
897-
off += buf.size;
899+
if (arr.request().ndim != ndim)
900+
throw std::invalid_argument("concatenate: all arrays must have same number of dimensions");
901+
}
902+
903+
// Collect shape (from first array) and per-array axis sizes
904+
std::vector<ptrdiff_t> shape(ndim);
905+
for (int d = 0; d < ndim; ++d) shape[d] = buf0.shape[d];
906+
907+
std::vector<size_t> axis_sizes(arrays.size());
908+
for (size_t i = 0; i < arrays.size(); ++i) {
909+
auto buf = arrays[i].request();
910+
axis_sizes[i] = static_cast<size_t>(buf.shape[axis]);
911+
}
912+
913+
// Validate non-axis dimensions match
914+
for (size_t i = 0; i < arrays.size(); ++i) {
915+
auto buf = arrays[i].request();
916+
for (int d = 0; d < ndim; ++d) {
917+
if (d == axis) continue;
918+
if (buf.shape[d] != shape[d])
919+
throw std::invalid_argument(
920+
"concatenate: all arrays must have same shape except along axis");
921+
}
898922
}
923+
924+
// Compute output shape
925+
std::vector<ptrdiff_t> out_shape = shape;
926+
ptrdiff_t total_axis = 0;
927+
for (size_t i = 0; i < arrays.size(); ++i)
928+
total_axis += static_cast<ptrdiff_t>(axis_sizes[i]);
929+
out_shape[axis] = total_axis;
930+
931+
std::vector<py::ssize_t> py_out_shape(out_shape.begin(), out_shape.end());
932+
py::array_t<T> result(py_out_shape);
933+
T* dst = static_cast<T*>(result.request().ptr);
934+
935+
// Build pointer array
936+
std::vector<const T*> ptrs(arrays.size());
937+
for (size_t i = 0; i < arrays.size(); ++i)
938+
ptrs[i] = static_cast<const T*>(arrays[i].request().ptr);
939+
940+
numpy::concatenate(ptrs.data(), dst, shape.data(), ndim, axis,
941+
axis_sizes.data(), arrays.size());
899942
return result;
900943
}
901944

902945
/// numpy.vstack(tup, *, dtype=None, casting=...)
903946
template<typename T>
904-
py::array_t<T> vstack(const std::vector<py::array_t<T>>& arrays) { return stack(arrays); }
947+
py::array_t<T> vstack(const std::vector<py::array_t<T>>& arrays) {
948+
if (arrays.empty()) return py::array_t<T>{};
949+
int ndim = static_cast<int>(arrays[0].request().ndim);
950+
if (ndim == 1) {
951+
// numpy.vstack: 1D arrays are reshaped to (1, N) before stacking
952+
auto buf0 = arrays[0].request();
953+
py::array_t<T> result({static_cast<py::ssize_t>(arrays.size()), static_cast<py::ssize_t>(buf0.size)});
954+
T* dst = static_cast<T*>(result.request().ptr);
955+
for (size_t i = 0; i < arrays.size(); ++i) {
956+
auto buf = arrays[i].request();
957+
std::memcpy(dst + i * buf0.size, static_cast<const T*>(buf.ptr),
958+
buf.size * sizeof(T));
959+
}
960+
return result;
961+
}
962+
return concatenate(arrays, 0);
963+
}
905964

906965
/// numpy.hstack(tup, *, dtype=None, casting=...)
907966
template<typename T>
908-
py::array_t<T> hstack(const std::vector<py::array_t<T>>& arrays) { return concatenate(arrays); }
967+
py::array_t<T> hstack(const std::vector<py::array_t<T>>& arrays) {
968+
if (arrays.empty()) return py::array_t<T>{};
969+
int ndim = static_cast<int>(arrays[0].request().ndim);
970+
// 1D arrays: hstack is identical to concatenate along axis=0
971+
// 2D+ arrays: hstack concatenates along axis=1
972+
return concatenate(arrays, (ndim == 1) ? 0 : 1);
973+
}
909974

910975
/// numpy.where(condition, x, y) — scalar x, y
911976
template<typename T>

tests/module.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,12 @@ PYBIND11_MODULE(numpycpp, m) {
196196
py::arg("arr"), py::arg("n") = 1, py::arg("axis") = -1);
197197
m.def("diff", static_cast<py::array_t<float>(*)(const py::array_t<float>&, int, int)>(&numpy::diff),
198198
py::arg("arr"), py::arg("n") = 1, py::arg("axis") = -1);
199-
BIND_F_STACK(stack); BIND_F_STACK(concatenate); BIND_F_STACK(vstack); BIND_F_STACK(hstack);
199+
BIND_F_STACK(stack); BIND_F_STACK(vstack); BIND_F_STACK(hstack);
200+
201+
m.def("concatenate", static_cast<py::array_t<double>(*)(const std::vector<py::array_t<double>>&, int)>(&numpy::concatenate),
202+
py::arg("arrays"), py::arg("axis") = 0);
203+
m.def("concatenate", static_cast<py::array_t<float>(*)(const std::vector<py::array_t<float>>&, int)>(&numpy::concatenate),
204+
py::arg("arrays"), py::arg("axis") = 0);
200205

201206
m.def("where", static_cast<py::array_t<double>(*)(const py::array_t<bool>&, double, double)>(&numpy::where));
202207
m.def("where", static_cast<py::array_t<float>(*)(const py::array_t<bool>&, float, float)>(&numpy::where));

tests/test_all.py

Lines changed: 122 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -580,17 +580,136 @@ def test_stack(cpp, dtype):
580580
arrays = [random_array((3,), seed=i, dtype=dtype) for i in range(4)]
581581
assert_bit_aligned(cpp.stack(arrays), np.stack(arrays), "stack")
582582

583-
def test_concatenate(cpp, dtype):
583+
def test_concatenate_1d(cpp, dtype):
584584
arrays = [random_array((3,), seed=i, dtype=dtype) for i in range(3)]
585-
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concatenate")
585+
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concatenate 1d")
586+
587+
def test_concatenate_2d_axis0(cpp, dtype):
588+
arrays = [random_array((2, 3), seed=i, dtype=dtype) for i in range(3)]
589+
assert_bit_aligned(cpp.concatenate(arrays, 0), np.concatenate(arrays, axis=0), "concatenate 2d axis=0")
590+
# Verify default axis=0
591+
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concatenate 2d default axis")
592+
593+
def test_concatenate_2d_axis1(cpp, dtype):
594+
arrays = [random_array((3, 2), seed=i, dtype=dtype) for i in range(3)]
595+
assert_bit_aligned(cpp.concatenate(arrays, 1), np.concatenate(arrays, axis=1), "concatenate 2d axis=1")
596+
597+
def test_concatenate_2d_axis_neg1(cpp, dtype):
598+
arrays = [random_array((3, 2), seed=i, dtype=dtype) for i in range(3)]
599+
assert_bit_aligned(cpp.concatenate(arrays, -1), np.concatenate(arrays, axis=-1), "concatenate 2d axis=-1")
600+
601+
def test_concatenate_3d_axis0(cpp, dtype):
602+
arrays = [random_array((2, 3, 4), seed=i, dtype=dtype) for i in range(2)]
603+
assert_bit_aligned(cpp.concatenate(arrays, 0), np.concatenate(arrays, axis=0), "concatenate 3d axis=0")
604+
605+
def test_concatenate_3d_axis1(cpp, dtype):
606+
arrays = [random_array((3, 2, 4), seed=i, dtype=dtype) for i in range(2)]
607+
assert_bit_aligned(cpp.concatenate(arrays, 1), np.concatenate(arrays, axis=1), "concatenate 3d axis=1")
608+
609+
def test_concatenate_3d_axis2(cpp, dtype):
610+
arrays = [random_array((3, 4, 2), seed=i, dtype=dtype) for i in range(2)]
611+
assert_bit_aligned(cpp.concatenate(arrays, 2), np.concatenate(arrays, axis=2), "concatenate 3d axis=2")
612+
613+
def test_concatenate_two_arrays(cpp, dtype):
614+
arrays = [random_array((5,), seed=0, dtype=dtype), random_array((7,), seed=1, dtype=dtype)]
615+
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concatenate two")
616+
617+
def test_concatenate_single(cpp, dtype):
618+
arrays = [random_array((5,), dtype=dtype)]
619+
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concatenate single")
586620

587621
def test_vstack(cpp, dtype):
588622
arrays = [random_array((1, 3), seed=i, dtype=dtype) for i in range(4)]
589623
assert_bit_aligned(cpp.vstack(arrays), np.vstack(arrays), "vstack")
590624

625+
def test_vstack_1d(cpp, dtype):
626+
arrays = [random_array((3,), seed=i, dtype=dtype) for i in range(4)]
627+
assert_bit_aligned(cpp.vstack(arrays), np.vstack(arrays), "vstack 1d")
628+
591629
def test_hstack(cpp, dtype):
592630
arrays = [random_array((3,), seed=i, dtype=dtype) for i in range(3)]
593-
assert_bit_aligned(cpp.hstack(arrays), np.hstack(arrays), "hstack")
631+
assert_bit_aligned(cpp.hstack(arrays), np.hstack(arrays), "hstack 1d")
632+
633+
def test_hstack_2d(cpp, dtype):
634+
arrays = [random_array((3, 2), seed=i, dtype=dtype) for i in range(3)]
635+
assert_bit_aligned(cpp.hstack(arrays), np.hstack(arrays), "hstack 2d")
636+
637+
# -- Concatenate complex / edge-case tests ----------------------------------
638+
639+
def test_concatenate_4d_axis0(cpp, dtype):
640+
arrays = [random_array((2, 3, 4, 5), seed=i, dtype=dtype) for i in range(2)]
641+
assert_bit_aligned(cpp.concatenate(arrays, 0), np.concatenate(arrays, axis=0), "concatenate 4d axis=0")
642+
643+
def test_concatenate_4d_axis2(cpp, dtype):
644+
arrays = [random_array((2, 3, 2, 5), seed=i, dtype=dtype) for i in range(2)]
645+
assert_bit_aligned(cpp.concatenate(arrays, 2), np.concatenate(arrays, axis=2), "concatenate 4d axis=2")
646+
647+
def test_concatenate_4d_axis_neg2(cpp, dtype):
648+
arrays = [random_array((2, 3, 2, 5), seed=i, dtype=dtype) for i in range(2)]
649+
assert_bit_aligned(cpp.concatenate(arrays, -2), np.concatenate(arrays, axis=-2), "concatenate 4d axis=-2")
650+
651+
def test_concatenate_unequal_axis_sizes(cpp, dtype):
652+
"""Concatenate arrays of different sizes along the concatenation axis."""
653+
a = random_array((3, 2), seed=1, dtype=dtype)
654+
b = random_array((3, 4), seed=2, dtype=dtype)
655+
c = random_array((3, 1), seed=3, dtype=dtype)
656+
assert_bit_aligned(cpp.concatenate([a, b, c], 1),
657+
np.concatenate([a, b, c], axis=1), "concat unequal axis sizes")
658+
659+
def test_concatenate_many_arrays(cpp, dtype):
660+
"""Concatenate 10 arrays along axis=0."""
661+
arrays = [random_array((3,), seed=i, dtype=dtype) for i in range(10)]
662+
assert_bit_aligned(cpp.concatenate(arrays), np.concatenate(arrays), "concat 10 arrays")
663+
664+
def test_concatenate_large_3d(cpp, dtype):
665+
"""Large 3D concatenation along middle axis."""
666+
arrays = [random_array((50, 20, 30), seed=i, dtype=dtype) for i in range(3)]
667+
assert_bit_aligned(cpp.concatenate(arrays, 1), np.concatenate(arrays, axis=1), "concat large 3d axis=1")
668+
669+
def test_concatenate_large_2d_axis0(cpp, dtype):
670+
"""Large 2D concatenation — 500 rows each, 4 arrays."""
671+
arrays = [random_array((500, 10), seed=i, dtype=dtype) for i in range(4)]
672+
assert_bit_aligned(cpp.concatenate(arrays, 0), np.concatenate(arrays, axis=0), "concat large 2d axis=0")
673+
674+
def test_concatenate_large_2d_axis1(cpp, dtype):
675+
"""Large 2D concatenation — 500 cols each, 3 arrays."""
676+
arrays = [random_array((10, 500), seed=i, dtype=dtype) for i in range(3)]
677+
assert_bit_aligned(cpp.concatenate(arrays, 1), np.concatenate(arrays, axis=1), "concat large 2d axis=1")
678+
679+
def test_concatenate_identity(cpp, dtype):
680+
"""Concatenating a single array returns identical copy."""
681+
a = random_array((3, 4), seed=42, dtype=dtype)
682+
assert_bit_aligned(cpp.concatenate([a], 0), np.concatenate([a], axis=0), "concat identity")
683+
assert_bit_aligned(cpp.concatenate([a], 1), np.concatenate([a], axis=1), "concat identity axis=1")
684+
685+
def test_concatenate_zeros(cpp, dtype):
686+
"""Concatenate arrays of zeros."""
687+
a = np.zeros((2, 3), dtype=dtype)
688+
b = np.zeros((2, 5), dtype=dtype)
689+
assert_bit_aligned(cpp.concatenate([a, b], 1), np.concatenate([a, b], axis=1), "concat zeros")
690+
691+
def test_concatenate_ones(cpp, dtype):
692+
"""Concatenate arrays of ones."""
693+
a = np.ones((3, 2), dtype=dtype)
694+
b = np.ones((5, 2), dtype=dtype)
695+
assert_bit_aligned(cpp.concatenate([a, b], 0), np.concatenate([a, b], axis=0), "concat ones")
696+
697+
def test_concatenate_3d_axis_neg2(cpp, dtype):
698+
"""3D concatenate along axis=-2 (middle axis)."""
699+
arrays = [random_array((2, 3, 4), seed=i, dtype=dtype) for i in range(3)]
700+
assert_bit_aligned(cpp.concatenate(arrays, -2), np.concatenate(arrays, axis=-2), "concat 3d axis=-2")
701+
702+
def test_concatenate_3d_axis_neg3(cpp, dtype):
703+
"""3D concatenate along axis=-3 (first axis)."""
704+
arrays = [random_array((2, 3, 4), seed=i, dtype=dtype) for i in range(2)]
705+
assert_bit_aligned(cpp.concatenate(arrays, -3), np.concatenate(arrays, axis=-3), "concat 3d axis=-3")
706+
707+
def test_concatenate_5d(cpp, dtype):
708+
"""5D concatenate along various axes."""
709+
arrays = [random_array((2, 3, 2, 3, 2), seed=i, dtype=dtype) for i in range(2)]
710+
assert_bit_aligned(cpp.concatenate(arrays, 0), np.concatenate(arrays, axis=0), "concat 5d axis=0")
711+
assert_bit_aligned(cpp.concatenate(arrays, 2), np.concatenate(arrays, axis=2), "concat 5d axis=2")
712+
assert_bit_aligned(cpp.concatenate(arrays, -1), np.concatenate(arrays, axis=-1), "concat 5d axis=-1")
594713

595714
def test_where_scalar(cpp, dtype):
596715
cond = np.array([True, False, True, False, True])

0 commit comments

Comments
 (0)