当先锋百科网

首页 1 2 3 4 5 6 7

keras实现图像数字多分类

目标:基于mnist数据集,建立mlp模型,实现0-9数字的十分类

1.实现mnist数据载入,可视化图形数字
2.完成数据预处理,图像数据维度转化与归一化,输出结果格式转化
3.计算模型在预测数据集的准确率
4.模型结构:两层隐藏层,每层有392个模型

一、数据处理及可视化

1、获取数据集

from keras.datasets import mnist
(X_train,y_train),(X_test,y_test) = mnist.load_data()

2、查看数据

X_train.shape 
#(60000, 28, 28)

3、部分数据的可视化

import matplotlib.pyplot as plt
img1 = X_train[0]
fig = plt.figure(figsize=(3,3))
plt.imshow(img1)
plt.title(y_train[0])
plt.show()

可视化效果为:数字5
在这里插入图片描述

二、数据预处理

1、查看图片大小

# 图片的大小
img1.shape
# (28, 28)

2、维度转换

feature_size = img1.shape[0]*img1.shape[1]
X_train_format = X_train.reshape(X_train.shape[0],feature_size)
X_test_format = X_test.reshape(X_test.shape[0],feature_size)
X_train_format.shape
# (60000, 784)

3、归一化处理

由于对图像进行数字处理,所以归一化时除以255即可

X_train_normal = X_train_format/255
X_test_normal = X_test_format/255

4、输出结果格式转化

tf版本过高时,导入包: from keras.utils import to_categorical
会显示报错
ImportError: cannot import name ‘to_categorical’ from ‘keras.utils’ (/usr/local/lib/python3.7/dist-packages/keras/utils/init.py)
现在keras完全置于tf模块中,这个要从tensoflow根模块导入,修改为:
from tensorflow.keras.utils import to_categorical

from tensorflow.keras.utils import to_categorical
y_train_format = to_categorical(y_train)
y_test_format = to_categorical(y_test)
print(y_train[0])
print(y_test_format[0])

三、建立模型,训练并预测

1、建立模型

2个隐藏层,每层有392个
最后的分类结果是10个

from keras.models import Sequential
from keras.layers import Dense,Activation
mlp = Sequential()
mlp.add(Dense(units=392,activation='sigmoid',input_dim=feature_size))
mlp.add(Dense(units=392,activation='sigmoid'))
mlp.add(Dense(units=10,activation='softmax'))
mlp.summary()

2、训练模型

#模型训练
mlp.fit(X_train_normal,y_train_format,epochs=10)

3、模型评估
3.1、训练集

# 训练集
import numpy as np 
y_train_predict = mlp.predict(X_train_normal)
y_train_predict=np.argmax(y_train_predict,axis=1)
y_train_predict
# array([5, 0, 4, ..., 5, 6, 8], dtype=int64)

准确度

# 计算准确率
from sklearn.metrics import accuracy_score
accuracy_train = accuracy_score(y_train,y_train_predict)
accuracy_train
# 0.9942833333333333

3.2 测试集

y_test_predict = mlp.predict(X_test_normal)
y_test_predict = np.argmax(y_test_predict,axis=1)
accuracy_test = accuracy_score(y_test,y_test_predict)
accuracy_test
# 0.98

四、可视化验证结果

mg2 = X_test[10]
fig2 = plt.figure(figsize=(3,3))
plt.imshow(img2)
plt.title(y_test_predict[10])

分类成功
在这里插入图片描述
完整代码已上传至
https://github.com/jrt-20/-