@@ -439,16 +439,70 @@ namespace xt
439
439
using requested_value_type = detail::conditional_promote_to_complex_t <e1_value_type, e2_requested_value_type>;
440
440
};
441
441
442
+ /* *********************************
443
+ * Expression Order Optimizations *
444
+ **********************************/
445
+
446
+ class optimize_expression
447
+ {
448
+ private:
449
+
450
+ template <class E1 , class E2 >
451
+ struct equal_rank
452
+ {
453
+ static constexpr bool value = get_rank<E1 >::value == get_rank<E2 >::value;
454
+ };
455
+
456
+ template <class E1 , class ... E>
457
+ struct all_equal_rank
458
+ {
459
+ static constexpr bool value = xtl::conjunction<equal_rank<E1 , E>...>::value
460
+ && (get_rank<E1 >::value != SIZE_MAX);
461
+ };
462
+
463
+ template <class F , class ... CT, class ... S, size_t ... I, size_t ... J>
464
+ inline auto
465
+ impl_reorder_function (const xfunction<F, CT...>& e, std::tuple<S...> slices, std::index_sequence<I...>, std::index_sequence<J...>)
466
+ {
467
+ return make_lambda_xfunction (F (), view (std::get<I>(e.arguments ()), std::get<J>(slices)...)...);
468
+ }
469
+
470
+ public:
471
+
472
+ // when we have a view of a function where the closures of the functions are of equal rank (i.e no
473
+ // broadcasting) we can flip the order of the function and the view such that we have a function of
474
+ // views of containers which can be linearly assigned unlike the inverse.
475
+ template <class F , class ... CT, class ... S, class = std::enable_if_t <all_equal_rank<std::decay_t <CT>...>::value>>
476
+ inline auto reorder (const xview<xfunction<F, CT...>, S...>& e)
477
+ {
478
+ return impl_reorder_function (
479
+ e.expression (),
480
+ e.slices (),
481
+ std::make_index_sequence<sizeof ...(CT)>(),
482
+ std::make_index_sequence<sizeof ...(S)>()
483
+ );
484
+ }
485
+
486
+ // base case no applicable optimization
487
+ template <class E >
488
+ inline auto & reorder (E&& e)
489
+ {
490
+ return std::forward<E>(e);
491
+ }
492
+ };
493
+
442
494
template <class E1 , class E2 >
443
495
inline void xexpression_assigner_base<xtensor_expression_tag>::assign_data(
444
496
xexpression<E1 >& e1 ,
445
497
const xexpression<E2 >& e2 ,
446
498
bool trivial
447
499
)
448
500
{
449
- E1 & de1 = e1 .derived_cast ();
450
- const E2 & de2 = e2 .derived_cast ();
451
- using traits = xassign_traits<E1 , E2 >;
501
+ auto & de1 = e1 .derived_cast ();
502
+ const auto & de2 = optimize_expression ().reorder (e2 .derived_cast ());
503
+ using dst_type = typename std::decay_t <decltype (de1)>;
504
+ using src_type = typename std::decay_t <decltype (de2)>;
505
+ using traits = xassign_traits<dst_type, src_type>;
452
506
453
507
bool linear_assign = traits::linear_assign (de1, de2, trivial);
454
508
constexpr bool simd_assign = traits::simd_assign ();
0 commit comments