模型微调(finetune)

----接上次的鸟的图像分类,其acc为84%。

这次依然使用此数据集,并用resenet网络进行finetune,然后进行鸟的图像分类。

1、什么是finetune?

       利用已训练好的模型进行重构(自己的理解)。 对给定的预训练模型(用数据训练好的模型)进行微调,直接利用预训练模型进行微调可以节省许多的时间,能在比较小的epoch下就达到比较好的效果。通常进行微调,1、自己构建模型效果差,所以采用一些常用的模型,别人用数据训好的。2、数据量不够大,所以采用微调。以下是模型微调的例子:

2、数据为鸟类的数据集,其一共有4个类别,如下所示:

 1、数据的类别

 2、图像数据

数据的前期处理和划分,划分可以用random.shuffule直接进行打乱,然后划分。

2、数据导入

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

random_seed = 1
torch.manual_seed(random_seed)

transform = transforms.Compose([
    # transforms.RandomRotation(1),
    transforms.Resize(224), #
    transforms.CenterCrop(224), #
    transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
    transforms.Normalize(mean=[.5,.5,.5], std=[.5,.5,.5]),      # 标准化至[-1, 1],规定均值和标准差

])


transform1 = transforms.Compose([
    # transforms.RandomRotation(1),
    # transforms.Resize(224), #
    # transforms.CenterCrop(224), # 从图片中间切出224*224的图片
    transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
    transforms.Normalize(mean=[.5,.5,.5], std=[.5,.5,.5]), # 标准化至[-1, 1],规定均值和标准差
    # transforms.Normalize(mean=[.5], std=[.5])   # 通道数为1
    #input[channel] = (input[channel] - mean[channel]) / std[channel]
])


class DogLoader(torch.utils.data.Dataset):

    def __init__(self,root,train=True, img_transform = None, target_transform=None,transform = None):  # 负责将输入的数据做成list形式
        # 这样不会调用父类中的init方法,这个是重载了子类的init,并且这里的属性都属于全局的属性

        self.root = root
        self.transform = img_transform
        self.target_transform = target_transform
        self.train = train
        self.transforms = transform

        self.data = []
        self.label = []
        if self.train:
            with open('trainImg.txt') as fr:
                fr = fr.readlines()
                for imgPath in fr:
                    img = imgPath.split('\t')[0]
                    label = imgPath.split('\t')[1]

                    self.data.append(root+img)
                    self.label.append(int(label.strip()))

        else:
            with open('testImg.txt') as fr1:
                fr1 = fr1.readlines()
                for imgPath in fr1:
                    img = imgPath.split('\t')[0]
                    label = imgPath.split('\t')[1]

                    self.data.append(root+img)
                    self.label.append(int(label.strip()))


    def __getitem__(self, index):   # 对数据进行编码,然后转换成我们想要的格式

        img,label = self.data[index],self.label[index]
        img = Image.open(img).convert('RGB')   # 将图片转为RGB图像,为了有一些图像不是RGB的
        # img = Image.open(img_path).convert('L')    # 将RGB的三通道转为一通道的数
        if self.transforms:
            img=self.transforms(img)     # 传入transforms是PIL数据

        array = np.asarray(img)
        img = torch.from_numpy(array)

        return img, label

    def __len__(self):
        return len(self.data)

train_data = DogLoader('data/',train=True,transform=transform)
train_loader = DataLoader(train_data, batch_size=16, shuffle=True, drop_last=False, num_workers=0)

test_data = DogLoader('data/',train=False,transform=transform1)   # 测试不进行数据增强,训练继续宁数据增
test_loader = DataLoader(test_data, batch_size=16, shuffle=True, drop_last=False, num_workers=0)

# for i, trainData in enumerate(train_loader):  # 将一个类的对象可以像list那样调用
#     print("第 {} 个Batch \n{}".format(i, trainData))
#
# for i, testData in enumerate(test_loader):  # 将一个类的对象可以像list那样调用
#     print("第 {} 个Batch \n{}".format(i, testData))

__init__() :类的属性形式,这里主要获取图像的地址和图像的标签。

__getitem__():  当使用这个函数时,它的实例对象(假设为P)就可以以P[key]形式取值,当实例对象做P[key]运算时,就会调用类中的__getitem__()方法。此处使用__getitem__()函数主要是通过后面训练时for循环得到图像的数据和标签。

2、预训练模型

# -*- coding: utf-8 -*-
from torchvision import models
from torch import nn
from global_config import *


def fine_tune_resnet18(): # 这里表示为

    model_ft = models.resnet18(pretrained=True)
    '''
    这里写为True,会自动下载模型的参数,并加载到模型中。
    当然也可以手动下载模型的参数,然后将模型的参数加载到模型中
    '''
    # 把前面的特征进行了拼接
    print('num_features', model_ft)
    num_features = model_ft.fc.in_features
    # fine tune we change original fc layer into classes num of our own
    model_ft.fc = nn.Linear(num_features, 4)

    if USE_GPU:
        model_ft = model_ft.cuda()
    return model_ft


def fine_tune_vgg16():

    model_ft = models.vgg16(pretrained=True)

    print('fine_tune_vgg16() = ',model_ft)
    num_features = model_ft.classifier[6].in_features
    model_ft.classifier[6] = nn.Linear(num_features, 4)

    if USE_GPU:
        model_ft = model_ft.cuda()
    return model_ft


def fine_tune_resnet18_():
    """
     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    """
    model_ft = models.resnet18(pretrained=False)
    model_ft.load_state_dict(torch.load('resnet18-5c106cde.pth'))   # 加载已经下载好的模型参数
    print('model_ft: ',model_ft)
    print('resnet18-5c106cde.pth',model_ft.load_state_dict(torch.load('resnet18-5c106cde.pth')))
    num_features = model_ft.fc.in_features
    # fine tune we change original fc layer into classes num of our own
    model_ft.fc = nn.Linear(num_features, 4)

    if USE_GPU:
        model_ft = model_ft.cuda()
    return model_ft


def fine_tune_resnet50():
    # 实际任务中这个挺重要的
    resNet50 = models.resnet50(pretrained=True) # 调用的预训练网络
    ResNet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2)   # 自己定义的网络

    # 读取参数
    pretrained_dict = resNet50.state_dict()   # 读取预训练网络模型的参数
    model_dict = ResNet50.state_dict()   # 读自定义模型的参数

    # 将pretained_dict里不属于model_dict的键剔除掉
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}   # 剔除一些不同的网络模型参数

    # 更新现有的model_dict
    model_dict.update(pretrained_dict)

    # 加载真正需要的state_dict
    ResNet50.load_state_dict(model_dict)

预训练模型的调用,models.resnet18(),表示的是调用resnet18模型,当models.resnet18(pretrained=True)的时候,则表示直接下载了模型的参数。当models.resnet18(pretrained=False)的时候,可以手动下载好模型torch.load('resnet18-5c106cde.pth')。由于不同的数据可能的类别不同所以通常对最后的一层更改。

4、实验结果:

经过2个epoch就可以达到99的准确率。