
I have a very large tensor L (millions of elements), from which I gather a relatively small subtensor S (maybe a thousand elements).

I then apply my model to S, compute the loss, and backpropagate to S and L, with the intent of updating only the selected elements of L. The problem is that PyTorch materializes L's gradient as a dense tensor with the same shape as L, so it essentially doubles L's memory usage.

Is there an easy way to compute and apply gradient to L without doubling memory usage?

Sample code to illustrate the problem:

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

net = nn.Sequential(
  nn.Linear(1, 64),
  nn.ReLU(),
  nn.Linear(64, 64),
  nn.ReLU(),
  nn.Linear(64, 1))

L = Parameter(torch.zeros([1024*1024*256], dtype=torch.float32))
L.data.uniform_(-1, 1)

indices = torch.randint(high=256*1024*1024, size=[1024])
S = torch.unsqueeze(L[indices], dim=1)  # gather a small subtensor, shape [1024, 1]

out = net(S)

loss = out.sum()

loss.backward()

print(loss)
g = L.grad
print(g.shape)  # this is huge!
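(For scale: 256*1024*1024 float32 values occupy 1 GiB, so the dense L.grad adds roughly another gigabyte on top of L itself.)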

1 Answer


You don't actually need requires_grad on L, since its gradient will be computed and applied manually. Instead, set it on S; that stops backpropagation at S.

Then you can update the values of L using S.grad and your preferred optimization rule. Something along these lines:

L = torch.zeros([1024*1024*256], dtype=torch.float32)  # plain tensor, no requires_grad

...

S = torch.unsqueeze(L[indices], dim=1)
S.requires_grad_()  # gradients are tracked from S onward only

out = net(S)

loss = torch.abs(out).sum()

loss.backward()  # populates S.grad (shape [1024, 1]); L.grad is never allocated

with torch.no_grad():
  L[indices] -= learning_rate * torch.squeeze(S.grad)  # update only the gathered elements
  S.grad.zero_()
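For completeness, here is a minimal sketch of how this update step might sit inside a full training loop (as discussed in the comments below). The learning_rate value, the step count, and the SGD optimizer for net are illustrative assumptions, not part of the original answer:

import torch
import torch.nn as nn

net = nn.Sequential(
  nn.Linear(1, 64),
  nn.ReLU(),
  nn.Linear(64, 64),
  nn.ReLU(),
  nn.Linear(64, 1))

# L is a plain tensor, not a Parameter: it never gets a grad buffer of its own
L = torch.zeros([1024*1024*256], dtype=torch.float32)
L.uniform_(-1, 1)

learning_rate = 1e-2                                       # hypothetical value
opt = torch.optim.SGD(net.parameters(), lr=learning_rate)  # assumes net is trained as well

for step in range(100):                                    # hypothetical number of steps
  indices = torch.randint(high=256*1024*1024, size=[1024])
  S = torch.unsqueeze(L[indices], dim=1)                   # fresh leaf tensor every iteration
  S.requires_grad_()                                       # backprop stops here; no detach needed

  out = net(S)
  loss = torch.abs(out).sum()

  opt.zero_grad()
  loss.backward()                                          # allocates S.grad ([1024, 1]) only

  opt.step()                                               # update the network parameters
  with torch.no_grad():
    L[indices] -= learning_rate * torch.squeeze(S.grad)    # update only the gathered elements of L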

5 Comments

Could you be more specific about the detach part? The training loop would start at the indices = ... line. Would I need to call S.detach()? L.detach()?
I figured it out, do you mind if I edit your answer with a full code sample?
Sure, go for it
Pretty sure you can also just do L = torch.zeros([1024*1024*256], dtype=torch.float32, requires_grad=False)
Actually, you need to explicitly set requires_grad to True on S, and you don't need to detach anything. I edited your answer, feel free to edit further.
