My model trains perfectly fine, but when I switch it to evaluation mode it complains about the data types of the input samples:
Traceback (most recent call last):
  File "model.py", line 558, in <module>
    main_function(train_sequicity=args.train)
  File "model.py", line 542, in main_function
    out = model(user, bspan, response_, degree)
  File "/home/memduh/git/project/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "model.py", line 336, in forward
    self.params['bspan_size'])
  File "model.py", line 283, in _greedy_decode_output
    out = decoder(input_, encoder_output)
  File "/home/memduh/git/project/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "model.py", line 142, in forward
    tgt = torch.cat([go_tokens, tgt], dim=0) # concat GO_2 token along sequence lenght axis
RuntimeError: Expected object of scalar type Long but got scalar type Float for sequence element 1 in sequence argument at position #1 'tensors'
The error occurs in a part of the code where concatenation happens. The architecture is similar to the PyTorch Transformer, just modified to have two decoders:
def forward(self, tgt, memory):
    """ Call decoder
    the decoder should be called repeatedly
    Args:
        tgt: input to transformer_decoder, shape: (seq, batch)
        memory: output from the encoder
    Returns:
        output from linear layer, (vocab size), pre softmax
    """
    go_tokens = torch.zeros((1, tgt.size(1)), dtype=torch.int64) + 3  # GO_2 token has index 3
    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence length axis

    mask = tgt.eq(0).transpose(0, 1)  # 0 corresponds to <pad>
    tgt = self.embedding(tgt) * self.ninp  # embed tokens, then scale
    tgt = self.pos_encoder(tgt)  # add positional encodings
    tgt_mask = self._generate_square_subsequent_mask(tgt.size(0))  # causal mask over the target sequence
    output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=mask)
    output = self.linear(output)  # project to vocabulary size, pre-softmax
    return output
The concatenation in the middle of the code block is where the problem happens. The odd thing is that it works perfectly fine in train mode, with the loss going down; the issue only comes up in eval mode. What could the problem be?
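For what it's worth, the error itself is easy to reproduce in isolation: on the PyTorch version shown in the traceback, torch.cat raises the same RuntimeError when the tensors in the list have different dtypes (newer versions may type-promote instead of erroring). A minimal standalone sketch, not taken from the actual model code:

import torch

go_tokens = torch.zeros((1, 2), dtype=torch.int64) + 3  # Long, as in forward()
tgt_long = torch.randint(0, 10, (5, 2))                 # Long token indices
tgt_float = torch.rand(5, 2)                            # Float tensor

torch.cat([go_tokens, tgt_long], dim=0)   # works
torch.cat([go_tokens, tgt_float], dim=0)  # RuntimeError: expected Long but got Float for sequence element 1

Since go_tokens is explicitly int64 and is "sequence element 0", the "sequence element 1" in the message apparently means that whatever reaches forward() as tgt during eval is a Float tensor rather than Long token indices.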