
Given a torch tensor:

# example tensor size 2 x 4
a = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

and another where every n rows are repeated:

# example tensor size 4 x 3 where every 2 rows repeated
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])

how can one perform matrix multiplication:

>>> torch.mm(a, b)
tensor([[ 28.,  38.,  48.],
        [ 68.,  94., 120.]])

without copying the whole repeated row tensor into memory or iterating?

i.e. only store the first 2 rows:

# example tensor size 2 x 3 where only the first two rows from b are actually stored in memory
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])

since these rows will be repeated.

There is a method

torch.Tensor.expand()

but it does not work when repeating more than a single row (it can only repeat dimensions of size 1), and also, as this question:

Repeating a pytorch tensor without copying memory

indicates, and my own tests confirm, it often ends up copying the whole tensor into memory anyway when calling

.to(device)
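
For illustration (the variable names here are just mine), the closest I can get with expand still materializes the full tensor at the final reshape:

# expand can only repeat dimensions of size 1, so the repeated block must be unsqueezed first
b_view = b_abbreviated.unsqueeze(0).expand(2, -1, -1)  # shape (2, 2, 3), still no copy
# torch.mm needs a 2-D tensor, and collapsing the expanded view forces a copy
b_full = b_view.reshape(4, 3)  # the whole repeated 4 x 3 tensor ends up in memory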

It is also possible to do this iteratively, but this is relatively slow.

Is there some way to perform this operation efficiently without storing the whole repeated row tensor in memory?

Edit explanation:

Sorry for not initially clarifying: a first dimension of 1 was used for the first tensor to keep the example simple, but I am actually looking for a solution to the general case of any two tensors a and b whose dimensions are compatible for matrix multiplication and where the rows of b repeat every n rows. I have updated the example to reflect this.

1 Answer

Assuming that the first dimension of a is 1, as in your original example, you could do the following:

a = torch.Tensor([[1, 2, 3, 4]])
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])
torch.mm(a.reshape(-1, 2), b_abbreviated).sum(dim=0, keepdim=True)

Here, instead of repeating the rows, you multiply a in chunks, then add them up column-wise to get the same result.
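
For instance, with a = torch.Tensor([[1, 2, 3, 4]]), the two chunk products are [9, 12, 15] and [19, 26, 33], and summing them column-wise gives the expected result:

>>> torch.mm(a.reshape(-1, 2), b_abbreviated)
tensor([[ 9., 12., 15.],
        [19., 26., 33.]])
>>> torch.mm(a.reshape(-1, 2), b_abbreviated).sum(dim=0, keepdim=True)
tensor([[28., 38., 48.]])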


If the first dimension of a is not necessarily 1, you could try the following:

torch.cat(torch.split(torch.mm(a.reshape(-1, 2), b_abbreviated), a.shape[1] // 2), dim=1).sum(
    dim=0, keepdim=True).reshape(a.shape[0], -1)

Here, you do the following:

  • With torch.mm(a.reshape(-1, 2), b_abbreviated), you again split each row of a into chunks of size 2 and stack them one over the other, and then stack the chunks of each row over those of the previous row.
  • With torch.split(torch.mm(a.reshape(-1, 2), b_abbreviated), a.shape[1] // 2), these stacks are then separated row-wise into groups of a.shape[1] // 2 rows, so that each component of the split corresponds to the chunks of a single row of a.
  • With torch.cat(torch.split(torch.mm(a.reshape(-1, 2), b_abbreviated), a.shape[1] // 2), dim=1), these components are then concatenated column-wise.
  • With .sum(dim=0, keepdim=True), the results corresponding to the separate chunks of each individual row of a are added up.
  • With .reshape(a.shape[0], -1), the rows of a that were concatenated column-wise are again stacked row-wise (a shape-annotated sketch of these steps follows this list).
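
A shape-annotated sketch of the same steps on the 2 x 4 example (the intermediate names are just for illustration):

chunk_rows = b_abbreviated.shape[0]                            # 2: rows of b that actually repeat
chunks_per_row = a.shape[1] // chunk_rows                      # 4 // 2 = 2 chunks per row of a
products = torch.mm(a.reshape(-1, chunk_rows), b_abbreviated)  # (4, 3): one row per chunk of a
per_row = torch.split(products, chunks_per_row)                # 2 tensors of shape (2, 3), one per row of a
side_by_side = torch.cat(per_row, dim=1)                       # (2, 6): columns 0-2 from row 0 of a, 3-5 from row 1
result = side_by_side.sum(dim=0, keepdim=True).reshape(a.shape[0], -1)  # (2, 3), equals torch.mm(a, b)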

It seems quite slow compared to direct matrix multiplication, which is not surprising, but I have not yet compared it against explicit iteration. There are likely better ways of doing this; I will edit if I think of any.


3 Comments

Sorry, this looks like a good answer for the example that I had originally given, but I had not meant to imply that the first dimension of the first tensor would always be one. I have now updated the question with a more general example to clarify.
Ah, that is cool that it can be generalized. I notice, however, that the size of the intermediate matrix produced by torch.cat(torch.split(torch.mm(a.reshape(-1, 2), b_abbreviated), a.shape[1] // 2), dim=1) grows in proportion to the first dimension of a, which may end up defeating the purpose of trying to avoid using a lot of memory when a has a large number of rows. I am curious whether such large intermediate matrices are avoidable, but I am thinking this may be about the best that can be done for this problem.
@nellapizza The size of the intermediate matrix is only 2x the size of the final one in this example. I assume any savings will be in the case when the number of repetitions in b is very high. In any case, I don't think this is the best possible way at all in terms of time efficiency, but I couldn't think of a better way.
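
A sketch of how the wide intermediate could be avoided altogether, relying only on the fact that torch.mm(a, b) distributes over the repeated blocks of b, so the column chunks of a can be summed before the multiplication (names are just for illustration):

m, k = a.shape
n = b_abbreviated.shape[0]                        # 2: rows of b that actually repeat
a_collapsed = a.reshape(m, k // n, n).sum(dim=1)  # (m, n): sum of a's column chunks
result = torch.mm(a_collapsed, b_abbreviated)     # equals torch.mm(a, b); (2, 3) for the example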