Skip to content
0

文章发布较早,内容可能过时,阅读注意甄别。

1 从零开始搭建扩散模型

以下的代码请在jupyter中运行

1.1 环境准备

1.1.1 环境的创建与导入

1、需要安装包diffusers

bash
pip install -q diffusers

2、然后导入需要的环境

python
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image

# Pick the compute device: the first CUDA GPU when one is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

1.1.2 数据集测试

读取数据集并且显示图片

python
# Load the MNIST training split (downloaded into ./mnist if missing), converting
# every image to a tensor, then preview one shuffled batch of 8 images.
to_tensor = torchvision.transforms.ToTensor()
dataset = torchvision.datasets.MNIST(
    root='mnist',
    train=True,
    download=True,
    transform=to_tensor,
)
# Alternative datasets -- swap one of these in to experiment:
# dataset = torchvision.datasets.FashionMNIST(root='fmnist', train=True, download=True, transform=torchvision.transforms.ToTensor())
# dataset = torchvision.datasets.CIFAR10(root='cifar10', train=True, download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Pull a single batch: x holds the images, y the class labels.
batch_iter = iter(train_dataloader)
x, y = next(batch_iter)
print('Input shape', x.shape)
print('Labels:', y)

# Tile the batch into one grid image and display the channel at index 1 of that grid.
preview_grid = torchvision.utils.make_grid(x)
plt.imshow(preview_grid[1])
# Output recorded by the author.
# NOTE(review): the [8, 3, 32, 32] shape looks like it was captured with the CIFAR10
# variant; MNIST batches elsewhere in this tutorial print [8, 1, 28, 28] -- confirm.
'''
Files already downloaded and verified
Input shape torch.Size([8, 3, 32, 32])
Labels: tensor([9, 1, 0, 0, 0, 2, 9, 9])
'''

1.2 扩散模型之退化

向图像中加入噪声并且画图展示出来

python
def corrupt(x, amount):
    """Blend each sample of ``x`` with uniform random noise.

    Args:
        x: Input tensor. The noise has the same shape as ``x``.
        amount: Tensor of mixing weights, reshaped to ``(-1, 1, 1, 1)`` so that
            one weight applies to each sample (assumes ``x`` is 4-D, e.g.
            ``(batch, channels, height, width)`` -- as used by the callers here).
            ``0`` keeps the sample untouched, ``1`` replaces it with noise.

    Returns:
        ``(1 - amount) * x + amount * noise`` with ``noise`` drawn by
        ``torch.rand_like(x)``.
    """
    uniform_noise = torch.rand_like(x)
    # One weight per sample, broadcast over the remaining three dimensions.
    weight = amount.view(-1, 1, 1, 1)
    keep = 1 - weight
    return keep * x + weight * uniform_noise

def show_images(x):
    """Render a batch of image tensors as a single PIL image grid.

    Args:
        x: Batch of images with values expected in ``[0, 1]`` (the range the
            ``ToTensor`` pipeline in this tutorial feeds in). If the data is
            normalized to ``[-1, 1]`` instead, enable the rescaling line below.

    Returns:
        A ``PIL.Image`` holding all images tiled by ``torchvision.utils.make_grid``.
    """
    # x = x * 0.5 + 0.5
    grid = torchvision.utils.make_grid(x)
    # Clamp to [0, 1] before scaling: the previous lower bound of -1 let negative
    # values reach the uint8 cast, where they do not map to valid pixel intensities.
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im
# Noise levels for the demo: one per image, evenly spaced from 0 (clean) to 1 (pure noise).
amount = torch.linspace(0, 1, x.shape[0])
print('x',x.shape)
print('amount',amount)
noise_x = corrupt(x, amount)
# print(x[0, :, :, :] == noise_x[0, :, :, :])

# Two stacked panels: the original batch on top, the corrupted batch below.
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
panels = (
    ('Input data', x),
    ('Corrupted data (-- amount imcrease -->)', noise_x),
)
for ax, (title, batch) in zip(axs, panels):
    ax.set_title(title)
    # Show the first channel of the tiled grid as a grayscale image.
    ax.imshow(torchvision.utils.make_grid(batch)[0], cmap='Greys')

# Alternative rendering through show_images (useful for RGB data):
# show_images(x).resize((8*98, 98),resample=Image.NEAREST)
# show_images(noise_x).resize((8*98, 98),resample=Image.NEAREST)
# Output recorded by the author:
'''
x torch.Size([8, 1, 28, 28])
amount tensor([0.0000, 0.1429, 0.2857, 0.4286, 0.5714, 0.7143, 0.8571, 1.0000])
'''

TIP

对于RGB图像,可以将imshow函数中的 cmap='Greys' 参数去掉,否则只会以灰度图展示。但是去掉之后显示的图像还是不够好看。

使用show_images函数可以将RGB图像更好地显示出来

1.3 扩散模型之训练

1.3.1 搭建网络

扩散模型中需要网络具有如下的特性:网络的输入和输出特征维度需要相同,常用的网络是Unet网络

现在自定义一个Unet网络并且进行测试

python
class BasicUNet(nn.Module):
    """Minimal UNet: three conv layers down, three conv layers up, two skip connections.

    Every convolution uses a 5x5 kernel with padding 2, so spatial size only
    changes through the 2x max-pool on the way down and the 2x upsample on the
    way up. The output therefore has the same height/width as the input when
    the input size survives two halvings exactly (e.g. 28 -> 14 -> 7).

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the produced image.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Encoder: in_channels -> 32 -> 64 -> 64.
        encoder_convs = [
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ]
        self.down_layers = torch.nn.ModuleList(encoder_convs)
        # Decoder mirrors the encoder: 64 -> 64 -> 32 -> out_channels.
        decoder_convs = [
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ]
        self.up_layers = torch.nn.ModuleList(decoder_convs)
        self.act = nn.SiLU()
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)
        # Registered but not applied in forward(); it still contributes parameters.
        self.batch_norm = nn.BatchNorm2d(64)

    def forward(self, x):
        """Run the encoder/decoder pass and return a tensor shaped like the input image."""
        skips = []
        deepest = len(self.down_layers) - 1
        for depth, conv in enumerate(self.down_layers):
            x = self.act(conv(x))  # convolution followed by the activation
            if depth == deepest:
                # The bottleneck layer is neither stored nor pooled.
                continue
            skips.append(x)  # remember the activation for the skip connection
            x = self.downscale(x)
        # x = self.batch_norm(x)
        for depth, conv in enumerate(self.up_layers):
            if depth:
                x = self.upscale(x)
                x += skips.pop()  # add the matching encoder activation back in
            x = self.act(conv(x))
        return x

# Smoke-test the freshly initialized network on random input: the output must
# keep the input's shape.
net = BasicUNet(in_channels=1, out_channels=1)
x = torch.rand(8, 1, 28, 28)
# x, y = next(iter(train_dataloader))
print(net(x).shape)

# Total number of parameters registered on the network.
print(sum(p.numel() for p in net.parameters()))

# Plot the random input above the network's output for it.
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
panels = (
    ('input', x),
    ('output', net(x)),
)
for ax, (title, batch) in zip(axs, panels):
    ax.set_title(title)
    ax.imshow(torchvision.utils.make_grid(batch)[0], cmap='Greys')

# show_images(x).resize((8*98, 98),resample=Image.NEAREST)
# show_images(net(x)).resize((8*98, 98),resample=Image.NEAREST)
# Output recorded by the author:
'''
torch.Size([8, 1, 28, 28])
309185
'''

其中print(sum([p.numel() for p in net.parameters()]))用来计算Unet网络的参数量

1.3.2 训练模型

1.3.2.1 训练代码

python
# Train BasicUNet to reconstruct the clean image from a randomly corrupted one.
batch_size = 128
n_epochs = 3
train_dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

net = BasicUNet(in_channels=1, out_channels=1)
net.to(device)

loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# One entry per optimization step, across all epochs.
losses = []
batches_per_epoch = len(train_dataloader)

# noise = torch.rand([128, 1, 28, 28])
for epoch in range(n_epochs):
    for x, y in train_dataloader:
        x = x.to(device)
        # noise = noise.to(device)
        # Each image in the batch gets its own random corruption strength.
        noise_amount = torch.rand(x.shape[0]).to(device)
        noisy_x = corrupt(x, noise_amount)

        # The target is the clean image itself, not the noise.
        pred = net(noisy_x)
        loss = loss_fn(pred, x)

        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    # Mean over the steps of the epoch that just finished.
    epoch_losses = losses[-batches_per_epoch:]
    avg_loss = sum(epoch_losses)/batches_per_epoch
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# Loss curve over all steps, zoomed to the [0, 0.1] range.
plt.plot(losses)
plt.ylim(0, 0.1)
# Output recorded by the author:
'''
Finished epoch 0. Average loss for this epoch: 0.028598
Finished epoch 1. Average loss for this epoch: 0.021729
Finished epoch 2. Average loss for this epoch: 0.019670
'''

1.3.2.2 测试训练模型

这里尝试抓取一批数据得到具有不同程度噪声的数据,然后将它们输入模型获得预测并且观察结果

python
# Take the first 8 images of a batch, corrupt them with increasing strength and
# compare the network's reconstruction against the originals.
x, y = next(iter(train_dataloader))
x = x[:8]

amount = torch.linspace(0, 1, x.shape[0])
# amount = torch.rand(x.shape[0])
# noise = torch.rand([8, 1, 28, 28])
print('amount',amount)
noised_x = corrupt(x, amount)

# Inference only: no gradients needed; bring the result back to the CPU for plotting.
with torch.no_grad():
    preds = net(noised_x.to(device)).detach().cpu()

fig, axs = plt.subplots(3, 1, figsize=(12, 7))
panels = (
    ('Input data', x),
    ('Corrupt data', noised_x),
    ('Network Predictions', preds),
)
for ax, (title, batch) in zip(axs, panels):
    ax.set_title(title)
    # First channel of the tiled grid, clamped to the displayable [0, 1] range.
    ax.imshow(torchvision.utils.make_grid(batch)[0].clip(0, 1), cmap='Greys')
# Output recorded by the author:
'''
amount tensor([0.0000, 0.1429, 0.2857, 0.4286, 0.5714, 0.7143, 0.8571, 1.0000])
'''

1.4 扩散模型之采样过程

思路:从完全随机的噪声开始,然后每次向着预测方向移动一小部分。

python
# Iterative sampling: start from uniform random values and, at every step, move
# part of the way from the current x towards the network's prediction.
n_steps = 5
x = torch.rand(8, 1, 28, 28).to(device)   # start from completely random values
step_history = [x.detach().cpu()]          # x before each step, plus the final x
pred_output_history = []                   # network prediction at each step

with torch.no_grad():
    for i in range(n_steps):
        pred = net(x)
        pred_output_history.append(pred.detach().cpu())

        # Fraction of the way to move: 1/n_steps at first, growing to 1 on the last step.
        mix_factor = 1 / (n_steps - i)
        x = x * (1 - mix_factor) + pred * mix_factor
        step_history.append(x.detach().cpu())   # keep every intermediate x for plotting

# One row per step: model input on the left, model prediction on the right.
fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4),sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('model prediction')
for (input_ax, pred_ax), model_in, model_out in zip(axs, step_history, pred_output_history):
    input_ax.imshow(torchvision.utils.make_grid(model_in)[0].clip(0, 1),cmap='Greys')
    pred_ax.imshow(torchvision.utils.make_grid(model_out)[0].clip(0, 1),cmap='Greys')

当然,也可以将采样过程拆解成更多步,以获得质量更高的图像

python
# Same sampling loop with more, smaller steps and a larger batch of 64 images.
n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
with torch.no_grad():
    for i in range(n_steps):
        # noise_amount = torch.ones((x.shape[0],)).to(device) * (1-(i/n_steps))
        pred = net(x)
        # Mixing weight grows from 1/n_steps to 1 as the remaining step count shrinks.
        steps_left = n_steps - i
        min_factor = 1/steps_left
        x = x*(1-min_factor) + pred*min_factor
        # if (i % 5==0):

# Show the final samples as an 8-column grid.
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
final_grid = torchvision.utils.make_grid(x.detach().cpu(), nrow=8)
ax.imshow(final_grid[0].clip(0, 1), cmap='Greys')
# show_images(x.detach().cpu()).resize((16*98, 500),resample=Image.NEAREST)

1.5 使用框架搭建网络

Diffusers库中的UNet2DModel是已经写好的Unet网络,可以直接调用

1.5.1 定义网络

python
# Build the network with diffusers' ready-made UNet2DModel.
unet_config = dict(
    sample_size=28,                      # matches the 28x28 images used above
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(32, 64, 64),
    # One entry per resolution level; the "Attn" variants are the attention blocks.
    down_block_types=(
        "DownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
    ),
)
net = UNet2DModel(**unet_config)
net = net.to(device)
# Parameter count; the author recorded 1707009.
print(sum(p.numel() for p in net.parameters()))

1.5.2 训练网络

python
# Train the diffusers UNet with the same "predict the clean image" objective.
batch_size = 128
n_epochs = 5
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# One entry per optimization step, across all epochs.
losses = []
batches_per_epoch = len(train_dataloader)

for epoch in range(n_epochs):
    for x, y in train_dataloader:
        x = x.to(device)
        # Random corruption strength per image.
        noise_amount = torch.rand(x.shape[0]).to(device)
        noisy_x = corrupt(x, noise_amount)

        # The model is always called with timestep 0 here; the prediction is
        # read from the `.sample` field of its output.
        pred = net(noisy_x, 0).sample
        loss = loss_fn(pred, x)

        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    # Mean over the steps of the epoch that just finished.
    epoch_losses = losses[-batches_per_epoch:]
    avg_loss = sum(epoch_losses)/batches_per_epoch
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# Loss curve over all steps, zoomed to the [0, 0.1] range.
plt.plot(losses)
plt.ylim(0, 0.1)
# Output recorded by the author:
"""
Finished epoch 0. Average loss for this epoch: 0.012590
Finished epoch 1. Average loss for this epoch: 0.011175
Finished epoch 2. Average loss for this epoch: 0.010676
Finished epoch 3. Average loss for this epoch: 0.010433
Finished epoch 4. Average loss for this epoch: 0.010035
"""

1.5.3 采样生成

python
# Sample 32 images with the trained diffusers UNet by repeatedly blending the
# current x with the model's prediction.
n_steps = 50
x = torch.rand(32, 1, 28, 28).to(device)
with torch.no_grad():
    for i in range(n_steps):
        # noise_amount = torch.ones((x.shape[0],)).to(device) * (1-(i/n_steps))
        pred = net(x, 0).sample
        # NOTE(review): the earlier samplers in this tutorial use 1/(n_steps - i);
        # with the extra +1 the last step only moves halfway to the prediction
        # (factor 1/2) -- confirm this is intended.
        denominator = n_steps - i + 1
        min_factor = 1/denominator
        x = x*(1-min_factor) + pred*min_factor

# Show the final samples as an 8-column grid.
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
final_grid = torchvision.utils.make_grid(x.detach().cpu(), nrow=8)
ax.imshow(final_grid[0].clip(0, 1), cmap='Greys')
# show_images(x).resize((10*35*3, 98*3),resample=Image.NEAREST)

2 DDPM和DDIM模型

与之前的自己搭建的扩散模型不同,DDPM和DDIM是更好的扩散模型,主要有两点不同:

1、在训练过程中,模型预测的是噪声,即损失函数设置为loss = F.mse_loss(noise_pred, noise)(这只是一个示意写法),而之前的模型预测的是原图

2、加入了时间步timesteps来调节噪声量

2.1 使用框架搭建

下面进行实战,请在jupyter中打开运行

2.1.1 定义必要函数

python
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
import torchvision
from PIL import Image
def show_images(x):
    """Turn a batch of image tensors normalized around 0 into one PIL image grid.

    The batch is first rescaled with ``x * 0.5 + 0.5`` (undoing the
    ``Normalize([0.5], [0.5])`` preprocessing used in this tutorial), tiled
    with ``torchvision.utils.make_grid``, clamped to ``[0, 1]`` and converted
    to 8-bit pixels.

    Args:
        x: Batch of image tensors.

    Returns:
        A ``PIL.Image`` containing the tiled batch.
    """
    rescaled = x * 0.5 + 0.5
    grid = torchvision.utils.make_grid(rescaled)
    # Channels-first -> channels-last, as expected by Image.fromarray.
    channels_last = grid.detach().cpu().permute(1, 2, 0)
    pixels = np.array(channels_last.clip(0, 1) * 255).astype(np.uint8)
    return Image.fromarray(pixels)

def make_grid(images, size=64):
    """Paste the given PIL images side by side into a single RGB strip.

    Args:
        images: Sequence of PIL images; each one is resized to ``size`` x ``size``.
        size: Edge length in pixels of every tile.

    Returns:
        A new RGB ``PIL.Image`` of width ``size * len(images)`` and height ``size``.
    """
    tile_shape = (size, size)
    strip = Image.new("RGB", (size * len(images), size))
    x_offset = 0
    for picture in images:
        strip.paste(picture.resize(tile_shape), (x_offset, 0))
        x_offset += size
    return strip
# Use the first CUDA GPU when available; fall back to the CPU otherwise.
device_name = 'cuda:0' if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

2.1.2 获取数据集

python
import torchvision
from torchvision import transforms, datasets
# Load CIFAR10 resized to image_size x image_size and normalized with mean 0.5 /
# std 0.5, then preview one batch.
image_size = 32
batch_size = 32
preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
# To load images from a local folder instead:
# data_path = f'/path/of/your/dataset'
# dataset = datasets.ImageFolder(root=data_path,transform=preprocess)
dataset = torchvision.datasets.CIFAR10(root='./cifar10', train=True, download=True, transform=preprocess)

train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
print(len(train_dataloader))
xb, yb = next(iter(train_dataloader))
xb = xb.to(device)
print('X Shape', xb.shape)

# Render the grid once (it was rendered three times before) and upscale it 2x
# with nearest-neighbour resampling. This is kept as the cell's last expression
# so Jupyter displays it; the recorded output is a comment rather than a
# trailing string literal, which would otherwise become the displayed value.
preview = show_images(xb)
preview.resize((preview.size[0] * 2, preview.size[1] * 2), resample=Image.NEAREST)
# Output recorded by the author:
# Files already downloaded and verified
# 1563
# X Shape torch.Size([32, 3, 32, 32])

2.1.3 定义调度器

python
from diffusers import DDPMScheduler
# Create a DDPM scheduler and use it to add noise to the preview batch, with
# timesteps spread evenly from 0 to 999 across the images of the batch.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.linspace(0, 999, xb.shape[0]).long().to(device)
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print("Noisy X Shape", noisy_xb.shape)

# Render the grid once (it was rendered three times before) and upscale it 2x.
# Kept as the cell's last expression so Jupyter displays it; the recorded output
# is a comment rather than a trailing string literal, which would otherwise
# become the displayed value.
noisy_preview = show_images(noisy_xb)
noisy_preview.resize((noisy_preview.size[0] * 2, noisy_preview.size[1] * 2), resample=Image.NEAREST)
# Output recorded by the author:
# Noisy X Shape torch.Size([32, 3, 32, 32])

2.1.4 定义Unet网络

python
from diffusers import UNet2DModel
# UNet for 3-channel images: three resolution levels, with an attention block
# only at the deepest level.
unet_config = dict(
    sample_size=image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(64, 64, 128),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        # "AttnDownBlock2D"
    ),
    up_block_types=(
        # "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
model = UNet2DModel(**unet_config)
model.to(device)
print('参数量:', sum(p.numel() for p in model.parameters()))
# Parameter count recorded by the author: 4654723

2.1.5 加噪预测测试

python
# Run the model once on the noised preview batch to check that the output has
# the same shape as the input (this cell precedes the training cell).
with torch.no_grad():
    # print('noisy_xb', noisy_xb.shape)
    model_prediciton = model(noisy_xb, timesteps).sample
print(model_prediciton.shape)

# Render the grid once (it was rendered three times before) and upscale it 2x.
# Kept as the cell's last expression so Jupyter displays it; the recorded output
# is a comment rather than a trailing string literal, which would otherwise
# become the displayed value.
prediction_preview = show_images(model_prediciton)
prediction_preview.resize((prediction_preview.size[0] * 2, prediction_preview.size[1] * 2), resample=Image.NEAREST)
# Output recorded by the author:
# torch.Size([32, 3, 32, 32])

2.1.6 模型训练

python
# Train the UNet to predict the noise that the scheduler mixed into each image.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
optimizer = torch.optim.AdamW(model.parameters(), lr = 4e-4)
losses = []   # one entry per optimization step, across all epochs
epochs = 10
for epoch in range(epochs):
    for step, (image, label) in enumerate(train_dataloader):
        clean_images = image.to(device)
        noise = torch.randn(clean_images.shape).to(device)
        bs = clean_images.shape[0]

        # Draw one random timestep per image. The value is read through
        # `.config`, matching the script version of this loop later in the file.
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bs, ),
            device = device
        ).long()

        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

        loss = F.mse_loss(noise_pred, noise)
        # Call backward() without arguments: passing `loss` as the `gradient`
        # argument multiplied every gradient by the loss value.
        loss.backward()
        losses.append(loss.item())

        optimizer.step()
        optimizer.zero_grad()
    # Report the mean loss of the finished epoch every second epoch.
    if (epoch + 1) % 2 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
        print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")
    # Optionally save a checkpoint per epoch:
    # torch.save({'model': model.state_dict()}, f'{epoch}epoch_ddpm_cifar10.pth')

# Plot the loss per step: linear scale on the left, log scale on the right.
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))
plt.show()

2.1.7 采样去噪

python
from pathlib import Path

# The snapshots below are written into ./pictures_cifar; create it up front so
# the first save does not fail when the directory does not exist yet.
Path("./pictures_cifar").mkdir(parents=True, exist_ok=True)

# Start from Gaussian noise and let the scheduler denoise it step by step.
sample = torch.randn(32, 3, 32, 32).to(device)

for i, t in enumerate(noise_scheduler.timesteps):
    with torch.no_grad():
        # print(f'{i}', model(sample, t).sample.shape) # [32, 3, 32, 32]
        # Model output for the current sample at timestep t.
        residual = model(sample, t).sample
        # print('t', t) # example: tensor(998, dtype=torch.int32)
        # The scheduler turns that output into the sample for the previous timestep.
        sample = noise_scheduler.step(residual, t, sample).prev_sample
        # Save a snapshot every 100 steps.
        if (i + 1) % 100 == 0:
            show_images(sample).save(f"./pictures_cifar/{i + 1}epoch_output_image.png")
show_images(sample)

2.1.8 python一键运行版

python
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
import torchvision
from PIL import Image
from torchvision import transforms, datasets
from diffusers import UNet2DModel, DDPMScheduler
from torch.utils.data import DataLoader
from pathlib import Path


def show_images(x):
    """Convert a batch of normalized image tensors into a single PIL image grid.

    Args:
        x: Batch of image tensors; rescaled here with ``x * 0.5 + 0.5``, which
            undoes the ``Normalize([0.5], [0.5])`` step of the preprocessing.

    Returns:
        A ``PIL.Image`` with the batch tiled by ``torchvision.utils.make_grid``.
    """
    grid = torchvision.utils.make_grid(x * 0.5 + 0.5)
    # Move to the CPU and put channels last, then clamp to the valid [0, 1] range.
    grid = grid.detach().cpu().permute(1, 2, 0).clip(0, 1)
    as_uint8 = np.array(grid * 255).astype(np.uint8)
    return Image.fromarray(as_uint8)


def make_grid(images, size=64):
    """Lay out PIL images in one horizontal row.

    Args:
        images: Sequence of PIL images; every image is resized to a
            ``size`` x ``size`` tile.
        size: Tile edge length in pixels.

    Returns:
        A new RGB ``PIL.Image`` that is ``size * len(images)`` wide and ``size`` high.
    """
    total_width = size * len(images)
    row = Image.new("RGB", (total_width, size))
    for left, tile in zip(range(0, total_width, size), images):
        row.paste(tile.resize((size, size)), (left, 0))
    return row


if __name__ == "__main__":
    # End-to-end script: build the UNet, train it on CIFAR10 to predict the
    # added noise, save a checkpoint per epoch, then sample images from noise.
    image_size = 32
    batch_size = 32
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
    # print(device)
    model = UNet2DModel(
        sample_size=image_size,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(64, 128, 128, 256),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D"
        )
    )
    model.to(device)
    print('参数量:', sum(p.numel() for p in model.parameters()))

    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    # To load images from a local folder instead:
    # data_path = f'path/of/your/dataset'
    # dataset = datasets.ImageFolder(root=data_path,transform=preprocess)

    dataset = torchvision.datasets.CIFAR10(root='./cifar10', train=True, download=True, transform=preprocess)

    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print('len(train_dataloader)', len(train_dataloader))
    # Fetch one batch only to report its shape.
    xb, yb = next(iter(train_dataloader))
    xb = xb.to(device)
    print('X Shape', xb.shape)
    # show_images(xb).resize((show_images(xb).size[0] * 2, show_images(xb).size[1] * 2), resample=Image.NEAREST)

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
    optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
    losses = []   # one entry per optimization step, across all epochs
    epochs = 10

    # Directory for model checkpoints.
    model_dir = './model'
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    # Directory for sampled images.
    save_pictures_dir = './pictures'
    Path(save_pictures_dir).mkdir(parents=True, exist_ok=True)

    # Training
    model.train()
    for epoch in range(epochs):
        for step, (image, label) in enumerate(train_dataloader):
            clean_images = image.to(device)
            noise = torch.randn(clean_images.shape).to(device)
            bs = clean_images.shape[0]

            # One random timestep per image in the batch.
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,),
                device=device
            ).long()

            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

            loss = F.mse_loss(noise_pred, noise)
            # Call backward() without arguments: passing `loss` as the
            # `gradient` argument multiplied every gradient by the loss value.
            loss.backward()
            losses.append(loss.item())

            optimizer.step()
            optimizer.zero_grad()
        # Report the mean loss of the finished epoch every second epoch.
        if (epoch + 1) % 2 == 0:
            loss_last_epoch = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
            print(f"Epoch:{epoch + 1}, loss: {loss_last_epoch}")

        # Checkpoint after every epoch.
        torch.save({'model': model.state_dict()}, f'{model_dir}/{epoch}epoch_ddpm_cifar10.pth')
    # fig, axs = plt.subplots(1, 2, figsize=(12, 4))
    # axs[0].plot(losses)
    # axs[1].plot(np.log(losses))
    # plt.show()

    # Sampling: start from Gaussian noise and denoise step by step.
    model.eval()
    sample = torch.randn(batch_size, 3, image_size, image_size).to(device)

    for i, t in enumerate(noise_scheduler.timesteps):
        with torch.no_grad():
            # print(f'{i}', model(sample, t).sample.shape) # [32, 3, 32, 32]
            # Model output for the current sample at timestep t.
            residual = model(sample, t).sample
            # print('t', t) # example: tensor(998, dtype=torch.int32)
            # The scheduler turns that output into the sample for the previous timestep.
            sample = noise_scheduler.step(residual, t, sample).prev_sample
            # Save a snapshot after the first step and then every 100 steps.
            if (i + 1) % 100 == 0 or i == 0:
                show_images(sample).save(f"{save_pictures_dir}/{i + 1}epoch_output_image.png")
    print(f'Save Image in {save_pictures_dir}')

[[#342 v]]

最近更新