|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | + Created on 2021/7/1 20:34 |
| 4 | + Filename : demo_mnist_cnn.py |
| 5 | + Author : Taosy.W |
| 6 | + Zhihu : https://www.zhihu.com/people/1105936347 |
| 7 | + Github : https://github.com/AFei19911012/PythonSamples |
| 8 | + Description: |
| 9 | +""" |
| 10 | + |
| 11 | +# ======================================================= |
| 12 | +from __future__ import print_function |
| 13 | +import os |
| 14 | +import random |
| 15 | +import keras |
| 16 | +import numpy as np |
| 17 | +from PIL import Image |
| 18 | +from keras.layers.core import Dense, Activation, Flatten |
| 19 | +from keras.models import Sequential, load_model |
| 20 | +from keras.layers import Conv2D, MaxPooling2D, Dropout |
| 21 | +from keras.preprocessing.image import ImageDataGenerator |
| 22 | + |
| 23 | + |
def load_data(data_dir="D:/MyPrograms/DataSet/mnist"):
    """Load the digit images stored in a directory.

    Every entry of ``data_dir`` is opened as an image and stored as one
    28x28 single-channel sample. The label of a sample is the integer that
    precedes the first ``.`` in its file name.

    Args:
        data_dir: Directory holding the image files. Defaults to the
            dataset location this script was written for.

    Returns:
        A ``(data, label)`` tuple. ``data`` is a float32 array of shape
        ``(num_images, 28, 28, 1)`` that has been divided by its maximum
        and then shifted by its mean. ``label`` is a uint8 array of shape
        ``(num_images,)``.
    """
    names = os.listdir(data_dir)
    num = len(names)

    # Size the arrays from the files actually present: a fixed size either
    # leaves uninitialised rows (which would corrupt the max/mean used for
    # normalisation) or overflows when the directory holds more images.
    data = np.empty((num, 28, 28, 1), dtype="float32")
    label = np.empty((num,), dtype="uint8")
    if num == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return data, label

    for i, name in enumerate(names):
        # The context manager closes the file handle once pixels are copied.
        with Image.open(os.path.join(data_dir, name)) as img:
            data[i, :, :, 0] = np.asarray(img, dtype="float32")
        label[i] = int(name.split('.')[0])

    peak = np.max(data)
    if peak != 0:
        # Skip the scaling for all-zero data to avoid producing NaN.
        data /= peak
    data -= np.mean(data)
    return data, label
| 39 | + |
| 40 | + |
def data_augmentation(x_train):
    """Create an ImageDataGenerator configured for this demo and fit it.

    Args:
        x_train: Training images passed to ``ImageDataGenerator.fit``.

    Returns:
        The fitted ``ImageDataGenerator`` instance.
    """
    # NOTE(review): horizontal_flip=True mirrors the digits; confirm this is
    # really wanted for handwritten-digit data.
    options = dict(
        featurewise_center=False,  # set the dataset-wide mean to 0
        samplewise_center=False,  # set each sample's mean to 0
        featurewise_std_normalization=False,  # divide inputs by the dataset std
        samplewise_std_normalization=False,  # divide each input by its own std
        zca_whitening=False,  # apply ZCA whitening
        zca_epsilon=1e-06,  # epsilon used by ZCA whitening
        rotation_range=0,  # random rotation range (degrees, 0 to 180)
        width_shift_range=0.1,  # random horizontal shift (fraction of width)
        height_shift_range=0.1,  # random vertical shift (fraction of height)
        shear_range=0.,  # random shear range
        zoom_range=0.,  # random zoom range
        channel_shift_range=0.,  # random channel shift range
        fill_mode='nearest',  # how points outside the input boundary are filled
        cval=0.,  # value used when fill_mode = "constant"
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False,  # randomly flip images vertically
        rescale=None,  # rescaling factor (applied before other transforms)
        preprocessing_function=None,  # function applied to every input
        data_format=None,  # "channels_first" or "channels_last"
        validation_split=0.0,  # fraction reserved for validation (strictly 0..1)
    )
    augmenter = ImageDataGenerator(**options)
    augmenter.fit(x_train)
    return augmenter
| 65 | + |
| 66 | + |
def train():
    """Train the CNN digit classifier and save it to 'models/mnist_cnn.h5'.

    Loads the images with ``load_data()``, one-hot encodes the labels,
    trains a small convolutional network with RMSprop for 100 epochs using
    the last 20% of the samples as validation data, saves the model, and
    prints the loss/accuracy measured on that held-out 20%.
    """
    model_path = 'models/mnist_cnn.h5'
    # Create the output directory up front so a missing folder cannot make
    # model.save() fail after the whole training run has completed.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    # Load the data.
    data, label = load_data()
    print(data.shape[0], ' samples')

    # Labels are the 10 classes 0-9; Keras expects binary class matrices.
    label = keras.utils.to_categorical(label, 10)
    print(data.shape[1:])

    # Build the CNN model.
    model = Sequential()
    model.add(Conv2D(32, (3, 3), padding='same', input_shape=data.shape[1:]))
    model.add(Activation('relu'))
    model.add(Conv2D(32, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Conv2D(64, (3, 3), padding='same'))
    model.add(Activation('relu'))
    model.add(Conv2D(64, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(10))
    model.add(Activation('softmax'))

    # Initialise the RMSprop optimizer and compile the model with it.
    opt = keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

    # Hold out the last 20% of the samples (the same slice that
    # validation_split=0.2 would take) so the final evaluation below can be
    # run on data the model was not trained on.
    split_at = int(data.shape[0] * (1.0 - 0.2))
    x_train, y_train = data[:split_at], label[:split_at]
    x_val, y_val = data[split_at:], label[split_at:]
    model.fit(x_train, y_train, batch_size=32, epochs=100, shuffle=True, verbose=1,
              validation_data=(x_val, y_val))
    # data_augmentation() can provide an ImageDataGenerator if augmented
    # training is wanted instead of the plain fit above.

    # Save the model.
    model.save(model_path)
    print('Saved trained model at: models/mnist_cnn.h5')

    # Evaluate on the held-out samples only; scoring the full dataset would
    # mostly measure performance on the training samples.
    scores = model.evaluate(x_val, y_val, verbose=1)
    print('Test loss:', scores[0])
    print('Test accuracy:', scores[1])
| 120 | + |
| 121 | + |
def test():
    """Spot-check the saved model on 100 randomly chosen samples.

    Loads 'models/mnist_cnn.h5' and the dataset from ``load_data()``, draws
    100 random samples (with replacement), prints the true and predicted
    digit for each one, and finally prints the fraction predicted correctly.
    """
    num_trials = 100

    # Load the model and the data.
    model = load_model('models/mnist_cnn.h5')
    data, label = load_data()

    # randrange excludes the upper bound; random.randint(0, n) could return
    # n itself and index one past the end of the arrays.
    total = data.shape[0]
    indices = [random.randrange(total) for _ in range(num_trials)]

    # One batched forward pass instead of one predict() call per sample.
    predicted = np.argmax(model.predict(data[indices]), axis=1)

    # Compare predictions with the labels and report each sample.
    accu = 0
    for y, predict in zip(label[indices], predicted):
        if y == predict:
            accu += 1
        print(f'number: {y}')
        print(f'predicted: {predict}')
        print(' --- ')
    print(f'accu = {accu / num_trials}')
| 149 | + |
| 150 | + |
if __name__ == "__main__":
    # Training is disabled here; uncomment train() to (re)build
    # 'models/mnist_cnn.h5' before running the spot-check.
    # train()
    test()
0 commit comments