How to fine-tune

“如果你想拥有一座自己的城堡,是会买一个旧城堡进行翻修,还是从头开始筑城?”

“那么,如果旧城堡白给呢?”

fine-tune 就是对在已有数据集上得到的预训练模型进行“翻修”,以适用于新的数据集及任务。

取决于新数据集与旧数据集的差异,我们所需要 fine-tune 的程度也有所不同。

由于 CNN 的卷积层提取的是较为通用的基础及抽象特征,而全连接层提取的是与具体任务相关的分类特征,因此 fine-tune 通常以 fc 层为主要目标。

一般 fine-tune 的方式有以下几种方案:

  1. 冻结conv层参数,仅训练更新fc层参数
  2. 同时更新conv、fc层参数,但conv层使用较小的学习率
  3. 仅使用预训练模型做conv层参数初始化,进行正常训练
  4. 先1后3进行两阶段训练

1 Keras冻结conv层进行fine-tune

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import os
import time
os.environ['CUDA_VISIBLE_DEVICES']= '0'
from keras import initializers
from keras.optimizers import Adam
from keras.models import Model, load_model
from keras.applications.mobilenet_v2 import MobileNetV2
from keras.layers import Dense, Dropout, Flatten, AveragePooling2D
from keras.preprocessing.image import ImageDataGenerator


# Root image directory passed to flow_from_directory() in fine_tune().
data_dir = 'data/katana_gun'
# Width of the final softmax layer built in fine_tune().
CLASS_NUM = 2
# Batch size for both generators; also used to derive the steps per epoch.
BATCH_SIZE = 8
# Number of epochs passed to fit_generator().
EPOCH = 5

def fine_tune():
    """Train a new classification head on top of a frozen MobileNetV2 base.

    Images are read from the module-level ``data_dir`` and split 70/30 into
    training and validation subsets.  Only the training subset is augmented;
    the validation subset is merely rescaled so that validation metrics are
    measured on undistorted images.

    Relies on the module-level ``data_dir``, ``CLASS_NUM``, ``BATCH_SIZE``
    and ``EPOCH`` settings.

    Returns:
        The compiled Keras ``Model`` after training.
    """
    # 1 data generators
    # Both generators must use the same validation_split so that the
    # 'training' and 'validation' subsets partition data_dir consistently.
    val_split = 0.3
    train_datagen = ImageDataGenerator(
            rescale=1./255,
            rotation_range=30,
            shear_range=0.2,
            validation_split=val_split
            )
    # No rotation/shear here: random augmentation would make the validation
    # score noisy and not comparable between epochs.
    val_datagen = ImageDataGenerator(
            rescale=1./255,
            validation_split=val_split
            )
    train_generator = train_datagen.flow_from_directory(
            data_dir,
            target_size=(224,224),
            class_mode='categorical',
            batch_size=BATCH_SIZE,
            subset='training'
            )
    validation_generator = val_datagen.flow_from_directory(
            data_dir,
            target_size=(224,224),
            class_mode='categorical',
            batch_size=BATCH_SIZE,
            subset='validation'
            )

    # 2 create model
    conv_base = MobileNetV2(
            input_shape = (224, 224, 3),weights='imagenet',include_top=False)
    conv_base.trainable = False
    # The model below is assembled from conv_base's individual layers, so
    # freeze every layer explicitly as well.
    # NOTE(review): some older Keras releases reportedly do not propagate
    # Model.trainable to the contained layers -- confirm for the version in use.
    for layer in conv_base.layers:
        layer.trainable = False
    x = conv_base.output
    x = AveragePooling2D(pool_size=(7, 7))(x)
    x = Flatten()(x)
    x = Dense(128, activation='relu', kernel_initializer=initializers.he_normal(seed=None))(x)
    x = Dropout(0.25)(x)
    x = Dense(64,activation='relu', kernel_initializer=initializers.he_normal(seed=None))(x)
    predictions = Dense(CLASS_NUM,  kernel_initializer="glorot_uniform", activation='softmax')(x)
    model = Model(inputs = conv_base.input, outputs=predictions)

    # 3 train
    opt = Adam()
    model.compile(
        loss='categorical_crossentropy',
        optimizer=opt,
        metrics=['accuracy']
    )
    # Integer division yields 0 steps when a subset holds fewer images than
    # one batch; always run at least one step.
    train_steps = max(1, train_generator.samples // BATCH_SIZE)
    val_steps = max(1, validation_generator.samples // BATCH_SIZE)
    model.fit_generator(
        train_generator,
        epochs=EPOCH,
        steps_per_epoch=train_steps,
        validation_data=validation_generator,
        validation_steps=val_steps
    )
    return model

if __name__ == '__main__':
    # Train, save the model to disk, then reload it to confirm that the
    # saved file can be read back.
    checkpoint_path = 'katana_gun.h5'

    print('start.', time.ctime())
    model = fine_tune()
    print('train over.', time.ctime())

    model.save(checkpoint_path)
    model = load_model(checkpoint_path)
    print('reload over.', time.ctime())

    # Fetch kernel and bias of the final (softmax) layer of the reloaded model.
    head_weights = model.layers[-1].get_weights()
    w, bias = head_weights
    print('end.', time.ctime())

2 PyTorch两阶段fine-tune

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import warnings
warnings.filterwarnings('ignore')
import time
import pickle
import torch
torch.backends.cudnn.benchmark = True
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torchsummary import summary
from efficientnet_pytorch import EfficientNet

#------------------------------- parameters ------------------------------
# Dataset and run configuration shared by the whole script.
NUM_CLASSES = 7
NUM_EPOCHS = 10000
data_path  = '/data/amoko/data/0714class7S7split'
dataset = os.path.basename(data_path)
# Checkpoints are written as <prefix><best validation accuracy>.
model_save_path_prefix = 'Efficientnet3.' + dataset + '.'
classes_path =  dataset + '_classes.pkl'
print(data_path)

# Training stage selector:
#   'freeze_conv' - stage 1: pretrained backbone frozen, only the new fc layer
#                   is trained.
#   'ultimate'    - stage 2: reload a stage-1 checkpoint and train every
#                   parameter with a much smaller learning rate.
mode = 'freeze_conv'
print('type:', mode)
if mode == 'freeze_conv':
    BATCH_SIZE = 256
    INTERVAL = 50      # print progress every INTERVAL batches
    LR = 0.001
elif mode == 'ultimate':
    BATCH_SIZE = 64
    INTERVAL = 100     # print progress every INTERVAL batches
    LR = 0.00001
    # Validation accuracy embedded in the file name of the checkpoint to resume.
    ACC = '0.8173'
    model_load_path = 'Efficientnet3.' + dataset + '.' + ACC
else:
    # Fail here with a clear message instead of a NameError on BATCH_SIZE,
    # INTERVAL or LR further down the script.
    raise ValueError(
        "unknown mode %r, expected 'freeze_conv' or 'ultimate'" % (mode,))
#------------------------------- parameters ------------------------------

def load_data():
    """Create the training and validation ``DataLoader`` objects.

    Images are read with ``ImageFolder`` from the ``train`` and ``val``
    sub-directories of the module-level ``data_path``.  Training images get a
    random resized crop and a random horizontal flip; validation images are
    resized and centre-cropped.  Both splits are converted to tensors and
    normalised with the same per-channel mean/std.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = datasets.ImageFolder(os.path.join(data_path, 'train'), train_tf)
    val_set = datasets.ImageFolder(os.path.join(data_path, 'val'), val_tf)

    # Both loaders share the same batching, shuffling and worker settings.
    loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_set, **loader_kwargs)
    return train_loader, val_loader

def train(model, device, train_loader, epoch):
    """Run one training epoch and print progress and epoch statistics.

    Relies on the module-level ``criterion`` (loss function), ``opt``
    (optimizer) and ``INTERVAL`` (progress is printed every ``INTERVAL``
    batches).

    Args:
        model: network to train; it is switched to training mode.
        device: device the input and label batches are moved to.
        train_loader: loader yielding ``(inputs, labels)`` batches; must
            support ``len()`` and expose ``.dataset``.
        epoch: epoch number, used only in the progress output.

    Returns:
        None.  The average batch loss and the accuracy are printed.
    """
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)

        loss = criterion(outputs, labels)

        opt.zero_grad()
        loss.backward()
        opt.step()

        if batch_idx % INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\t batch_loss: {:.6f}'
                    .format(epoch, batch_idx * len(inputs), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

        train_loss += loss.item() # sum up batch loss
        pred = outputs.argmax(dim=1, keepdim=True) # index of the highest score per sample
        correct += pred.eq(labels.view_as(pred)).sum().item()

    # Average over the number of batches.  The last batch *index* is one less
    # than the batch count, so dividing by it skews the mean and raises
    # ZeroDivisionError when the loader yields a single batch.
    num_batches = len(train_loader)
    train_loss /= max(num_batches, 1)
    num_samples = len(train_loader.dataset)
    train_acc = correct / max(num_samples, 1)
    print('train_set: average_batch_loss: {:.4f}, accuracy: {}/{} ({:.2f}%)'.format(
        train_loss, correct, num_samples, 100. * train_acc))
  
def validation(model, device, val_loader):
    """Evaluate the model on the validation set and print the statistics.

    Relies on the module-level ``criterion`` (loss function).  Gradients are
    disabled during the evaluation.

    Args:
        model: network to evaluate; it is switched to evaluation mode.
        device: device the input and label batches are moved to.
        val_loader: loader yielding ``(inputs, labels)`` batches; must
            support ``len()`` and expose ``.dataset``.

    Returns:
        Validation accuracy as a float (correct predictions divided by the
        size of ``val_loader.dataset``).
    """
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_loss += loss.item() # sum up batch loss
            pred = outputs.argmax(dim=1, keepdim=True) # index of the highest score per sample
            correct += pred.eq(labels.view_as(pred)).sum().item()

    # Average over the number of batches.  The last batch *index* is one less
    # than the batch count, so dividing by it skews the mean and raises
    # ZeroDivisionError when the loader yields a single batch.
    num_batches = len(val_loader)
    val_loss /= max(num_batches, 1)
    num_samples = len(val_loader.dataset)
    val_acc = correct / max(num_samples, 1)
    print('val_set:   average_batch_loss: {:.4f}, accuracy: {}/{} ({:.2f}%)'.format(
        val_loss, correct, num_samples, 100. * val_acc))
    return val_acc

if __name__ == '__main__':
    # 0 data: build the loaders and persist the class-name list.
    train_loader, val_loader = load_data()
    class_names = train_loader.dataset.classes
    with open(classes_path, 'wb') as fp:
        pickle.dump(class_names, fp)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # 1 model
    if mode == 'ultimate':
        # Resume from a saved checkpoint and make every parameter trainable.
        model = torch.load(model_load_path)
        for param in model.parameters():
            param.requires_grad = True # unlock conv
    else:
        # Start from the pretrained network, freeze all of its parameters and
        # attach a new fc layer (created trainable) sized for NUM_CLASSES.
        model = EfficientNet.from_pretrained('efficientnet-b3')
        for param in model.parameters():
            param.requires_grad = False # save computation
        conv_output = model._fc.in_features # 512 for resnet18, 1536 for b3, 1792 for b4
        model._fc = nn.Linear(conv_output, NUM_CLASSES)
    model = model.to(device)

    # 2 loss, opt: the optimizer only receives parameters that require grad.
    criterion = nn.CrossEntropyLoss()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    opt = optim.Adam(trainable_params, lr=LR, weight_decay=0.0001)
    print(summary(model, (3, 224, 224)))

    # 3 train: save a checkpoint whenever validation accuracy improves.
    print('start,', time.ctime())
    best_val_acc = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        train(model, device, train_loader, epoch)
        val_acc = validation(model, device, val_loader)
        improved = val_acc > best_val_acc
        if improved:
            best_val_acc = val_acc
            model_save_path = model_save_path_prefix + '%0.4f' % (best_val_acc)
            torch.save(model, model_save_path)
        print('best_val_acc: %0.4f' % (best_val_acc))
        print(time.ctime())
        print('-' * 60)
    print('end,', time.ctime())

Reference

[1] PyTorch Transfer Learning tutorial

[2] MXNet 微调