基于PyTorch的SimpleNN,实现MNIST手写数字识别
发布于 作者:苏南大叔 来源:程序如此灵动~
PyTorch是一个开源的深度学习框架,由Facebook开发并维护。它以动态计算图和灵活的API而闻名,非常适合研究和开发深度学习模型。本文中,苏南大叔将介绍如何使用PyTorch的全连接神经网络(SimpleNN),实现一个简单的MNIST手写数字识别模型。

苏南大叔的“程序如此灵动”博客,记录苏南大叔的代码编程经验总结。测试环境:win10,python@3.12.9,torch@2.6.0,torchvision@0.21.0。值得注意的是:pytorch的正式包名是torch,这很容易引起混淆。另外,还需要额外安装一个图像库torchvision。
项目概述
本文依然还是处理MNIST手写数字数据集,这里通过PyTorch的全连接神经网络(SimpleNN)来处理。为了验证这个数据集最终生成的大模型的效果。最终,还会通过Gradio搭建一个交互式界面,允许用户上传手写数字图片并获得预测结果。

MNIST数据集包含7万张28x28像素的灰度手写数字图片,每张图片对应一个数字(0-9)标签。本文中的MNIST数据集,是由torchvision.datasets.MNIST()内置函数来获得的,依然不是原版的图片格式。
参考文章:
模型架构
使用的是【简单的全连接神经网络】(SimpleNN),其架构如下:
- 输入层:
28x28的图像被展平为784个特征。 - 隐藏层:一个包含
128个神经元的全连接层,激活函数为ReLU()。 - 输出层:一个包含
10个神经元的全连接层,对应数字0-9。
代码实现如下:
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.flatten(x)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x数据加载与预处理
torchvision 提供的 MNIST 数据集,并对数据进行以下预处理:
- 归一化:将像素值从
[0, 255]映射到[-1, 1]。 - 转换为张量:将图片转换为
PyTorch张量。
代码如下:
def load_data(batch_size=64):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return train_loader, test_loader模型加载
这个pytorch有自己的加载已有模型的方案,参考代码:
model = SimpleNN()
if os.path.exists("pytorch_mnist_model.pth"):
model.load_state_dict(torch.load("pytorch_mnist_model.pth"))
print("Model loaded from pytorch_mnist_model.pth")
else:
train_model(model, train_loader)
torch.save(model.state_dict(), "pytorch_mnist_model.pth")
print("Model trained and saved to pytorch_mnist_model.pth")
模型训练
使用交叉熵损失函数(CrossEntropyLoss())和Adam优化器(Adam Optimizer)来训练模型。过程包括:
- 前向传播计算损失。
- 反向传播更新权重。
代码如下:
def train_model(model, train_loader, epochs=5):
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
model.train()
for epoch in range(epochs):
for images, labels in train_loader:
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels)
loss.backward()
optimizer.step()模型评估
在测试集上评估模型的准确率:
def evaluate_model(model, test_loader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
output = model(images)
_, predicted = torch.max(output.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
return accuracy在训练 5 个epoch 后,我们的模型在测试集上的准确率约为 97%。
交互式预测界面
使用 Gradio 搭建了一个简单的交互式界面,允许用户上传手写数字图片并获得预测结果。以下是核心代码:
def predict_digit(image):
model = SimpleNN()
model.load_state_dict(torch.load("pytorch_mnist_model.pth"))
model.eval()
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
image = image.convert("L").resize((28, 28))
image_tensor = transforms.ToTensor()(image).unsqueeze(0)
with torch.no_grad():
output = model(image_tensor)
_, predicted = torch.max(output, 1)
return int(predicted.item())Gradio 界面代码:
interface = gr.Interface(
fn=predict_digit,
inputs=gr.Image(image_mode="L", label="Upload a 28x28 grayscale image"),
outputs=gr.Label(label="Prediction"),
title="MNIST Digit Prediction",
description="28*28灰度图"
)
interface.launch()参考文章:
完整代码

从图上可以看到:虽然说程序跑出来的预测率挺高,但是到实际的预测阶段,效果依然不尽如人意。
完整代码如下:
总结
PyTorch的灵活性和Gradio的易用性使得构建和部署深度学习模型变得更加简单。更多机器学习的预测文章,请点击: