
Similar to this answer, I have a pair of 3D numpy arrays, a and b, and I want to sort the entries of b by the values of a. Unlike this answer, I want to sort only along one axis of the arrays.

My naive reading of the numpy.argsort() documentation:

Returns
-------
index_array : ndarray, int
    Array of indices that sort `a` along the specified axis.
    In other words, ``a[index_array]`` yields a sorted `a`.

led me to believe that I could do my sort with the following code:

import numpy

a = numpy.zeros((3, 3, 3))
a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
print("a")
print(a)
"""
[[[ 1.  1.  1.]
  [ 1.  1.  1.]
  [ 1.  1.  1.]]

 [[ 3.  3.  3.]
  [ 3.  3.  3.]
  [ 3.  3.  3.]]

 [[ 2.  2.  2.]
  [ 2.  2.  2.]
  [ 2.  2.  2.]]]
"""
b = numpy.arange(3*3*3).reshape((3, 3, 3))
print("b")
print(b)
"""
[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]

 [[18 19 20]
  [21 22 23]
  [24 25 26]]]
"""
print("a, sorted")
print(numpy.sort(a, axis=0))
"""
[[[ 1.  1.  1.]
  [ 1.  1.  1.]
  [ 1.  1.  1.]]

 [[ 2.  2.  2.]
  [ 2.  2.  2.]
  [ 2.  2.  2.]]

 [[ 3.  3.  3.]
  [ 3.  3.  3.]
  [ 3.  3.  3.]]]
"""

## This isn't working how I'd like
sort_indices = numpy.argsort(a, axis=0)
c = b[sort_indices]
"""
Desired output:

[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]

 [[18 19 20]
  [21 22 23]
  [24 25 26]]

 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]]
"""
print("Desired shape of b[sort_indices]: (3, 3, 3).")
print("Actual shape of b[sort_indices]:")
print(c.shape)
"""
(3, 3, 3, 3, 3)
"""

What's the right way to do this?

2 Answers

23

You still have to supply indices for the other two dimensions for this to work correctly.

>>> a = numpy.zeros((3, 3, 3))
>>> a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
>>> b = numpy.arange(3*3*3).reshape((3, 3, 3))
>>> sort_indices = numpy.argsort(a, axis=0)
>>> static_indices = numpy.indices((3, 3, 3))
>>> b[sort_indices, static_indices[1], static_indices[2]]
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]]])

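numpy.indices returns one index array per axis, each with the full shape of the grid. In the array for axis n, every entry holds its own position along axis n, so static_indices[0][i, j, k] == i, static_indices[1][i, j, k] == j, and static_indices[2][i, j, k] == k. For this example (apologies for the long post):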

>>> static_indices
array([[[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],

        [[1, 1, 1],
         [1, 1, 1],
         [1, 1, 1]],

        [[2, 2, 2],
         [2, 2, 2],
         [2, 2, 2]]],


       [[[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]],

        [[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]],

        [[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]]],


       [[[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]],

        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]],

        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]]]])

These are the identity indices for each axis; when used to index b, they recreate b.

>>> b[static_indices[0], static_indices[1], static_indices[2]]
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])
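
To see why the bare b[sort_indices] from the question comes back five-dimensional, it may help to compare the two indexing forms side by side. This is a minimal sketch reusing the arrays from the question; the names wrong, right, j and k are only illustrative.

```python
import numpy

a = numpy.zeros((3, 3, 3))
a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
b = numpy.arange(3 * 3 * 3).reshape((3, 3, 3))
sort_indices = numpy.argsort(a, axis=0)

# A single integer array indexes axis 0 only: every entry of sort_indices
# selects a whole (3, 3) block b[n], so the result has shape
# sort_indices.shape + b.shape[1:].
wrong = b[sort_indices]
print(wrong.shape)  # (3, 3, 3, 3, 3)

# One index array per axis: the three arrays are broadcast together and
# right[i, j, k] == b[sort_indices[i, j, k], j, k].
_, j, k = numpy.indices(b.shape)
right = b[sort_indices, j, k]
print(right.shape)  # (3, 3, 3)
```

With one index array per axis, the identity indices on axes 1 and 2 keep each element in its own column, and only the position along axis 0 is permuted.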

As an alternative to numpy.indices, you could use numpy.ogrid, as unutbu suggests in the comments. Since the object generated by ogrid is smaller, I'll create all three axes for consistency's sake, but see unutbu's comment for a way to do this by generating only two.

>>> static_indices = numpy.ogrid[0:a.shape[0], 0:a.shape[1], 0:a.shape[2]]
>>> a[sort_indices, static_indices[1], static_indices[2]]
array([[[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]],

       [[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]],

       [[ 3.,  3.,  3.],
        [ 3.,  3.,  3.],
        [ 3.,  3.,  3.]]])
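
The two-array variant mentioned in the comments could look roughly like the following. It is a sketch under the same setup as the question; j and k are illustrative names, and only the axes that are not being sorted get an ogrid entry.

```python
import numpy

a = numpy.zeros((3, 3, 3))
a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
b = numpy.arange(3 * 3 * 3).reshape((3, 3, 3))
sort_indices = numpy.argsort(a, axis=0)

# ogrid returns "open" grids: shapes (3, 1) and (1, 3) here,
# rather than two full (3, 3, 3) arrays.
j, k = numpy.ogrid[0:a.shape[1], 0:a.shape[2]]
print(j.shape, k.shape)  # (3, 1) (1, 3)

# Broadcasting stretches j and k against the (3, 3, 3) sort_indices,
# so this selects the same elements as the numpy.indices version.
c = b[sort_indices, j, k]
print(c.shape)  # (3, 3, 3)
```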

3 Comments

I like this answer. To save some memory, perhaps change static_indices to static_indices = np.ogrid[0:a.shape[1],0:a.shape[2]]. This will produce smaller arrays, but will do the same thing as np.indices by taking advantage of broadcasting. It could be used like this: b[sort_indices, static_indices[1], static_indices[2]].
Err, b[sort_indices, static_indices[0], static_indices[1]] rather.
@unutbu, thanks! I'm still getting to know numpy's very rich indexing system; it's nice to know about an automatic way to generate broadcastable indices.
2

numpy.take_along_axis() (added in NumPy 1.15) does this directly, and presumably without unnecessary extra memory usage:

# From the original question:
import numpy
a = numpy.zeros((3, 3, 3))
a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
b = numpy.arange(3*3*3).reshape((3, 3, 3))
sort_indices = numpy.argsort(a, axis=0)
# This is not working as expected:
c = b[sort_indices]
# This does what is expected:
c = numpy.take_along_axis(b, sort_indices, axis=0)
print(c)
"""
[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]

 [[18 19 20]
  [21 22 23]
  [24 25 26]]

 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]]
"""

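As a rough check that this matches the manual indexing approach, the sketch below sorts along each axis of a non-constant array and compares both results. The random test data, the seed and the names order and idx are assumptions made for illustration; numpy.random.default_rng needs NumPy 1.17 or later.

```python
import numpy

rng = numpy.random.default_rng(0)  # arbitrary seed, for reproducibility only
a = rng.random((3, 4, 5))
b = numpy.arange(a.size).reshape(a.shape)

for axis in range(a.ndim):
    order = numpy.argsort(a, axis=axis)
    b_sorted = numpy.take_along_axis(b, order, axis=axis)
    a_sorted = numpy.take_along_axis(a, order, axis=axis)

    # Reordering a by its own argsort must reproduce numpy.sort.
    assert numpy.array_equal(a_sorted, numpy.sort(a, axis=axis))

    # Manual equivalent: start from the identity indices and replace
    # the entry for the sorted axis with the argsort result.
    idx = list(numpy.indices(a.shape))
    idx[axis] = order
    assert numpy.array_equal(b_sorted, b[tuple(idx)])
```

The same call works for any axis and any number of dimensions, which is the main convenience over building the index arrays by hand.
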
1 Comment

The answer should be more exhaustive and detailed
