为什么在template函数广播中把两个extensor表达式加在一起不正确

Why does adding two xtensor expressions together in template function broadcast incorrectly?

本文关键字:两个 extensor 表达式 不正确 加在一起 template 函数 广播 为什么      更新时间:2023-10-16

考虑以下程序:

#include <iostream>
#include "xtensor/xarray.hpp"
#include "xtensor/xio.hpp"
#include "xtensor/xview.hpp"
xt::xarray<double> arr1
{1.0, 2.0, 3.0};
xt::xarray<double> arr2
{5.0, 6.0, 7.0};
template <typename T, typename U>
struct container{
container(const T& t, const U& u) : a(t), b(u) {}
T a;
U b;
};
template <typename T, typename U>
container<T, U> make_container(const T& t, const U& u){
return container<T,U>(t, u);
}
auto c = make_container(arr1, arr1);
std::cout << (arr1 * arr1) + arr2;
template <typename A, typename B, typename R>
auto operator+(const container<A, B>& e1, const R& e2){
return (e1.a * e1.b) + e2;
}
std::cout << (c + arr2);

如果我们看看代码:

std::cout << (arr1 * arr1) + arr2;

它将输出:

{  6.,  10.,  16.}

然而,运行最后一行:

std::cout << (c + arr2);

产生以下结果:

{{  6.,   9.,  14.}, {  7.,  10.,  15.}, {  8.,  11.,  16.}}

为什么会出现这种情况?我将operator+的功能定义更改为:

template <typename A, typename B, typename R>
auto operator+(const container<A, B>& e1, const R& e2){
std::cout << __PRETTY_FUNCTION__ << std::endl;
return (e1.b * e1.alpha) + e2;
}

输出有点令人惊讶:

auto operator+(const container<A, B> &, const R &) [A = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, B = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, R = double]
auto operator+(const container<A, B> &, const R &) [A = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, B = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, R = double]
auto operator+(const container<A, B> &, const R &) [A = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, B = xt::xarray_container<xt::uvector<double, std::allocator<double> >, xt::layout_type::row_major, xt::svector<unsigned long, 4, std::allocator<unsigned long>, true>, xt::xtensor_expression_tag>, R = double]
{{  6.,   9.,  14.}, {  7.,  10.,  15.}, {  8.,  11.,  16.}}

为什么在一个操作中调用3个+操作?是否在某个地方定义了导致这种行为的宏?operator+中的R类型给出了double,实际上应该是xt::xarray<double>

任何见解都将不胜感激,谢谢。

在命名空间xt中定义的operator+采用通用引用,因此在编写c + arr2时,它优先于重载。

因此,最后一行将返回一个xfunction,其第一个操作数是container,第二个操作数为xarray

现在,由于container不是xexpression,在xfunction内部,它被处理为。。。xscalar<container>

因此,当您尝试访问此xfunction的第i个元素时,将执行以下操作:xscalar<container> + arr2[i](广播xscalar(。由于xscalar<container>可转换为container,因此调用operator+重载时,R被解析为arr2value_type,即double

以下循环说明了这种行为:

auto f = c + arr2;
for(auto iter = f.begin(); iter != f.end(); ++iter)
{
std::cout << *iter << std::endl;
}

它生成operator+的以下调用:

operator+(c, arr[0]);
operator+(c, arr[1]);
operator+(c, arr[2]);

这就是为什么您会看到operator+的3个呼叫。