
I have a list of indices and a tensor with shape:

shape = [batch_size, d_0, d_1, ..., d_k]
idx = [i_0, i_1, ..., i_k]

Is there a way to efficiently index the tensor on each dim d_0, ..., d_k with the indices i_0, ..., i_k? (k is available only at run time)

The result should be:

tensor[:, i_0, i_1, ..., i_k]  # result.shape = [batch_size]

At the moment I'm creating a tuple of slices, one for each dimension:

idx = (slice(tensor.shape[0]),) + tuple(slice(i, i+1) for i in idx)
tensor[idx]
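For reference, the slice version keeps a length-1 axis for every indexed dimension, so its result has shape [batch_size, 1, ..., 1] rather than [batch_size]. Here is a small sketch comparing it with a tuple of plain integers. The sizes and index values are arbitrary choices for illustration:

```python
import torch

tensor = torch.randint(0, 10, (4, 3, 3, 3))
idx = [1, 2, 0]

# Slice-based key: each slice(i, i + 1) keeps its axis, with length 1.
slice_key = (slice(tensor.shape[0]),) + tuple(slice(i, i + 1) for i in idx)
kept = tensor[slice_key]
print(kept.shape)  # torch.Size([4, 1, 1, 1])

# Integer-based key: each plain int removes its axis entirely.
int_key = (slice(None),) + tuple(idx)
dropped = tensor[int_key]
print(dropped.shape)  # torch.Size([4])

# Both hold the same values once the singleton axes are flattened away.
assert torch.equal(kept.reshape(-1), dropped)
```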

but I would prefer something like:

tensor[:, *idx]

Example:

a = torch.randint(0,10,[3,3,3,3])
indexes = torch.LongTensor([1,1,1])

I would like to index only the last len(indexes) dimensions like:

a[:, indexes[0], indexes[1], indexes[2]]

but in the general case where I don't know how long indexes is.


Note: this answer does not help since it indexes all the dimensions, and does not work for a proper subset!


1 Answer


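Unfortunately, before Python 3.11 you can't mix a slice with star-unpacking inside a subscript¹ (e.g. a[:, *idx] is a SyntaxError). Python 3.11 added that syntax through PEP 646. However, on any version you can get exactly the same result by building the index tuple yourself with parentheses: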

a[(slice(None), *idx)]
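Here is a runnable sketch on the example from the question. It assumes the 1-D index tensor is converted to plain Python ints with .tolist() first. That step is my addition and not strictly required, but it makes every entry an ordinary integer index. The Ellipsis variant at the end is an extra illustration for the case where you want the last len(idx) dimensions of a tensor with more than one leading dimension:

```python
import torch

a = torch.randint(0, 10, [3, 3, 3, 3])
indexes = torch.tensor([1, 1, 1])

# Turn the 1-D index tensor into plain Python ints.
idx = indexes.tolist()

# (slice(None), i_0, ..., i_k) is the key that a[:, i_0, ..., i_k] would pass.
result = a[(slice(None), *idx)]
print(result.shape)  # torch.Size([3])

# Same values as writing the indices out by hand.
assert torch.equal(result, a[:, 1, 1, 1])

# Variant: a leading Ellipsis targets the *last* len(idx) dimensions,
# however many leading dimensions the tensor has.
b = torch.randint(0, 10, [2, 5, 3, 3, 3])
tail = b[(..., *idx)]
print(tail.shape)  # torch.Size([2, 5])
```

When len(idx) == a.dim() - 1, the (slice(None), ...) and (..., ...) keys select the same elements.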

  1. In Python, x[(exp1, exp2, ..., expN)] is equivalent to x[exp1, exp2, ..., expN]; the latter is just syntactic sugar for the former.

    https://numpy.org/doc/stable/reference/arrays.indexing.html
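You can check the footnote in plain Python with a toy class. ShowKey is a hypothetical name, and its __getitem__ simply returns whatever key the subscript hands it:

```python
class ShowKey:
    """Toy container whose __getitem__ returns the key it receives."""

    def __getitem__(self, key):
        return key


x = ShowKey()
idx = [1, 2, 3]

# Comma-separated subscripts arrive as one tuple.
assert x[0, 1] == (0, 1)
# An explicit tuple is passed through unchanged, so both spellings are identical.
assert x[(0, 1)] == x[0, 1]
# A bare colon becomes slice(None, None, None), which equals slice(None).
assert x[:, 1] == (slice(None), 1)
# Hence the key built with unpacking matches the hand-written subscript.
assert x[(slice(None), *idx)] == x[:, 1, 2, 3]
```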
