Skip to content

Commit 491158f

Browse files
committed
Moved re-ordering logic into assignment
1 parent ae52796 commit 491158f

File tree

1 file changed

+57
-3
lines changed

1 file changed

+57
-3
lines changed

include/xtensor/xassign.hpp

+57-3
Original file line numberDiff line numberDiff line change
@@ -439,16 +439,70 @@ namespace xt
439439
using requested_value_type = detail::conditional_promote_to_complex_t<e1_value_type, e2_requested_value_type>;
440440
};
441441

442+
/**********************************
443+
* Expression Order Optimizations *
444+
**********************************/
445+
446+
class optimize_expression
447+
{
448+
private:
449+
450+
template <class E1, class E2>
451+
struct equal_rank
452+
{
453+
static constexpr bool value = get_rank<E1>::value == get_rank<E2>::value;
454+
};
455+
456+
template <class E1, class... E>
457+
struct all_equal_rank
458+
{
459+
static constexpr bool value = xtl::conjunction<equal_rank<E1, E>...>::value
460+
&& (get_rank<E1>::value != SIZE_MAX);
461+
};
462+
463+
template <class F, class... CT, class... S, size_t... I, size_t... J>
464+
inline auto
465+
impl_reorder_function(const xfunction<F, CT...>& e, std::tuple<S...> slices, std::index_sequence<I...>, std::index_sequence<J...>)
466+
{
467+
return make_lambda_xfunction(F(), view(std::get<I>(e.arguments()), std::get<J>(slices)...)...);
468+
}
469+
470+
public:
471+
472+
// when we have a view of a function where the closures of the functions are of equal rank (i.e no
473+
// broadcasting) we can flip the order of the function and the view such that we have a function of
474+
// views of containers which can be linearly assigned unlike the inverse.
475+
template <class F, class... CT, class... S, class = std::enable_if_t<all_equal_rank<std::decay_t<CT>...>::value>>
476+
inline auto reorder(const xview<xfunction<F, CT...>, S...>& e)
477+
{
478+
return impl_reorder_function(
479+
e.expression(),
480+
e.slices(),
481+
std::make_index_sequence<sizeof...(CT)>(),
482+
std::make_index_sequence<sizeof...(S)>()
483+
);
484+
}
485+
486+
// base case no applicable optimization
487+
template <class E>
488+
inline auto& reorder(E&& e)
489+
{
490+
return std::forward<E>(e);
491+
}
492+
};
493+
442494
template <class E1, class E2>
443495
inline void xexpression_assigner_base<xtensor_expression_tag>::assign_data(
444496
xexpression<E1>& e1,
445497
const xexpression<E2>& e2,
446498
bool trivial
447499
)
448500
{
449-
E1& de1 = e1.derived_cast();
450-
const E2& de2 = e2.derived_cast();
451-
using traits = xassign_traits<E1, E2>;
501+
auto& de1 = e1.derived_cast();
502+
const auto& de2 = optimize_expression().reorder(e2.derived_cast());
503+
using dst_type = typename std::decay_t<decltype(de1)>;
504+
using src_type = typename std::decay_t<decltype(de2)>;
505+
using traits = xassign_traits<dst_type, src_type>;
452506

453507
bool linear_assign = traits::linear_assign(de1, de2, trivial);
454508
constexpr bool simd_assign = traits::simd_assign();

0 commit comments

Comments
 (0)