Skip to content
231 changes: 200 additions & 31 deletions include/xtensor/core/xmath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
#include <cmath>
#include <complex>
#include <type_traits>
#include <utility>

#include <xtl/xcomplex.hpp>
#include <xtl/xmasked_value.hpp>
#include <xtl/xsequence.hpp>
#include <xtl/xtype_traits.hpp>

Expand Down Expand Up @@ -569,13 +571,93 @@ namespace xt

namespace math
{
namespace detail
{
template <typename T>
constexpr decltype(auto) masked_data(const T& value) noexcept
{
return value;
}

template <typename T, typename B>
constexpr decltype(auto) masked_data(const xtl::xmasked_value<T, B>& value) noexcept
{
return value.value();
}

template <typename T>
constexpr bool masked_visible(const T&) noexcept
{
return true;
}

template <typename T, typename B>
constexpr bool masked_visible(const xtl::xmasked_value<T, B>& value) noexcept
{
return static_cast<bool>(value.visible());
}

template <class... Args>
inline constexpr bool has_masked_value_v = (xtl::is_xmasked_value<std::decay_t<Args>>::value || ...);

template <class T>
using masked_data_type_t = std::decay_t<decltype(masked_data(std::declval<const T&>()))>;

template <class... Args>
using masked_common_value_type_t = xtl::promote_type_t<masked_data_type_t<Args>...>;

template <class T>
using masked_return_type_t = xtl::xmasked_value<T, bool>;

template <class... Args>
constexpr bool all_masked_visible(const Args&... args) noexcept
{
return (masked_visible(args) && ...);
}

template <class T>
constexpr auto hidden_masked_value() noexcept -> masked_return_type_t<T>
{
return masked_return_type_t<T>(T(0), false);
}

template <class Result, class F, class... Args>
constexpr auto masked_map(F&& function, const Args&... args) -> masked_return_type_t<Result>
{
if (all_masked_visible(args...))
{
return masked_return_type_t<Result>(
static_cast<Result>(std::forward<F>(function)(masked_data(args)...)),
true
);
}

return hidden_masked_value<Result>();
}
}

template <class T = void>
struct minimum
{
template <class A1, class A2>
constexpr auto operator()(const A1& t1, const A2& t2) const noexcept
{
return xtl::select(t1 < t2, t1, t2);
if constexpr (detail::has_masked_value_v<A1, A2>)
{
using value_type = detail::masked_common_value_type_t<A1, A2>;
return detail::masked_map<value_type>(
[](const auto& lhs, const auto& rhs)
{
return lhs < rhs ? lhs : rhs;
},
t1,
t2
);
}
else
{
return xtl::select(t1 < t2, t1, t2);
}
}

template <class A1, class A2>
Expand All @@ -591,7 +673,22 @@ namespace xt
template <class A1, class A2>
constexpr auto operator()(const A1& t1, const A2& t2) const noexcept
{
return xtl::select(t1 > t2, t1, t2);
if constexpr (detail::has_masked_value_v<A1, A2>)
{
using value_type = detail::masked_common_value_type_t<A1, A2>;
return detail::masked_map<value_type>(
[](const auto& lhs, const auto& rhs)
{
return lhs > rhs ? lhs : rhs;
},
t1,
t2
);
}
else
{
return xtl::select(t1 > t2, t1, t2);
}
}

template <class A1, class A2>
Expand All @@ -606,7 +703,23 @@ namespace xt
template <class A1, class A2, class A3>
constexpr auto operator()(const A1& v, const A2& lo, const A3& hi) const
{
return xtl::select(lo < hi, xtl::select(v < lo, lo, xtl::select(hi < v, hi, v)), hi);
if constexpr (detail::has_masked_value_v<A1, A2, A3>)
{
using value_type = detail::masked_common_value_type_t<A1, A2, A3>;
return detail::masked_map<value_type>(
[](const auto& value, const auto& lower, const auto& upper)
{
return value < lower ? lower : (upper < value ? upper : value);
},
v,
lo,
hi
);
}
else
{
return xtl::select(v < lo, lo, xtl::select(hi < v, hi, v));
}
}

template <class A1, class A2, class A3>
Expand All @@ -618,16 +731,29 @@ namespace xt

struct deg2rad
{
template <class A, std::enable_if_t<xtl::is_integral<A>::value, int> = 0>
constexpr double operator()(const A& a) const noexcept
{
return a * xt::numeric_constants<double>::PI / 180.0;
}

template <class A, std::enable_if_t<std::is_floating_point<A>::value, int> = 0>
template <class A>
constexpr auto operator()(const A& a) const noexcept
{
return a * xt::numeric_constants<A>::PI / A(180.0);
if constexpr (detail::has_masked_value_v<A>)
{
using data_type = detail::masked_data_type_t<A>;
using result_type = std::conditional_t<xtl::is_integral<data_type>::value, double, data_type>;
return detail::masked_map<result_type>(
[](const auto& value)
{
return value * xt::numeric_constants<result_type>::PI / result_type(180.0);
},
a
);
}
else if constexpr (xtl::is_integral<A>::value)
{
return a * xt::numeric_constants<double>::PI / 180.0;
}
else
{
return a * xt::numeric_constants<A>::PI / A(180.0);
}
}

template <class A, std::enable_if_t<xtl::is_integral<A>::value, int> = 0>
Expand All @@ -645,16 +771,29 @@ namespace xt

struct rad2deg
{
template <class A, std::enable_if_t<xtl::is_integral<A>::value, int> = 0>
constexpr double operator()(const A& a) const noexcept
{
return a * 180.0 / xt::numeric_constants<double>::PI;
}

template <class A, std::enable_if_t<std::is_floating_point<A>::value, int> = 0>
template <class A>
constexpr auto operator()(const A& a) const noexcept
{
return a * A(180.0) / xt::numeric_constants<A>::PI;
if constexpr (detail::has_masked_value_v<A>)
{
using data_type = detail::masked_data_type_t<A>;
using result_type = std::conditional_t<xtl::is_integral<data_type>::value, double, data_type>;
return detail::masked_map<result_type>(
[](const auto& value)
{
return value * result_type(180.0) / xt::numeric_constants<result_type>::PI;
},
a
);
}
else if constexpr (xtl::is_integral<A>::value)
{
return a * 180.0 / xt::numeric_constants<double>::PI;
}
else
{
return a * A(180.0) / xt::numeric_constants<A>::PI;
}
}

template <class A, std::enable_if_t<xtl::is_integral<A>::value, int> = 0>
Expand Down Expand Up @@ -858,7 +997,22 @@ namespace xt
template <class T>
constexpr auto operator()(const T& x) const
{
return sign_impl<T>::run(x);
if constexpr (detail::has_masked_value_v<T>)
{
using data_type = detail::masked_data_type_t<T>;
using result_type = std::decay_t<decltype(sign_impl<data_type>::run(detail::masked_data(x)))>;
return detail::masked_map<result_type>(
[](const auto& value)
{
return sign_impl<data_type>::run(value);
},
x
);
}
else
{
return sign_impl<T>::run(x);
}
}
};
}
Expand Down Expand Up @@ -1031,6 +1185,19 @@ namespace xt
{
};

template <typename T>
inline decltype(auto) lambda_argument(T&& value)
{
if constexpr (xtl::is_xmasked_value<std::decay_t<T>>::value)
{
return +value;
}
else
{
return std::forward<T>(value);
}
}

template <class F>
struct lambda_adapt
{
Expand All @@ -1040,15 +1207,15 @@ namespace xt
}

template <class... T>
auto operator()(T... args) const
auto operator()(T&&... args) const
{
return m_lambda(args...);
return m_lambda(lambda_argument(std::forward<T>(args))...);
}

template <class... T, XTL_REQUIRES(detail::supports<F(T...)>)>
auto simd_apply(T... args) const
auto simd_apply(T&&... args) const
{
return m_lambda(args...);
return m_lambda(lambda_argument(std::forward<T>(args))...);
}

F m_lambda;
Expand Down Expand Up @@ -1136,30 +1303,32 @@ namespace xt
struct pow_impl
{
template <class T>
auto operator()(T v) const -> decltype(v * v)
auto operator()(T&& v) const
{
T temp = pow_impl<N / 2>{}(v);
return temp * temp * pow_impl<N & 1>{}(v);
auto value = lambda_argument(std::forward<T>(v));
auto temp = pow_impl<N / 2>{}(value);
return temp * temp * pow_impl<N & 1>{}(value);
}
};

template <>
struct pow_impl<1>
{
template <class T>
auto operator()(T v) const -> T
decltype(auto) operator()(T&& v) const
{
return v;
return lambda_argument(std::forward<T>(v));
}
};

template <>
struct pow_impl<0>
{
template <class T>
auto operator()(T /*v*/) const -> T
auto operator()(T&& v) const
{
return T(1);
using value_type = std::decay_t<decltype(lambda_argument(std::forward<T>(v)))>;
return value_type(1);
}
};
}
Expand Down
21 changes: 13 additions & 8 deletions include/xtensor/io/xio.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,18 @@ namespace xt

namespace detail
{
template <typename T, typename B>
inline auto printable_value(const xtl::xmasked_value<T, B>& value)
{
return +value;
}

template <typename T>
inline const T& printable_value(const T& value)
{
return value;
}

template <class E, class F>
std::ostream& xoutput(
std::ostream& out,
Expand Down Expand Up @@ -647,14 +659,7 @@ namespace xt
void update(const_reference val)
{
std::stringstream buf;
if constexpr (xtl::is_xmasked_value<value_type>::value)
{
buf << +val;
}
else
{
buf << val;
}
buf << printable_value(val);
std::string s = buf.str();
if (int(s.size()) > m_width)
{
Expand Down
Loading
Loading