I'm trying to load a model by following this tutorial: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference . Unfortunately, I'm a beginner and I've run into some problems.

I created a checkpoint:

checkpoint = {'epoch': epochs, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),'loss': loss}
torch.save(checkpoint, 'checkpoint.pth')

Then I wrote a class for my network and tried to load the file:

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 1000)
        self.fc3 = nn.Linear(1000, 102)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        x = F.log_softmax(x, dim=1)
        return x

Like this:

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = Network()
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

model = load_checkpoint('checkpoint.pth')

I got this error (edited to show the whole message):

RuntimeError: Error(s) in loading state_dict for Network:
    Missing key(s) in state_dict: "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "fc3.weight", "fc3.bias". 
    Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias", "features.3.weight", "features.3.bias", "features.6.weight", "features.6.bias", "features.8.weight", "features.8.bias", "features.10.weight", "features.10.bias", "classifier.fc1.weight", "classifier.fc1.bias", "classifier.fc2.weight", "classifier.fc2.bias", "classifier.fc3.weight", "classifier.fc3.bias". 

This is my model.state_dict().keys():

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 
'features.3.bias', 'features.6.weight', 'features.6.bias', 
'features.8.weight', 'features.8.bias', 'features.10.weight', 
'features.10.bias', 'classifier.fc1.weight', 'classifier.fc1.bias', 
'classifier.fc2.weight', 'classifier.fc2.bias', 'classifier.fc3.weight', 
'classifier.fc3.bias'])

This is my model:

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (fc1): Linear(in_features=9216, out_features=4096, bias=True)
    (relu1): ReLU()
    (fc2): Linear(in_features=4096, out_features=1000, bias=True)
    (relu2): ReLU()
    (fc3): Linear(in_features=1000, out_features=102, bias=True)
    (output): LogSoftmax()
  )
)

It's my first network ever and I'm blundering along. Thanks for steering me in the right direction!

  • What if you just rename the corresponding keys in your model.state_dict().keys(), so that features.3.weight becomes fc3.weight, and so on? Commented Dec 23, 2018 at 21:12
  • I'll try and let you know in a moment. Commented Dec 23, 2018 at 21:14
  • It's weird, but when I do that, the model is None after loading. Commented Dec 23, 2018 at 21:50
  • Ah OK. Your function has no return statement, so calling load_checkpoint returns nothing, hence NoneType. If you want to get the model back from your function, add return model at the bottom of it. If you do not need the return value, drop the model = from model = load_checkpoint('checkpoint.pth') and just call the function. Commented Dec 23, 2018 at 22:12
  • If you want to return multiple variables, return them together, e.g. return checkpoint, model, epoch, loss, and unpack each return value into its own variable where you call the function, e.g. checkpoint, model, epoch, loss = load_checkpoint('checkpoint.pth'). Commented Dec 23, 2018 at 22:15

2 Answers

So your Network is essentially the classifier part of AlexNet, and you're trying to load the full set of saved AlexNet weights into it. The problem is that the keys in state_dict are "fully qualified": if you look at your network as a tree of nested modules, each key is the path of module names from the root down to the parameter, joined with dots, like grandparent.parent.child. That is why the checkpoint contains classifier.fc1.weight while Network expects fc1.weight. You want to

  1. Keep only the tensors with name starting with "classifier."
  2. Remove the "classifier." part of keys

so try

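import torch

# Load the saved checkpoint and pull out the full AlexNet state dict.
checkpoint = torch.load('checkpoint.pth')
model = Network()
loaded_dict = checkpoint['model_state_dict']
prefix = 'classifier.'
n_clip = len(prefix)
# Keep only the classifier tensors and strip the prefix from their keys.
adapted_dict = {k[n_clip:]: v for k, v in loaded_dict.items()
                if k.startswith(prefix)}
model.load_state_dict(adapted_dict)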
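
The same filtering can be wrapped in a small helper. This is only a minimal sketch: `strip_prefix` is a hypothetical name, not a PyTorch function, and the placeholder strings stand in for the real tensors so the snippet runs without a checkpoint file. The keys are a subset of those in the error message.

```python
def strip_prefix(state_dict, prefix):
    """Keep only entries whose key starts with `prefix`, with the prefix removed."""
    n_clip = len(prefix)
    return {k[n_clip:]: v for k, v in state_dict.items() if k.startswith(prefix)}


# Placeholder values stand in for tensors; only the key names matter here.
loaded_dict = {
    'features.0.weight': 'w0',
    'features.0.bias': 'b0',
    'classifier.fc1.weight': 'w1',
    'classifier.fc1.bias': 'b1',
    'classifier.fc3.weight': 'w3',
}

adapted_dict = strip_prefix(loaded_dict, 'classifier.')
print(sorted(adapted_dict))
```

The remaining keys (fc1.weight, fc1.bias, fc3.weight) now have the names that Network expects, and the features.* entries are dropped.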

5 Comments

It doesn't return any error but when I print(model) it shows None. I mean after this model = load_checkpoint('checkpoint.pth')
I tried to achieve 1) and 2) in a more understandable way for me but the outcome is still the same. The model after loading is empty.
Right, as mentioned by Adam above, you need to return the model if you want to get it as the return value of your function.
What do you mean? AlexNet consists of two parts called features and classifier. Your Network implements only classifier, so yes, you're losing the features part. I was assuming that's what you meant to do.
I'd like to save the whole model to be able to train it in the future
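
If the goal is to resume training later, one option is to rebuild the same architecture that produced the checkpoint, load the full state dict into it, and return everything from the function. A minimal sketch, assuming the checkpoint layout from the question; the tiny `nn.Linear` model, the SGD optimizer, and the temporary file are stand-ins for illustration, not the original AlexNet setup:

```python
import os
import tempfile

import torch
from torch import nn, optim


def load_checkpoint(filepath, model, optimizer):
    """Restore model and optimizer state in place and return the saved values."""
    checkpoint = torch.load(filepath, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer, checkpoint['epoch'], checkpoint['loss']


# Stand-in model and optimizer; use the architecture you actually trained.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)

path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')
torch.save({'epoch': 3,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': 0.5}, path)

# A fresh instance with the same architecture receives the saved weights.
restored = nn.Linear(4, 2)
restored_opt = optim.SGD(restored.parameters(), lr=0.01)
restored, restored_opt, epoch, loss = load_checkpoint(path, restored, restored_opt)
restored.train()  # switch to training mode before continuing training
```

Because the function ends with a return statement, the caller gets the model back instead of None, and the keys match because the restored model has the same module structure as the one that was saved.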

In my case, I had to remove the "module." prefix from the state dict keys before loading.

    model = Model()
    state_dict = torch.load(model_path)
    remove_prefix = 'module.'
    state_dict = {k[len(remove_prefix):] if k.startswith(remove_prefix) else k: v for k, v in state_dict.items()}

After that,


    model.load_state_dict(state_dict)

Worked!
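
The "module." prefix usually appears when the model was wrapped in `nn.DataParallel` before saving, because the wrapper stores the original network under an attribute named `module`. A small sketch of that effect, using a stand-in `nn.Linear` model (an assumption for illustration only):

```python
import torch
from torch import nn

net = nn.Linear(4, 2)            # stand-in for the real model
wrapped = nn.DataParallel(net)   # the wrapper keeps the network as `wrapped.module`

# Keys of the wrapper carry the "module." prefix; keys of the inner network do not.
wrapped_keys = list(wrapped.state_dict().keys())
plain_keys = list(wrapped.module.state_dict().keys())

print(wrapped_keys)
print(plain_keys)
```

Saving `wrapped.module.state_dict()` rather than `wrapped.state_dict()` avoids the prefix in the first place, so no renaming is needed at load time.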
