LeNet-5 框架

Welcome file

LeNet-5框架代码

import torch  
import torch.nn as nn  
  
class LeNet5(nn.Module):
    """LeNet-5 convolutional network for small single-channel images.

    Expects input of shape (N, in_channels, 28, 28). conv1 uses padding=2, so a
    28x28 image yields the same 28x28 C1 feature map that the paper's unpadded
    convolution produces from a 32x32 input. The spatial size then goes
    28 -> 14 (pool1) -> 10 (conv2) -> 5 (pool2), which matches fc1's
    16 * 5 * 5 input features.

    Args:
        in_channels: number of input image channels (1 for MNIST).
        num_classes: number of output logits produced by fc3.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5, padding=2)  # in_channels -> 6 maps
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)  # average pooling (subsampling)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits (no softmax applied) for a batch of images.

        Raises:
            RuntimeError: if the input size does not lead to a 16x5x5 feature
                map before fc1 (e.g. 32x32 inputs, which give 16x6x6).
        """
        # tanh activations, as used in the original paper.
        x = self.pool1(torch.tanh(self.conv1(x)))
        x = self.pool2(torch.tanh(self.conv2(x)))
        # Validate before flattening: view(-1, 400) on a wrong-sized map would
        # either fail with an opaque error or silently regroup features across
        # samples and change the batch size.
        if tuple(x.shape[-3:]) != (16, 5, 5):
            raise RuntimeError(
                "LeNet5 expects 28x28 inputs (feature map before fc1 must be "
                f"16x5x5), got feature map {tuple(x.shape[-3:])}"
            )
        x = x.view(-1, 16 * 5 * 5)  # flatten to (N, 400)
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.fc3(x)

MNIST分类完整项目

任务简介:

  • 目标:识别 0-9 的手写数字。
  • 数据集:MNIST 数据集(包含 60k 训练样本和 10k 测试样本)。
import torchvision  
import torch.optim as optim  
from torchvision import transforms  
  
# Data preprocessing.
# MNIST images are 28x28. LeNet5's conv1 uses padding=2, which yields the same
# 28x28 C1 feature map that the paper gets from an unpadded conv on a 32x32
# input, so no resize is applied. Resizing to 32x32 here would produce a 16x6x6
# map after pool2 and break fc1, which expects 16 * 5 * 5 features.
# Normalize((0.5,), (0.5,)) maps ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load the datasets (downloaded into ./data on first run).
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

# Model, loss function and optimizer; use a GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LeNet5().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training. Report the mean loss over each epoch (the last-batch loss alone is
# noisy) and keep it in epoch_losses so it can be plotted afterwards.
epoch_losses = []
for epoch in range(10):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so the smaller final batch is averaged correctly.
        running_loss += loss.item() * labels.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_losses.append(epoch_loss)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')

# Test accuracy.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test Accuracy: {100*correct / total:.2f}%')

可视化训练曲线:使用 Matplotlib 绘制损失曲线(准确率曲线可按同样方法绘制)。

import matplotlib.pyplot as plt  
  
# Placeholder per-epoch losses for illustration; substitute the values recorded
# during training (one entry per epoch) to plot the real curve.
loss_history = [0.5, 0.3, 0.2, 0.1, 0.08, 0.06, 0.05, 0.04, 0.03, 0.02]
# The training log numbers epochs from 1, so use the same numbering on the x
# axis instead of plt.plot's default 0-based indices.
epochs = range(1, len(loss_history) + 1)
plt.plot(epochs, loss_history, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.grid(True)
plt.show()

此博客中的热门博文

Numerical Analysis --- Interpolation & Polynomial Approximation

Computer Animation (Postgraduate Course): Lecture 4: Keyframe interpolation and velocity control

Computer Animation (Postgraduate Course): Lecture 3: Representation of transformation and rotation