Multi-index array access

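import numpy as np

probs = np.array([[0.1, 0.2, 0.7],
                  [0.2, 0.8, 0.6],
                  [0.3, 0.6, 0.1]])
y = [2, 1, 1]  # column index to pick from each row
# Row i is paired with column y[i]: probs[0, 2], probs[1, 1], probs[2, 1]
print(probs[np.arange(len(probs)), y])  # [0.7 0.8 0.6]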
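Accessing array values with multiple index arrays: passing one integer array per axis pairs the arrays up element by element, so the result holds exactly one entry per row (here the entry in column y[i] of row i).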
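The sketch below checks that the paired-index lookup matches an explicit loop and NumPy's np.take_along_axis, and contrasts it with probs[:, y], which is a common mix-up. The variable names (rows, fancy, looped, along, columns) are illustrative only.

```python
import numpy as np

probs = np.array([[0.1, 0.2, 0.7],
                  [0.2, 0.8, 0.6],
                  [0.3, 0.6, 0.1]])
y = np.array([2, 1, 1])  # one column index per row

rows = np.arange(len(probs))

# 1. Integer-array ("fancy") indexing: rows[i] is paired with y[i]
fancy = probs[rows, y]

# 2. Equivalent explicit loop, one element per (row, column) pair
looped = np.array([probs[i, y[i]] for i in range(len(probs))])

# 3. Equivalent take_along_axis: y gets a trailing axis so its ndim matches probs
along = np.take_along_axis(probs, y[:, None], axis=1).ravel()

# Contrast: probs[:, y] selects whole columns 2, 1, 1, giving a 3x3 array
columns = probs[:, y]

print(fancy)          # [0.7 0.8 0.6]
print(columns.shape)  # (3, 3)
```

A typical use, assuming probs holds per-sample class probabilities and y the true class labels, is picking each sample's probability for its correct class, for example inside a negative log-likelihood such as -np.log(probs[rows, y]).mean().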

