我们相信:世界是美好的,你是我也是。 来玩一下解压小游戏吧!

PyTorch是一个开源的深度学习框架,由Facebook开发并维护。它以动态计算图和灵活的API而闻名,非常适合研究和开发深度学习模型。本文中,苏南大叔将介绍如何使用PyTorch的全连接神经网络(SimpleNN),实现一个简单的MNIST手写数字识别模型。

苏南大叔:基于PyTorch的SimpleNN,实现MNIST手写数字识别 - simplenn全连接网络
基于PyTorch的SimpleNN,实现MNIST手写数字识别(图4-1)

苏南大叔的“程序如此灵动”博客,记录苏南大叔的代码编程经验总结。测试环境:win10python@3.12.9torch@2.6.0torchvision@0.21.0。值得注意的是:pytorch的正式包名是torch,这很容易引起混淆。另外,还需要额外安装一个图像库torchvision

项目概述

本文依然还是处理MNIST手写数字数据集,这里通过PyTorch的全连接神经网络(SimpleNN)来处理。为了验证这个数据集最终生成的大模型的效果。最终,还会通过Gradio搭建一个交互式界面,允许用户上传手写数字图片并获得预测结果。

苏南大叔:基于PyTorch的SimpleNN,实现MNIST手写数字识别 - pytorch-datasets-mnist
基于PyTorch的SimpleNN,实现MNIST手写数字识别(图4-2)

MNIST数据集包含7万张28x28像素的灰度手写数字图片,每张图片对应一个数字(0-9)标签。本文中的MNIST数据集,是由torchvision.datasets.MNIST()内置函数来获得的,依然不是原版的图片格式。

参考文章:

模型架构

使用的是【简单的全连接神经网络】(SimpleNN),其架构如下:

  1. 输入层28x28 的图像被展平为784个特征。
  2. 隐藏层:一个包含128个神经元的全连接层,激活函数为ReLU()
  3. 输出层:一个包含10个神经元的全连接层,对应数字0-9

代码实现如下:

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

数据加载与预处理

torchvision 提供的 MNIST 数据集,并对数据进行以下预处理:

  1. 归一化:将像素值从[0, 255]映射到[-1, 1]
  2. 转换为张量:将图片转换为PyTorch张量。

代码如下:

def load_data(batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

模型加载

这个pytorch有自己的加载已有模型的方案,参考代码:

model = SimpleNN()
if os.path.exists("pytorch_mnist_model.pth"):
    model.load_state_dict(torch.load("pytorch_mnist_model.pth"))
    print("Model loaded from pytorch_mnist_model.pth")
else:
    train_model(model, train_loader)
    torch.save(model.state_dict(), "pytorch_mnist_model.pth")
    print("Model trained and saved to pytorch_mnist_model.pth")

苏南大叔:基于PyTorch的SimpleNN,实现MNIST手写数字识别 - 文件结构
基于PyTorch的SimpleNN,实现MNIST手写数字识别(图4-3)

模型训练

使用交叉熵损失函数(CrossEntropyLoss())和Adam优化器(Adam Optimizer)来训练模型。过程包括:

  1. 前向传播计算损失。
  2. 反向传播更新权重。

代码如下:

def train_model(model, train_loader, epochs=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    model.train()
    for epoch in range(epochs):
        for images, labels in train_loader:
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

模型评估

在测试集上评估模型的准确率:

def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            output = model(images)
            _, predicted = torch.max(output.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = correct / total
    return accuracy

在训练 5 个epoch 后,我们的模型在测试集上的准确率约为 97%

交互式预测界面

使用 Gradio 搭建了一个简单的交互式界面,允许用户上传手写数字图片并获得预测结果。以下是核心代码:

def predict_digit(image):
    model = SimpleNN()
    model.load_state_dict(torch.load("pytorch_mnist_model.pth"))
    model.eval()
    
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    
    image = image.convert("L").resize((28, 28))
    image_tensor = transforms.ToTensor()(image).unsqueeze(0)
    
    with torch.no_grad():
        output = model(image_tensor)
        _, predicted = torch.max(output, 1)
    
    return int(predicted.item())

Gradio 界面代码:

interface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Image(image_mode="L", label="Upload a 28x28 grayscale image"),
    outputs=gr.Label(label="Prediction"),
    title="MNIST Digit Prediction",
    description="28*28灰度图"
)
interface.launch()

参考文章:

完整代码

苏南大叔:基于PyTorch的SimpleNN,实现MNIST手写数字识别 - 预测结果
基于PyTorch的SimpleNN,实现MNIST手写数字识别(图4-4)

从图上可以看到:虽然说程序跑出来的预测率挺高,但是到实际的预测阶段,效果依然不尽如人意。

完整代码如下:

newsn.net:这里是【评论】可见内容

总结

PyTorch的灵活性和Gradio的易用性使得构建和部署深度学习模型变得更加简单。更多机器学习的预测文章,请点击:

如果本文对您有帮助,或者节约了您的时间,欢迎打赏瓶饮料,建立下友谊关系。
本博客不欢迎:各种镜像采集行为。请尊重原创文章内容,转载请保留作者链接。