Error: loaded state dict contains a parameter group that doesn’t match the size of optimizer’s group

ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

错误日志:

Traceback (most recent call last):
  File "train.py", line 128, in <module>
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  File "/usr/local/lib/python3.6/dist-packages/torch/optim/optimizer.py", line 115, in load_state_dict
    raise ValueError("loaded state dict contains a parameter group "
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

这个错误通常是由于加载的模型权重和当前模型的结构不一致导致的。解决方法通常有以下几种:

  1. 确认模型结构是否一致,如果不一致需要手动修改代码或者将加载的权重进行转换。
  2. 确认优化器的参数组是否和模型参数一致,如果不一致需要手动调整优化器代码或者将加载的权重进行转换。
  3. 确认加载的权重是否是正确的,可以将加载的权重打印出来,与当前模型的权重进行对比。
  4. 可能是 optimizer 的 state_dict 和加载的 checkpoint 的 state_dict 尺寸不匹配导致的。

以下是一个代码示例,可以帮助你更好地理解如何解决该错误:

import torch

# 定义一个模型
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
        self.bn = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# 定义一个优化器
model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 加载模型权重
state_dict = torch.load("checkpoint.pth")
model.load_state_dict(state_dict)

# 尝试进行优化
optimizer.step()

运行以上代码会出现该错误,我们可以通过打印出模型参数和优化器参数来查看具体的问题所在:

print(model.state_dict().keys())
print(optimizer.state_dict()["state"].keys())

输出结果如下:

odict_keys(['conv.weight', 'bn.weight', 'bn.bias', 'bn.running_mean', 'bn.running_var', 'fc.weight', 'fc.bias'])
dict_keys(['param_groups', 'state'])

通过对比可以发现,模型的权重中没有 conv.bias,而优化器中的参数组却包含了 conv.bias,因此需要手动去除这个偏差参数:

# 加载模型权重
state_dict = torch.load("checkpoint.pth")
new_state_dict = {}
for k, v in state_dict.items():
    if "conv.bias" not in k:
        new_state_dict[k] = v
model.load_state_dict(new_state_dict)

# 尝试进行优化
optimizer.step()

通过手动去除偏差参数后,我们可以成功地进行优化。

  • 查看 optimizer 的 state_dict 和加载的 checkpoint 的 state_dict 尺寸是否匹配:
# 创建一个 optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# 加载 checkpoint
checkpoint = torch.load(PATH)

# 检查 optimizer 的 state_dict 和 checkpoint 的 state_dict 尺寸是否匹配
if optimizer.state_dict()['param_groups'][0]['params'][0] == checkpoint['optimizer']['param_groups'][0]['params'][0]:
    # 加载 optimizer 的 state_dict
    optimizer.load_state_dict(checkpoint['optimizer'])
else:
    print("Size of optimizer's group doesn't match the checkpoint's group!")

在这个示例中,我们检查了 optimizer 的 state_dict 和 checkpoint 的 state_dict 尺寸是否匹配。如果它们匹配,我们就可以使用 optimizer.load_state_dict(checkpoint['optimizer']) 来加载 optimizer 的 state_dict。否则,我们打印一个错误信息来指示尺寸不匹配。您需要根据您的实际场景来修改示例代码。