Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 200 additions & 31 deletions include/xtensor/core/xmath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
#include <cmath>
#include <complex>
#include <type_traits>
#include <utility>

#include <xtl/xcomplex.hpp>
#include <xtl/xmasked_value.hpp>
#include <xtl/xsequence.hpp>
#include <xtl/xtype_traits.hpp>

Expand Down Expand Up @@ -569,13 +571,93 @@ namespace xt

namespace math
{
namespace detail
{
template <typename T>
constexpr decltype(auto) masked_data(const T& value) noexcept
{
return value;
}

template <typename T, typename B>
constexpr decltype(auto) masked_data(const xtl::xmasked_value<T, B>& value) noexcept
{
return value.value();
}

template <typename T>
constexpr bool masked_visible(const T&) noexcept
{
return true;
}

template <typename T, typename B>
constexpr bool masked_visible(const xtl::xmasked_value<T, B>& value) noexcept
{
return static_cast<bool>(value.visible());
}

template <class... Args>
inline constexpr bool has_masked_value_v = (xtl::is_xmasked_value<std::decay_t<Args>>::value || ...);

template <class T>
using masked_data_type_t = std::decay_t<decltype(masked_data(std::declval<const T&>()))>;

template <class... Args>
using masked_common_value_type_t = xtl::promote_type_t<masked_data_type_t<Args>...>;

template <class T>
using masked_return_type_t = xtl::xmasked_value<T, bool>;

template <class... Args>
constexpr bool all_masked_visible(const Args&... args) noexcept
{
return (masked_visible(args) && ...);
}

template <class T>
constexpr auto hidden_masked_value() noexcept -> masked_return_type_t<T>
{
return masked_return_type_t<T>(T(0), false);
}

template <class Result, class F, class... Args>
constexpr auto masked_map(F&& function, const Args&... args) -> masked_return_type_t<Result>
{
if (all_masked_visible(args...))
{
return masked_return_type_t<Result>(
static_cast<Result>(std::forward<F>(function)(masked_data(args)...)),
true
);
}

return hidden_masked_value<Result>();
}
}

template <class T = void>
struct minimum
{
template <class A1, class A2>
constexpr auto operator()(const A1& t1, const A2& t2) const noexcept
{
return xtl::select(t1 < t2, t1, t2);
if constexpr (detail::has_masked_value_v<A1, A2>)
{
using value_type = detail::masked_common_value_type_t<A1, A2>;
return detail::masked_map<value_type>(
[](const auto& lhs, const auto& rhs)
{
return lhs < rhs ? lhs : rhs;
},
t1,
t2
);
}
else
{
return xtl::select(t1 < t2, t1, t2);
}
}

template <class A1, class A2>
Expand All @@ -591,7 +673,22 @@ namespace xt
template <class A1, class A2>
constexpr auto operator()(const A1& t1, const A2& t2) const noexcept
{
return xtl::select(t1 > t2, t1, t2);
if constexpr (detail::has_masked_value_v<A1, A2>)
{
using value_type = detail::masked_common_value_type_t<A1, A2>;
return detail::masked_map<value_type>(
[](const auto& lhs, const auto& rhs)
{
return lhs > rhs ? lhs : rhs;
},
t1,
t2
);
}
else
{
return xtl::select(t1 > t2, t1, t2);
}
}

template <class A1, class A2>
Expand All @@ -606,7 +703,23 @@ namespace xt
template <class A1, class A2, class A3>
constexpr auto operator()(const A1& v, const A2& lo, const A3& hi) const
{
return xtl::select(v < lo, lo, xtl::select(hi < v, hi, v));
if constexpr (detail::has_masked_value_v<A1, A2, A3>)
{
using value_type = detail::masked_common_value_type_t<A1, A2, A3>;
return detail::masked_map<value_type>(
[](const auto& value, const auto& lower, const auto& upper)
{
return value < lower ? lower : (upper < value ? upper : value);
},
v,
lo,
hi
);
}
else
{
return xtl::select(v < lo, lo, xtl::select(hi < v, hi, v));
}
}

template <class A1, class A2, class A3>
Expand All @@ -618,16 +731,29 @@ namespace xt

struct deg2rad
{
template <class A, std::enable_if_t<xtl::is_integral<A>::value, int> = 0>
constexpr double operator()(const A& a) const noexcept
{
return a * xt::numeric_constants<double>::PI / 180.0;
}

template <class A, std::enable_if_t<std::is_floating_point<A>::value, int> = 0>
template <class A>
constexpr auto operator()(const A& a) const noexcept
{
return a * xt::numeric_constants<A>::PI / A(180.0);
if constexpr (detail::has_masked_value_v<A>)
{
using data_type = detail::masked_data_type_t<A>;
using result_type = std::conditional_t<xtl::is_integral<data_type>::value, double, data_type>;
return detail::masked_map<result_type>(
[](const auto& value)
{
return value * xt::numeric_constants<result_type>::PI / result_type(180.0);
},
a
);
}
else if constexpr (xtl::is_integral<A>::value)
{
return a * xt::numeric_constants<double>::PI / 180.0;
}
else
{
return a * xt::numeric_constants<A>::PI / A(180.0);
}
}

template <class A, std::enable_if_t<xtl::is_integral<A>::value, int> = 0>
Expand All @@ -645,16 +771,29 @@ namespace xt

struct rad2deg
{
template <class A, std::enable_if_t<xtl::is_integral<A>::value, int> = 0>
constexpr double operator()(const A& a) const noexcept
{
return a * 180.0 / xt::numeric_constants<double>::PI;
}

template <class A, std::enable_if_t<std::is_floating_point<A>::value, int> = 0>
template <class A>
constexpr auto operator()(const A& a) const noexcept
{
return a * A(180.0) / xt::numeric_constants<A>::PI;
if constexpr (detail::has_masked_value_v<A>)
{
using data_type = detail::masked_data_type_t<A>;
using result_type = std::conditional_t<xtl::is_integral<data_type>::value, double, data_type>;
return detail::masked_map<result_type>(
[](const auto& value)
{
return value * result_type(180.0) / xt::numeric_constants<result_type>::PI;
},
a
);
}
else if constexpr (xtl::is_integral<A>::value)
{
return a * 180.0 / xt::numeric_constants<double>::PI;
}
else
{
return a * A(180.0) / xt::numeric_constants<A>::PI;
}
}

template <class A, std::enable_if_t<xtl::is_integral<A>::value, int> = 0>
Expand Down Expand Up @@ -858,7 +997,22 @@ namespace xt
template <class T>
constexpr auto operator()(const T& x) const
{
return sign_impl<T>::run(x);
if constexpr (detail::has_masked_value_v<T>)
{
using data_type = detail::masked_data_type_t<T>;
using result_type = std::decay_t<decltype(sign_impl<data_type>::run(detail::masked_data(x)))>;
return detail::masked_map<result_type>(
[](const auto& value)
{
return sign_impl<data_type>::run(value);
},
x
);
}
else
{
return sign_impl<T>::run(x);
}
}
};
}
Expand Down Expand Up @@ -1031,6 +1185,19 @@ namespace xt
{
};

template <typename T>
inline decltype(auto) lambda_argument(T&& value)
{
if constexpr (xtl::is_xmasked_value<std::decay_t<T>>::value)
{
return +value;
}
else
{
return std::forward<T>(value);
}
}

template <class F>
struct lambda_adapt
{
Expand All @@ -1040,15 +1207,15 @@ namespace xt
}

template <class... T>
auto operator()(T... args) const
auto operator()(T&&... args) const
{
return m_lambda(args...);
return m_lambda(lambda_argument(std::forward<T>(args))...);
}

template <class... T, XTL_REQUIRES(detail::supports<F(T...)>)>
auto simd_apply(T... args) const
auto simd_apply(T&&... args) const
{
return m_lambda(args...);
return m_lambda(lambda_argument(std::forward<T>(args))...);
}

F m_lambda;
Expand Down Expand Up @@ -1171,30 +1338,32 @@ namespace xt
struct pow_impl
{
template <class T>
auto operator()(T v) const -> decltype(v * v)
auto operator()(T&& v) const
{
T temp = pow_impl<N / 2>{}(v);
return temp * temp * pow_impl<N & 1>{}(v);
auto value = lambda_argument(std::forward<T>(v));
auto temp = pow_impl<N / 2>{}(value);
return temp * temp * pow_impl<N & 1>{}(value);
}
};

template <>
struct pow_impl<1>
{
template <class T>
auto operator()(T v) const -> T
decltype(auto) operator()(T&& v) const
{
return v;
return lambda_argument(std::forward<T>(v));
}
};

template <>
struct pow_impl<0>
{
template <class T>
auto operator()(T /*v*/) const -> T
auto operator()(T&& v) const
{
return T(1);
using value_type = std::decay_t<decltype(lambda_argument(std::forward<T>(v)))>;
return value_type(1);
}
};
}
Expand Down
14 changes: 13 additions & 1 deletion include/xtensor/io/xio.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,18 @@ namespace xt

namespace detail
{
template <typename T, typename B>
inline auto printable_value(const xtl::xmasked_value<T, B>& value)
{
return +value;
}

template <typename T>
inline const T& printable_value(const T& value)
{
return value;
}

template <class E, class F>
std::ostream& xoutput(
std::ostream& out,
Expand Down Expand Up @@ -646,7 +658,7 @@ namespace xt
void update(const_reference val)
{
std::stringstream buf;
buf << val;
buf << printable_value(val);
std::string s = buf.str();
if (int(s.size()) > m_width)
{
Expand Down
Loading
Loading