Loading a model: how do you initialize part of a modified VGG model with the parameters of a pretrained VGG?
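In experiments we often initialize a model from existing parameters so that it does not have to be trained from scratch. Suppose you have modified part of the original network but still want to initialize the layers it shares with the original model from that model's weights. How do you do it?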
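For example, what if VGG16's fully connected classifier layers are removed and replaced by one prediction-branch layer and one classification-branch layer?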
Loading the pretrained model and modifying it
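Take the discriminator of a GAN as an example. The discriminator is a VGG16 followed by two network layers: one predicts whether the input is real, and the other classifies it.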
Prediction (real/fake): PatchGAN or GAP (global average pooling);
Classification: GAP
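To make the two head options concrete, here is a small sketch of the output shapes each one produces. It assumes a 512-channel 7x7 feature map (what VGG16's conv stack gives for a 224x224 input); the batch size, c_dim = 20 and the variable names are illustrative only.

```python
import torch
import torch.nn as nn

# Hypothetical input: a batch of 2 feature maps with 512 channels at 7x7,
# the shape VGG16's conv stack produces for a 224x224 image.
feat = torch.randn(2, 512, 7, 7)
c_dim = 20  # hypothetical number of classes

# PatchGAN-style real/fake head: keep the spatial grid, one score per patch.
patch_head = nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)

# GAP-style heads: average the 7x7 map down to 1x1, then predict.
gap = nn.AvgPool2d(kernel_size=7, stride=1)
src_head = nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)
cls_head = nn.Conv2d(512, c_dim, kernel_size=3, stride=1, padding=1)

with torch.no_grad():
    patch_scores = patch_head(feat)               # (2, 1, 7, 7): one score per patch
    pooled = gap(feat)                            # (2, 512, 1, 1)
    src_scores = src_head(pooled)                 # (2, 1, 1, 1): one score per image
    cls_logits = cls_head(pooled).view(2, c_dim)  # (2, 20): class logits
```

On a 1x1 input, a 3x3 convolution with padding=1 only uses its centre weights (the other taps see zero padding), so after GAP it behaves like a 1x1 convolution.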
Code:
1. Build a VGG model:
from torchvision.models.vgg import VGG

# networks.make_layers is the author's helper shown below; opt.cfg is the layer config list
vggmodel = VGG(networks.make_layers(opt.cfg, batch_norm=True))
where make_layers is defined as:
import torch.nn as nn

def make_layers(cfg, batch_norm=False):
    # cfg lists the output channels of each conv layer; 'M' marks a 2x2 max-pooling layer
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
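The keys in a VGG state_dict are positional, of the form features.N.weight, where N is the layer's index inside the nn.Sequential built by make_layers (torchvision's VGG additionally has classifier.* keys). For the pretrained weights to line up, the discriminator must therefore create its layers in exactly the same order. The hypothetical helper below is only a sketch that recomputes which indices own parameters for a given cfg, so you can see which keys to expect.

```python
# VGG16 layer config ('D' in torchvision's VGG cfgs): numbers are conv output
# channels, 'M' is a max-pooling layer.
VGG16_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512, 'M']


def param_layer_indices(cfg, batch_norm=False):
    """Hypothetical helper: return the nn.Sequential indices of the conv and BN
    layers that make_layers would create, i.e. the N in "features.N.weight"."""
    conv_idx, bn_idx = [], []
    idx = 0
    for v in cfg:
        if v == 'M':
            idx += 1                 # MaxPool2d: no parameters
        else:
            conv_idx.append(idx)
            idx += 1                 # Conv2d
            if batch_norm:
                bn_idx.append(idx)
                idx += 1             # BatchNorm2d
            idx += 1                 # ReLU: no parameters
    return conv_idx, bn_idx


conv_idx, bn_idx = param_layer_indices(VGG16_CFG, batch_norm=True)
print(conv_idx)  # 13 conv layers; the last one is features.40
print(bn_idx)    # each BN layer sits right after its conv
```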
2. Define the discriminator: VGG16 as the base model, plus a prediction layer and a classification layer:
import numpy as np
import torch.nn as nn

class VGGDiscriminator(nn.Module):
    def __init__(self, image_size, cfg, curr_dim, c_dim, norm_layer, batch_norm):
        super(VGGDiscriminator, self).__init__()
        # Same layer order as make_layers, so the "features.*" keys match the pretrained VGG
        self.features, self.repeat_num = self.build_downsample(cfg, c_dim, norm_layer, batch_norm)
        self.features = nn.Sequential(*self.features)
        # Side length of the last feature map: each 'M' halves the resolution
        kernel_size = int(image_size / np.power(2, self.repeat_num))
        # curr_dim must equal the channel count of the last conv layer (512 for VGG16)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1)  # real/fake head; , bias=False
        self.avgpool = nn.AvgPool2d(kernel_size, stride=1)  # global average pooling
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=3, stride=1, padding=1)  # classification head; , bias=False

    def build_downsample(self, cfg, c_dim, norm_layer, batch_norm=False):
        model = []
        num = 0  # number of max-pooling layers
        in_channels = 3
        for v in cfg:
            if v == 'M':
                model += [nn.MaxPool2d(kernel_size=2, stride=2)]
                num += 1
            else:
                # bias=batch_norm: with BN the conv keeps its bias, so its keys match the
                # pretrained vgg16_bn checkpoint (torchvision's VGG convs always have a bias)
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=batch_norm)
                if batch_norm:
                    model += [conv2d, norm_layer(v), nn.ReLU(inplace=True)]
                else:
                    model += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return model, num

    def forward(self, x):
        h = self.features(x)
        h = self.avgpool(h)      # N x curr_dim x 1 x 1
        out_src = self.conv1(h)  # N x 1 x 1 x 1
        out_cls = self.conv2(h)  # N x c_dim x 1 x 1, e.g. 1*20*1*1
        return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))
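The avgpool kernel in VGGDiscriminator has to equal the side length of the last feature map, and curr_dim has to equal the last conv layer's channel count. Otherwise the heads do not see a 1x1, curr_dim-channel input. Below is a small sketch of that arithmetic. The helper names are hypothetical, and unlike the original int(image_size / 2**n) this version rejects image sizes that are not divisible by 2**n, because repeated max-pooling floors at every step.

```python
VGG16_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512, 'M']


def final_map_size(image_size, cfg):
    """Side length of the last feature map (the value VGGDiscriminator uses
    as the avgpool kernel_size): each 'M' halves it."""
    repeat_num = sum(1 for v in cfg if v == 'M')
    stride = 2 ** repeat_num
    if image_size % stride != 0:
        raise ValueError(f"image_size {image_size} is not divisible by {stride}")
    return image_size // stride


def last_conv_channels(cfg):
    """Channel count of the last conv layer; curr_dim should equal this."""
    return [v for v in cfg if v != 'M'][-1]


print(final_map_size(224, VGG16_CFG))  # 7
print(final_map_size(128, VGG16_CFG))  # 4
print(last_conv_channels(VGG16_CFG))   # 512
```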
3. Load the pretrained VGG16 and initialize the discriminator:
vggmodel.load_state_dict(torch.load(opt.vgg16bn_pre_model))  # pretrained VGG16-BN weights
model = create_model(opt)  # the author's project code; the discriminator is model.netD
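If the VGG16 checkpoint was saved from GPU tensors, passing map_location="cpu" to torch.load lets it load on a machine without the same GPUs. The sketch below shows the save/load round trip with a tiny conv+BN block standing in for VGG; the block and the file name are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def tiny_block():
    # Hypothetical stand-in for a VGG conv block (conv + BN + ReLU).
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(inplace=True))


src = tiny_block()
ckpt_path = os.path.join(tempfile.mkdtemp(), "tiny_vgg_bn.pth")  # hypothetical file name
torch.save(src.state_dict(), ckpt_path)

dst = tiny_block()
# map_location="cpu" remaps every stored tensor onto the CPU while loading.
state = torch.load(ckpt_path, map_location="cpu")
dst.load_state_dict(state)

same = all(torch.equal(a, b) for a, b in zip(src.state_dict().values(), dst.state_dict().values()))
print(same)  # True
```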
4. Get the VGG16 parameters and the freshly initialized discriminator parameters:
pretrained_dict = vggmodel.state_dict()
model_dict = model.netD.module.state_dict()  # .module: netD is wrapped, e.g. in nn.DataParallel
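Step 4 calls .module.state_dict() because a wrapper such as nn.DataParallel stores the real network in an attribute named module. The wrapper's own state_dict keys therefore begin with "module.", and they would never match VGG keys such as "features.0.weight", so the name filter in step 5 would drop everything. The sketch below reproduces the prefix with a hypothetical Wrapper class (it behaves like DataParallel only in this respect). It also shows a hypothetical strip_module_prefix helper for checkpoints that were saved from a wrapped model.

```python
import torch.nn as nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for nn.DataParallel, which likewise stores the
    wrapped network in an attribute named `module`."""
    def __init__(self, module):
        super().__init__()
        self.module = module


def strip_module_prefix(state_dict, prefix="module."):
    """Remove a leading 'module.' from every key that has one."""
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}


net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1))
wrapped = Wrapper(net)

wrapped_keys = list(wrapped.state_dict().keys())       # ['module.0.weight', 'module.0.bias']
inner_keys = list(wrapped.module.state_dict().keys())  # ['0.weight', '0.bias']
stripped_keys = list(strip_module_prefix(wrapped.state_dict()).keys())
print(wrapped_keys, inner_keys, stripped_keys)
```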
5. Drop the VGG16 entries that have no counterpart in the discriminator:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
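The filter above matches by key name only. If a layer keeps its name but changes shape (for example a different channel count), load_state_dict will raise a size-mismatch error. A slightly more defensive variant, which goes beyond the original, also compares tensor shapes and reports what it skipped. The helper name and the tensors below are made up for illustration.

```python
import torch


def filter_matching(pretrained, target):
    """Keep pretrained entries whose key exists in target AND whose shape matches.
    Returns the kept dict and the sorted list of skipped keys."""
    kept = {k: v for k, v in pretrained.items()
            if k in target and v.shape == target[k].shape}
    skipped = sorted(k for k in pretrained if k not in kept)
    return kept, skipped


# Made-up shapes for illustration only.
pretrained = {
    "features.0.weight": torch.zeros(64, 3, 3, 3),
    "features.0.bias": torch.zeros(64),
    "classifier.6.weight": torch.zeros(10, 40),  # head that the new model dropped
}
target = {
    "features.0.weight": torch.ones(64, 3, 3, 3),
    "features.0.bias": torch.ones(32),            # same name, different shape
    "conv2.weight": torch.ones(20, 512, 3, 3),    # new head, nothing to copy
}
kept, skipped = filter_matching(pretrained, target)
print(list(kept))  # ['features.0.weight']
print(skipped)     # ['classifier.6.weight', 'features.0.bias']
```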
6. Update the discriminator's parameter dict with the VGG parameters:
model_dict.update(pretrained_dict)
7. Load the updated dict into the whole model:
model.netD.module.load_state_dict(model_dict)
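Steps 4 to 7 can be checked end to end on toy models. In this sketch, TinyVGG and TinyDiscriminator are hypothetical stand-ins for the pretrained VGG and VGGDiscriminator, built with a tiny config instead of VGG16's. The sketch then shows an alternative the original does not use: load_state_dict(..., strict=False) with the filtered dict, which loads only the given keys and reports the ones it did not cover. strict=False still raises on shape mismatches, so the filter remains necessary.

```python
import torch
import torch.nn as nn

CFG = [8, 'M', 16, 'M']  # hypothetical tiny config instead of VGG16's


def make_features(cfg):
    # Same layout as make_layers(cfg, batch_norm=True) above.
    layers, in_ch = [], 3
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers += [nn.Conv2d(in_ch, v, 3, padding=1), nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            in_ch = v
    return nn.Sequential(*layers)


class TinyVGG(nn.Module):
    """Stand-in for the pretrained VGG: features.* plus classifier.*."""
    def __init__(self):
        super().__init__()
        self.features = make_features(CFG)
        self.classifier = nn.Linear(16, 10)


class TinyDiscriminator(nn.Module):
    """Stand-in for VGGDiscriminator: same features.*, new heads conv1/conv2."""
    def __init__(self, c_dim=5):
        super().__init__()
        self.features = make_features(CFG)
        self.conv1 = nn.Conv2d(16, 1, 3, padding=1)
        self.conv2 = nn.Conv2d(16, c_dim, 3, padding=1)


torch.manual_seed(0)
vgg, disc = TinyVGG(), TinyDiscriminator()

# Steps 4-7: read both state dicts, keep shared keys, merge, load.
pretrained_dict = vgg.state_dict()
model_dict = disc.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(pretrained_dict)
disc.load_state_dict(model_dict)

# Alternative: strict=False loads only the given keys and reports the rest.
disc2 = TinyDiscriminator()
report = disc2.load_state_dict(pretrained_dict, strict=False)
print(sorted(report.missing_keys))  # the freshly initialized heads conv1.*, conv2.*
print(report.unexpected_keys)       # []
```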
Note: be careful when extracting or updating the parameters.