DenseNet Framework
import torch
import torch.nn as nn
import torchvision
from torch import Tensor
from torch.utils.data import DataLoader
import torch.nn.functional as F
class Bottleneck(nn.Module):
    """DenseNet bottleneck layer: BN-ReLU-Conv1x1 followed by BN-ReLU-Conv3x3."""
    def __init__(self, in_channels, growth_rate, bn_size=4):
        super().__init__()
        inner_channels = bn_size * growth_rate  # 1x1 conv narrows the input to bn_size * growth_rate channels
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, inner_channels, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(inner_channels)
        self.conv2 = nn.Conv2d(inner_channels, growth_rate, 3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(x)))
        # Return only the growth_rate new feature maps; DenseBlock below does the concatenation.
        # Concatenating here as well would double-count channels and break the layer widths.
        return out
class Transition(nn.Module):
    def __init__(self, in_channels, compression=0.5):
        super().__init__()
        out_channels = int(in_channels * compression)
        self.bn = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.pool = nn.AvgPool2d(2, stride=2)

    def forward(self, x):
        x = self.conv(F.relu(self.bn(x)))
        return self.pool(x)
class DenseBlock(nn.ModuleDict):
    """A stack of bottleneck layers with dense connectivity."""
    def __init__(self, num_layers, in_channels, growth_rate, bn_size=4):
        super().__init__()
        for i in range(num_layers):
            # Each layer sees all preceding feature maps, so its input grows by growth_rate per layer
            layer = Bottleneck(in_channels + i * growth_rate, growth_rate, bn_size)
            self.add_module(f"denselayer{i+1}", layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(torch.cat(features, 1))
            features.append(new_features)
        return torch.cat(features, 1)
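A quick sanity check of the channel bookkeeping (an illustrative sketch, not part of the original code): a block with 6 layers and growth rate 32 on a 64-channel input should emit 64 + 6 × 32 = 256 channels.

# Sanity check (illustrative): channels grow by growth_rate per layer.
block = DenseBlock(num_layers=6, in_channels=64, growth_rate=32)
x = torch.randn(2, 64, 56, 56)
print(block(x).shape)  # expected: torch.Size([2, 256, 56, 56])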
class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, compression=0.5, num_classes=1000):
        super().__init__()

        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling
        self.features = nn.Sequential(
            nn.Conv2d(3, num_init_features, 7, 2, 3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1)
        )

        # Stack dense blocks, each followed by a transition layer except the last one
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = DenseBlock(
                num_layers=num_layers,
                in_channels=num_features,
                growth_rate=growth_rate,
                bn_size=bn_size
            )
            self.features.add_module(f"denseblock{i+1}", block)
            num_features += num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = Transition(num_features, compression=compression)
                self.features.add_module(f"transition{i+1}", trans)
                num_features = int(num_features * compression)

        self.classifier = nn.Linear(num_features, num_classes)

        # Kaiming initialization for convolutions, constant initialization for BatchNorm
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.adaptive_avg_pool2d(features, (1, 1))
        out = torch.flatten(out, 1)
        return self.classifier(out)
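Before training, a quick forward pass on dummy data verifies that the shapes line up. A minimal sketch with the default DenseNet-121-style configuration; the variable names are illustrative.

# Sanity check (illustrative): dummy batch through the default configuration.
net = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000)
dummy = torch.randn(1, 3, 224, 224)
print(net(dummy).shape)  # expected: torch.Size([1, 1000])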
CIFAR-100 Classification in Practice
config = {
    "batch_size": 64,
    "epochs": 300,
    "lr": 0.1,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "lr_schedule": {"milestones": [150, 225], "gamma": 0.1},
    "growth_rate": 12,
    "block_config": (16, 16, 16),
    "compression": 0.5,
}
from torchvision import transforms
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
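To feed these transforms into training, the CIFAR-100 datasets and loaders can be set up like this. A sketch assuming `torchvision.datasets.CIFAR100` with a local `./data` directory; adjust the root path and `num_workers` to your environment.

# Sketch: CIFAR-100 datasets and loaders (root path and num_workers are assumptions).
from torchvision.datasets import CIFAR100

train_set = CIFAR100(root="./data", train=True, download=True, transform=train_transform)
test_set = CIFAR100(root="./data", train=False, download=True, transform=test_transform)

train_loader = DataLoader(train_set, batch_size=config["batch_size"],
                          shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=config["batch_size"],
                         shuffle=False, num_workers=4, pin_memory=True)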
scaler = torch.cuda.amp.GradScaler()

# One training epoch (wrap in an epoch loop and call scheduler.step() once per epoch)
for inputs, labels in train_loader:
    inputs = inputs.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    # Run the forward pass in mixed precision
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = criterion(outputs, labels)

    # Scale the loss to avoid fp16 gradient underflow, then step and update the scale factor
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
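For completeness, a plain evaluation pass over the test set might look like the following sketch (top-1 accuracy; not part of the original snippet).

# Sketch: top-1 accuracy on the test set (illustrative evaluation loop).
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model(inputs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"top-1 accuracy: {correct / total:.4f}")
model.train()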
from torch.utils.checkpoint import checkpoint
class MemoryEfficientDenseBlock(DenseBlock):
    """Drop-in replacement for DenseBlock that recomputes each layer's activations
    during the backward pass instead of storing them (gradient checkpointing)."""
    def forward(self, init_features):
        features = [init_features]
        for layer in self.values():
            # checkpoint() discards the layer's intermediate activations and recomputes them on backward
            new_features = checkpoint(layer, torch.cat(features, 1))
            features.append(new_features)
        return torch.cat(features, 1)
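One way to see the effect is to compare peak GPU memory for a regular block versus the checkpointed variant. A sketch assuming a CUDA device is available; `peak_memory` is an illustrative helper, not part of the original code.

# Sketch: compare peak memory of a regular vs. checkpointed dense block (CUDA assumed).
def peak_memory(block_cls):
    torch.cuda.reset_peak_memory_stats()
    block = block_cls(num_layers=16, in_channels=64, growth_rate=12).cuda()
    # The input must require grad so the reentrant checkpoint can build a backward graph
    x = torch.randn(64, 64, 32, 32, device="cuda", requires_grad=True)
    block(x).sum().backward()
    return torch.cuda.max_memory_allocated() / 2**20  # MiB

print("DenseBlock:                %.1f MiB" % peak_memory(DenseBlock))
print("MemoryEfficientDenseBlock: %.1f MiB" % peak_memory(MemoryEfficientDenseBlock))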
import torch.nn.utils.prune as prune

# Structured L2 pruning: remove 30% of the output channels of one 3x3 convolution
prune_rate = 0.3
module = model.features.denseblock1.denselayer1.conv2
prune.ln_structured(module, 'weight', amount=prune_rate, n=2, dim=0)
prune.remove(module, 'weight')  # make the pruning permanent by baking the mask into the weight
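The same call can be applied to every bottleneck's 3×3 convolution rather than a single hand-picked module. A sketch, typically followed by fine-tuning; the sparsity printout is only for verification.

# Sketch: apply the same structured pruning to every Bottleneck's 3x3 conv.
for name, m in model.named_modules():
    if isinstance(m, Bottleneck):
        prune.ln_structured(m.conv2, 'weight', amount=prune_rate, n=2, dim=0)
        prune.remove(m.conv2, 'weight')

# Verify: fraction of exactly-zero weights in the pruned convolutions
zeros = sum((m.conv2.weight == 0).sum().item() for m in model.modules() if isinstance(m, Bottleneck))
total = sum(m.conv2.weight.numel() for m in model.modules() if isinstance(m, Bottleneck))
print(f"global conv2 sparsity: {zeros / total:.2%}")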
teacher_model = torchvision.models.densenet201(pretrained=True)
student_model = DenseNet(block_config=(6, 12, 24, 16))
distill_loss = nn.KLDivLoss(reduction="batchmean")
temperature = 3
alpha = 0.7
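Before the distillation loop, the teacher is usually frozen and switched to eval mode so its weights and BatchNorm statistics stay fixed; only the student is optimized. A short sketch, not shown in the original; the hard-label `criterion` and the student `optimizer` are illustrative assumptions.

# Sketch: freeze the teacher and set up the student's optimizer (illustrative).
teacher_model = teacher_model.to(device).eval()
student_model = student_model.to(device)
for p in teacher_model.parameters():
    p.requires_grad_(False)

criterion = nn.CrossEntropyLoss()  # hard-label loss for the student (assumed)
optimizer = torch.optim.SGD(student_model.parameters(), lr=0.1, momentum=0.9)  # illustrative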
for data, labels in dataloader:
    data, labels = data.to(device), labels.to(device)

    with torch.no_grad():  # the teacher only provides soft targets, no gradients needed
        teacher_logits = teacher_model(data)
    student_logits = student_model(data)

    # Temperature-softened KL term plus the usual hard-label cross-entropy
    loss = alpha * distill_loss(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1)
    ) + (1 - alpha) * criterion(student_logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()