基于VGG的迁移学习
步骤
- 读取本地的图片数据以及类别
- 模型的结构修改(添加我们自定的分类层)
- freeze掉原始VGG模型
- 编译以及训练和保存模型方式
- 输入数据进行预测
读取本地图片
ImageDataGenerator:生产图片的批次张量值并且提供数据增强功能
参数:
- rescale=1.0 / 255:归一化,将像素值缩放到[0, 1]区间
- zca_whitening=False: # zca白化的作用是针对图片进行PCA降维操作,减少图片的冗余信息
- rotation_range=20:默认0, 旋转角度,在这个角度范围随机生成一个值
- width_shift_range=0.2,:默认0,水平平移
- height_shift_range=0.2:默认0, 垂直平移
- shear_range=0.2:默认0, 剪切(错切)变换的强度
- zoom_range=0.2:默认0, 随机缩放的范围
- horizontal_flip=True:水平翻转
使用flow
# Example: train manually on augmented batches produced by ImageDataGenerator.flow.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
gen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)
# The featurewise_* options rely on dataset statistics, which must be
# computed from the training data before any batch is generated.
gen.fit(x_train)
for e in range(epochs):
    print('Epoch', e)
    batches = 0
    for x_batch, y_batch in gen.flow(x_train, y_train, batch_size=32):
        model.fit(x_batch, y_batch)
        batches += 1
        if batches >= len(x_train) / 32:
            # flow() yields batches indefinitely, so the epoch has to be
            # ended by hand once every sample has been seen once.
            break
使用flow_from_directory
- directory=path,# 读取目录
- target_size=(h,w),# 目标形状
- batch_size=size,# 批数量大小
- class_mode='binary', # 目标值格式,One of "categorical", "binary", "sparse",
- "categorical" :2D one-hot encoded labels
- "binary" will be 1D binary labels
- "sparse" will be 1D integer labels
- shuffle=True
这个API固定了读取的目录格式,参考:
# Training images are rescaled and randomly sheared, zoomed and flipped;
# validation images are only rescaled.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
test_datagen = ImageDataGenerator(rescale=1. / 255)

# Both directory iterators share the same image size, batch size and label mode.
flow_options = dict(target_size=(150, 150), batch_size=32, class_mode='binary')
train_generator = train_datagen.flow_from_directory('data/train', **flow_options)
validation_generator = test_datagen.flow_from_directory('data/validation', **flow_options)

# Train from the generators with fit_generator.
model.fit_generator(
    train_generator,
    steps_per_epoch=2000,
    epochs=50,
    validation_data=validation_generator,
    validation_steps=800,
)
VGG模型的修改
notop模型:不包含最后的3个全连接层。用来做fine-tuning专用,专门开源了这类模型。
# Add this inside __init__: VGG16 with ImageNet weights and without the
# top classifier layers (include_top=False).
self.base_model = VGG16(weights='imagenet', include_top=False)
做法:一个GlobalAveragePooling2D + 两个全连接层
如下:
from keras.layers import Dense, Input, Conv2D
from keras.layers import MaxPooling2D, GlobalAveragePooling2D
# Stand-in for the output of the last convolutional layer,
# assumed here to have shape (None, 8, 8, 2048).
x = Input(shape=[8, 8, 2048])
x = GlobalAveragePooling2D(name='avg_pool')(x)  # shape=(?, 2048)
# Each feature map is reduced to its mean value, which replaces the
# flatten + fully connected stage.
x = Dense(1000, activation='softmax', name='predictions')(x)  # shape=(?, 1000)
freeze 模型
让VGG结构当中的权重参数不参与训练,只训练我们添加的最后两层全连接网络的权重参数。
通过使用每一层的layer.trainable=False
def freeze_vgg_model(self):
    """Mark every layer of ``self.base_model`` as non-trainable.

    Sets ``trainable = False`` on each layer so that only layers added
    on top of the base model are updated during training.
    """
    for layer in self.base_model.layers:
        layer.trainable = False
编译和训练
在迁移学习中算法:学习率初始化较小的值,0.001,0.0001,因为已经在已训练好的模型基础之上更新,所以不需要太大学习率去学习
def compile(self, model):
    """Compile ``model`` for training on integer class labels.

    Uses an Adam optimizer constructed with its default arguments,
    sparse categorical cross-entropy as the loss, and accuracy as the
    reported metric.

    Args:
        model: The Keras model to compile in place.
    """
    model.compile(optimizer=keras.optimizers.Adam(),
                  loss=keras.losses.sparse_categorical_crossentropy,
                  metrics=['accuracy'])
使用ModelCheckpoint指定相关参数:
# Checkpoint callback: the file name embeds the epoch number and the
# training accuracy, while the monitored quantity is the validation accuracy.
calls = keras.callbacks.ModelCheckpoint(
    filepath='./snn_model/transfer-{epoch:02d}-{acc:.2f}.h5',
    monitor='val_acc',
    save_best_only=True,     # only write a file when the monitored value improves
    save_weights_only=True,  # store weights only, not the full model
    mode='auto',
    period=1                 # check after every epoch
)
# Train for 3 epochs, validating on test_g and checkpointing via `calls`.
fine_model.fit_generator(train_g, epochs=3, validation_data=test_g, callbacks=[calls])
预测
读取图片以及处理到模型中预测,加载我们训练的模型
def predict(self, model):
    """Load trained weights into ``model`` and classify one test image.

    Prints the shape of the batched input and then the predicted class
    name looked up in ``self.label_dict``.

    Args:
        model: A model with the same architecture as the one whose
            weights were saved to ``./Transfer.h5``.
    """
    model.load_weights("./Transfer.h5")
    # Load the image at the network's input size and convert it to an array.
    image = load_img("./data/test/dinosaurs/402.jpg", target_size=(224, 224))
    image = img_to_array(image)
    # Add a leading batch axis: 3 dimensions -> 4 dimensions.
    img = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
    print("改变形状结果:", img.shape)
    # Inference must use the same preprocessing as training. The training
    # generators in this file use rescale=1/255, so scale to [0, 1] here
    # instead of applying VGG's mean-subtracting preprocess_input.
    img = img / 255.0
    y_predict = model.predict(img)
    index = np.argmax(y_predict, axis=1)
    print(self.label_dict[str(index[0])])
完整代码
import numpy as np
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python import keras
import tensorflow as tf
from tensorflow.python.keras.applications.vgg16 import VGG16
from tensorflow.python.keras.preprocessing.image import load_img, img_to_array
from tensorflow.python.keras.applications.vgg16 import preprocess_input, decode_predictions
class Transfer(object):
    """Transfer learning on top of a VGG16 backbone for a five-class image task.

    The workflow is: read images from class-named directories, attach a new
    classification head to the headless VGG16 model, freeze the VGG16 layers,
    then compile, train with checkpointing, and predict.
    """

    def __init__(self):
        """Create the data generators, dataset paths, backbone and label map."""
        # Both generators only rescale pixel values by 1/255.
        self.train_generator = ImageDataGenerator(rescale=1.0 / 255.0)
        self.test_generator = ImageDataGenerator(rescale=1.0 / 255.0)

        self.train_dir = "./data/train"
        self.test_dir = "./data/test"

        # VGG16 with ImageNet weights and without its fully connected top.
        self.base_model = VGG16(weights='imagenet', include_top=False)

        # Predicted class index (as a string) -> class name.
        self.label_dict = {
            '0': 'bus',
            '1': 'dinosaurs',
            '2': 'elephants',
            '3': 'flowers',
            '4': 'horse',
        }

    def get_data(self):
        """Build the training and test directory iterators.

        Returns:
            A ``(train_g, test_g)`` tuple of iterators yielding batches of
            32 images resized to 224x224 together with their labels.
        """
        # 'sparse' yields 1D integer class indices, which is what
        # sparse_categorical_crossentropy expects for a multi-class problem;
        # 'binary' is meant for two-class data.
        train_g = self.train_generator.flow_from_directory(
            self.train_dir,
            target_size=(224, 224),
            class_mode='sparse',
            batch_size=32,
        )
        test_g = self.test_generator.flow_from_directory(
            self.test_dir,
            target_size=(224, 224),
            class_mode='sparse',
            batch_size=32,
        )
        return train_g, test_g

    def refine_model(self):
        """Attach a new classification head to the headless VGG16 model.

        The head is global average pooling, a 1024-unit ReLU layer and a
        5-unit softmax layer.

        Returns:
            A Keras ``Model`` mapping the base model's inputs to the
            5-way softmax output.
        """
        # Output tensor of the headless base model.
        x = self.base_model.outputs[0]
        x = keras.layers.GlobalAveragePooling2D()(x)
        x = keras.layers.Dense(1024, activation=tf.nn.relu)(x)
        y_p = keras.layers.Dense(5, activation=tf.nn.softmax)(x)
        fine_model = keras.models.Model(inputs=self.base_model.inputs,
                                        outputs=y_p)
        return fine_model

    def freeze_model(self):
        """Mark every layer of the base model as non-trainable."""
        for layer in self.base_model.layers:
            layer.trainable = False

    def compile(self, model):
        """Compile ``model`` with Adam, sparse categorical cross-entropy and accuracy.

        Args:
            model: The Keras model to compile in place.
        """
        model.compile(optimizer=keras.optimizers.Adam(),
                      loss=keras.losses.sparse_categorical_crossentropy,
                      metrics=['accuracy'])

    def fit(self, fine_model, train_g, test_g):
        """Train ``fine_model`` for 3 epochs, checkpointing weights.

        Args:
            fine_model: The compiled model to train.
            train_g: Iterator yielding training batches.
            test_g: Iterator yielding validation batches.
        """
        import os

        # Make sure the checkpoint directory exists; saving an .h5 file into
        # a missing directory may fail instead of creating it.
        os.makedirs('./snn_model', exist_ok=True)

        calls = keras.callbacks.ModelCheckpoint(
            filepath='./snn_model/transfer-{epoch:02d}-{acc:.2f}.h5',
            monitor='val_acc',
            save_best_only=True,
            save_weights_only=True,
            mode='auto',
            period=1,
        )
        fine_model.fit_generator(train_g, epochs=3, validation_data=test_g,
                                 callbacks=[calls])

    def predict(self, model):
        """Load trained weights into ``model`` and classify one test image.

        Prints the predicted class name looked up in ``self.label_dict``.

        Args:
            model: A model with the same architecture as the one that
                produced the saved weights.
        """
        # Load the weights written by the training run.
        model.load_weights("./snn_model/transfer-03-0.98.h5")

        # Read the image at the network's input size.
        img = load_img("./data/test/bus/300.jpg", target_size=(224, 224))
        image = img_to_array(img)
        # (224, 224, 3) -> (1, 224, 224, 3)
        image = image.reshape([1, image.shape[0], image.shape[1], image.shape[2]])

        # Inference must use the same preprocessing as training: the
        # generators created in __init__ rescale by 1/255, so do the same
        # here instead of applying VGG's mean-subtracting preprocess_input.
        image = image / 255.0

        y_p = model.predict(image)
        res = np.argmax(y_p, axis=1)
        print(f"预测了类别为:{self.label_dict[str(res[0])]}")
def train(cnn):
    """Run the full training pipeline on ``cnn``.

    Loads the data iterators, builds the fine-tuning model, freezes the
    base model, then compiles and fits the new model.

    Args:
        cnn: An object providing ``get_data``, ``refine_model``,
            ``freeze_model``, ``compile`` and ``fit``.
    """
    train_g, test_g = cnn.get_data()
    model = cnn.refine_model()
    # Freeze before compiling so the compiled model sees the frozen layers.
    cnn.freeze_model()
    cnn.compile(model)
    cnn.fit(model, train_g, test_g)
def use_train(cnn):
    """Rebuild the fine-tuning model and run a prediction with it.

    Args:
        cnn: An object providing ``refine_model`` and ``predict``.
    """
    model = cnn.refine_model()
    cnn.predict(model)
if __name__ == '__main__':
    cnn = Transfer()
    # Uncomment to train first; by default only prediction is run.
    # train(cnn)
    use_train(cnn)
选取一张汽车的图片进行预测:
本作品采用《CC 协议》,转载必须注明作者和本文链接