\$\begingroup\$

I have a large numpy array of shape (n, m). I need to extract one element from each row, and I have another array of shape (n,) that gives the column index of the element I need. The following code does this, but it requires an explicit loop (in the form of a list comprehension):

import numpy as np

arr = np.arange(12).reshape(4, 3)
keys = np.array([1, 0, 1, 2])
# This is the line that I'd like to optimize
answers = np.array([arr[i, keys[i]] for i in range(len(keys))])
print(answers)
# [ 1  3  7 11]

Is there a built-in numpy (or pandas?) function that could do this more efficiently?

\$\endgroup\$
  • \$\begingroup\$ arr[np.arange(4), keys] \$\endgroup\$ Commented Nov 1, 2015 at 10:54
  • \$\begingroup\$ It's correct. @GarethRees, please post that as an answer so others don't waste time looking at a question that appears unanswered. \$\endgroup\$ Commented Nov 1, 2015 at 14:26

1 Answer

\$\begingroup\$

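The best way to do this is what @GarethRees suggested in the comments. Indexing with two integer arrays pairs them element by element, so row i is matched with column keys[i]: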

>>> arr[np.arange(4), keys]
array([ 1,  3,  7, 11])
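
The 4 in that one-liner is the number of rows in this particular example. A minimal sketch of a version that derives the row count from the array is below; the names `rows`, `picked` and `picked_alt` are my own and not from the question. It also shows `np.take_along_axis`, which I believe requires NumPy 1.15 or later, as a built-in that does the same selection.

```python
import numpy as np

arr = np.arange(12).reshape(4, 3)
keys = np.array([1, 0, 1, 2])

# Row indices 0..n-1, derived from the data instead of hard-coding 4.
rows = np.arange(arr.shape[0])

# Integer-array indexing pairs rows[i] with keys[i] element by element.
picked = arr[rows, keys]

# Equivalent selection with take_along_axis (assumed NumPy 1.15+); keys is
# reshaped to a column so it has the same number of dimensions as arr.
picked_alt = np.take_along_axis(arr, keys[:, None], axis=1).ravel()

print(picked)      # [ 1  3  7 11]
print(picked_alt)  # [ 1  3  7 11]
```

Both forms avoid the Python-level loop. The plain indexing version is probably the more readable of the two for a 2-D array.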

If you do keep an explicit loop, there is a cleaner (though still slower than the indexing above) alternative to arr[i, keys[i]] for i in range(len(keys)). Whenever you want both the index and the item from some iterable, you should generally use the enumerate function:

>>> np.array([arr[i, item] for i, item in enumerate(keys)])
array([ 1,  3,  7, 11])
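
To check the "not as good" claim on your own data, a rough timing sketch like the following may help. The array sizes, the seed and the helper names `pick_loop` and `pick_indexed` are illustrative assumptions, and the actual timings will depend on your machine and NumPy version.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
n, m = 10_000, 50  # illustrative sizes, not taken from the question
arr = rng.integers(0, 100, size=(n, m))
keys = rng.integers(0, m, size=n)


def pick_loop(arr, keys):
    """Select arr[i, keys[i]] for every row with a Python-level loop."""
    return np.array([arr[i, k] for i, k in enumerate(keys)])


def pick_indexed(arr, keys):
    """Select the same elements with integer-array indexing."""
    return arr[np.arange(len(keys)), keys]


# Both approaches should return identical results.
assert np.array_equal(pick_loop(arr, keys), pick_indexed(arr, keys))

loop_time = timeit.timeit(lambda: pick_loop(arr, keys), number=10)
indexed_time = timeit.timeit(lambda: pick_indexed(arr, keys), number=10)
print(f"loop:    {loop_time:.4f} s")
print(f"indexed: {indexed_time:.4f} s")
```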
\$\endgroup\$
  • \$\begingroup\$ Good point about enumerate since it's generally applicable. Thanks to Gareth for providing the initial answer in a comment. I'm accepting this one so that the question doesn't show as unanswered. \$\endgroup\$ Commented Nov 2, 2015 at 5:08
