I have a list of indices and a tensor with shape:
shape = [batch_size, d_0, d_1, ..., d_k]
idx = [i_0, i_1, ..., i_k]
Is there a way to efficiently index the tensor along each of the dimensions d_0, ..., d_k with the indices i_0, ..., i_k? (k is known only at run time.)
The result should be:
tensor[:, i_0, i_1, ..., i_k]  # result has shape [batch_size]
At the moment I'm building the index tuple by hand, with slice(None) for the batch dimension followed by one integer per remaining dimension:
index_tuple = (slice(None),) + tuple(idx)
tensor[index_tuple]  # shape: [batch_size]
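For reference, here is a minimal runnable sketch of that tuple-based approach (the shapes and index values are made up for illustration). Because the trailing entries are plain integers rather than slices, each indexed dimension is dropped and the result has shape [batch_size]:

```python
import torch

tensor = torch.randint(0, 10, (5, 3, 3, 3))  # batch_size = 5, k = 2
idx = [1, 0, 2]                              # one index per trailing dim

# slice(None) keeps the whole batch dimension; the plain ints that
# follow drop each indexed dimension, leaving shape [batch_size].
index_tuple = (slice(None),) + tuple(idx)
result = tensor[index_tuple]

assert result.shape == (5,)
assert torch.equal(result, tensor[:, 1, 0, 2])
```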
but I would prefer something like:
tensor[:, *idx]
Example:
a = torch.randint(0,10,[3,3,3,3])
indexes = torch.LongTensor([1,1,1])
I would like to index only the last len(indexes) dimensions like:
a[:, indexes[0], indexes[1], indexes[2]]
but in the general case, where the length of indexes is not known in advance.
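Applied to this example, the same tuple-building trick handles any run-time length of indexes (a sketch, assuming indexes is a 1-D LongTensor as above):

```python
import torch

a = torch.randint(0, 10, (3, 3, 3, 3))
indexes = torch.LongTensor([1, 1, 1])

# Build the index tuple at run time: slice(None) stands in for ':' on
# the batch dim, then one integer per entry of `indexes`, however many.
index_tuple = (slice(None),) + tuple(indexes.tolist())
result = a[index_tuple]

assert torch.equal(result, a[:, 1, 1, 1])
```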
Note: this answer does not help, since it indexes all the dimensions and does not work when indexing only a proper subset of them!