花式索引
组合索引
与numpy索引方式几乎一致,参考numpy学习
import torch
x = torch.randn(1, 3, 4, 5)
print(x.shape)
x
torch.Size([1, 3, 4, 5])
tensor([[[[ 2.2271, 2.0278, -0.5272, 1.4764, -0.6787],
[-0.7569, -0.8737, 0.0031, 0.7208, 1.9315],
[ 0.6217, 0.3060, -0.3342, 0.1870, 1.5624],
[-1.2602, -1.7282, -0.1555, 2.0486, 0.2562]],
[[-0.3865, -0.4514, 0.6373, -0.2836, -0.0389],
[ 1.1489, -1.4416, 0.2477, 1.7385, 0.7891],
[ 0.7877, 0.6736, -1.0933, 0.3111, 0.2772],
[ 0.4253, 1.3894, -1.1572, 1.6345, 0.1543]],
[[-1.6844, 0.5027, -0.4037, 1.5505, -0.6599],
[ 0.8898, -1.4016, -0.2114, 1.2523, 0.3860],
[-0.0034, 0.9842, -0.2550, -0.6099, -0.5720],
[-1.4554, 0.3559, 1.4612, -0.4953, -1.1379]]]])
dim0 = torch.tensor([0, 0])
dim2 = torch.tensor([0, 1])
x_ = x[dim0, :, dim2]
x_
tensor([[[ 2.2271, 2.0278, -0.5272, 1.4764, -0.6787],
[-0.3865, -0.4514, 0.6373, -0.2836, -0.0389],
[-1.6844, 0.5027, -0.4037, 1.5505, -0.6599]],
[[-0.7569, -0.8737, 0.0031, 0.7208, 1.9315],
[ 1.1489, -1.4416, 0.2477, 1.7385, 0.7891],
[ 0.8898, -1.4016, -0.2114, 1.2523, 0.3860]]])