I'm implementing a policy gradient method in PyTorch. I wanted to move the network update into the loop, and it stopped working. I'm still a PyTorch newbie, so sorry if the explanation is obvious.

Here is the original code that works:

self.policy.optimizer.zero_grad()
G = T.tensor(G, dtype=T.float).to(self.policy.device)  # discounted returns

loss = 0
for g, logprob in zip(G, self.action_memory):
    loss += -g * logprob  # accumulate the whole episode's loss

loss.backward()
self.policy.optimizer.step()

And after the change:

G = T.tensor(G, dtype=T.float).to(self.policy.device)  # discounted returns

loss = 0
for g, logprob in zip(G, self.action_memory):
    loss = -g * logprob
    self.policy.optimizer.zero_grad()
    loss.backward()  # backward once per (g, logprob) pair
    self.policy.optimizer.step()

I get the error:

File "g:\VScode_projects\pytorch_shenanigans\policy_gradient.py", line 86, in learn
    loss.backward()
  File "G:\Anaconda3\envs\pytorch_env\lib\site-packages\torch\tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "G:\Anaconda3\envs\pytorch_env\lib\site-packages\torch\autograd\__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [128, 4]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I read that this RuntimeError often has to do with needing to clone something, because the same tensor is being used to compute itself, but I can't make heads or tails of what is wrong in my case.
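
For reference, the hint at the end of the traceback refers to autograd's anomaly detection; enabling it makes the backward error point at the forward operation that produced the offending tensor:

import torch

# Debugging aid only: records the forward pass so errors raised during
# backward report the forward op that created the failing tensor. It slows
# training, so turn it off once the bug is found.
torch.autograd.set_detect_anomaly(True)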

1 Answer

This line, loss += -g * logprob, is what is wrong in your case.

Change it to this:

loss = loss + (-g * logprob)

And yes, they are different. Both compute the same value, but += on a tensor is an in-place operation (torch.Tensor.add_): it overwrites the tensor's storage and bumps its internal version counter. loss = loss + (-g * logprob) instead builds a new tensor and leaves the old one untouched. Autograd raises exactly this RuntimeError when a tensor it saved for the backward pass has been modified in place, i.e. when its version no longer matches the one recorded at save time.
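
A minimal, self-contained sketch of the difference (toy tensors, not your network):

import torch

w = torch.randn(3, requires_grad=True)

# exp() saves its own output for the backward pass (d/dx exp(x) = exp(x)).
y = w.exp()
y += 1  # in-place add_: bumps y's version counter
try:
    y.sum().backward()
except RuntimeError as e:
    print(e)  # "... has been modified by an inplace operation ..."

# Out-of-place: a new tensor is created, the saved output stays valid.
y = w.exp()
y = y + 1
y.sum().backward()  # works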

1 Comment

But the code with this line works. It's the second snippet that is problematic.
