Pytorch 对张量进行随机采样

更多pytorch操作见:
Pytorch学习之torch
python选取tensor某一维_Pytorch的Tensor操作(1)

如何对张量进行随机采样

假设我目前的张量的shape为 [ B , N , C ] = [ 32 , 512 , 64 ] [B,N,C] = [32,512,64] [B,N,C]=[32,512,64],N表示点的数量,我要从中随机选取S个点,且为不重复采样。

a = torch.randn((32,512,64))
print(a.shape)  # torch.Size([32, 512, 64])
B,N,C = a.shape
S = 64
index = torch.LongTensor(random.sample(range(N), S))
print(index)
b = torch.index_select(a, 1, index)
print(b.shape)  # torch.Size([32, 64, 64])

如果想要截取 [ B , S ] [B,S] [B,S]的话,可以使用“ …”:表示同时取多个维度,只能全选或已知前后具体采样维度,剩下的全选,某一维度取“1”时,会自动降维。

b = torch.index_select(a[..., 0], 1, index)