GoogLeNet（Inception v1）网络框架
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import T
from torch.utils.hooks import RemovableHandle
class InceptionModule(nn.Module):
    """One Inception (v1) block: four parallel branches concatenated on channels.

    Args:
        in_channels: Number of channels of the input feature map.
        ch1x1: Output channels of the 1x1 convolution branch.
        ch3x3red: Channels of the 1x1 "reduce" convolution before the 3x3 conv.
        ch3x3: Output channels of the 3x3 convolution branch.
        ch5x5red: Channels of the 1x1 "reduce" convolution before the 5x5 conv.
        ch5x5: Output channels of the 5x5 convolution branch.
        pool_proj: Output channels of the 1x1 projection after 3x3 max pooling.

    The output has ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels and the same
    spatial size as the input (every branch uses stride 1 with "same" padding).
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super().__init__()
        # Branch 1: 1x1 convolution.
        self.branch1 = nn.Conv2d(in_channels, ch1x1, kernel_size=1)
        # Branch 2: 1x1 reduction -> 3x3 convolution.
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3red, kernel_size=1),
            nn.Conv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
        )
        # Branch 3: 1x1 reduction -> 5x5 convolution.
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5red, kernel_size=1),
            nn.Conv2d(ch5x5red, ch5x5, kernel_size=5, padding=2),
        )
        # Branch 4: 3x3 max pooling -> 1x1 projection.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them along dim 1."""
        # Every convolution is followed by a ReLU, as in the original GoogLeNet;
        # without it the 1x1 -> kxk chains collapse into a single linear map.
        # The activations are applied here instead of being inserted into the
        # nn.Sequential containers so the parameter names (state_dict keys)
        # stay exactly as before.
        b1 = F.relu(self.branch1(x))
        b2 = F.relu(self.branch2[1](F.relu(self.branch2[0](x))))
        b3 = F.relu(self.branch3[1](F.relu(self.branch3[0](x))))
        b4 = F.relu(self.branch4[1](self.branch4[0](x)))
        return torch.cat([b1, b2, b3, b4], 1)
class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) image classifier.

    Args:
        num_classes: Length of the logits vector produced per image.

    The Inception stages use the channel configuration of the original
    GoogLeNet: the last block (``inception5b``) emits
    384 + 384 + 128 + 128 = 1024 channels, which matches the input size of
    the final linear layer. Auxiliary classifiers and local response
    normalization are not included.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=1)
        self.conv3 = nn.Conv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1)

        # Stage 3: 192 -> 256 -> 480 channels.
        self.inception3a = InceptionModule(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = InceptionModule(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, padding=1)

        # Stage 4: 480 -> 512 -> 512 -> 512 -> 528 -> 832 channels.
        self.inception4a = InceptionModule(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = InceptionModule(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = InceptionModule(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = InceptionModule(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = InceptionModule(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1)

        # Stage 5: 832 -> 832 -> 1024 channels.
        self.inception5a = InceptionModule(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = InceptionModule(832, 384, 192, 384, 48, 128, 128)

        # Classification head; 1024 = output channels of inception5b.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return logits of shape (N, num_classes) for an image batch ``x``.

        ``x`` must have 3 input channels (``conv1`` expects 3); the adaptive
        average pool makes the head independent of the input height/width.
        """
        x = F.relu(self.conv1(x))
        x = self.maxpool1(x)
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        return x
典型示例：Oxford-IIIT Pet 细粒度分类
- 目标：对 37 种宠物（猫、狗的不同品种）进行分类。
- 数据集：Oxford-IIIT Pet 数据集（共包含 7,349 张图像）。
import torchvision
import torch.optim as optim
from torchvision import transforms
# Preprocessing shared by both splits: resize the shorter image side to 256 px,
# center-crop to 224x224, convert to a tensor, and normalize each RGB channel
# with the commonly used ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def _make_pet_split(split, shuffle):
    """Build the Oxford-IIIT Pet dataset for ``split`` and a batch-32 loader over it.

    download=True fetches the archives into ./data when they are missing.
    Returns a ``(dataset, loader)`` pair.
    """
    dataset = torchvision.datasets.OxfordIIITPet(
        root='./data', split=split, download=True, transform=transform
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=shuffle)
    return dataset, loader


# Training data is reshuffled every epoch; test data keeps a fixed order.
train_set, train_loader = _make_pet_split('trainval', shuffle=True)
test_set, test_loader = _make_pet_split('test', shuffle=False)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# GoogLeNet with a 37-way output layer (one logit per pet category), moved to
# the target device before the optimizer is created from its parameters.
model = GoogLeNet(num_classes=37)
model.to(device)

# Multi-class cross-entropy on the raw logits, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Train for 25 epochs and report the mean mini-batch loss after each one.
for epoch in range(25):
    model.train()  # training mode (enables dropout)
    running_loss = 0.0
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear old gradients, forward, backprop, update.
        optimizer.zero_grad()
        loss = criterion(model(batch_images), batch_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {running_loss / len(train_loader):.4f}')
# Measure top-1 accuracy on the test split.
model.eval()  # inference mode (disables dropout)
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # The predicted class is the index of the largest logit; argmax avoids
        # the discouraged `.data` access and the unused max values.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report once, after the whole test set has been scored.
if total:
    print(f'Test Accuracy: {100 * correct / total:.2f}%')
else:
    print('Test Accuracy: n/a (empty test set)')
特征可视化：使用 PCA 将高层特征降至二维，展示其分布。
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
# Collect the globally pooled features (the input to model.fc) for every test
# image together with its label, then project them to 2-D with PCA.
feature_chunks = []
label_chunks = []


def _capture_pooled(module, inputs, output):
    """Forward hook on model.avgpool: store its output flattened to (N, -1) on the CPU."""
    feature_chunks.append(torch.flatten(output, 1).cpu())


model.eval()
hook_handle = model.avgpool.register_forward_hook(_capture_pooled)
try:
    with torch.no_grad():
        for images, batch_labels in test_loader:
            # Run the complete network so avgpool receives the backbone's final
            # feature maps; calling a deep block directly on raw 3-channel
            # images would fail. The logits themselves are not needed.
            model(images.to(device))
            # test_loader does not shuffle, so labels stay aligned with features.
            label_chunks.append(batch_labels)
finally:
    # Always detach the hook so later forward passes are unaffected.
    hook_handle.remove()

features = torch.cat(feature_chunks, dim=0).numpy()
# Labels are taken from the loader: the OxfordIIITPet dataset has no
# public `targets` attribute.
targets = torch.cat(label_chunks, dim=0).numpy()

pca = PCA(n_components=2)
reduced = pca.fit_transform(features)
plt.scatter(reduced[:, 0], reduced[:, 1], c=targets, s=8)
plt.colorbar(label='class index')
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('PCA of pooled GoogLeNet features (test set)')
plt.show()