I have a PyTorch LSTM model, and my forward function looks like this:
def forward(self, x, hidden):
    print('in forward', x.dtype, hidden[0].dtype, hidden[1].dtype)
    lstm_out, hidden = self.lstm(x, hidden)
    return lstm_out, hidden
All of the print statements show torch.float64, which I believe is a double. So why am I getting this issue? I've already cast to double in all of the relevant places.
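For reference, here is a minimal, self-contained sketch of the casting I'm describing, assuming the usual cause of a float/double mismatch with an LSTM: both the module's weights (via .double()) and every input tensor, including the initial hidden and cell states, must be float64. The layer sizes here are placeholders, not my actual model:

```python
import torch
import torch.nn as nn

# Hypothetical dimensions for illustration only
lstm = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
lstm = lstm.double()  # casts all weights and biases to float64

x = torch.randn(2, 5, 4, dtype=torch.float64)   # (batch, seq, features)
h0 = torch.zeros(1, 2, 8, dtype=torch.float64)  # (layers, batch, hidden)
c0 = torch.zeros(1, 2, 8, dtype=torch.float64)

out, (hn, cn) = lstm(x, (h0, c0))
print(out.dtype)   # all float64, since weights and inputs agree
```

In my code, the prints inside forward confirm x, hidden[0], and hidden[1] are all torch.float64, exactly as in this sketch.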