From 9864f1c829e484954dc21d8b35f85488dfc02584 Mon Sep 17 00:00:00 2001 From: ThreeMonth03 Date: Sun, 10 May 2026 20:54:22 +0800 Subject: [PATCH] Accelerate fast_matmul with packing and tiling --- cpp/modmesh/buffer/SimpleArray.hpp | 540 +++++++++++++----- cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp | 12 + profiling/profile_matrix_ops.py | 34 +- tests/test_matrix.py | 65 +++ 4 files changed, 508 insertions(+), 143 deletions(-) diff --git a/cpp/modmesh/buffer/SimpleArray.hpp b/cpp/modmesh/buffer/SimpleArray.hpp index 89f7d8f0f..d5b15b53f 100644 --- a/cpp/modmesh/buffer/SimpleArray.hpp +++ b/cpp/modmesh/buffer/SimpleArray.hpp @@ -134,6 +134,363 @@ struct SimpleArrayInternalTypes using buffer_type = ConcreteBuffer; }; /* end class SimpleArrayInternalType */ +template +class SimpleArrayMatmulHelper +{ + +private: + + using internal_types = detail::SimpleArrayInternalTypes; + +public: + + using value_type = typename internal_types::value_type; + using shape_type = typename internal_types::shape_type; + + SimpleArrayMatmulHelper() = delete; + SimpleArrayMatmulHelper(A const & lhs, A const & rhs); + SimpleArrayMatmulHelper(A const & lhs, + A const & rhs, + size_t tile_x, + size_t tile_y, + size_t tile_z); + ~SimpleArrayMatmulHelper() = default; + + SimpleArrayMatmulHelper(SimpleArrayMatmulHelper const &) = delete; + SimpleArrayMatmulHelper(SimpleArrayMatmulHelper &&) = delete; + SimpleArrayMatmulHelper & operator=(SimpleArrayMatmulHelper const &) = delete; + SimpleArrayMatmulHelper & operator=(SimpleArrayMatmulHelper &&) = delete; + + A matmul(); + A matmul_fast(); + +private: + + static std::string shape_str(A const & arr); + void check_dims() const; + void check_inner(size_t lhs_idx, size_t rhs_idx) const; + void check_tiles() const; + A matmul_vec_vec(); + A matmul_vec_mat(); + A matmul_mat_vec(); + A matmul_mat_mat(); + A pack_rhs(size_t n, size_t k); + void accumulate_tile(A const & packed_rhs, + size_t row_begin, + size_t row_end, + size_t col_begin, + size_t col_end, + size_t inner_begin, + size_t inner_end); + A matmul_mat_mat_tiled(); + + A const & m_lhs; + A const & m_rhs; + A m_result; + size_t m_tile_x; + size_t m_tile_y; + size_t m_tile_z; + +}; /* end class SimpleArrayMatmulHelper */ + +template +SimpleArrayMatmulHelper::SimpleArrayMatmulHelper(A const & lhs, A const & rhs) + : SimpleArrayMatmulHelper(lhs, rhs, 0, 0, 0) +{ +} + +template +SimpleArrayMatmulHelper::SimpleArrayMatmulHelper(A const & lhs, + A const & rhs, + size_t tile_x, + size_t tile_y, + size_t tile_z) + : m_lhs(lhs) + , m_rhs(rhs) + , m_tile_x(tile_x) + , m_tile_y(tile_y) + , m_tile_z(tile_z) +{ + check_dims(); + + size_t const lhs_ndim = m_lhs.ndim(); + size_t const rhs_ndim = m_rhs.ndim(); + + if (lhs_ndim == 1 && rhs_ndim == 1) + { + check_inner(0, 0); + m_result = A(1); + return; + } + + if (lhs_ndim == 1) + { + check_inner(0, 0); + m_result = A(m_rhs.shape(1)); + return; + } + + if (rhs_ndim == 1) + { + check_inner(1, 0); + m_result = A(m_lhs.shape(0)); + return; + } + + check_inner(1, 0); + shape_type const result_shape{m_lhs.shape(0), m_rhs.shape(1)}; + m_result = A(result_shape); +} + +template +A SimpleArrayMatmulHelper::matmul() +{ + if (m_lhs.ndim() == 1 && m_rhs.ndim() == 1) + { + return matmul_vec_vec(); + } + if (m_lhs.ndim() == 1) + { + return matmul_vec_mat(); + } + if (m_rhs.ndim() == 1) + { + return matmul_mat_vec(); + } + + return matmul_mat_mat(); +} + +/** + * Perform fast matrix multiplication for SimpleArrays. + * This implementation currently uses tiling for 2D x 2D matrix multiplication. + * Future optimizations may add other techniques such as SIMD kernels. + */ +template +A SimpleArrayMatmulHelper::matmul_fast() +{ + check_tiles(); + + if (m_lhs.ndim() == 1 && m_rhs.ndim() == 1) + { + return matmul_vec_vec(); + } + if (m_lhs.ndim() == 1) + { + return matmul_vec_mat(); + } + if (m_rhs.ndim() == 1) + { + return matmul_mat_vec(); + } + + return matmul_mat_mat_tiled(); +} + +/** + * Format shape for matrix multiplication diagnostics. + */ +template +std::string SimpleArrayMatmulHelper::shape_str(A const & arr) +{ + if (arr.ndim() == 0) + { + return "()"; + } + + std::string result = "("; + for (size_t i = 0; i < arr.ndim(); ++i) + { + if (i > 0) + { + result += ","; + } + result += std::to_string(arr.shape(i)); + } + result += ")"; + return result; +} + +template +void SimpleArrayMatmulHelper::check_dims() const +{ + bool const lhs_is_supported = m_lhs.ndim() == 1 || m_lhs.ndim() == 2; + bool const rhs_is_supported = m_rhs.ndim() == 1 || m_rhs.ndim() == 2; + if (lhs_is_supported && rhs_is_supported) + { + return; + } + + std::string const err = std::format("SimpleArray::matmul(): unsupported dimensions: " + "this={} other={}. SimpleArray must be 1D or 2D.", + shape_str(m_lhs), + shape_str(m_rhs)); + throw std::out_of_range(err); +} + +template +void SimpleArrayMatmulHelper::check_inner(size_t lhs_idx, size_t rhs_idx) const +{ + if (m_lhs.shape(lhs_idx) == m_rhs.shape(rhs_idx)) + { + return; + } + + throw std::out_of_range( + std::format("SimpleArray::matmul(): shape mismatch: this={} other={}", + shape_str(m_lhs), + shape_str(m_rhs))); +} + +template +void SimpleArrayMatmulHelper::check_tiles() const +{ + if (m_tile_x != 0 && m_tile_y != 0 && m_tile_z != 0) + { + return; + } + + throw std::out_of_range( + std::format("SimpleArray::fast_matmul(): tile sizes must be positive: " + "tile_x={} tile_y={} tile_z={}", + m_tile_x, + m_tile_y, + m_tile_z)); +} + +template +A SimpleArrayMatmulHelper::matmul_vec_vec() +{ + size_t const k = m_lhs.shape(0); + value_type v = 0; + for (size_t i = 0; i < k; ++i) + { + v += m_lhs(i) * m_rhs.data(i); + } + m_result.data(0) = v; + return std::move(m_result); +} + +template +A SimpleArrayMatmulHelper::matmul_vec_mat() +{ + size_t const n = m_result.size(); + size_t const k = m_lhs.shape(0); + for (size_t j = 0; j < n; ++j) + { + value_type v = 0; + for (size_t l = 0; l < k; ++l) + { + v += m_lhs(l) * m_rhs(l, j); + } + m_result.data(j) = v; + } + return std::move(m_result); +} + +template +A SimpleArrayMatmulHelper::matmul_mat_vec() +{ + size_t const m = m_result.size(); + size_t const k = m_lhs.shape(1); + for (size_t i = 0; i < m; ++i) + { + value_type v = 0; + for (size_t l = 0; l < k; ++l) + { + v += m_lhs(i, l) * m_rhs(l); + } + m_result.data(i) = v; + } + return std::move(m_result); +} + +template +A SimpleArrayMatmulHelper::matmul_mat_mat() +{ + size_t const m = m_result.shape(0); + size_t const n = m_result.shape(1); + size_t const k = m_lhs.shape(1); + for (size_t i = 0; i < m; ++i) + { + for (size_t j = 0; j < n; ++j) + { + value_type v = 0; + for (size_t l = 0; l < k; ++l) + { + v += m_lhs(i, l) * m_rhs(l, j); + } + m_result(i, j) = v; + } + } + return std::move(m_result); +} + +template +A SimpleArrayMatmulHelper::pack_rhs(size_t n, size_t k) +{ + shape_type const packing_shape{n, k}; + A packing(packing_shape); + for (size_t i = 0; i < n; ++i) + { + for (size_t j = 0; j < k; ++j) + { + packing(i, j) = m_rhs(j, i); + } + } + return packing; +} + +template +void SimpleArrayMatmulHelper::accumulate_tile(A const & packed_rhs, + size_t row_begin, + size_t row_end, + size_t col_begin, + size_t col_end, + size_t inner_begin, + size_t inner_end) +{ + for (size_t i = row_begin; i < row_end; ++i) + { + for (size_t j = col_begin; j < col_end; ++j) + { + value_type v = m_result(i, j); + for (size_t l = inner_begin; l < inner_end; ++l) + { + v += m_lhs(i, l) * packed_rhs(j, l); + } + m_result(i, j) = v; + } + } +} + +template +A SimpleArrayMatmulHelper::matmul_mat_mat_tiled() +{ + size_t const m = m_result.shape(0); + size_t const n = m_result.shape(1); + size_t const k = m_lhs.shape(1); + A packed_rhs = pack_rhs(n, k); + for (size_t i = 0; i < m_result.size(); ++i) + { + m_result.data(i) = value_type{0}; + } + for (size_t row = 0; row < m; row += m_tile_x) + { + size_t const row_end = std::min(row + m_tile_x, m); + for (size_t col = 0; col < n; col += m_tile_y) + { + size_t const col_end = std::min(col + m_tile_y, n); + for (size_t inner = 0; inner < k; inner += m_tile_z) + { + size_t const inner_end = std::min(inner + m_tile_z, k); + accumulate_tile(packed_rhs, row, row_end, col, col_end, inner, inner_end); + } + } + } + return std::move(m_result); +} + template class SimpleArrayMixinModifiers { @@ -854,8 +1211,6 @@ class SimpleArrayMixinCalculators return *athis; } - A & imatmul(A const & other); - A add_simd(A const & other) const { A const * athis = static_cast(this); @@ -953,6 +1308,15 @@ class SimpleArrayMixinCalculators } A matmul(A const & other) const; + A & imatmul(A const & other); + A fast_matmul(A const & other, + size_t tile_x, + size_t tile_y, + size_t tile_z) const; + A & fast_imatmul(A const & other, + size_t tile_x, + size_t tile_y, + size_t tile_z); private: static void find_two_bins(const uint32_t * freq, size_t n, int & bin1, int & bin2); @@ -1051,152 +1415,56 @@ detail::SimpleArrayMixinCalculators::median_freq(small_vector * This implementation supports 1D x 1D, 1D x 2D, 2D x 1D, and 2D x 2D matrix multiplication. */ template -A SimpleArrayMixinCalculators::matmul(A const & other) const // NOLINT(readability-function-cognitive-complexity) +A SimpleArrayMixinCalculators::matmul(A const & other) const { - auto athis = static_cast(this); - const size_t this_ndim = athis->ndim(); - const size_t other_ndim = other.ndim(); - - auto format_shape = [](A const * arr) -> std::string - { - if (arr->ndim() == 0) - { - return "()"; - } - else - { - std::string result = "("; - for (size_t i = 0; i < arr->ndim(); ++i) - { - if (i > 0) - { - result += ","; - } - result += std::to_string(arr->shape(i)); - } - result += ")"; - return result; - } - }; - - auto check_product_shape = [&](A const * athis, A const * other, ssize_t athis_idx, ssize_t other_idx) -> void - { - if (athis->shape(athis_idx) != other->shape(other_idx)) - { - throw std::out_of_range( - std::format("SimpleArray::matmul(): shape mismatch: this={} other={}", - format_shape(athis), - format_shape(other))); - } - }; - - if ((this_ndim != 2 && this_ndim != 1) || (other_ndim != 2 && other_ndim != 1)) - { - const std::string err = std::format("SimpleArray::matmul(): unsupported dimensions: " - "this={} other={}. SimpleArray must be 1D or 2D.", - format_shape(athis), - format_shape(&other)); - throw std::out_of_range(err); - } - - bool const this_is_1d = (this_ndim == 1); - bool const other_is_1d = (other_ndim == 1); - - // 1D x 1D - if (this_is_1d && other_is_1d) - { - check_product_shape(athis, &other, 0, 0); - A result(1); - value_type v = 0; - for (size_t i = 0; i < athis->shape(0); ++i) - { - v += (*athis)(i)*other.data(i); - } - result.data(0) = v; - return result; - } - // 1D x 2D - else if (this_is_1d) - { - const size_t k = athis->shape(0); - const size_t n = other.shape(1); - check_product_shape(athis, &other, 0, 0); - A result(n); - - for (size_t j = 0; j < n; ++j) - { - value_type v = 0; - for (size_t l = 0; l < k; ++l) - { - v += (*athis)(l)*other(l, j); - } - result.data(j) = v; - } - return result; - } - // 2D x 1D - else if (other_is_1d) - { - const size_t m = athis->shape(0); - const size_t k = athis->shape(1); - - check_product_shape(athis, &other, 1, 0); - A result(m); - - for (size_t i = 0; i < m; ++i) - { - value_type v = 0; - for (size_t l = 0; l < k; ++l) - { - v += (*athis)(i, l) * other(l); - } - result.data(i) = v; - } - return result; - } - // 2D x 2D - else - { - const size_t m = athis->shape(0); - const size_t k = athis->shape(1); - const size_t n = other.shape(1); - check_product_shape(athis, &other, 1, 0); + auto const * athis = static_cast(this); + SimpleArrayMatmulHelper helper(*athis, other); + return helper.matmul(); +} - shape_type const result_shape{m, n}; - A result(result_shape); +/** + * Perform in-place matrix multiplication for SimpleArrays. + * This implementation supports 1D x 1D, 1D x 2D, 2D x 1D, and 2D x 2D matrix multiplication. + * The result replaces the content of the current array. + */ +template +A & SimpleArrayMixinCalculators::imatmul(A const & other) +{ + auto athis = static_cast(this); + A result = athis->matmul(other); + *athis = std::move(result); - for (size_t i = 0; i < m; ++i) - { - for (size_t j = 0; j < n; ++j) - { - value_type v = 0; - for (size_t l = 0; l < k; ++l) - { - v += (*athis)(i, l) * other(l, j); - } - result(i, j) = v; - } - } - return result; - } + return *athis; +} - throw std::out_of_range( - std::format("SimpleArray::matmul(): this={} other={}" - " cannot perform matrix multiplication.", - format_shape(athis), - format_shape(&other))); +/** + * Perform fast matrix multiplication for SimpleArrays. + * This implementation supports 1D x 1D, 1D x 2D, 2D x 1D, and 2D x 2D matrix multiplication. + */ +template +A SimpleArrayMixinCalculators::fast_matmul(A const & other, + size_t tile_x, + size_t tile_y, + size_t tile_z) const +{ + auto const * athis = static_cast(this); + SimpleArrayMatmulHelper helper(*athis, other, tile_x, tile_y, tile_z); + return helper.matmul_fast(); } /** - * Perform in-place matrix multiplication for 2D arrays. - * This implementation supports only 2D x 2D matrix multiplication. + * Perform in-place fast matrix multiplication for SimpleArrays. + * This implementation supports 1D x 1D, 1D x 2D, 2D x 1D, and 2D x 2D matrix multiplication. * The result replaces the content of the current array. */ template -A & SimpleArrayMixinCalculators::imatmul(A const & other) +A & SimpleArrayMixinCalculators::fast_imatmul(A const & other, + size_t tile_x, + size_t tile_y, + size_t tile_z) { auto athis = static_cast(this); - A result = athis->matmul(other); + A result = athis->fast_matmul(other, tile_x, tile_y, tile_z); *athis = std::move(result); return *athis; diff --git a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp index 311f3ae5a..42c9c2b75 100644 --- a/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp +++ b/cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp @@ -359,6 +359,18 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray { return self.div(scalar); }) .def("matmul", &wrapped_type::matmul) .def("__matmul__", &wrapped_type::matmul) + .def( + "fast_matmul", + [](wrapped_type const & self, + wrapped_type const & other, + size_t tile_x, + size_t tile_y, + size_t tile_z) + { return self.fast_matmul(other, tile_x, tile_y, tile_z); }, + py::arg("other"), + py::arg("tile_x") = 16, + py::arg("tile_y") = 16, + py::arg("tile_z") = 16) // TODO: In-place operation should return reference to self to support function chaining /* * Regular in-place methods (iadd, imul, etc.) are procedural calls and do diff --git a/profiling/profile_matrix_ops.py b/profiling/profile_matrix_ops.py index 05e2443dd..1cfddfa55 100644 --- a/profiling/profile_matrix_ops.py +++ b/profiling/profile_matrix_ops.py @@ -54,15 +54,26 @@ def profile_matmul_np(lhs, rhs): @profile_function -def profile_matmul_sa(lhs, rhs): +def profile_matmul_naive_sa(lhs, rhs): return lhs.matmul(rhs) +def profile_matmul_fast_sa(lhs, rhs, tile_x, tile_y, tile_z): + name = f"profile_matmul_fast_sa_{tile_x}_{tile_y}_{tile_z}" + _ = modmesh.CallProfilerProbe(name) + return lhs.fast_matmul(rhs, tile_x=tile_x, tile_y=tile_y, tile_z=tile_z) + + def make_data(dtype, shape): return np.random.rand(*shape).astype(dtype) def profile_matmul_operation(dtype, shapes, it=10): + tile_configs = ( + (16, 16, 16), + (32, 32, 32), + (64, 64, 64), + ) for m in shapes: lhs = make_data(dtype, (m, m)) rhs = make_data(dtype, (m, m)) @@ -71,7 +82,9 @@ def profile_matmul_operation(dtype, shapes, it=10): modmesh.call_profiler.reset() for _ in range(it): profile_matmul_np(lhs, rhs) - profile_matmul_sa(lhs_sa, rhs_sa) + profile_matmul_naive_sa(lhs_sa, rhs_sa) + for tile_x, tile_y, tile_z in tile_configs: + profile_matmul_fast_sa(lhs_sa, rhs_sa, tile_x, tile_y, tile_z) res = modmesh.call_profiler.result()["children"] out = {} @@ -79,16 +92,23 @@ def profile_matmul_operation(dtype, shapes, it=10): name = r["name"].replace("profile_matmul_", "") out[name] = r["total_time"] / r["count"] - print(f"## 2D x 2D shape: ({m}, {m}) x ({m}, {m}) dtype:" - f"`{np.dtype(dtype)}`\n") + print( + f"## 2D x 2D shape: ({m}, {m}) x ({m}, {m}) dtype:" + f"`{np.dtype(dtype)}`\n" + ) def print_row(*cols): - print(str.format("| {:10s} | {:15s} | {:15s} |", *(cols[0:3]))) + print(str.format("| {:20s} | {:15s} | {:15s} |", *(cols[0:3]))) print_row("func", "per call (ms)", "cmp to np") - print_row("-" * 10, "-" * 15, "-" * 15) + print_row("-" * 20, "-" * 15, "-" * 15) npbase = out["np"] - for key in ("np", "sa"): + keys = ["np", "naive_sa"] + keys += [ + f"fast_sa_{tile_x}_{tile_y}_{tile_z}" + for tile_x, tile_y, tile_z in tile_configs + ] + for key in keys: value = out[key] print_row(f"{key:8s}", f"{value:.3E}", f"{value / npbase:.3f}") print() diff --git a/tests/test_matrix.py b/tests/test_matrix.py index bf402c6b4..fe3e8a0c0 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -97,9 +97,14 @@ def test_square(self): # Test matrix multiplication result = a.matmul(b) + fast_result = a.fast_matmul(b) self.assertEqual(list(result.shape), [2, 2]) np.testing.assert_array_almost_equal(result.ndarray, expected) + self.assertEqual(list(fast_result.shape), [2, 2]) + np.testing.assert_array_almost_equal(fast_result.ndarray, expected) + np.testing.assert_array_almost_equal(fast_result.ndarray, + result.ndarray) def test_rectangular(self): """Test rectangular matrix multiplication""" @@ -117,9 +122,14 @@ def test_rectangular(self): dtype=self.dtype) result = a.matmul(b) + fast_result = a.fast_matmul(b) self.assertEqual(list(result.shape), [2, 2]) np.testing.assert_array_almost_equal(result.ndarray, expected) + self.assertEqual(list(fast_result.shape), [2, 2]) + np.testing.assert_array_almost_equal(fast_result.ndarray, expected) + np.testing.assert_array_almost_equal(fast_result.ndarray, + result.ndarray) def test_identity(self): """Test multiplication with identity matrix""" @@ -131,9 +141,14 @@ def test_identity(self): identity = self.SimpleArray.eye(3) result = a.matmul(identity) + fast_result = a.fast_matmul(identity) self.assertEqual(list(result.shape), [3, 3]) np.testing.assert_array_almost_equal(result.ndarray, a_data) + self.assertEqual(list(fast_result.shape), [3, 3]) + np.testing.assert_array_almost_equal(fast_result.ndarray, a_data) + np.testing.assert_array_almost_equal(fast_result.ndarray, + result.ndarray) def test_zero(self): """Test multiplication with zero matrix""" @@ -144,9 +159,14 @@ def test_zero(self): zero = self.SimpleArray(array=zero_data) result = a.matmul(zero) + fast_result = a.fast_matmul(zero) self.assertEqual(list(result.shape), [2, 2]) np.testing.assert_array_almost_equal(result.ndarray, zero_data) + self.assertEqual(list(fast_result.shape), [2, 2]) + np.testing.assert_array_almost_equal(fast_result.ndarray, zero_data) + np.testing.assert_array_almost_equal(fast_result.ndarray, + result.ndarray) def test_dimension_mismatch_error(self): """Test error handling for incompatible dimensions""" @@ -167,6 +187,12 @@ def test_dimension_mismatch_error(self): r"\(3,3\)" ): a.matmul(b) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): shape mismatch: this=\(2,2\) other=" + r"\(3,3\)" + ): + a.fast_matmul(b) def test_compare_with_numpy(self): """Compare results with NumPy using fixed test data""" @@ -268,6 +294,7 @@ def test_compare_with_numpy(self): # Compute with our implementation result = a.matmul(b) + fast_result = a.fast_matmul(b) # Verify with NumPy np_result = np.matmul(a_data, b_data) @@ -275,12 +302,20 @@ def test_compare_with_numpy(self): # Compare our result with expected self.assertEqual(list(result.shape), list(expected.shape)) + self.assertEqual(list(fast_result.shape), + list(expected.shape)) + np.testing.assert_array_almost_equal(fast_result.ndarray, + result.ndarray) if self.dtype == np.float32: np.testing.assert_array_almost_equal( result.ndarray, expected, decimal=4) + np.testing.assert_array_almost_equal( + fast_result.ndarray, expected, decimal=4) else: np.testing.assert_array_almost_equal( result.ndarray, expected, decimal=10) + np.testing.assert_array_almost_equal( + fast_result.ndarray, expected, decimal=10) def test_wrong_shape_error(self): """Test error handling for wrong shapes""" @@ -298,6 +333,12 @@ def test_wrong_shape_error(self): r"this=\(2,2,2\) other=\(2,2,2\)\. SimpleArray must be 1D or 2D." ): a_3d.matmul(b_3d) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): unsupported dimensions: " + r"this=\(2,2,2\) other=\(2,2,2\)\. SimpleArray must be 1D or 2D." + ): + a_3d.fast_matmul(b_3d) a = np.zeros((3, 3), dtype=self.dtype) b = np.zeros((2, 3), dtype=self.dtype) @@ -309,6 +350,12 @@ def test_wrong_shape_error(self): r"this=\(3,3\) other=\(2,3\)" ): a.matmul(b) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): shape mismatch: " + r"this=\(3,3\) other=\(2,3\)" + ): + a.fast_matmul(b) a = np.zeros((3, 3), dtype=self.dtype) b = np.zeros((2), dtype=self.dtype) @@ -320,6 +367,12 @@ def test_wrong_shape_error(self): r"this=\(3,3\) other=\(2\)" ): a.matmul(b) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): shape mismatch: " + r"this=\(3,3\) other=\(2\)" + ): + a.fast_matmul(b) a = np.zeros((2), dtype=self.dtype) b = np.zeros((3, 3), dtype=self.dtype) @@ -331,6 +384,12 @@ def test_wrong_shape_error(self): r"this=\(2\) other=\(3,3\)" ): a.matmul(b) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): shape mismatch: " + r"this=\(2\) other=\(3,3\)" + ): + a.fast_matmul(b) a = np.zeros((2), dtype=self.dtype) b = np.zeros((3), dtype=self.dtype) @@ -342,6 +401,12 @@ def test_wrong_shape_error(self): r"this=\(2\) other=\(3\)" ): a.matmul(b) + with self.assertRaisesRegex( + IndexError, + r"SimpleArray::matmul\(\): shape mismatch: " + r"this=\(2\) other=\(3\)" + ): + a.fast_matmul(b) def test_matmul_operator(self): """Test @ operator for matrix multiplication"""