基于TensorFlow的卷积神经网络,实现 MNIST 手写数字识别
发布于 作者:苏南大叔 来源:程序如此灵动~

既然是MNIST
手写数字识别的话题,本文中,苏南大叔使用tensorflow
的卷积神经网络cnn
进行训练。从simplenn
切换到cnn
后,生成的模型的数字识别精准度,得到了质的提升。所以,推荐本文中的代码。
苏南大叔的“程序如此灵动”博客,记录苏南大叔的代码编程经验总结。测试环境:win10
,python@3.12.9
,tensorflow@2.19.0
。本文将介绍如何使用TensorFlow
和卷积神经网络(CNN
)实现MNIST
手写数字识别,并通过Gradio
提供一个支持多张图片预测的交互式界面。
TensorFlow
TensorFlow
是一个由谷歌开发的开源深度学习框架,广泛应用于机器学习和人工智能领域。它支持从简单的模型到复杂的深度学习网络的构建和训练。
MNIST
数据集包含七万张 28x28 像素的灰度手写数字图片,每张图片对应一个数字(0-9)。
模型架构
使用卷积神经网络(CNN)来处理 MNIST 数据集。架构如下:
- 输入层:28x28 的灰度图像。
- 卷积层 1:32 个 3x3 的卷积核,激活函数为 ReLU。
- 池化层 1:2x2 的最大池化层。
- 卷积层 2:64 个 3x3 的卷积核,激活函数为 ReLU。
- 池化层 2:2x2 的最大池化层。
- 全连接层:128 个神经元,激活函数为 ReLU。
- 输出层:10 个神经元,激活函数为 Softmax,用于分类。
代码实现如下:
def create_model():
model = Sequential([
Input(shape=(28, 28, 1)), # Add channel dimension for grayscale images
tf.keras.layers.Conv2D(32, (3, 3), activation='relu'), # First convolutional layer
tf.keras.layers.MaxPooling2D((2, 2)), # First pooling layer
tf.keras.layers.Conv2D(64, (3, 3), activation='relu'), # Second convolutional layer
tf.keras.layers.MaxPooling2D((2, 2)), # Second pooling layer
Flatten(),
Dense(128, activation='relu'), # Fully connected layer
Dense(10, activation='softmax') # Output layer for 10 classes
])
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
return model
数据加载与预处理
使用 TensorFlow
提供的MNIST
数据集,并对数据进行以下预处理:
- 归一化:将像素值从 [0, 255] 映射到 [0, 1]。
- 独热编码:将标签转换为独热编码格式。
代码如下:
def load_data():
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
return (x_train, y_train), (x_test, y_test)
模型训练与评估
使用交叉熵损失函数(Categorical Crossentropy
)和 Adam
优化器(Adam Optimizer
)来训练模型。训练完成后,在测试集上评估模型的准确率。
训练代码:
def train_model(model, x_train, y_train, epochs=5):
model.fit(x_train, y_train, epochs=epochs)
评估代码:
def evaluate_model(model, x_test, y_test):
loss, accuracy = model.evaluate(x_test, y_test)
return accuracy
在训练 【5个epoch
】后,模型在测试集上的准确率约为 99%。表面上来看,虽然和之前的simplenn
准确率相差不大。但是,实际上的实测效果,区别显著!
保存模型
tensorflow
的模型文件,和sklearn
或者pytorch
的模型文件,存在着明显不同,加载机制也是不同的。
from tensorflow.keras.models import load_model
#
MODEL_PATH = "tensorflow_mnist_model.h5"
if os.path.exists(MODEL_PATH):
model = load_model(MODEL_PATH)
# Recompile the model to ensure metrics are initialized
# model.compile(optimizer='adam',
# loss='categorical_crossentropy',
# metrics=['accuracy'])
print("Loaded existing model.")
else:
(x_train, y_train), (x_test, y_test) = load_data()
model = create_model()
train_model(model, x_train, y_train)
model.save(MODEL_PATH)
print("Model trained and saved.")
accuracy = evaluate_model(model, x_test, y_test)
print(f'Accuracy: {accuracy}')
在加载已有模型的时候,存在一个可能的warning
信息:
WARNING:absl:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
使用下面的代码进行去除:
import absl.logging
# Suppress absl logging warnings
absl.logging.set_verbosity(absl.logging.ERROR)
交互式预测界面
使用 Gradio
提供了一个支持多张图片预测的交互式界面。预测多张的接口,是因为需要同时验证0-9
十个数字。以下是核心代码:
单张图片预测
def predict_digit(image):
if image is None:
return "No image provided. Please upload an image."
if len(image.shape) == 2: # If the image is grayscale (2D), add a channel dimension
image = np.expand_dims(image, axis=-1)
image = tf.image.resize(image, (28, 28)).numpy() # Resize to 28x28
image = image.reshape(1, 28, 28, 1) / 255.0 # Normalize and reshape for the model
prediction = model.predict(image)
return int(np.argmax(prediction)) # Convert to Python int
多张图片预测
def predict_digits(files):
if files is None or len(files) == 0:
return "No images provided. Please upload one or more images."
results = []
for file in files:
image = tf.io.decode_image(tf.io.read_file(file), channels=1).numpy() # Read and decode image
image_resized = tf.image.resize(image, (28, 28)).numpy() # Resize to 28x28
image_normalized = image_resized.reshape(1, 28, 28, 1) / 255.0 # Normalize and reshape for the model
prediction = model.predict(image_normalized)
predicted_label = int(np.argmax(prediction)) # Convert to Python int
results.append((image.squeeze(), f"Prediction: {predicted_label}"))
return results
Gradio 界面代码:
interface = gr.Interface(
fn=predict_digits,
inputs=gr.Files(file_types=["image"], label="Upload Images"),
outputs=gr.Gallery(label="Predictions", columns=5), # Adjust columns to control image size
live=True
)
interface.launch()
完整代码
总结
通过本文,苏南大叔使用TensorFlow
和卷积神经网络(CNN
)构建了一个高效的手写数字识别模型,并通过Gradio
提供了一个支持多张图片预测的交互式界面。以下是项目的关键点:
- 模型架构:使用 CNN 提取图像特征,提升分类性能。
- 数据预处理:归一化和独热编码。
- 训练与评估:模型在测试集上的准确率约为 99%。
- 交互式界面:支持用户上传多张图片并获得预测结果。
更多苏南大叔的机器学习的文章,请点击:


