I have a very large tensor L (hundreds of millions of elements), from which I gather a relatively small subtensor S (maybe a thousand elements).
I then apply my model to S, compute the loss, and backpropagate to S and to L, with the intent of updating only the selected elements of L. The problem is that PyTorch allocates L's gradient as a dense tensor with the same shape as L, so it effectively doubles L's memory usage.
Is there an easy way to compute and apply the gradient to L without doubling memory usage?
Sample code to illustrate the problem:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
net = nn.Sequential(
    nn.Linear(1, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 1))

L = Parameter(torch.zeros([1024 * 1024 * 256], dtype=torch.float32))  # 256M elements, ~1 GB
L.data.uniform_(-1, 1)

indices = torch.randint(high=256 * 1024 * 1024, size=[1024])
S = torch.unsqueeze(L[indices], dim=1)  # gather 1024 elements, shape [1024, 1]

out = net(S)
loss = out.sum()
loss.backward()
print(loss)

g = L.grad
print(g.shape)  # this is huge! same shape as L, another ~1 GB
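One workaround I have considered (a sketch, not a full solution): detach the gathered slice so that autograd treats S as a leaf, then scatter the small gradient back into L by hand with `index_add_`. This sidesteps L's dense gradient entirely, at the cost of a manual optimizer step. The learning rate `lr` and the reduced tensor size here are just for illustration.

```python
import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(1, 64),
    nn.ReLU(),
    nn.Linear(64, 1))

# Plain tensor, no requires_grad: autograd never allocates a full-size grad for it.
# (Smaller than in the question, just to keep the demo light.)
L = torch.empty(1024 * 1024, dtype=torch.float32).uniform_(-1, 1)
indices = torch.randint(high=L.numel(), size=[1024])

# Gather as a fresh leaf: the gradient buffer is only 1024 elements.
S = L[indices].detach().requires_grad_(True)

out = net(S.unsqueeze(1))
loss = out.sum()
loss.backward()

# Manual SGD step on just the selected elements.
# index_add_ accumulates correctly even when indices contains duplicates.
lr = 0.01
with torch.no_grad():
    L.index_add_(0, indices, -lr * S.grad)

print(S.grad.shape)  # small: torch.Size([1024]), not the full size of L
```

The trade-off is that L can no longer live inside an optimizer as a Parameter; the update rule is hand-rolled, so momentum or Adam state would also have to be managed per-index.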