diff --git a/include/xtensor/core/xmath.hpp b/include/xtensor/core/xmath.hpp index 77f864929..e73878e11 100644 --- a/include/xtensor/core/xmath.hpp +++ b/include/xtensor/core/xmath.hpp @@ -19,8 +19,10 @@ #include #include #include +#include #include +#include #include #include @@ -569,13 +571,93 @@ namespace xt namespace math { + namespace detail + { + template + constexpr decltype(auto) masked_data(const T& value) noexcept + { + return value; + } + + template + constexpr decltype(auto) masked_data(const xtl::xmasked_value& value) noexcept + { + return value.value(); + } + + template + constexpr bool masked_visible(const T&) noexcept + { + return true; + } + + template + constexpr bool masked_visible(const xtl::xmasked_value& value) noexcept + { + return static_cast(value.visible()); + } + + template + inline constexpr bool has_masked_value_v = (xtl::is_xmasked_value>::value || ...); + + template + using masked_data_type_t = std::decay_t()))>; + + template + using masked_common_value_type_t = xtl::promote_type_t...>; + + template + using masked_return_type_t = xtl::xmasked_value; + + template + constexpr bool all_masked_visible(const Args&... args) noexcept + { + return (masked_visible(args) && ...); + } + + template + constexpr auto hidden_masked_value() noexcept -> masked_return_type_t + { + return masked_return_type_t(T(0), false); + } + + template + constexpr auto masked_map(F&& function, const Args&... args) -> masked_return_type_t + { + if (all_masked_visible(args...)) + { + return masked_return_type_t( + static_cast(std::forward(function)(masked_data(args)...)), + true + ); + } + + return hidden_masked_value(); + } + } + template struct minimum { template constexpr auto operator()(const A1& t1, const A2& t2) const noexcept { - return xtl::select(t1 < t2, t1, t2); + if constexpr (detail::has_masked_value_v) + { + using value_type = detail::masked_common_value_type_t; + return detail::masked_map( + [](const auto& lhs, const auto& rhs) + { + return lhs < rhs ? lhs : rhs; + }, + t1, + t2 + ); + } + else + { + return xtl::select(t1 < t2, t1, t2); + } } template @@ -591,7 +673,22 @@ namespace xt template constexpr auto operator()(const A1& t1, const A2& t2) const noexcept { - return xtl::select(t1 > t2, t1, t2); + if constexpr (detail::has_masked_value_v) + { + using value_type = detail::masked_common_value_type_t; + return detail::masked_map( + [](const auto& lhs, const auto& rhs) + { + return lhs > rhs ? lhs : rhs; + }, + t1, + t2 + ); + } + else + { + return xtl::select(t1 > t2, t1, t2); + } } template @@ -606,7 +703,23 @@ namespace xt template constexpr auto operator()(const A1& v, const A2& lo, const A3& hi) const { - return xtl::select(v < lo, lo, xtl::select(hi < v, hi, v)); + if constexpr (detail::has_masked_value_v) + { + using value_type = detail::masked_common_value_type_t; + return detail::masked_map( + [](const auto& value, const auto& lower, const auto& upper) + { + return value < lower ? lower : (upper < value ? upper : value); + }, + v, + lo, + hi + ); + } + else + { + return xtl::select(v < lo, lo, xtl::select(hi < v, hi, v)); + } } template @@ -618,16 +731,29 @@ namespace xt struct deg2rad { - template ::value, int> = 0> - constexpr double operator()(const A& a) const noexcept - { - return a * xt::numeric_constants::PI / 180.0; - } - - template ::value, int> = 0> + template constexpr auto operator()(const A& a) const noexcept { - return a * xt::numeric_constants::PI / A(180.0); + if constexpr (detail::has_masked_value_v) + { + using data_type = detail::masked_data_type_t; + using result_type = std::conditional_t::value, double, data_type>; + return detail::masked_map( + [](const auto& value) + { + return value * xt::numeric_constants::PI / result_type(180.0); + }, + a + ); + } + else if constexpr (xtl::is_integral::value) + { + return a * xt::numeric_constants::PI / 180.0; + } + else + { + return a * xt::numeric_constants::PI / A(180.0); + } } template ::value, int> = 0> @@ -645,16 +771,29 @@ namespace xt struct rad2deg { - template ::value, int> = 0> - constexpr double operator()(const A& a) const noexcept - { - return a * 180.0 / xt::numeric_constants::PI; - } - - template ::value, int> = 0> + template constexpr auto operator()(const A& a) const noexcept { - return a * A(180.0) / xt::numeric_constants::PI; + if constexpr (detail::has_masked_value_v) + { + using data_type = detail::masked_data_type_t; + using result_type = std::conditional_t::value, double, data_type>; + return detail::masked_map( + [](const auto& value) + { + return value * result_type(180.0) / xt::numeric_constants::PI; + }, + a + ); + } + else if constexpr (xtl::is_integral::value) + { + return a * 180.0 / xt::numeric_constants::PI; + } + else + { + return a * A(180.0) / xt::numeric_constants::PI; + } } template ::value, int> = 0> @@ -858,7 +997,22 @@ namespace xt template constexpr auto operator()(const T& x) const { - return sign_impl::run(x); + if constexpr (detail::has_masked_value_v) + { + using data_type = detail::masked_data_type_t; + using result_type = std::decay_t::run(detail::masked_data(x)))>; + return detail::masked_map( + [](const auto& value) + { + return sign_impl::run(value); + }, + x + ); + } + else + { + return sign_impl::run(x); + } } }; } @@ -1031,6 +1185,19 @@ namespace xt { }; + template + inline decltype(auto) lambda_argument(T&& value) + { + if constexpr (xtl::is_xmasked_value>::value) + { + return +value; + } + else + { + return std::forward(value); + } + } + template struct lambda_adapt { @@ -1040,15 +1207,15 @@ namespace xt } template - auto operator()(T... args) const + auto operator()(T&&... args) const { - return m_lambda(args...); + return m_lambda(lambda_argument(std::forward(args))...); } template )> - auto simd_apply(T... args) const + auto simd_apply(T&&... args) const { - return m_lambda(args...); + return m_lambda(lambda_argument(std::forward(args))...); } F m_lambda; @@ -1171,10 +1338,11 @@ namespace xt struct pow_impl { template - auto operator()(T v) const -> decltype(v * v) + auto operator()(T&& v) const { - T temp = pow_impl{}(v); - return temp * temp * pow_impl{}(v); + auto value = lambda_argument(std::forward(v)); + auto temp = pow_impl{}(value); + return temp * temp * pow_impl{}(value); } }; @@ -1182,9 +1350,9 @@ namespace xt struct pow_impl<1> { template - auto operator()(T v) const -> T + decltype(auto) operator()(T&& v) const { - return v; + return lambda_argument(std::forward(v)); } }; @@ -1192,9 +1360,10 @@ namespace xt struct pow_impl<0> { template - auto operator()(T /*v*/) const -> T + auto operator()(T&& v) const { - return T(1); + using value_type = std::decay_t(v)))>; + return value_type(1); } }; } diff --git a/include/xtensor/io/xio.hpp b/include/xtensor/io/xio.hpp index fbc0cd3a0..d0a99f4ed 100644 --- a/include/xtensor/io/xio.hpp +++ b/include/xtensor/io/xio.hpp @@ -185,6 +185,18 @@ namespace xt namespace detail { + template + inline auto printable_value(const xtl::xmasked_value& value) + { + return +value; + } + + template + inline const T& printable_value(const T& value) + { + return value; + } + template std::ostream& xoutput( std::ostream& out, @@ -646,7 +658,7 @@ namespace xt void update(const_reference val) { std::stringstream buf; - buf << val; + buf << printable_value(val); std::string s = buf.str(); if (int(s.size()) > m_width) { diff --git a/include/xtensor/views/xmasked_view.hpp b/include/xtensor/views/xmasked_view.hpp index 05210d8c8..f2ca32d33 100644 --- a/include/xtensor/views/xmasked_view.hpp +++ b/include/xtensor/views/xmasked_view.hpp @@ -21,6 +21,19 @@ namespace xt { + namespace detail + { + template + struct xmasked_view_strides + { + using fallback_type = get_strides_type; + using strides_type = xtl::mpl::eval_if_t, expr_strides_type, fallback_type>; + using backstrides_type = xtl::mpl::eval_if_t, expr_backstrides_type, fallback_type>; + using inner_strides_type = xtl::mpl::eval_if_t, expr_inner_strides_type, fallback_type>; + using inner_backstrides_type = xtl::mpl::eval_if_t, expr_inner_backstrides_type, fallback_type>; + }; + } + /**************************** * xmasked_view declaration * *****************************/ @@ -118,14 +131,16 @@ namespace xt using bool_load_type = xtl::xmasked_value; using shape_type = typename data_type::shape_type; - using strides_type = typename data_type::strides_type; + using strides_helper = detail::xmasked_view_strides; + using strides_type = typename strides_helper::strides_type; + using backstrides_type = typename strides_helper::backstrides_type; static constexpr layout_type static_layout = data_type::static_layout; static constexpr bool contiguous_layout = false; using inner_shape_type = typename data_type::inner_shape_type; - using inner_strides_type = typename data_type::inner_strides_type; - using inner_backstrides_type = typename data_type::inner_backstrides_type; + using inner_strides_type = typename strides_helper::inner_strides_type; + using inner_backstrides_type = typename strides_helper::inner_backstrides_type; using expression_tag = xtensor_expression_tag; @@ -163,7 +178,12 @@ namespace xt size_type size() const noexcept; const inner_shape_type& shape() const noexcept; + template + requires has_strides
::value const inner_strides_type& strides() const noexcept; + + template + requires has_strides
::value const inner_backstrides_type& backstrides() const noexcept; using accessible_base::dimension; using accessible_base::shape; @@ -202,6 +222,9 @@ namespace xt template bool has_linear_assign(const S& strides) const noexcept; + template + bool broadcast_shape(S& shape, bool reuse_cache = false) const; + data_type& value() noexcept; const data_type& value() const noexcept; @@ -338,6 +361,8 @@ namespace xt * Returns the strides of the xmasked_view. */ template + template + requires has_strides
::value inline auto xmasked_view::strides() const noexcept -> const inner_strides_type& { return m_data.strides(); @@ -347,6 +372,8 @@ namespace xt * Returns the backstrides of the xmasked_view. */ template + template + requires has_strides
::value inline auto xmasked_view::backstrides() const noexcept -> const inner_backstrides_type& { return m_data.backstrides(); @@ -370,6 +397,13 @@ namespace xt return false; } + template + template + inline bool xmasked_view::broadcast_shape(S& shape, bool) const + { + return xt::broadcast_shape(m_data.shape(), shape); + } + /** * Fills the data with the given value. * @param value the value to fill the data with. diff --git a/test/test_utils.hpp b/test/test_utils.hpp index d196ce63c..9855d059b 100644 --- a/test/test_utils.hpp +++ b/test/test_utils.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include "xtensor/core/xexpression.hpp" @@ -95,6 +96,20 @@ namespace xt } return res; } + + template + std::string stream_output(const E& expression) + { + std::stringstream out; + out << expression; + return out.str(); + } + + template + bool has_stream_output(const E& expression) + { + return !stream_output(expression).empty(); + } } #endif diff --git a/test/test_xmasked_view.cpp b/test/test_xmasked_view.cpp index 14613d169..382bf9261 100644 --- a/test/test_xmasked_view.cpp +++ b/test/test_xmasked_view.cpp @@ -7,6 +7,9 @@ * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ +#include + +#include "xtensor/core/xmath.hpp" #include "xtensor/io/xio.hpp" #include "xtensor/optional/xoptional_assembly.hpp" #include "xtensor/views/xmasked_view.hpp" @@ -210,6 +213,22 @@ namespace xt EXPECT_EQ(data, expected2); } + TEST(xmasked_view, lazy_expression_stream) + { + using array_type = xarray; + const array_type a = {1., 1., 1., 1.}; + const array_type b = {0.1, 0.7, 0.3, 0.9}; + + const auto mask = b < 0.5; + + EXPECT_TRUE(has_stream_output(minimum(masked_view(a, mask), masked_view(b, mask)))); + EXPECT_TRUE(has_stream_output(masked_view(minimum(a, b), mask))); + EXPECT_TRUE(has_stream_output(maximum(masked_view(a, mask), masked_view(b, mask)))); + EXPECT_TRUE(has_stream_output(masked_view(maximum(a, b), mask))); + EXPECT_TRUE(has_stream_output(clip(masked_view(a, mask), 0.2, 0.8))); + EXPECT_TRUE(has_stream_output(masked_view(clip(a, 0.2, 0.8), mask))); + } + TEST(xmasked_view, assign) { xarray data = {{1., -2., 3.}, {4., 5., -6.}, {7., 8., -9.}}; @@ -223,6 +242,21 @@ namespace xt EXPECT_EQ(data, expected1); } + TEST(xmasked_view, assign_const_masked_view_rhs) + { + xarray data = {{1., -2., 3.}, {4., 5., -6.}, {7., 8., -9.}}; + const xarray data2 = {{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}}; + xarray mask = {{true, true, true}, {true, false, false}, {true, false, true}}; + + auto masked_data = masked_view(data, mask); + const auto masked_data2 = masked_view(data2, mask); + + masked_data = masked_data2; + + xarray expected = {{0.1, 0.2, 0.3}, {0.4, 5., -6.}, {0.7, 8., 0.9}}; + EXPECT_EQ(data, expected); + } + TEST(xmasked_view, view) { xt::xarray data = {{0, 1}, {2, 3}, {4, 5}}; diff --git a/test/test_xmath.cpp b/test/test_xmath.cpp index 8f56e93b9..c2e536792 100644 --- a/test/test_xmath.cpp +++ b/test/test_xmath.cpp @@ -14,7 +14,9 @@ #include "xtensor/containers/xarray.hpp" #include "xtensor/core/xmath.hpp" #include "xtensor/generators/xrandom.hpp" +#include "xtensor/io/xio.hpp" #include "xtensor/optional/xoptional_assembly.hpp" +#include "xtensor/views/xmasked_view.hpp" #include "test_common_macros.hpp" @@ -23,6 +25,42 @@ namespace xt using std::size_t; using shape_type = dynamic_shape; + template + void expect_streamable(const E& expression) + { + EXPECT_TRUE(has_stream_output(expression)); + } + + template + void expect_masked_unary_stream(F&& function, const D& data, const M& mask) + { + expect_streamable(function(masked_view(data, mask))); + expect_streamable(masked_view(function(data), mask)); + } + + template + void expect_masked_binary_stream(F&& function, const D1& lhs, const D2& rhs, const M& mask) + { + expect_streamable(function(masked_view(lhs, mask), masked_view(rhs, mask))); + expect_streamable(masked_view(function(lhs, rhs), mask)); + } + + template + void + expect_masked_ternary_stream(F&& function, const D1& arg1, const D2& arg2, const D3& arg3, const M& mask) + { + expect_streamable(function(masked_view(arg1, mask), masked_view(arg2, mask), masked_view(arg3, mask))); + expect_streamable(masked_view(function(arg1, arg2, arg3), mask)); + } + + template + void + expect_masked_ternary_scalar_stream(F&& function, const D& data, const T1& arg2, const T2& arg3, const M& mask) + { + expect_streamable(function(masked_view(data, mask), arg2, arg3)); + expect_streamable(masked_view(function(data, arg2, arg3), mask)); + } + /******************** * Basic operations * ********************/ @@ -222,6 +260,517 @@ namespace xt EXPECT_EQ(res1, clip(opt_a, 2.0, 4.0)); } + TEST(xmath, masked_view_lazy_expressions) + { + using array_type = xarray; + + const array_type a = {1., 1., 1., 1.}; + const array_type b = {0.1, 0.7, 0.3, 0.9}; + const auto mask = b < 0.5; + + const auto expected_min = eval(masked_view(minimum(a, b), mask)); + const auto expected_max = eval(masked_view(maximum(a, b), mask)); + const auto expected_clip = eval(masked_view(clip(a, 0.2, 0.8), mask)); + + EXPECT_EQ(expected_min, eval(minimum(masked_view(a, mask), masked_view(b, mask)))); + EXPECT_EQ(expected_max, eval(maximum(masked_view(a, mask), masked_view(b, mask)))); + EXPECT_EQ(expected_clip, eval(clip(masked_view(a, mask), 0.2, 0.8))); + } + + TEST(xmath, masked_view_lazy_unary_math_functions) + { + const xarray mask = {true, false, true, false}; + const xarray positive = {1.25, 1.5, 1.75, 2.0}; + const xarray unit = {-0.75, -0.25, 0.25, 0.75}; + const xarray signed_values = {-1.8, -0.2, 0.2, 1.8}; + const xarray special = { + 1.0, + std::numeric_limits::infinity(), + std::numeric_limits::quiet_NaN(), + -std::numeric_limits::infinity() + }; + + expect_masked_unary_stream( + [](const auto& e) + { + return abs(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return fabs(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return exp(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return exp2(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return expm1(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return log(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return log10(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return log2(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return log1p(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return sqrt(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return cbrt(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return sin(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return cos(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return tan(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return asin(e); + }, + unit, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return acos(e); + }, + unit, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return atan(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return sinh(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return cosh(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return tanh(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return asinh(e); + }, + signed_values, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return acosh(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return atanh(e); + }, + unit, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return erf(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return erfc(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return tgamma(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return lgamma(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return ceil(e); + }, + signed_values, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return floor(e); + }, + signed_values, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return trunc(e); + }, + signed_values, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return round(e); + }, + signed_values, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return nearbyint(e); + }, + signed_values, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return rint(e); + }, + signed_values, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return isfinite(e); + }, + special, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return isinf(e); + }, + special, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return isnan(e); + }, + special, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return sign(e); + }, + signed_values, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return deg2rad(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return radians(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return rad2deg(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return degrees(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return square(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return cube(e); + }, + positive, + mask + ); + expect_masked_unary_stream( + [](const auto& e) + { + return pow<3>(e); + }, + positive, + mask + ); + } + + TEST(xmath, masked_view_lazy_binary_math_functions) + { + const xarray mask = {true, false, true, false}; + const xarray lhs = {1.25, 1.5, 1.75, 2.0}; + const xarray rhs = {0.5, 0.75, 1.25, 1.5}; + + expect_masked_binary_stream( + [](const auto& lhs_expr, const auto& rhs_expr) + { + return fmod(lhs_expr, rhs_expr); + }, + lhs, + rhs, + mask + ); + expect_masked_binary_stream( + [](const auto& lhs_expr, const auto& rhs_expr) + { + return remainder(lhs_expr, rhs_expr); + }, + lhs, + rhs, + mask + ); + expect_masked_binary_stream( + [](const auto& lhs_expr, const auto& rhs_expr) + { + return fmax(lhs_expr, rhs_expr); + }, + lhs, + rhs, + mask + ); + expect_masked_binary_stream( + [](const auto& lhs_expr, const auto& rhs_expr) + { + return fmin(lhs_expr, rhs_expr); + }, + lhs, + rhs, + mask + ); + expect_masked_binary_stream( + [](const auto& lhs_expr, const auto& rhs_expr) + { + return fdim(lhs_expr, rhs_expr); + }, + lhs, + rhs, + mask + ); + expect_masked_binary_stream( + [](const auto& lhs_expr, const auto& rhs_expr) + { + return pow(lhs_expr, rhs_expr); + }, + lhs, + rhs, + mask + ); + expect_masked_binary_stream( + [](const auto& lhs_expr, const auto& rhs_expr) + { + return hypot(lhs_expr, rhs_expr); + }, + lhs, + rhs, + mask + ); + expect_masked_binary_stream( + [](const auto& lhs_expr, const auto& rhs_expr) + { + return atan2(lhs_expr, rhs_expr); + }, + lhs, + rhs, + mask + ); + expect_masked_binary_stream( + [](const auto& lhs_expr, const auto& rhs_expr) + { + return minimum(lhs_expr, rhs_expr); + }, + lhs, + rhs, + mask + ); + expect_masked_binary_stream( + [](const auto& lhs_expr, const auto& rhs_expr) + { + return maximum(lhs_expr, rhs_expr); + }, + lhs, + rhs, + mask + ); + } + + TEST(xmath, masked_view_lazy_ternary_math_functions) + { + const xarray mask = {true, false, true, false}; + const xarray a = {1.25, 1.5, 1.75, 2.0}; + const xarray b = {0.5, 0.75, 1.25, 1.5}; + const xarray c = {2.0, 2.0, 2.0, 2.0}; + + expect_masked_ternary_stream( + [](const auto& arg1_expr, const auto& arg2_expr, const auto& arg3_expr) + { + return fma(arg1_expr, arg2_expr, arg3_expr); + }, + a, + b, + c, + mask + ); + expect_masked_ternary_scalar_stream( + [](const auto& data_expr, const auto& lower, const auto& upper) + { + return clip(data_expr, lower, upper); + }, + a, + 0.75, + 1.8, + mask + ); + } + TEST(xmath, sign) { shape_type shape = {3, 2};