LeNet-5 框架
LeNet-5框架代码
import torch
import torch.nn as nn
class LeNet5(nn.Module):
    """LeNet-5 convolutional network for 32x32 single-channel images.

    Feature-map sizes for an input of shape (N, 1, 32, 32):
        conv1 -> (N, 6, 28, 28)  -> pool1 -> (N, 6, 14, 14)
        conv2 -> (N, 16, 10, 10) -> pool2 -> (N, 16, 5, 5)
        flatten -> (N, 400) -> fc1 (120) -> fc2 (84) -> fc3 (num_classes)

    Args:
        num_classes: Number of output logits. Defaults to 10 (digits 0-9).
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        # No padding: as in the original LeNet-5, the 32x32 input already
        # contains the border, so C1 yields 28x28 maps and the network reaches
        # 16x5x5 before the classifier. (padding=2 is only correct for raw
        # 28x28 inputs; with 32x32 inputs it produces 16x6x6 and breaks fc1.)
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)  # 1 input channel -> 6 feature maps
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)  # average pooling (subsampling)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # 6 -> 16 feature maps
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits of shape (N, num_classes).

        Args:
            x: Image batch of shape (N, 1, 32, 32).
        """
        x = self.pool1(torch.tanh(self.conv1(x)))  # tanh activation, as in the original design
        x = self.pool2(torch.tanh(self.conv2(x)))
        # flatten(x, 1) preserves the batch dimension; view(-1, 400) could
        # silently fold spatial elements into the batch axis on a size mismatch.
        x = torch.flatten(x, 1)
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.fc3(x)  # raw logits; CrossEntropyLoss applies log-softmax itself
MNIST 分类完整项目
任务简介:
- 目标:识别 0-9 的手写数字。
- 数据集:MNIST 数据集(包含 60k 训练样本和 10k 测试样本)。
import torchvision
import torch.optim as optim
from torchvision import transforms
# Preprocessing pipeline: upscale each MNIST digit to the 32x32 input size
# LeNet-5 expects, convert the PIL image to a float tensor in [0, 1], then
# map it to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Load the MNIST train/test splits (downloaded into ./data on first run) and
# wrap each in a DataLoader of 64-sample batches. Only the training split is
# shuffled, so evaluation always iterates in the same order.
train_set = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_set, batch_size=64, shuffle=True
)

test_set = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set, batch_size=64, shuffle=False
)
# Build the network, a cross-entropy loss over its raw logits, and a plain
# SGD optimizer (learning rate 0.01, no momentum) over all model parameters.
model, criterion = LeNet5(), nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01)
# Train for 10 epochs and report the mean per-sample training loss of each
# epoch (rather than the loss of whichever mini-batch happened to be last).
for epoch in range(10):
    model.train()  # training mode; matters once dropout/batch-norm layers are added
    running_loss = 0.0
    for images, labels in train_loader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The loss is a per-batch mean; re-weight by batch size so the final,
        # smaller batch does not skew the epoch average.
        running_loss += loss.item() * images.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')
# Measure top-1 accuracy on the test split.
model.eval()  # inference mode; no-op for the current layers, required if dropout/BN are added
correct = 0
total = 0
with torch.no_grad():  # no autograd graph is needed for evaluation
    for images, labels in test_loader:
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the largest logit = predicted class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test Accuracy: {100*correct / total:.2f}%')
可视化训练曲线:使用 Matplotlib 绘制损失曲线(准确率曲线可按相同方式绘制)。
import matplotlib.pyplot as plt
# Placeholder per-epoch losses for illustration only; replace them with the
# values actually recorded during training.
loss_history = [0.5, 0.3, 0.2, 0.1, 0.08, 0.06, 0.05, 0.04, 0.03, 0.02]
# Number epochs from 1 so the x-axis matches the "Epoch N" training log.
epochs = range(1, len(loss_history) + 1)
plt.plot(epochs, loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.show()