How to fine-tune
“如果你想拥有一座自己的城堡,是会买一个旧城堡进行翻修,还是从头开始筑城?”
“那么,如果旧城堡白给呢?”
fine-tune 就是对在已有数据集上得到的预训练模型进行“翻修”,以适用于新的数据集及任务。
新数据集与旧数据集的差异越大,所需要 fine-tune 的程度也就越深。
因为CNN模型卷积层提取基础及抽象特征,全连接层提取分类特征,所以 fine-tune 以fc层为主要目标。
一般 fine-tune 的方式有以下几种方案:
- 冻结conv层参数,仅训练更新fc层参数
- 同时更新conv、fc层参数,但conv层使用较小的学习率
- 仅使用预训练模型做conv层参数初始化,进行正常训练
- 两阶段训练:先按第1种方案(冻结conv层,仅训练fc层)训练,再按第3种方案(以第一阶段得到的模型为初始化,全部参数正常训练)继续训练
1 Keras冻结conv层进行fine-tune
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
import os
import time
os.environ['CUDA_VISIBLE_DEVICES']= '0'
from keras import initializers
from keras.optimizers import Adam
from keras.models import Model, load_model
from keras.applications.mobilenet_v2 import MobileNetV2
from keras.layers import Dense, Dropout, Flatten, AveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
# Root image directory handed to ImageDataGenerator.flow_from_directory.
data_dir = 'data/katana_gun'
# Number of units in the final softmax layer, i.e. the number of classes.
CLASS_NUM = 2
# Batch size of both generators; also used to derive the steps per epoch.
BATCH_SIZE = 8
# Number of epochs passed to fit_generator.
EPOCH = 5
def fine_tune():
    """Train a new classification head on top of a frozen MobileNetV2 base.

    Images are read from ``data_dir`` and split 70/30 into a training and a
    validation subset. Only the layers added after the convolutional base
    are trained.

    Returns:
        The trained Keras ``Model``.
    """
    # 1 data generator
    # Augmentation (rotation/shear) is applied to the training subset only,
    # so validation metrics are measured on undistorted images. Both
    # generators use the same ``validation_split`` so that they select
    # complementary subsets of ``data_dir``.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=30,
        shear_range=0.2,
        validation_split=0.3
    )
    val_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=0.3
    )
    train_generator = train_datagen.flow_from_directory(
        data_dir,
        target_size=(224, 224),
        class_mode='categorical',
        batch_size=BATCH_SIZE,
        #save_to_dir='data/augment',
        subset='training'
    )
    validation_generator = val_datagen.flow_from_directory(
        data_dir,
        target_size=(224, 224),
        class_mode='categorical',
        batch_size=BATCH_SIZE,
        subset='validation'
    )

    # 2 create model
    conv_base = MobileNetV2(
        input_shape=(224, 224, 3), weights='imagenet', include_top=False)
    # The base layers are re-used directly in the new Model below, so freeze
    # every layer explicitly instead of relying on the container-level flag
    # being propagated (that behaviour differs between Keras versions).
    conv_base.trainable = False
    for layer in conv_base.layers:
        layer.trainable = False

    x = conv_base.output
    x = AveragePooling2D(pool_size=(7, 7))(x)
    x = Flatten()(x)
    x = Dense(128, activation='relu',
              kernel_initializer=initializers.he_normal(seed=None))(x)
    x = Dropout(0.25)(x)
    x = Dense(64, activation='relu',
              kernel_initializer=initializers.he_normal(seed=None))(x)
    predictions = Dense(CLASS_NUM, kernel_initializer="glorot_uniform",
                        activation='softmax')(x)
    model = Model(inputs=conv_base.input, outputs=predictions)

    # 3 train
    opt = Adam()
    model.compile(
        loss='categorical_crossentropy',
        optimizer=opt,
        metrics=['accuracy']
    )
    # Guard against a subset smaller than one batch, which would otherwise
    # yield zero steps.
    train_steps = max(1, train_generator.samples // BATCH_SIZE)
    val_steps = max(1, validation_generator.samples // BATCH_SIZE)
    model.fit_generator(
        train_generator,
        epochs=EPOCH,
        steps_per_epoch=train_steps,
        validation_data=validation_generator,
        validation_steps=val_steps
    )
    return model
if __name__ == '__main__':
    # Train, save to disk, then reload the saved file and read back the
    # weights of the last layer.
    saved_model_path = 'katana_gun.h5'

    print('start.', time.ctime())
    model = fine_tune()
    print('train over.', time.ctime())

    model.save(saved_model_path)
    model = load_model(saved_model_path)
    print('reload over.', time.ctime())

    last_layer = model.layers[-1]
    w, bias = last_layer.get_weights()
    print('end.', time.ctime())
2 Pytorch两阶段fine-tune
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import warnings
warnings.filterwarnings('ignore')
import time
import pickle
import torch
torch.backends.cudnn.benchmark = True
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torchsummary import summary
from efficientnet_pytorch import EfficientNet
#------------------------------- parameters ------------------------------
NUM_CLASSES = 7      # output size of the replacement classifier layer
NUM_EPOCHS = 10000   # number of epochs the training loop iterates over
data_path = '/data/amoko/data/0714class7S7split'
dataset = os.path.basename(data_path)
# Checkpoints are saved as <prefix><val_acc>, e.g. 'Efficientnet3.<dataset>.0.8173'.
model_save_path_prefix = 'Efficientnet3.' + dataset + '.'
classes_path = dataset + '_classes.pkl'
print(data_path)

# Training stage:
#   'freeze_conv' - stage 1, only the new classifier layer is trained
#   'ultimate'    - stage 2, reload a stage-1 checkpoint and train all layers
mode = 'freeze_conv'
#mode = 'ultimate'
print('type:', mode)

if mode == 'freeze_conv':
    BATCH_SIZE = 256
    INTERVAL = 50
    LR = 0.001
elif mode == 'ultimate':
    BATCH_SIZE = 64
    INTERVAL = 100
    LR = 0.00001
else:
    # Fail here rather than with a NameError on BATCH_SIZE/LR later on.
    raise ValueError('unknown mode: %r' % (mode,))

# Validation accuracy suffix of the stage-1 checkpoint loaded in 'ultimate' mode.
ACC = '0.8173'
model_load_path = model_save_path_prefix + ACC
#------------------------------- parameters ------------------------------
def load_data():
    """Create the training and validation data loaders.

    Datasets are read with ``ImageFolder`` from the ``train`` and ``val``
    sub-directories of ``data_path``. Training images get a random resized
    crop and a random horizontal flip; validation images are resized and
    centre-cropped. Both are converted to tensors and normalised with the
    same per-channel mean/std.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = datasets.ImageFolder(os.path.join(data_path, 'train'), train_tf)
    val_set = datasets.ImageFolder(os.path.join(data_path, 'val'), val_tf)

    # Both loaders share the same batching/shuffling configuration.
    loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_set, **loader_kwargs)
    return train_loader, val_loader
def train(model, device, train_loader, epoch,
          loss_fn=None, optimizer=None, log_interval=None):
    """Run one training epoch and print batch- and epoch-level statistics.

    Args:
        model: network to optimise; it is switched to training mode.
        device: device every batch is moved to before the forward pass.
        train_loader: iterable of ``(inputs, labels)`` batches that exposes
            ``.dataset`` and supports ``len()``.
        epoch: epoch number, used only in the log output.
        loss_fn: loss callable; defaults to the module-level ``criterion``.
        optimizer: optimizer to step; defaults to the module-level ``opt``.
        log_interval: print the batch loss every this many batches; defaults
            to the module-level ``INTERVAL``.

    Returns:
        Training accuracy of the epoch as a fraction of the dataset size.
    """
    # Fall back to the script-level globals when no override is supplied.
    loss_fn = criterion if loss_fn is None else loss_fn
    optimizer = opt if optimizer is None else optimizer
    log_interval = INTERVAL if log_interval is None else log_interval

    model.train()
    train_loss = 0
    correct = 0
    num_batches = 0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\t batch_loss: {:.6f}'
                  .format(epoch, batch_idx * len(inputs), len(train_loader.dataset),
                          100. * batch_idx / len(train_loader), loss.item()))
        train_loss += loss.item()  # sum up batch loss
        pred = outputs.argmax(dim=1, keepdim=True)  # index of the largest output = predicted class
        correct += pred.eq(labels.view_as(pred)).sum().item()
        num_batches += 1

    # Average over the number of batches seen. Dividing by the last batch
    # index under-counts by one and fails when there is a single batch.
    train_loss /= max(num_batches, 1)
    num_samples = len(train_loader.dataset)
    train_acc = correct / num_samples if num_samples else 0.0
    print('train_set: average_batch_loss: {:.4f}, accuracy: {}/{} ({:.2f}%)'.format(
        train_loss, correct, num_samples, 100. * train_acc))
    return train_acc
def validation(model, device, val_loader, loss_fn=None):
    """Evaluate the model on the validation set and print the statistics.

    Args:
        model: network to evaluate; it is switched to evaluation mode.
        device: device every batch is moved to before the forward pass.
        val_loader: iterable of ``(inputs, labels)`` batches that exposes
            ``.dataset``.
        loss_fn: loss callable; defaults to the module-level ``criterion``.

    Returns:
        Validation accuracy as a fraction of the dataset size.
    """
    # Fall back to the script-level global when no override is supplied.
    loss_fn = criterion if loss_fn is None else loss_fn

    model.eval()
    val_loss = 0
    correct = 0
    num_batches = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            val_loss += loss.item()  # sum up batch loss
            pred = outputs.argmax(dim=1, keepdim=True)  # index of the largest output = predicted class
            correct += pred.eq(labels.view_as(pred)).sum().item()
            num_batches += 1

    # Average over the number of batches seen. Dividing by the last batch
    # index under-counts by one and fails when there is a single batch.
    val_loss /= max(num_batches, 1)
    num_samples = len(val_loader.dataset)
    val_acc = correct / num_samples if num_samples else 0.0
    print('val_set: average_batch_loss: {:.4f}, accuracy: {}/{} ({:.2f}%)'.format(
        val_loss, correct, num_samples, 100. * val_acc))
    return val_acc
if __name__ == '__main__':
    # 0 data: build the loaders and persist the class-name list
    train_loader, val_loader = load_data()
    class_names = train_loader.dataset.classes
    with open(classes_path, 'wb') as fp:
        pickle.dump(class_names, fp)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # 1 model
    if mode == 'ultimate':
        # Stage 2: reload the saved stage-1 model and make every parameter trainable.
        model = torch.load(model_load_path)
        for param in model.parameters():
            param.requires_grad = True  # unlock conv
    else:
        # Stage 1: freeze the pretrained network, then attach a fresh
        # classifier layer (its parameters are trainable by default).
        model = EfficientNet.from_pretrained('efficientnet-b3')
        for param in model.parameters():
            param.requires_grad = False  # save computation
        conv_output = model._fc.in_features  # 512 for resnet18, 1536 for b3, 1792 for b4
        model._fc = nn.Linear(conv_output, NUM_CLASSES)
    model = model.to(device)

    # 2 loss, opt: only parameters that require gradients are optimised
    criterion = nn.CrossEntropyLoss()
    trainable_params = (p for p in model.parameters() if p.requires_grad)
    opt = optim.Adam(trainable_params, lr=LR, weight_decay=0.0001)
    print(summary(model, (3, 224, 224)))

    # 3 train: save the model each time the validation accuracy improves
    print('start,', time.ctime())
    best_val_acc = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        train(model, device, train_loader, epoch)
        val_acc = validation(model, device, val_loader)
        improved = val_acc > best_val_acc
        if improved:
            best_val_acc = val_acc
            model_save_path = model_save_path_prefix + '%0.4f' % (best_val_acc)
            torch.save(model, model_save_path)
        print('best_val_acc: %0.4f' % (best_val_acc))
        print(time.ctime())
        print('-' * 60)
    print('end,', time.ctime())
Reference
[1] Pytorch Transfer learning tutorial
[2] MXNet 微调