ResNet Framework
import torch
import torch.nn as nn
from torch import Tensor
class BasicBlock(nn.Module):
    """Basic residual block (two 3x3 convs), as used in ResNet-18/34."""
    # Output channels = expansion * channels; 1 for BasicBlock (4 for Bottleneck)
    expansion: int = 1

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
                 downsample: nn.Module | None = None):
        super().__init__()
        # First 3x3 conv carries the stride and may downsample
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Second 3x3 conv preserves the spatial size
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when the main path changes shape
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
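A quick shape check makes the shortcut logic concrete: with stride=2 the main path halves the feature map and doubles the channels, so the shortcut needs a matching 1x1 projection before `out += identity` is well-defined. A minimal sketch (not part of the original listing), reusing the class above:

# Strided block: main path goes 64x32x32 -> 128x16x16, so project the shortcut too
proj = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
block = BasicBlock(64, 128, stride=2, downsample=proj)
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)  # torch.Size([1, 128, 16, 16])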
class ResNet(nn.Module):
    """ResNet backbone: a 7x7 stem followed by four residual stages."""

    def __init__(self, block: type[BasicBlock], layers: list[int],
                 num_classes: int = 1000):
        super().__init__()
        self.in_channels = 64
        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool (ImageNet-style)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; stages 2-4 halve the spatial size
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block: type[BasicBlock], channels: int,
                    blocks: int, stride: int = 1) -> nn.Sequential:
        # A 1x1 projection aligns the shortcut when the stage changes
        # resolution or channel count
        downsample = None
        if stride != 1 or self.in_channels != channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channels * block.expansion),
            )

        # Only the first block of a stage downsamples; the rest keep stride 1
        layers = [block(self.in_channels, channels, stride, downsample)]
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, channels))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
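As a sanity check of the assembled network (again a sketch, not part of the original code): a [2, 2, 2, 2] layer configuration gives ResNet-18, and a forward pass with an ImageNet-sized input should yield one logit per class:

net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=1000)
x = torch.randn(1, 3, 224, 224)
print(net(x).shape)  # torch.Size([1, 1000])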
Classic Application Example (CIFAR-10 Classification)
import torch
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader
config = {
    "batch_size": 128,
    "lr": 0.1,
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "epochs": 200,
    "lr_milestones": [60, 120, 160],
    "gamma": 0.2,
}
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    # Per-channel CIFAR-10 statistics
    torchvision.transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010]),
])
test_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010]),
])
train_set = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=train_transform)
test_set = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=test_transform)
train_loader = DataLoader(
train_set, batch_size=config["batch_size"], shuffle=True, num_workers=4)
test_loader = DataLoader(
test_set, batch_size=config["batch_size"], shuffle=False, num_workers=4)
# ResNet-18 configuration: [2, 2, 2, 2] BasicBlocks per stage, 10 CIFAR-10 classes
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10)
# Replace the ImageNet stem (7x7, stride 2) with a 3x3, stride-1 conv for 32x32 inputs
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(), lr=config["lr"], momentum=config["momentum"],
    weight_decay=config["weight_decay"])
# MultiStepLR multiplies the learning rate by `gamma` at each milestone epoch
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=config["lr_milestones"], gamma=config["gamma"])
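With milestones [60, 120, 160] and gamma=0.2, the learning rate steps from 0.1 to 0.02, 0.004, and finally 0.0008 over the 200 epochs.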
for epoch in range(config["epochs"]):
    # Training phase
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    scheduler.step()

    # Evaluation phase
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Epoch {epoch+1} | Acc: {100*correct/total:.2f}%")