diff --git a/include/xtensor/views/xview.hpp b/include/xtensor/views/xview.hpp index ea546d4ff..025df0ae4 100644 --- a/include/xtensor/views/xview.hpp +++ b/include/xtensor/views/xview.hpp @@ -1645,8 +1645,9 @@ namespace xt { if constexpr (lesser_condition::value) { - return sliced_access(I) + newaxis_count_before(I + 1)>( - std::get(I + 1)>(m_slices), + constexpr size_type slice_index = newaxis_skip(I); + return sliced_access(slice_index)>( + std::get(m_slices), args... ); } diff --git a/test/test_xview.cpp b/test/test_xview.cpp index 5f81047b2..c98993bf9 100644 --- a/test/test_xview.cpp +++ b/test/test_xview.cpp @@ -1591,6 +1591,45 @@ namespace xt EXPECT_EQ(a, b); } + TEST(xview, assign_through_multiple_leading_newaxis) + { + SUBCASE("updates the underlying tensor for every element") + { + xt::xtensor tensor = xt::zeros({4, 3}); + auto view = xt::view(tensor, xt::newaxis(), xt::newaxis(), xt::newaxis(), xt::all(), xt::all()); + + uint8_t value = 0; + for (std::size_t row = 0; row < 4; ++row) + { + for (std::size_t col = 0; col < 3; ++col) + { + view(std::size_t{0}, std::size_t{0}, std::size_t{0}, row, col) = value; + EXPECT_EQ(tensor(row, col), value); + ++value; + } + } + + EXPECT_EQ(tensor, xt::arange(12).reshape({4, 3})); + } + + SUBCASE("preserves bool assignment semantics") + { + xt::xtensor tensor = xt::zeros({4, 3}); + auto view = xt::view(tensor, xt::newaxis(), xt::newaxis(), xt::newaxis(), xt::all(), xt::all()); + + for (std::size_t row = 0; row < 4; ++row) + { + for (std::size_t col = 0; col < 3; ++col) + { + view(std::size_t{0}, std::size_t{0}, std::size_t{0}, row, col) = true; + EXPECT_TRUE(tensor(row, col)); + } + } + + EXPECT_EQ(tensor, xt::ones({4, 3})); + } + } + TEST(xview, in_bounds) { xt::xtensor a = {{0, 1, 2}, {3, 4, 5}};