Note: this article was published a while ago; some of its content may be out of date, so read with discretion.

Playing with the CIFAR10 Dataset

In this post we explore the CIFAR10 dataset through a range of hyperparameter and architecture tweaks.

1. Convolutional Neural Networks

1.1. Plain Convolution

1.1.1. Baseline Hyperparameters

python
import torchvision
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import time
from pathlib import Path


class Module(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        x = self.model(x)
        return x


# Training
def train_sector(net, device, train_load, loss_fn, optimizer, epoch, writer):
    net.train()
    total_train_loss = 0.0
    total_train_accuracy = 0.0
    for batch_index, (image, label) in enumerate(train_load):
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        output = net(image)
        loss = loss_fn(output, label)
        total_train_loss += loss.item()
        accuracy = (output.argmax(dim=1) == label).sum().item()
        total_train_accuracy += accuracy
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print(f"Train Epoch:{epoch} [{batch_index}/{len(train_load)}] loss:{loss.item():.6f}")
    total_train_loss /= len(train_load.dataset)
    total_train_accuracy /= len(train_load.dataset)
    writer.add_scalar("train_loss", total_train_loss, epoch)
    writer.add_scalar("train_accuracy", total_train_accuracy, epoch)
    print(f"Train Average Loss:{total_train_loss},Accuracy:{total_train_accuracy}")


# Testing
def test_sector(net, device, test_load, loss_fn, epoch, writer):
    net.eval()
    total_test_loss = 0.0
    total_test_accuracy = 0.0
    with torch.no_grad():
        for batch_index, (image, label) in enumerate(test_load):
            image, label = image.to(device), label.to(device)
            output = net(image)
            loss = loss_fn(output, label)
            total_test_loss += loss.item()
            accuracy = (output.argmax(dim=1) == label).sum().item()
            total_test_accuracy += accuracy

        total_test_loss /= len(test_load.dataset)
        total_test_accuracy /= len(test_load.dataset)
        writer.add_scalar("test_loss", total_test_loss, epoch)
        writer.add_scalar("test_accuracy", total_test_accuracy, epoch)
        print(f"Test Average Loss:{total_test_loss},Accuracy:{total_test_accuracy}")


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([transforms.ToTensor()])
    # Prepare the datasets
    train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=transform, download=True)
    test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=transform, download=True)
    train_data_size = len(train_data)
    test_data_size = len(test_data)
    # print(f"训练数据集的长度:{train_data_size}")
    # print(f"测试数据集的长度:{test_data_size}")
    writer = SummaryWriter("Common_logs")
    # Load the datasets with DataLoader
    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
    Net = Module()
    Net.to(device)
    # Loss function
    loss_fn = nn.CrossEntropyLoss().to(device)
    # Learning rate
    lr = 0.01
    # Optimizer
    optimizer = torch.optim.SGD(Net.parameters(), lr=lr)
    # Directory for saving models
    model_dir = './Common_models'
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    # Total number of training epochs
    EPOCH = 100
    for epoch in range(1, EPOCH + 1):
        start_time = time.time()
        train_sector(Net, device, train_dataloader, loss_fn, optimizer, epoch, writer)
        end_time = time.time()
        print(f'time, {(end_time - start_time):.2f} s')
        test_sector(Net, device, test_dataloader, loss_fn, epoch, writer)
        torch.save(Net.state_dict(), f"{model_dir}/Net{epoch}.pth")
        print("模型已保存")
'''
Files already downloaded and verified
Files already downloaded and verified
Train Epoch:1 [0/782] loss:2.312440
Train Epoch:1 [100/782] loss:2.310811
Train Epoch:1 [200/782] loss:2.289058
Train Epoch:1 [300/782] loss:2.284024
Train Epoch:1 [400/782] loss:2.199121
Train Epoch:1 [500/782] loss:2.186231
Train Epoch:1 [600/782] loss:2.084236
Train Epoch:1 [700/782] loss:2.000870
Train Average Loss:0.03442628162384033,Accuracy:0.17418
time, 9.44 s
Test Average Loss:0.031229505383968355,Accuracy:0.2815
Model saved
Train Epoch:2 [0/782] loss:2.353905
......
'''

The model from each epoch is saved in ./Common_models, and the TensorBoard logs in ./Common_logs.

After training, run the following command to see how the loss and accuracy evolve:

bash
tensorboard --logdir "path/of/the/Common_logs" --port=6008  # the port can be changed

Result: ==training accuracy approaches 100%, while test accuracy stays around 67%==.

1.1.2. Tuning the Learning Rate

1.1.2.1. Cosine Annealing

Core code:

python
# Optimizer
optimizer = torch.optim.SGD(Net.parameters(), lr=lr, weight_decay=1e-5)
# Cosine annealing
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, EPOCH, eta_min=1e-5, last_epoch=-1)
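CosineAnnealingLR decays the learning rate from its initial value down to eta_min along a half cosine: lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2. A minimal sketch (my own addition, assuming the same lr=0.01, T_max=EPOCH=100 and eta_min=1e-5 used below) to preview the per-epoch values without training anything:

python
import torch
from torch import nn
from torch.optim import lr_scheduler

# A throwaway one-parameter optimizer, used only to drive the scheduler.
dummy = torch.optim.SGD([nn.Parameter(torch.zeros(1))], lr=0.01)
scheduler = lr_scheduler.CosineAnnealingLR(dummy, T_max=100, eta_min=1e-5)
for epoch in range(100):
    dummy.step()       # step the optimizer first so the scheduler does not warn
    scheduler.step()
    if epoch % 20 == 0:
        print(epoch, scheduler.get_last_lr()[0])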

Full code:

python
import torchvision
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import time


class Module(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        x = self.model(x)
        return x


# Training
def train_sector(net, device, train_load, loss_fn, optimizer, epoch, writer):

    net.train()
    total_train_loss = 0.0
    total_train_accuracy = 0.0
    for batch_index, (image, label) in enumerate(train_load):
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        output = net(image)
        loss = loss_fn(output, label)
        total_train_loss += loss.item()
        accuracy = (output.argmax(dim=1) == label).sum().item()
        total_train_accuracy += accuracy
        loss.backward()
        optimizer.step()

        if batch_index % 100 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Train Epoch:{epoch} [{batch_index}/{len(train_load)}] lr: {current_lr}  loss:{loss.item():.6f}")
    total_train_loss /= len(train_load.dataset)
    total_train_accuracy /= len(train_load.dataset)
    writer.add_scalar("train_loss", total_train_loss, epoch)
    writer.add_scalar("train_accuracy", total_train_accuracy, epoch)
    print(f"Train Average Loss:{total_train_loss},Accuracy:{total_train_accuracy}")


# Testing
def test_sector(net, device, test_load, loss_fn, epoch, writer):
    net.eval()
    total_test_loss = 0.0
    total_test_accuracy = 0.0
    with torch.no_grad():
        for batch_index, (image, label) in enumerate(test_load):
            image, label = image.to(device), label.to(device)
            output = net(image)
            loss = loss_fn(output, label)
            total_test_loss += loss.item()
            accuracy = (output.argmax(dim=1) == label).sum().item()
            total_test_accuracy += accuracy

        total_test_loss /= len(test_load.dataset)
        total_test_accuracy /= len(test_load.dataset)
        writer.add_scalar("test_loss", total_test_loss, epoch)
        writer.add_scalar("test_accuracy", total_test_accuracy, epoch)
        print(f"Test Average Loss:{total_test_loss},Accuracy:{total_test_accuracy}")


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([transforms.ToTensor()])
    # Prepare the datasets
    train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=transform, download=True)
    test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=transform,  download=True)
    train_data_size = len(train_data)
    test_data_size = len(test_data)
    # print(f"Training set size: {train_data_size}")
    # print(f"Test set size: {test_data_size}")
    writer = SummaryWriter("Common_cosine_logs")
    # Load the datasets with DataLoader
    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
    # Define the model
    Net = Module()
    Net.to(device)
    # Loss function
    loss_fn = nn.CrossEntropyLoss().to(device)
    # Total number of training epochs
    EPOCH = 100
    # Learning rate
    lr = 0.01
    # Optimizer
    optimizer = torch.optim.SGD(Net.parameters(), lr=lr, weight_decay=1e-5)
    # Cosine annealing
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, EPOCH, eta_min=1e-5, last_epoch=-1)
    # Directory for saving models
    model_dir = './Common_cosine_models'
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    for epoch in range(1, EPOCH + 1):
        start_time = time.time()
        train_sector(Net, device, train_dataloader, loss_fn, optimizer, epoch, writer)
        end_time = time.time()
        print(f'time, {(end_time - start_time):.2f} s')
        scheduler.step()
        test_sector(Net, device, test_dataloader, loss_fn, epoch, writer)
        torch.save(Net.state_dict(), f"{model_dir}/Net{epoch}.pth")
        print("模型已保存")
'''
Files already downloaded and verified
Files already downloaded and verified
Train Epoch:1 [0/782] lr: 0.01  loss:2.321574
Train Epoch:1 [100/782] lr: 0.01  loss:2.293098
Train Epoch:1 [200/782] lr: 0.01  loss:2.301508
Train Epoch:1 [300/782] lr: 0.01  loss:2.263761
Train Epoch:1 [400/782] lr: 0.01  loss:2.227565
Train Epoch:1 [500/782] lr: 0.01  loss:2.189559
Train Epoch:1 [600/782] lr: 0.01  loss:2.162553
Train Epoch:1 [700/782] lr: 0.01  loss:1.969154
Train Average Loss:0.0345455189204216,Accuracy:0.17764
time, 8.33 s
Test Average Loss:0.0320733146905899,Accuracy:0.249
Model saved
Train Epoch:2 [0/782] lr: 0.009997535269026829  loss:1.968917
Train Epoch:2 [100/782] lr: 0.009997535269026829  loss:1.969962
......
'''

The model from each epoch is saved in ./Common_cosine_models, and the TensorBoard logs in ./Common_cosine_logs.

After training, run the following command to see how the loss and accuracy evolve:

bash
tensorboard --logdir "path/of/the/Common_cosine_logs" --port=6008  # the port can be changed

Result: ==the curves are smoother, but accuracy shows no clear change: training accuracy approaches 100% and test accuracy stays around 67%==.

1.1.3. Changing the Optimizer

So far we have used the SGD optimizer; next we switch to other optimizers for comparison.

1.1.3.1. The Adam Optimizer

python
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import time


class Module(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            # nn.ReLU(),  # note: the activation between the two linear layers is commented out in this variant
            nn.Linear(64, 10),
        )

    def forward(self, x):
        x = self.model(x)
        return x


# Training
def train_sector(net, device, train_load, loss_fn, optimizer, epoch, writer):
    net.train()
    total_train_loss = 0.0
    total_train_accuracy = 0.0
    for batch_index, (image, label) in enumerate(train_load):
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        output = net(image)
        loss = loss_fn(output, label)
        total_train_loss += loss.item()
        accuracy = (output.argmax(dim=1) == label).sum().item()
        total_train_accuracy += accuracy
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print(f"Train Epoch:{epoch} [{batch_index}/{len(train_load)}] loss:{loss.item():.6f}")
    total_train_loss /= len(train_load.dataset)
    total_train_accuracy /= len(train_load.dataset)
    writer.add_scalar("train_loss", total_train_loss, epoch)
    writer.add_scalar("train_accuracy", total_train_accuracy, epoch)
    print(f"Train Average Loss:{total_train_loss},Accuracy:{total_train_accuracy}")


# Testing
def test_sector(net, device, test_load, loss_fn, epoch, writer):
    net.eval()
    total_test_loss = 0.0
    total_test_accuracy = 0.0
    with torch.no_grad():
        for batch_index, (image, label) in enumerate(test_load):
            image, label = image.to(device), label.to(device)
            output = net(image)
            loss = loss_fn(output, label)
            total_test_loss += loss.item()
            accuracy = (output.argmax(dim=1) == label).sum().item()
            total_test_accuracy += accuracy

        total_test_loss /= len(test_load.dataset)
        total_test_accuracy /= len(test_load.dataset)
        writer.add_scalar("test_loss", total_test_loss, epoch)
        writer.add_scalar("test_accuracy", total_test_accuracy, epoch)
        print(f"Test Average Loss:{total_test_loss},Accuracy:{total_test_accuracy}")


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([transforms.ToTensor()])
    # Prepare the datasets
    train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=transform, download=True)
    test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=transform, download=True)
    train_data_size = len(train_data)
    test_data_size = len(test_data)
    # print(f"Training set size: {train_data_size}")
    # print(f"Test set size: {test_data_size}")
    writer = SummaryWriter("Common_Adam_logs")
    # Load the datasets with DataLoader
    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
    # Define the model
    Net = Module()
    Net.to(device)
    # Loss function
    loss_fn = nn.CrossEntropyLoss().to(device)
    # Learning rate
    lr = 0.01
    # Optimizer
    optimizer = torch.optim.Adam(Net.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
    # Directory for saving models
    model_dir = './Common_Adam_models'
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    # Total number of training epochs
    EPOCH = 100
    for epoch in range(1, EPOCH + 1):
        start_time = time.time()
        train_sector(Net, device, train_dataloader, loss_fn, optimizer, epoch, writer)
        end_time = time.time()
        print(f'time, {(end_time - start_time):.2f} s')
        test_sector(Net, device, test_dataloader, loss_fn, epoch, writer)
        torch.save(Net.state_dict(), f"{model_dir}/Net{epoch}.pth")
        print("模型已保存")
'''
Files already downloaded and verified
Files already downloaded and verified
Train Epoch:1 [0/782] loss:2.306522
Train Epoch:1 [100/782] loss:1.963838
Train Epoch:1 [200/782] loss:1.936326
Train Epoch:1 [300/782] loss:1.936133
Train Epoch:1 [400/782] loss:2.125272
Train Epoch:1 [500/782] loss:2.678718
Train Epoch:1 [600/782] loss:1.772193
Train Epoch:1 [700/782] loss:1.678015
Train Average Loss:0.1620299960231781,Accuracy:0.30402
time, 12.33 s
Test Average Loss:0.028692625737190246,Accuracy:0.3801
Model saved
Train Epoch:2 [0/782] loss:1.947087
Train Epoch:2 [100/782] loss:1.863164
......
'''

The model from each epoch is saved in ./Common_Adam_models, and the TensorBoard logs in ./Common_Adam_logs.

After training, run the following command to see how the loss and accuracy evolve:

bash
tensorboard --logdir "path/of/the/Common_Adam_logs" --port=6008  # the port can be changed

Result: the curves oscillate noticeably; training and test accuracy both fluctuate around ==50%==.
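Such oscillation usually means the learning rate is too high for Adam. A hedged tweak (my assumption, not something the original run tested) is to drop lr from 0.01 to 1e-3, keeping the other Adam settings:

python
# Assumption: lr=1e-3 is a more typical Adam setting for this model;
# the original experiment used lr=0.01, which is what oscillates above.
optimizer = torch.optim.Adam(Net.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)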

1.2. Modifying the Convolutional Network

1.2.1. Adding Dropout and Attention

Here we add a channel attention mechanism and a spatial attention mechanism (CBAM), combined with Dropout.

1.2.1.1. Without Image Augmentation

python
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import time


# Channel attention
class ChannelAttention(nn.Module):
    # in_channel is the number of input feature-map channels; ratio is the reduction factor of the first FC layer
    def __init__(self, in_channel, ratio=4):
        # Initialize the parent class
        super(ChannelAttention, self).__init__()

        # Global max pooling [b,c,h,w] ==> [b,c,1,1]
        self.max_pool = nn.AdaptiveMaxPool2d(output_size=1)
        # Global average pooling [b,c,h,w] ==> [b,c,1,1]
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)

        # First FC layer, reduces the channel count by a factor of `ratio`
        self.fc1 = nn.Linear(in_features=in_channel, out_features=in_channel // ratio, bias=False)
        # Second FC layer, restores the channel count
        self.fc2 = nn.Linear(in_features=in_channel // ratio, out_features=in_channel, bias=False)

        # ReLU activation
        self.relu = nn.ReLU()
        # Sigmoid activation
        self.sigmoid = nn.Sigmoid()

    # Forward pass
    def forward(self, inputs):
        # Shape of the input feature map
        b, c, h, w = inputs.shape

        # Global max pooling of the input [b,c,h,w] ==> [b,c,1,1]
        max_pool = self.max_pool(inputs)
        # Global average pooling of the input [b,c,h,w] ==> [b,c,1,1]
        avg_pool = self.avg_pool(inputs)

        # Flatten the pooled results [b,c,1,1] ==> [b,c]
        max_pool = max_pool.view([b, c])
        avg_pool = avg_pool.view([b, c])

        # First FC layer reduces channels [b,c] ==> [b,c//ratio]
        x_maxpool = self.fc1(max_pool)
        x_avgpool = self.fc1(avg_pool)

        # Activation
        x_maxpool = self.relu(x_maxpool)
        x_avgpool = self.relu(x_avgpool)

        # Second FC layer restores channels [b,c//ratio] ==> [b,c]
        x_maxpool = self.fc2(x_maxpool)
        x_avgpool = self.fc2(x_avgpool)

        # Sum the two pooled branches [b,c] ==> [b,c]
        x = x_maxpool + x_avgpool
        # Normalize the weights with sigmoid
        x = self.sigmoid(x)
        # Reshape [b,c] ==> [b,c,1,1]
        x = x.view([b, c, 1, 1])
        # Scale the input feature map by the channel weights [b,c,h,w]
        outputs = inputs * x

        return outputs


# Spatial attention
class SpatialAttention(nn.Module):
    # kernel_size defaults to 7x7
    def __init__(self, kernel_size=7):
        # Initialize the parent class
        super(SpatialAttention, self).__init__()

        # Padding keeps the feature-map shape unchanged by the convolution
        padding = kernel_size // 2
        # 7x7 convolution fuses channel information [b,2,h,w] ==> [b,1,h,w]
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size,
                              padding=padding, bias=False)
        # Sigmoid activation
        self.sigmoid = nn.Sigmoid()

    # Forward pass
    def forward(self, inputs):
        # Max over the channel dimension [b,1,h,w]; keepdim preserves the dimension
        # (torch.max returns both the max values and their indices)
        x_maxpool, _ = torch.max(inputs, dim=1, keepdim=True)

        # Mean over the channel dimension [b,1,h,w]
        x_avgpool = torch.mean(inputs, dim=1, keepdim=True)
        # Stack the pooled results along the channel dimension [b,2,h,w]
        x = torch.cat([x_maxpool, x_avgpool], dim=1)

        # Convolution fuses channel information [b,2,h,w] ==> [b,1,h,w]
        x = self.conv(x)
        # Normalize the spatial weights
        x = self.sigmoid(x)
        # Scale the input feature map by the spatial weights
        outputs = inputs * x

        return outputs


# CBAM attention
class CBAM(nn.Module):
    # in_channel and ratio configure the channel attention;
    # kernel_size configures the spatial attention
    def __init__(self, in_channel, ratio=4, kernel_size=7):
        # Initialize the parent class
        super(CBAM, self).__init__()

        # Channel attention
        self.channel_attention = ChannelAttention(in_channel=in_channel, ratio=ratio)
        # Spatial attention
        self.spatial_attention = SpatialAttention(kernel_size=kernel_size)

    # Forward pass
    def forward(self, inputs):
        # Channel attention first
        x = self.channel_attention(inputs)
        # Then spatial attention
        x = self.spatial_attention(x)

        return x


class Module(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 32, 5, 1, 2)
        self.conv2 = nn.Conv2d(32, 32, 5, 1, 2)
        self.conv3 = nn.Conv2d(32, 64, 5, 1, 2)
        self.pool = nn.MaxPool2d(2)
        self.cbam1 = CBAM(32)
        self.cbam2 = CBAM(64)
        # self.cbam3=cbam(64)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        # Apply each attention block after its convolution
        # x = self.cbam1(self.pool(self.conv1(x)))  # cbam1 after conv1
        # x = self.cbam2(self.pool(self.conv2(x)))  # cbam2 after conv2
        # x = self.cbam3(self.pool(self.conv3(x)))  # cbam3 after conv3

        # output = self.fc(x)
        # return output
        x = self.pool(self.cbam1(self.conv1(x)))
        x = self.pool(self.cbam1(self.conv2(x)))  # conv2 also outputs 32 channels, so cbam1 is reused here
        x = self.pool(self.cbam2(self.conv3(x)))
        output = self.fc(x)
        return output


# Training
def train_sector(net, device, train_load, loss_fn, optimizer, epoch, writer):
    net.train()
    total_train_loss = 0.0
    total_train_accuracy = 0.0
    for batch_index, (image, label) in enumerate(train_load):
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        output = net(image)
        loss = loss_fn(output, label)
        total_train_loss += loss.item()
        accuracy = (output.argmax(dim=1) == label).sum().item()
        total_train_accuracy += accuracy
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print(f"Train Epoch:{epoch} [{batch_index}/{len(train_load)}] loss:{loss.item():.6f}")
    total_train_loss /= len(train_load.dataset)
    total_train_accuracy /= len(train_load.dataset)
    writer.add_scalar("train_loss", total_train_loss, epoch)
    writer.add_scalar("train_accuracy", total_train_accuracy, epoch)
    print(f"Train Average Loss:{total_train_loss},Accuracy:{total_train_accuracy}")


# Testing
def test_sector(net, device, test_load, loss_fn, epoch, writer):
    net.eval()
    total_test_loss = 0.0
    total_test_accuracy = 0.0
    with torch.no_grad():
        for batch_index, (image, label) in enumerate(test_load):
            image, label = image.to(device), label.to(device)
            output = net(image)
            loss = loss_fn(output, label)
            total_test_loss += loss.item()
            accuracy = (output.argmax(dim=1) == label).sum().item()
            total_test_accuracy += accuracy

        total_test_loss /= len(test_load.dataset)
        total_test_accuracy /= len(test_load.dataset)
        writer.add_scalar("test_loss", total_test_loss, epoch)
        writer.add_scalar("test_accuracy", total_test_accuracy, epoch)
        print(f"Test Average Loss:{total_test_loss},Accuracy:{total_test_accuracy}")


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        # transforms.Resize((224, 224)),
        # transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    # Prepare the datasets
    train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=transform, download=True)
    test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=transform,  download=True)
    train_data_size = len(train_data)
    test_data_size = len(test_data)
    # print(f"Training set size: {train_data_size}")
    # print(f"Test set size: {test_data_size}")
    writer = SummaryWriter("Drop_Attention_logs")
    # Load the datasets with DataLoader
    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
    # Define the model
    Net = Module()
    Net.to(device)
    # Loss function
    loss_fn = nn.CrossEntropyLoss().to(device)
    # Learning rate
    lr = 0.01
    # Optimizer
    optimizer = torch.optim.Adam(Net.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
    # Directory for saving models
    model_dir = './Drop_Attention_models'
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    # Total number of training epochs
    EPOCH = 80
    for epoch in range(1, EPOCH + 1):
        start_time = time.time()
        train_sector(Net, device, train_dataloader, loss_fn, optimizer, epoch, writer)
        end_time = time.time()
        print(f'time, {(end_time - start_time):.2f} s')
        test_sector(Net, device, test_dataloader, loss_fn, epoch, writer)
        torch.save(Net.state_dict(), f"{model_dir}/Net{epoch}.pth")
        print("模型已保存")

'''
Files already downloaded and verified
Files already downloaded and verified
Train Epoch:1 [0/782] loss:2.296275
Train Epoch:1 [100/782] loss:2.284125
Train Epoch:1 [200/782] loss:2.250010
Train Epoch:1 [300/782] loss:2.043480
Train Epoch:1 [400/782] loss:2.072961
Train Epoch:1 [500/782] loss:1.992128
Train Epoch:1 [600/782] loss:1.789630
Train Epoch:1 [700/782] loss:1.743162
Train Average Loss:0.031432345366477966,Accuracy:0.2442
time, 16.54 s
Test Average Loss:0.025988068795204163,Accuracy:0.3929
Model saved
Train Epoch:2 [0/782] loss:1.481611
Train Epoch:2 [100/782] loss:1.717270
Train Epoch:2 [200/782] loss:1.813353
......
'''

Result: training accuracy stays around 85%, test accuracy around 68%.

1.2.1.2. With Image Augmentation

Change the image transform as follows:

python
transform = transforms.Compose([
    # transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
'''
Files already downloaded and verified
Files already downloaded and verified
Train Epoch:1 [0/782] loss:2.311297
Train Epoch:1 [100/782] loss:1.786359
Train Epoch:1 [200/782] loss:1.672049
Train Epoch:1 [300/782] loss:1.480483
Train Epoch:1 [400/782] loss:1.702224
Train Epoch:1 [500/782] loss:2.059962
Train Epoch:1 [600/782] loss:1.430759
Train Epoch:1 [700/782] loss:1.554896
Train Average Loss:0.02562361562013626,Accuracy:0.4016
time, 21.76 s
Test Average Loss:0.020984113544225694,Accuracy:0.5245
Model saved
Train Epoch:2 [0/782] loss:1.366217
Train Epoch:2 [100/782] loss:1.592983
......
'''

Result: training accuracy stays around 85%, test accuracy around 75%.
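The Normalize values above are the ImageNet statistics. If you would rather normalize with CIFAR10's own statistics, here is a small sketch (my own addition, not part of the original experiment) that computes them:

python
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

# Compute CIFAR10's per-channel mean and std over the training set.
data = torchvision.datasets.CIFAR10(root="./dataset", train=True,
                                    transform=transforms.ToTensor(), download=True)
loader = DataLoader(data, batch_size=1000, shuffle=False)
n = 0
mean = torch.zeros(3)
mean_sq = torch.zeros(3)
for images, _ in loader:
    n += images.size(0)
    mean += images.mean(dim=[2, 3]).sum(dim=0)
    mean_sq += (images ** 2).mean(dim=[2, 3]).sum(dim=0)
mean /= n
std = (mean_sq / n - mean ** 2).sqrt()
print(mean, std)  # roughly (0.4914, 0.4822, 0.4465) and (0.247, 0.243, 0.261)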

The model from each epoch is saved in ./Drop_Attention_models, and the TensorBoard logs in ./Drop_Attention_logs.

After training, run the following command to see how the loss and accuracy evolve:

bash
tensorboard --logdir "path/of/the/Drop_Attention_logs" --port=6008  # the port can be changed

1.3. Using Pretrained Models

1.3.1. ResNet18

ResNet18 expects larger inputs than CIFAR10's 32x32 images, so either the images or the network structure has to be adapted.

1.3.1.1. Method 1

1. Resize the images so they match ResNet18's expected input size:

python
transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

2. Change the output dimension of ResNet18's final layer to 10:

python
# Define the model
Net = models.resnet18()
# Replace ResNet18's final fully connected layer with one that outputs 10 classes
inchannel = Net.fc.in_features
Net.fc = nn.Linear(inchannel, 10)
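Note that `models.resnet18()` with no arguments builds a randomly initialized network, so this variant actually trains from scratch. To load the ImageNet-pretrained weights that the section title suggests, recent torchvision takes a `weights` argument (a sketch under that assumption; older releases used `pretrained=True` instead):

python
from torch import nn
from torchvision import models

# Assumption: torchvision >= 0.13, where the weights enum API is available.
Net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
inchannel = Net.fc.in_features
Net.fc = nn.Linear(inchannel, 10)  # only the new head starts from random weights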

3. Full code:

python
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import time


# Training
def train_sector(net, device, train_load, loss_fn, optimizer, epoch, writer):
    net.train()
    total_train_loss = 0.0
    total_train_accuracy = 0.0
    for batch_index, (image, label) in enumerate(train_load):
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        output = net(image)
        loss = loss_fn(output, label)
        total_train_loss += loss.item()
        accuracy = (output.argmax(dim=1) == label).sum().item()
        total_train_accuracy += accuracy
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print(f'Train Epoch:{epoch} [{batch_index}/{len(train_load)}] loss:{loss.item():.6f}')
    total_train_loss /= len(train_load.dataset)
    total_train_accuracy /= len(train_load.dataset)
    writer.add_scalar("train_loss", total_train_loss, epoch)
    writer.add_scalar("train_accuracy", total_train_accuracy, epoch)
    print(f"Train Average Loss:{total_train_loss},Accuracy:{total_train_accuracy}")


# Testing
def test_sector(net, device, test_load, loss_fn, epoch, writer):
    net.eval()
    total_test_loss = 0.0
    total_test_accuracy = 0.0
    with torch.no_grad():
        for batch_index, (image, label) in enumerate(test_load):
            image, label = image.to(device), label.to(device)
            output = net(image)
            loss = loss_fn(output, label)
            total_test_loss += loss.item()
            accuracy = (output.argmax(dim=1) == label).sum().item()
            total_test_accuracy += accuracy

        total_test_loss /= len(test_load.dataset)
        total_test_accuracy /= len(test_load.dataset)
        writer.add_scalar("test_loss", total_test_loss, epoch)
        writer.add_scalar("test_accuracy", total_test_accuracy, epoch)
        print(f"Test Average Loss:{total_test_loss},Accuracy:{total_test_accuracy}")


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    # Prepare the datasets
    train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=transform, download=True)
    test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=transform, download=True)
    train_data_size = len(train_data)
    test_data_size = len(test_data)
    # print(f"Training set size: {train_data_size}")
    # print(f"Test set size: {test_data_size}")
    writer = SummaryWriter("Resnet18_logs")
    # Load the datasets with DataLoader
    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
    # Define the model
    Net = models.resnet18()
    # Replace ResNet18's final fully connected layer with one that outputs 10 classes
    inchannel = Net.fc.in_features
    Net.fc = nn.Linear(inchannel, 10)
    Net.to(device)
    # Loss function
    loss_fn = nn.CrossEntropyLoss().to(device)
    # Learning rate
    lr = 0.01
    # Optimizer
    optimizer = torch.optim.SGD(Net.parameters(), lr=lr)
    # Directory for saving models
    model_dir = './Resnet18_models'
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    # Total number of training epochs
    EPOCH = 40
    for epoch in range(1, EPOCH + 1):
        start_time = time.time()
        train_sector(Net, device, train_dataloader, loss_fn, optimizer, epoch, writer)
        end_time = time.time()
        print(f'time, {(end_time - start_time):.2f} s')
        test_sector(Net, device, test_dataloader, loss_fn, epoch, writer)
        torch.save(Net.state_dict(), f"{model_dir}/Net{epoch}.pth")
        print("模型已保存")
'''
Files already downloaded and verified
Files already downloaded and verified
Train Epoch:1 [0/782] loss:2.452383
Train Epoch:1 [100/782] loss:1.911389
Train Epoch:1 [200/782] loss:1.618683
Train Epoch:1 [300/782] loss:1.749338
Train Epoch:1 [400/782] loss:1.578882
Train Epoch:1 [500/782] loss:1.804207
Train Epoch:1 [600/782] loss:1.470715
Train Epoch:1 [700/782] loss:1.287879
Train Average Loss:0.026067537117004394,Accuracy:0.3761
time, 70.00 s
Test Average Loss:0.02401012862920761,Accuracy:0.4195
Model saved
Train Epoch:2 [0/782] loss:1.595474
Train Epoch:2 [100/782] loss:1.311501
......
'''

Result: training accuracy settles at 100%, while test accuracy oscillates around 75%.

1.3.1.2. Method 2

1. Define image augmentation without changing the image size:

python
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(32, padding=4),  # zero-pad the borders, then randomly crop back to 32x32
    transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip with probability 0.5
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # per-channel mean and std
])
transforms_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

2. Modify ResNet18 so it matches CIFAR10's input size:

python
# Define the model
Net = models.resnet18()
# Modify the model
Net.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)  # replace the first layer with a 3x3 convolution
Net.maxpool = nn.MaxPool2d(1, 1, 0)  # the images are tiny, so a 1x1 pooling kernel effectively disables this layer
num_features = Net.fc.in_features  # number of input features of the fc layer
Net.fc = nn.Linear(num_features, 10)  # replace the final fully connected layer with one that outputs 10 classes
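A quick sanity check (my own addition): with these changes the network maps a CIFAR10-sized batch straight to 10 logits, since only the stride-2 stages halve the 32x32 resolution:

python
import torch
from torch import nn
from torchvision import models

net = models.resnet18()
net.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
net.maxpool = nn.MaxPool2d(1, 1, 0)
net.fc = nn.Linear(net.fc.in_features, 10)
print(net(torch.randn(2, 3, 32, 32)).shape)  # expected: torch.Size([2, 10])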

3. Full code (note: the save directories have changed, to ./Resnet18_models_2 and ./Resnet18_logs_2):

python
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import time


# Training
def train_sector(net, device, train_load, loss_fn, optimizer, epoch, writer):
    net.train()
    total_train_loss = 0.0
    total_train_accuracy = 0.0
    for batch_index, (image, label) in enumerate(train_load):
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        output = net(image)
        loss = loss_fn(output, label)
        total_train_loss += loss.item()
        accuracy = (output.argmax(dim=1) == label).sum().item()
        total_train_accuracy += accuracy
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print(f'Train Epoch:{epoch} [{batch_index}/{len(train_load)}] loss:{loss.item():.6f}')
    total_train_loss /= len(train_load.dataset)
    total_train_accuracy /= len(train_load.dataset)
    writer.add_scalar("train_loss", total_train_loss, epoch)
    writer.add_scalar("train_accuracy", total_train_accuracy, epoch)
    print(f"Train Average Loss:{total_train_loss},Accuracy:{total_train_accuracy}")


# Testing
def test_sector(net, device, test_load, loss_fn, epoch, writer):
    net.eval()
    total_test_loss = 0.0
    total_test_accuracy = 0.0
    with torch.no_grad():
        for batch_index, (image, label) in enumerate(test_load):
            image, label = image.to(device), label.to(device)
            output = net(image)
            loss = loss_fn(output, label)
            total_test_loss += loss.item()
            accuracy = (output.argmax(dim=1) == label).sum().item()
            total_test_accuracy += accuracy

        total_test_loss /= len(test_load.dataset)
        total_test_accuracy /= len(test_load.dataset)
        writer.add_scalar("test_loss", total_test_loss, epoch)
        writer.add_scalar("test_accuracy", total_test_accuracy, epoch)
        print(f"Test Average Loss:{total_test_loss},Accuracy:{total_test_accuracy}")


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomCrop(32, padding=4),  # zero-pad the borders, then randomly crop back to 32x32
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip with probability 0.5
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # per-channel mean and std
    ])
    transforms_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # Prepare the datasets
    train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=transform_train, download=True)
    test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=transforms_test, download=True)
    train_data_size = len(train_data)
    test_data_size = len(test_data)
    # print(f"Training set size: {train_data_size}")
    # print(f"Test set size: {test_data_size}")
    writer = SummaryWriter("Resnet18_logs_2")
    # Load the datasets with DataLoader
    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

    # Define the model
    Net = models.resnet18()
    # Modify the model
    Net.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)  # replace the first layer with a 3x3 convolution
    Net.maxpool = nn.MaxPool2d(1, 1, 0)  # the images are tiny, so a 1x1 pooling kernel effectively disables this layer
    num_features = Net.fc.in_features  # number of input features of the fc layer
    Net.fc = nn.Linear(num_features, 10)  # replace the final fully connected layer with one that outputs 10 classes
    Net.to(device)
    # Loss function
    loss_fn = nn.CrossEntropyLoss().to(device)
    # Learning rate
    lr = 0.01
    # Optimizer
    optimizer = torch.optim.SGD(Net.parameters(), lr=lr)
    # Directory for saving models
    model_dir = './Resnet18_models_2'
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    # Total number of training epochs
    EPOCH = 40
    for epoch in range(1, EPOCH + 1):
        start_time = time.time()
        train_sector(Net, device, train_dataloader, loss_fn, optimizer, epoch, writer)
        end_time = time.time()
        print(f'time, {(end_time - start_time):.2f} s')
        test_sector(Net, device, test_dataloader, loss_fn, epoch, writer)
        torch.save(Net.state_dict(), f"{model_dir}/Net{epoch}.pth")
        print("模型已保存")
'''
Files already downloaded and verified
Files already downloaded and verified
Train Epoch:1 [0/782] loss:2.391769
Train Epoch:1 [100/782] loss:2.118423
Train Epoch:1 [200/782] loss:1.792327
Train Epoch:1 [300/782] loss:1.880022
Train Epoch:1 [400/782] loss:1.605273
Train Epoch:1 [500/782] loss:1.583118
Train Epoch:1 [600/782] loss:1.655769
Train Epoch:1 [700/782] loss:1.533173
Train Average Loss:0.027535110058784486,Accuracy:0.33896
time, 38.65 s
Test Average Loss:0.025295356225967406,Accuracy:0.4011
Model saved
Train Epoch:2 [0/782] loss:1.667214
......
'''

Result: training accuracy stays above 90%, and test accuracy around 85%.

The model from each epoch is saved in ./Resnet18_models (./Resnet18_models_2 for method 2), and the TensorBoard logs in ./Resnet18_logs (./Resnet18_logs_2).

After training, run the following command to see how the loss and accuracy evolve:

bash
tensorboard --logdir "path/of/the/Resnet18_logs" --port=6008  # the port can be changed

1.4. Evaluating with a Trained Model

Here we evaluate using the model trained with ResNet18 method 2.

python
# Evaluate on CIFAR10
import torch
from pathlib import Path
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import datasets, transforms


def eval_dataset(dataloader, device):
    total_loss = 0.0
    total_accuracy = 0.0
    with torch.no_grad():
        for batch_index, (image, label) in enumerate(dataloader):
            image, label = image.to(device), label.to(device)
            output = Net(image)
            loss = loss_fn(output, label)
            total_loss += loss.item()
            accuracy = (output.argmax(dim=1) == label).sum().item()
            total_accuracy += accuracy

        total_loss /= len(dataloader.dataset)
        total_accuracy /= len(dataloader.dataset)
    return total_loss, total_accuracy


if __name__ == '__main__':
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    epoch = 35  # which epoch's checkpoint to load
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomCrop(32, padding=4),  # zero-pad the borders, then randomly crop back to 32x32
        transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip with probability 0.5
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # per-channel mean and std
    ])
    transforms_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    train_data = datasets.CIFAR10(root="./dataset", train=True, transform=transform_train, download=True)
    test_data = datasets.CIFAR10(root="./dataset", train=False, transform=transforms_test, download=True)
    # Load the datasets with DataLoader
    train_load = DataLoader(train_data, batch_size=64, shuffle=True)
    test_load = DataLoader(test_data, batch_size=64, shuffle=False)
    # Directory where the trained models were saved
    model_dir = './Resnet18_models_2'
    Path(model_dir).mkdir(parents=True, exist_ok=True)

    # Define the model
    Net = models.resnet18()
    # Modify the model
    Net.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)  # replace the first layer with a 3x3 convolution
    Net.maxpool = nn.MaxPool2d(1, 1, 0)  # the images are tiny, so a 1x1 pooling kernel effectively disables this layer
    num_features = Net.fc.in_features  # number of input features of the fc layer
    Net.fc = nn.Linear(num_features, 10)  # replace the final fully connected layer with one that outputs 10 classes
    Net.to(device)
    Net.load_state_dict(torch.load(f"{model_dir}/Net{epoch}.pth", map_location=device))

    loss_fn = nn.CrossEntropyLoss().to(device)
    Net.eval()
    train_loss, train_accuracy = eval_dataset(train_load, device)
    test_loss, test_accuracy = eval_dataset(test_load, device)
    print(f"Train Dataloader, Average Loss:{train_loss},Accuracy:{train_accuracy}")
    print(f"Test Dataloader,  Average Loss:{test_loss},Accuracy:{test_accuracy}")
'''
Train Dataloader, Average Loss:0.0034980082565546034,Accuracy:0.92248
Test Dataloader,  Average Loss:0.007022373244166374,Accuracy:0.8619
'''
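Beyond the overall accuracy, a per-class breakdown often reveals which categories the model confuses. A hedged extension (my own addition, reusing `Net`, `test_load`, `device` and `test_data` from the script above):

python
# Per-class accuracy on the test set; test_data.classes holds the label names.
correct = torch.zeros(10)
total = torch.zeros(10)
with torch.no_grad():
    for image, label in test_load:
        image, label = image.to(device), label.to(device)
        pred = Net(image).argmax(dim=1)
        for cls in range(10):
            mask = label == cls
            total[cls] += mask.sum().item()
            correct[cls] += (pred[mask] == cls).sum().item()
for cls, name in enumerate(test_data.classes):
    print(f"{name}: {correct[cls] / total[cls]:.4f}")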

2. The Vision Transformer (ViT)

2.1. A Hand-Written ViT Model

2.1.1. Training and Evaluation Code

python
"""
特征提取的实例:
利用迁移学习中特征提取的方法来对CIFAR-10数据集实现对10类无体的分类
"""
import torchvision
import time
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
# from kan import KAN
import torch
from torch import nn
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def pair(t):
    return t if isinstance(t, tuple) else (t, t)


# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.1):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):  # the forward function is the most important part
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # Chunk the tensor: x is (1, 197, 1024); qkv is a tuple of length 3, each element (1, 197, 1024)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        # Split into heads; simpler than the original Transformer, with no encoder/decoder distinction

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


# 1. The overall ViT architecture starts here
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        # Resolve image_size and patch_size into (height, width) pairs
        image_height, image_width = pair(image_size)  # e.g. 224 x 224, 3 channels
        patch_height, patch_width = pair(patch_size)  # e.g. 16 x 16, 3 channels
        # The image size must be divisible by the patch size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)  # Step 1: split the image into N patches
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),  # Step 2.1: flatten the patches
            nn.Linear(patch_dim, dim),  # Step 2.2: project to the embedding dimension
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
            # KAN([dim, num_classes])
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)  # img: (1, 3, 224, 224); output x: (1, 196, 1024)
        b, n, _ = x.shape
        # Replicate the cls token across the batch
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        # Prepend the cls token along dimension 1
        x = torch.cat((cls_tokens, x), dim=1)
        # Add the positional embedding
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        # Feed into the Transformer
        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)


def get_vit_model():
    v = ViT(
        image_size=112,
        patch_size=16,
        num_classes=10,  # must match the number of label classes
        dim=128,
        depth=2,
        heads=4,
        mlp_dim=128,
        dropout=0.1,
        emb_dropout=0.1
    )
    return v


# Training
def train_sector(net, device, train_load, loss_fn, optimizer, epoch, writer):
    net.train()
    total_train_loss = 0.0
    total_train_accuracy = 0.0
    for batch_index, (image, label) in enumerate(train_load):
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        output = net(image)
        loss = loss_fn(output, label)
        total_train_loss += loss.item()
        accuracy = (output.argmax(dim=1) == label).sum().item()
        total_train_accuracy += accuracy
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print(f'Train Epoch:{epoch} [{batch_index}/{len(train_load)}] loss:{loss.item():.6f}')
    total_train_loss /= len(train_load.dataset)
    total_train_accuracy /= len(train_load.dataset)
    writer.add_scalar("train_loss", total_train_loss, epoch)
    writer.add_scalar("train_accuracy", total_train_accuracy, epoch)
    print(f"Train Average Loss:{total_train_loss},Accuracy:{total_train_accuracy}")


# Testing
def test_sector(net, device, test_load, loss_fn, epoch, writer):
    net.eval()
    total_test_loss = 0.0
    total_test_accuracy = 0.0
    with torch.no_grad():
        for batch_index, (image, label) in enumerate(test_load):
            image, label = image.to(device), label.to(device)
            output = net(image)
            loss = loss_fn(output, label)
            total_test_loss += loss.item()
            accuracy = (output.argmax(dim=1) == label).sum().item()
            total_test_accuracy += accuracy

        total_test_loss /= len(test_load.dataset)
        total_test_accuracy /= len(test_load.dataset)
        writer.add_scalar("test_loss", total_test_loss, epoch)
        writer.add_scalar("test_accuracy", total_test_accuracy, epoch)
        print(f"Test Average Loss:{total_test_loss},Accuracy:{total_test_accuracy}")


if __name__ == '__main__':
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # Load and preprocess the datasets
    trans_train = transforms.Compose(
        [transforms.RandomResizedCrop(112),  # randomly crop to a varying size and aspect ratio,
         # then resize the crop to the target size (default scale=(0.08, 1.0))
         transforms.RandomHorizontalFlip(),  # random horizontal flip (default probability 0.5)
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
    )
    trans_valid = transforms.Compose(
        [transforms.Resize(256),  # scale the shorter side to 256, keeping the aspect ratio
         transforms.CenterCrop(112),  # center-crop to the given size
         transforms.ToTensor(),  # convert a PIL Image or ndarray to a tensor scaled to [0,1] (divides by 255)
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])]  # per-channel normalization: subtract mean, divide by std
    )

    train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True, transform=trans_train)
    test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, download=False, transform=trans_valid)

    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=2)

    # train_data_size = len(train_data)
    # test_data_size = len(test_data)
    # print(f"训练数据集的长度:{train_data_size}")
    # print(f"测试数据集的长度:{test_data_size}")
    writer = SummaryWriter("Custom_ViT_logs")

    # Define the model
    Net = get_vit_model().to(device)

    # Count the total number of parameters
    total_params = sum(p.numel() for p in Net.parameters())
    print('Total parameters: {}'.format(total_params))
    # Loss function
    loss_fn = nn.CrossEntropyLoss().to(device)
    # Learning rate
    lr = 1e-3
    # Optimizer
    # optimizer = torch.optim.SGD(Net.parameters(), lr=lr, weight_decay=1e-3, momentum=0.9)
    optimizer = torch.optim.Adam(Net.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-8)
    # Directory for saving models
    model_dir = './Custom_ViT_models'
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    # Total number of training epochs
    EPOCH = 2000
    for epoch in range(1, EPOCH + 1):
        start_time = time.time()
        train_sector(Net, device, train_dataloader, loss_fn, optimizer, epoch, writer)
        end_time = time.time()
        print(f'time, {(end_time - start_time):.2f} s')
        test_sector(Net, device, test_dataloader, loss_fn, epoch, writer)
        torch.save(Net.state_dict(), f"{model_dir}/Net{epoch}.pth")
        print("模型已保存")

2.2. Using an Existing Library

2.2.1. SimpleViT

1. First install vit_pytorch:

bash
pip install vit_pytorch

2. Full code:

2.2.1.1. Training and Evaluation Code

python
"""
特征提取的实例:
利用迁移学习中特征提取的方法来对CIFAR-10数据集实现对10类无体的分类
"""
import torch
import torchvision
import time
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from vit_pytorch import SimpleViT


def get_vit_model():
    v = SimpleViT(image_size=224, patch_size=16, num_classes=10, dim=256, depth=2, heads=4, mlp_dim=128)
    return v


# Training
def train_sector(net, device, train_load, loss_fn, optimizer, epoch, writer):
    net.train()
    total_train_loss = 0.0
    total_train_accuracy = 0.0
    for batch_index, (image, label) in enumerate(train_load):
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        output = net(image)
        loss = loss_fn(output, label)
        total_train_loss += loss.item()
        accuracy = (output.argmax(dim=1) == label).sum().item()
        total_train_accuracy += accuracy
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print(f'Train Epoch:{epoch} [{batch_index}/{len(train_load)}] loss:{loss.item():.6f}')
    total_train_loss /= len(train_load.dataset)
    total_train_accuracy /= len(train_load.dataset)
    writer.add_scalar("train_loss", total_train_loss, epoch)
    writer.add_scalar("train_accuracy", total_train_accuracy, epoch)
    print(f"Train Average Loss:{total_train_loss},Accuracy:{total_train_accuracy}")


# Testing
def test_sector(net, device, test_load, loss_fn, epoch, writer):
    net.eval()
    total_test_loss = 0.0
    total_test_accuracy = 0.0
    with torch.no_grad():
        for batch_index, (image, label) in enumerate(test_load):
            image, label = image.to(device), label.to(device)
            output = net(image)
            loss = loss_fn(output, label)
            total_test_loss += loss.item()
            accuracy = (output.argmax(dim=1) == label).sum().item()
            total_test_accuracy += accuracy

        total_test_loss /= len(test_load.dataset)
        total_test_accuracy /= len(test_load.dataset)
        writer.add_scalar("test_loss", total_test_loss, epoch)
        writer.add_scalar("test_accuracy", total_test_accuracy, epoch)
        print(f"Test Average Loss:{total_test_loss},Accuracy:{total_test_accuracy}")


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load and preprocess the datasets
    trans_train = transforms.Compose(
        [transforms.RandomResizedCrop(224),  # randomly crop to a varying size and aspect ratio,
         # then resize the crop to the target size (default scale=(0.08, 1.0))
         transforms.RandomHorizontalFlip(),  # random horizontal flip (default probability 0.5)
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
    )
    trans_valid = transforms.Compose(
        [transforms.Resize(256),  # scale the shorter side to 256, keeping the aspect ratio
         transforms.CenterCrop(224),  # center-crop to the given size
         transforms.ToTensor(),  # convert a PIL Image or ndarray to a tensor scaled to [0,1] (divides by 255)
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])]  # per-channel normalization: subtract mean, divide by std
    )

    train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True, transform=trans_train)
    test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, download=False, transform=trans_valid)

    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=2)

    # train_data_size = len(train_data)
    # test_data_size = len(test_data)
    # print(f"训练数据集的长度:{train_data_size}")
    # print(f"测试数据集的长度:{test_data_size}")
    writer = SummaryWriter("Simple_ViT_logs")

    # Define the model (SimpleViT here starts from random weights, not pretrained ones)
    Net = get_vit_model().to(device)

    # Count the total number of parameters
    total_params = sum(p.numel() for p in Net.parameters())
    print('Total parameters: {}'.format(total_params))
    # Loss function
    loss_fn = nn.CrossEntropyLoss().to(device)
    # Learning rate
    lr = 1e-3
    # Optimizer
    optimizer = torch.optim.SGD(Net.parameters(), lr=lr, weight_decay=1e-3, momentum=0.9)
    # Directory for saving models
    model_dir = './Simple_ViT_models'
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    # Total number of training epochs
    EPOCH = 40
    for epoch in range(1, EPOCH + 1):
        start_time = time.time()
        train_sector(Net, device, train_dataloader, loss_fn, optimizer, epoch, writer)
        end_time = time.time()
        print(f'time, {(end_time - start_time):.2f} s')
        test_sector(Net, device, test_dataloader, loss_fn, epoch, writer)
        torch.save(Net.state_dict(), f"{model_dir}/Net{epoch}.pth")
        print("模型已保存")
'''
Files already downloaded and verified
Total parameters: 860170
Train Epoch:1 [0/782] loss:2.441724
Train Epoch:1 [100/782] loss:2.156467
Train Epoch:1 [200/782] loss:2.059536
Train Epoch:1 [300/782] loss:2.063913
Train Epoch:1 [400/782] loss:1.985299
Train Epoch:1 [500/782] loss:1.981158
Train Epoch:1 [600/782] loss:2.040962
Train Epoch:1 [700/782] loss:2.010400
Train Average Loss:0.03235056596040726,Accuracy:0.22038
time, 40.29 s
Test Average Loss:0.031189283430576323,Accuracy:0.2707
Model saved
Train Epoch:2 [0/782] loss:2.160810
Train Epoch:2 [100/782] loss:2.005908
......
'''

Result: training accuracy is around 50% and test accuracy around 55%; training for more epochs could improve this.

The model from each epoch is saved in ./Simple_ViT_models, and the TensorBoard logs in ./Simple_ViT_logs.

After training, run the following command to see how the loss and accuracy evolve:

bash
tensorboard --logdir "path/of/the/Simple_ViT_logs" --port=6008  # the port can be changed

2.2.1.2. Inference Code

python
from PIL import Image
import torchvision
from pathlib import Path
from vit_pytorch import SimpleViT
import torch


if __name__ == '__main__':
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    image_path = "dog.jpg"  # 图片途径
    image = Image.open(image_path)
    image = image.convert("RGB")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # print(image)

    transform = torchvision.transforms.Compose([torchvision.transforms.Resize((224, 224)),
                                                torchvision.transforms.ToTensor()])
    image = transform(image)
    # print(image.shape)

    # Directory where the trained models were saved
    model_dir = './Simple_ViT_models'
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    epoch = 40  # which epoch's checkpoint to load

    # Define the model
    model = SimpleViT(image_size=224, patch_size=16, num_classes=10, dim=256, depth=2, heads=4, mlp_dim=128)
    model.to(device)
    model.load_state_dict(torch.load(f"{model_dir}/Net{epoch}.pth", map_location=device))
    # print(model.device)

    image = torch.reshape(image, (1, 3, 224, 224)).to(device)
    model.eval()
    with torch.no_grad():
        output = model(image)

    print(output)
    print("It is a " + class_names[output.argmax(1)] + ".")
'''
tensor([[-2.2602, -3.1466,  1.0401,  1.9852,  1.2179,  2.7685,  0.9605,  1.8921,
         -2.2489, -2.2318]], device='cuda:0')
It is a dog.
'''
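The printed tensor holds raw logits. An optional follow-up (my own addition, reusing `output` and `class_names` from the script above) turns them into class probabilities:

python
probs = torch.softmax(output, dim=1)  # normalize the logits into probabilities
top_p, top_i = probs.topk(3, dim=1)   # three most likely classes
for p, i in zip(top_p[0], top_i[0]):
    print(f"{class_names[i]}: {p.item():.3f}")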

3、Swin Transformer Architecture

3.1、Training and Evaluation Code

python
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import torchvision
import time
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
# from kan import KAN
import torch
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
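

# Illustrative sanity check (assumed shapes): partitioning and then reversing
# reconstructs the original tensor exactly, since both are pure view/permute ops.
#   x = torch.randn(2, 56, 56, 96)                       # B, H, W, C
#   w = window_partition(x, 7)                           # (2*8*8, 7, 7, 96)
#   assert torch.equal(window_reverse(w, 7, 56, 56), x)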


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        ##################### Shifted-window (cyclic shift) local self-attention #####################
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
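

# Illustrative shapes for PatchMerging: with input_resolution=(56, 56) and dim=96,
# an input of (B, 56*56, 96) becomes (B, 28*28, 192): the spatial resolution
# halves along each axis while the channel count doubles.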


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
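

# Note: within each BasicLayer, blocks alternate between regular window attention
# (shift_size=0 at even indices) and shifted-window attention (shift_size=window_size//2
# at odd indices), which is what lets information propagate across window boundaries.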


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops


class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)  # x.shape = (B, num_classes)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops


# Training
def train_sector(net, device, train_load, loss_fn, optimizer, epoch, writer):
    net.train()
    total_train_loss = 0.0
    total_train_accuracy = 0.0
    for batch_index, (image, label) in enumerate(train_load):
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        output = net(image)
        loss = loss_fn(output, label)
        total_train_loss += loss.item()
        accuracy = (output.argmax(dim=1) == label).sum().item()
        total_train_accuracy += accuracy
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 0:
            print(f'Train Epoch:{epoch} [{batch_index}/{len(train_load)}] loss:{loss.item():.6f}')
    total_train_loss /= len(train_load.dataset)
    total_train_accuracy /= len(train_load.dataset)
    writer.add_scalar("train_loss", total_train_loss, epoch)
    writer.add_scalar("train_accuracy", total_train_accuracy, epoch)
    print(f"Train Average Loss:{total_train_loss},Accuracy:{total_train_accuracy}")


# Testing
def test_sector(net, device, test_load, loss_fn, epoch, writer):
    net.eval()
    total_test_loss = 0.0
    total_test_accuracy = 0.0
    with torch.no_grad():
        for batch_index, (image, label) in enumerate(test_load):
            image, label = image.to(device), label.to(device)
            output = net(image)
            loss = loss_fn(output, label)
            total_test_loss += loss.item()
            accuracy = (output.argmax(dim=1) == label).sum().item()
            total_test_accuracy += accuracy

        total_test_loss /= len(test_load.dataset)
        total_test_accuracy /= len(test_load.dataset)
        writer.add_scalar("test_loss", total_test_loss, epoch)
        writer.add_scalar("test_accuracy", total_test_accuracy, epoch)
        print(f"Test Average Loss:{total_test_loss},Accuracy:{total_test_accuracy}")


if __name__ == '__main__':
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # Load and preprocess the dataset
    trans_train = transforms.Compose(
        [transforms.RandomResizedCrop(224),  # randomly crop a region of varying size and aspect ratio,
         # then resize the crop to the given size (default scale=(0.08, 1.0))
         transforms.RandomHorizontalFlip(),  # flip the PIL image horizontally with the given probability (default 0.5)
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
    )
    trans_valid = transforms.Compose(
        [transforms.Resize(256),  # scale the shorter side to 256, keeping the aspect ratio
         transforms.CenterCrop(224),  # crop a 224x224 patch from the center
         transforms.ToTensor(),  # convert the PIL Image or ndarray to a tensor scaled to [0, 1]
         # (scaling to [0, 1] divides by 255; adjust if your ndarray has a different range)
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])]  # per-channel standardization: subtract the mean, divide by the std (applied to the CHW tensor)
    )

    train_data = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True, transform=trans_train)
    test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, download=False, transform=trans_valid)
    
    # train_data = torchvision.datasets.ImageFolder(root="./Dataset/train", transform=trans_train)
    # test_data = torchvision.datasets.ImageFolder(root="./Dataset/test", transform=trans_valid)

    train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=2)

    # train_data_size = len(train_data)
    # test_data_size = len(test_data)
    # print(f"Training set size: {train_data_size}")
    # print(f"Test set size: {test_data_size}")
    writer = SummaryWriter("Custom_SwinT_logs")

    # Build the model (trained from scratch here; no pretrained weights are loaded)
    Net = SwinTransformer(img_size=224, in_chans=3, num_classes=10).to(device)

    # Print the total number of parameters
    total_params = sum(p.numel() for p in Net.parameters())
    print('Total parameters: {}'.format(total_params))
    # Loss function
    loss_fn = nn.CrossEntropyLoss().to(device)
    # Learning rate
    lr = 1e-3
    # Optimizer
    optimizer = torch.optim.SGD(Net.parameters(), lr=lr, weight_decay=1e-3, momentum=0.9)
    # optimizer = torch.optim.Adam(Net.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-8)
    # Directory where model checkpoints are saved
    model_dir = './Custom_SwinT_models'
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    # Total number of training epochs
    EPOCH = 2000
    for epoch in range(1, EPOCH + 1):
        start_time = time.time()
        train_sector(Net, device, train_dataloader, loss_fn, optimizer, epoch, writer)
        end_time = time.time()
        print(f'time, {(end_time - start_time):.2f} s')
        test_sector(Net, device, test_dataloader, loss_fn, epoch, writer)
        torch.save(Net.state_dict(), f"{model_dir}/Net{epoch}.pth")
        print("模型已保存")

Modifying PatchEmbed to replace the convolutional projection with an einops Rearrange plus a Linear layer; it does not seem to make much difference.

python
from einops.layers.torch import Rearrange  # needed for to_patch_embedding below


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        image_height, image_width = to_2tuple(img_size)
        patch_height, patch_width = to_2tuple(patch_size)
        patches_resolution = [image_height // patch_height, image_width // patch_width]
        self.img_size = (image_height, image_width)
        self.patch_size = (patch_height, patch_width)  # kept as a tuple so flops() can index it
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
        patch_dim = in_chans * patch_height * patch_width
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, embed_dim)
        )

    def forward(self, x):
        # B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        x = self.to_patch_embedding(x)  # b c (h p1) (w p2) -> b (h w) (p1 p2 c) -> b (h w) dim
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
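
To check that the rewritten module is a drop-in replacement, a quick shape test (a minimal sketch; PatchEmbed here refers to the Rearrange-based version above):

python
import torch

embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
x = torch.randn(2, 3, 224, 224)
print(embed(x).shape)  # torch.Size([2, 3136, 96]), the same (B, num_patches, embed_dim) as the conv version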