ONNXRUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:‘Where‘

遇到此类错误,如:

onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Gather node. Name:'Gather_4445' Status Message: indices element out of data bounds, idx=8 must be within the inclusive range [-3,2]

或:

RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where' Status Message...

可以配合Netron工具(安装方法:pip install netron,使用时终端输入netron)查看导出的onnx模型网络图,可以查找相应的Node(如:Where_XXXX),再去代码中找对应代码,将其改为onnx支持的tensor运算方式即可解决相应问题。

根据在ONNX导出时遇到的问题比较麻烦的是和torch.gather、torch.where、torch.split等Tensor运算方法。

1. torch.where函数

torch.where(condition,x,y)->tensor

当满足condition,则来自于a,反之来自b

import torch
condition=torch.randn(2,2)
# tensor([[ 0.2589, -0.5600],
#        [ 0.9056, -0.3915]])
a=torch.tensor([[0,0],[0,0]])
b=torch.tensor([[1,1],[1,1]])
torch.where(cond>0.5,a,b)

得到结果

tensor([[1, 1],
        [0, 1]])

输出为0的代表来源为a,输出为1的代表来源为b

2. torch.gather(查表的过程)

torch.gather(input,dim,index,out=None)->tensor

就像是给了数据以后,查表得到对应参数,再收集回来进行输出。

gather函数即为gather(对应的参数表,dim,数据表)

import torch
prob=torch.randn(4,4)
#tensor([[-0.9845,  0.5094, -0.5014, -0.5354],
#        [-1.8514,  0.2640,  0.7895, -0.1660],
#        [ 0.3955,  0.7571,  0.1451,  0.1970],
#        [ 0.3674, -0.8006, -0.5625,  1.3455]])
idx=prob.topk(dim=1,k=2)
idx=idx[1]
#tensor([[1, 2],
#        [2, 1],
#        [1, 0],
#        [3, 0]]))
label=torch.arange(4)+100
#tensor([100, 101, 102, 103])
torch.gather(label.expand(4,4),dim=1,index=idx.long())

 输出结果:

tensor([[101, 102],
        [102, 101],
        [101, 100],
        [103, 100]])

3. torch.split

含义:将一个张量分为几个chunks

torch.split(tensor, split_size_or_sections, dim=0)

参数

  • tensor(Tensor) -张量分裂。

  • split_size_or_sections(int) 或者(list(int)) -单个块的大小或每个块的大小列表

  • dim(int) -沿其分割张量的维度。

如果split_size_or_sections 是整数类型,那么tensor将被分成大小相等的块(如果可能)。如果沿给定维度 dim 的张量大小不能被 split_size 整除,则最后一个块会更小。

如果 split_size_or_sections 是一个列表,那么 tensor 将根据 split_size_or_sections 被拆分为大小在 dim 中的 len(split_size_or_sections) 块。

示例:

>>> a = torch.arange(8).reshape(4,2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7]])
>>> torch.split(a, 3)
(tensor([[0, 1],
         [2, 3],
         [4, 5]]),
 tensor([[6, 7]]))
>>> torch.split(a, [1,3])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7]]))

4. Tensor.scatter_函数

TORCH.TENSOR.SCATTER_

Tensor.scatter_(dimindexsrcreduce=None) → Tensor

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

For a 3-D tensor, self is updated as: 

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
This is the reverse operation of the manner described in gather().

selfindex and src (if it is a Tensor) should all have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim. Note that index and src do not broadcast.

Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive.

Parameters

  • dim (int) – the axis along which to index

  • index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.

  • src (Tensor or float) – the source element(s) to scatter.

  • reduce (stroptional) – reduction operation to apply, can be either 'add' or 'multiply'.

总结:scatter函数就是把src数组中的数据重新分配到output数组当中,index数组中表示了要把src数组中的数据分配到output数组中的位置,若未指定,则填充0.

举例:

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])