I have a large NumPy array of shape (n, m). I need to extract one element from each row, and I have another array of shape (n,) that gives the column index of the element I need in each row. The following code does this, but it requires an explicit loop (in the form of a list comprehension):
import numpy as np
arr = np.array(range(12))
arr = arr.reshape((4,3))
keys = np.array([1,0,1,2])
#This is the line that I'd like to optimize
answers = np.array([arr[i,keys[i]] for i in range(len(keys))])
print(answers)
# [ 1 3 7 11]
Is there a built-in numpy (or pandas?) function that could do this more efficiently?
arr[np.arange(4), keys]
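To spell out the suggestion above, here is a minimal sketch using the example data from the question. It relies on NumPy's integer array indexing, where a row-index array and a column-index array of the same length are paired element by element. The names `rows` and `answers_alt` are introduced here for illustration only, and the `np.take_along_axis` variant is an alternative I am adding, not something from the original post (it assumes NumPy 1.15 or later).

```python
import numpy as np

arr = np.arange(12).reshape(4, 3)
keys = np.array([1, 0, 1, 2])

# Integer array indexing: rows[i] is paired with keys[i],
# so this selects arr[i, keys[i]] for every row without a Python loop.
rows = np.arange(arr.shape[0])
answers = arr[rows, keys]

# Alternative (assumes NumPy >= 1.15): take_along_axis needs an index
# array with the same number of dimensions as arr, hence keys[:, None].
answers_alt = np.take_along_axis(arr, keys[:, None], axis=1)[:, 0]

print(answers)
# [ 1  3  7 11]
```

Using `arr.shape[0]` instead of a hard-coded 4 keeps the expression valid for any number of rows, as long as `keys` has one entry per row.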