import torch

def analyze_your_actual_code():
    """Analyze your actual code."""

    a = torch.randn(4, requires_grad=True)
    
    print(f"a: {a}")
    print(f"a.grad: {a.grad}")
    
    loss = a.sum()
    loss.backward()      # first backward pass
    print(f"  a.grad: {a.grad}")

    loss.backward()      # second backward pass on the same graph
    print(f"  a.grad: {a.grad}")

analyze_your_actual_code()

Why can I call loss.backward() twice here, with the gradients simply accumulating? Isn't the compute graph supposed to be released after the first backward pass?

Code output:

a: tensor([-0.3121, -0.2331,  0.9317, -0.5075], requires_grad=True)
a.grad: None
  a.grad: tensor([1., 1., 1., 1.])
  a.grad: tensor([2., 2., 2., 2.])

2 Answers


When you use loss.backward(), the second loss.backward() call normally raises a RuntimeError. If you want to call loss.backward() twice, you have to pass retain_graph=True to the first call.

The behavior of the backward pass is controlled by the retain_graph parameter.

Here's the corrected code that will run without error and produce the output you described:

import torch

def analyze_your_actual_code():
    a = torch.randn(4, requires_grad=True)

    print(f"a: {a}")
    print(f"a.grad: {a.grad}")
    
    loss = a.sum()
    loss.backward(retain_graph=True)   # keep the graph so a second backward is allowed
    print(f"  a.grad after 1st backward: {a.grad}")

    loss.backward()                    # gradients accumulate into a.grad
    print(f"  a.grad after 2nd backward: {a.grad}")

analyze_your_actual_code()

1 Comment

That is exactly the problem: it should raise a RuntimeError in this situation, but it does not; the code runs and the gradient just accumulates.

You're right that the first loss.backward() is supposed to "free the graph" (unless you specify retain_graph=True). However, freeing the graph does not delete it; it only frees the saved tensors stored in it. Your example is so simple that it does not contain any such saved tensor, so you're still able to call backward a second time (or as many times as you want) without getting an error.
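
If you want to check this yourself, here is a rough sketch, assuming PyTorch 1.10+ (which provides torch.autograd.graph.saved_tensors_hooks): register a pack hook and count how many tensors autograd saves while building the graph of a.sum().

import torch

seen = []

def pack(t):
    seen.append(t)   # called whenever autograd saves a tensor for the backward pass
    return t

def unpack(t):
    return t

a = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    loss = a.sum()

print(len(seen))  # 0 -- SumBackward0 only needs the input's shape, no tensor is saved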

Simply changing loss = a.sum() to loss = a.sum() * 2 will make autograd store the tensor of value 2 in the graph. Then, the first backward will free this tensor, and the second backward will trigger the error you expected.
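
A minimal sketch of that modified scenario, with the second call wrapped in try/except so the script still runs to completion:

import torch

a = torch.randn(4, requires_grad=True)
loss = a.sum() * 2       # the multiplication makes autograd save a tensor in the graph

loss.backward()          # frees the saved tensors of the graph
try:
    loss.backward()      # the saved tensor is already freed, so autograd complains
except RuntimeError as e:
    print(f"Second backward failed: {e}")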

Note that the actual graph (which is very light in memory, since its saved tensors have already been freed by the call to backward) is discarded (i.e. becomes ready to be garbage collected) whenever you no longer hold any reference to loss, or if you call loss = loss.detach().
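
As a small illustration of that last point (just a sketch): after detaching and rebinding, the tensor no longer references any graph node, so nothing keeps the freed graph alive.

import torch

a = torch.randn(4, requires_grad=True)
loss = a.sum()
loss.backward()

print(loss.grad_fn)   # <SumBackward0 ...>: loss still holds a reference to the graph node
loss = loss.detach()
print(loss.grad_fn)   # None: the old graph is now free to be garbage collected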

