https://numpy.org/doc/stable/reference/generated/numpy.where.html?highlight=where#numpy.where
https://www.cnblogs.com/massquantity/p/8908859.html
import numpy as np
# Stretch uniform samples from [0, 1) to [-1, 1): double them, then shift down by one.
scores = 2 * np.random.random((2, 4, 4)) - 1
scores
array([[[-0.9941625 , -0.13803124, -0.1789303 , -0.84169049],
[-0.34679528, -0.09184122, 0.2305758 , -0.83257719],
[ 0.31593582, -0.14398776, 0.24329462, -0.72213975],
[ 0.07164275, 0.76827745, -0.50925405, -0.55656705]],
[[-0.59053878, -0.44938365, 0.49978236, 0.51931104],
[-0.7532766 , -0.63130856, 0.57406023, -0.77735105],
[ 0.84732412, 0.56404257, 0.93578561, 0.34568311],
[-0.07856562, 0.24609107, 0.83911535, 0.58042135]]])
# First 4x4 block along axis 0.
scores[0]
array([[-0.9941625 , -0.13803124, -0.1789303 , -0.84169049],
[-0.34679528, -0.09184122, 0.2305758 , -0.83257719],
[ 0.31593582, -0.14398776, 0.24329462, -0.72213975],
[ 0.07164275, 0.76827745, -0.50925405, -0.55656705]])
# Called with only a condition, np.where returns a tuple of index arrays,
# one per axis of `scores`, locating every element greater than 0.5.
locs = np.where(scores > 0.5)
locs
(array([0, 1, 1, 1, 1, 1, 1, 1]),
array([3, 0, 1, 2, 2, 2, 3, 3]),
array([1, 3, 2, 0, 1, 2, 2, 3]))
# zip(*locs) pairs up the i-th entry of every per-axis index array, so each
# iteration yields the full 3-D index (label, start, end) of one match.
for label, start, end in zip(*locs):
    print(label, start, end)
0 3 1
1 0 3
1 1 2
1 2 0
1 2 1
1 2 2
1 3 2
1 3 3