
I have a PyTorch LSTM model and my forward function looks like:

    def forward(self, x, hidden):
        print('in forward', x.dtype, hidden[0].dtype, hidden[1].dtype)
        lstm_out, hidden = self.lstm(x, hidden)
        return lstm_out, hidden

All of the print statements show torch.float64, which I believe is a double. So then why am I getting this issue?

I've cast to double in all of the relevant places already.
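One thing the prints inside `forward` don't cover is the dtype of the model's *parameters*, which can still be float32 even when the inputs are float64. A quick way to check, using an illustrative LSTM rather than the actual model from the question:

```python
import torch
import torch.nn as nn

# Hypothetical layer sizes for illustration only.
net = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
x = torch.randn(2, 5, 4, dtype=torch.float64)

# The input may be float64 while the parameters default to float32:
print(x.dtype)                        # torch.float64
print(next(net.parameters()).dtype)  # torch.float32
```

A mismatch like this is what typically triggers a dtype error inside `self.lstm(x, hidden)`.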

1 Answer


Make sure both your data and model are in dtype double.

For the model:

net = net.double()

For the data:

net(x.double())

This has been discussed in a thread on the PyTorch forum.
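Putting both casts together, a minimal sketch (the model and tensor shapes are illustrative, not taken from the question):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical sizes for illustration.
        self.lstm = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)

    def forward(self, x, hidden):
        lstm_out, hidden = self.lstm(x, hidden)
        return lstm_out, hidden

net = Net().double()               # casts every parameter to float64
x = torch.randn(2, 5, 4).double()  # cast the input as well

# Initial hidden and cell states must match the dtype too.
h0 = torch.zeros(1, 2, 8, dtype=torch.float64)
c0 = torch.zeros(1, 2, 8, dtype=torch.float64)

out, _ = net(x, (h0, c0))
print(out.dtype)  # torch.float64
```

If any one of the parameters, the input, or the hidden/cell states is left in float32, the LSTM call will raise a dtype mismatch.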
