Skip to content
0

文章发布较早,内容可能过时,阅读注意甄别。

生成对抗网络(GAN)

以下是GAN网络的一个简单实现

输入:单张图片,输出:单张图片

1、代码

请创建5个文件

python
Config.py------------配置函数
GAN_Model.py---------模型函数
main.py--------------主函数
Train.py-------------训练函数
True_Image.py--------预处理函数

1.1 Config.py

python
import argparse


def parse_args(argv=None):
    """Build the command-line interface of the GAN demo and parse it.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is exactly
            what a plain ``parse_args()`` call did before this parameter
            existed. Passing a list makes the function usable from tests or
            from other Python code.

    Returns:
        argparse.Namespace: the parsed options. ``image_path`` is required;
        every other option has the default shown below.
    """
    parser = argparse.ArgumentParser(description='GAN NetWork')
    parser.add_argument('--image_path', type=str, required=True, help='Path to dataset')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=2e-4, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=50000, help='Number of Epoch')
    parser.add_argument('--optimizer', type=str, default='Adam', help='Optimizer to use')
    parser.add_argument('--device', type=str, default='cpu', help='Device to use')
    parser.add_argument('--latent_dim', type=int, default=100, help='Latent dimension')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--save_iter', type=int, default=100, help='Save iter')
    parser.add_argument('--save_path', type=str, default='./save', help='Path to save picture')
    return parser.parse_args(argv)

# print(parse_args())

1.2 GAN_Model.py

python
import torch
import torch.nn as nn


# 生成器模型
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to a flat image.

    The stack is Linear -> ReLU repeated for the widths
    latent_dim -> 256 -> 512 -> 1024, followed by Linear(1024, image_size)
    and a Tanh, so every output value lies in [-1, 1].
    """

    def __init__(self, latent_dim, image_size):
        super().__init__()
        widths = (latent_dim, 256, 512, 1024)
        stack = []
        # Each hidden stage is a Linear layer followed by a ReLU.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # Output stage: project to the flattened image and squash with Tanh.
        stack.append(nn.Linear(widths[-1], image_size))
        stack.append(nn.Tanh())
        self.main = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        generated = self.main(x)
        return generated


# 判别器模型
class Discriminator(nn.Module):
    """Fully connected discriminator scoring a flat image with one value.

    The stack is Linear -> LeakyReLU(0.2) -> Dropout(0.3) repeated for the
    widths image_size -> 1024 -> 512 -> 256, followed by Linear(256, 1) and
    a Sigmoid, so the output is a single value in [0, 1] per sample.
    """

    def __init__(self, image_size):
        super().__init__()
        widths = (image_size, 1024, 512, 256)
        stack = []
        # Each hidden stage: Linear, leaky activation, then dropout.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend((
                nn.Linear(fan_in, fan_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            ))
        # Output stage: one score per sample, squashed by a Sigmoid.
        stack.extend((nn.Linear(widths[-1], 1), nn.Sigmoid()))
        self.main = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the score."""
        score = self.main(x)
        return score

NOTE

这里的生成器和判别器都由全连接(Linear)层搭建,可以根据需要自行修改网络结构。

1.3 main.py

python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from PIL import Image
from GAN_Model import Discriminator, Generator
from True_Image import get_image
import numpy as np
from Config import parse_args
from Train import train
import os
if __name__ == '__main__':
    # Script entry point: parse the CLI, load the single training image,
    # build both networks and their optimizers, then hand over to train().
    args = parse_args()

    # Device selection: the requested device is honoured only when it is not
    # 'cpu' and CUDA is available; in every other case the CPU is used.
    use_requested = args.device != 'cpu' and torch.cuda.is_available()
    device = torch.device(args.device if use_requested else "cpu")

    # Seed PyTorch's RNG so runs can be reproduced.
    torch.manual_seed(args.seed)

    # Hyper-parameters. The batch size and image dimensions come from the
    # loaded image tensor, not from the command line.
    batch_size, channel, width, length, real_images = get_image(args.image_path)
    image_size = channel * width * length
    latent_dim = args.latent_dim
    epochs = args.epochs

    # Generator and discriminator, both placed on the selected device.
    generator = Generator(latent_dim=latent_dim, image_size=image_size).to(device)
    discriminator = Discriminator(image_size=image_size).to(device)

    # Loss function and one Adam optimizer per network.
    criterion = nn.BCELoss()
    adam_betas = (0.5, 0.999)
    optimizer_g = optim.Adam(generator.parameters(), lr=args.learning_rate, betas=adam_betas)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=args.learning_rate, betas=adam_betas)

    # Log and save a sample once every `save_iter` epochs.
    save_iter = args.save_iter

    # Output directory for the saved samples; created when missing.
    save_path = args.save_path
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    # Train the GAN.
    train(
        epochs=epochs,
        real_image=real_images,
        device=device,
        discriminator=discriminator,
        generator=generator,
        criterion=criterion,
        latent_dim=latent_dim,
        optimizer_d=optimizer_d,
        optimizer_g=optimizer_g,
        batch_size=batch_size,
        channel=channel,
        width=width,
        length=length,
        save_iter=save_iter,
        save_path=save_path,
    )

1.4 Train.py

python
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torchvision.utils as vutils


# 训练GAN
def train(epochs, real_image, device, discriminator, generator, criterion, latent_dim,
          optimizer_d, optimizer_g, batch_size, channel, width, length, save_iter, save_path):
    """Train the generator/discriminator pair on one fixed batch of real images.

    Every epoch performs one discriminator update (real batch plus a detached
    fake batch) followed by one generator update. On every epoch whose
    zero-based index is a multiple of ``save_iter`` the losses are printed and
    a freshly generated sample is written to ``save_path``.

    Args:
        epochs: Number of training iterations.
        real_image: Tensor of real images; it is flattened to
            ``(real_image.size(0), -1)`` before being fed to the discriminator.
        device: Device the tensors created here are moved to.
        discriminator: Module returning one score per flattened image.
        generator: Module mapping ``(n, latent_dim)`` noise to flattened images.
        criterion: Loss comparing discriminator scores with 0/1 labels.
        latent_dim: Size of the noise vector fed to the generator.
        optimizer_d: Optimizer stepping the discriminator.
        optimizer_g: Optimizer stepping the generator.
        batch_size: Number of samples generated when saving a picture.
        channel: Channel count used to reshape generated samples; 3 selects
            the ``.jpg`` extension, anything else ``.png``.
        width: Third dimension used when reshaping generated samples.
        length: Fourth dimension used when reshaping generated samples.
        save_iter: Logging/saving period in epochs.
        save_path: Directory the generated pictures are written to.

    Returns:
        None.
    """
    # Loop-invariant work, done once instead of on every epoch: flatten the
    # real batch, move it to the device and build the constant label tensors.
    real_image = real_image.view(real_image.size(0), -1).to(device)
    n_real = real_image.size(0)
    real_labels = torch.ones(n_real, 1).to(device)
    fake_labels = torch.zeros(n_real, 1).to(device)

    for epoch in range(epochs):
        # ---- Discriminator step ----
        discriminator.zero_grad()

        outputs_real = discriminator(real_image)
        d_loss_real = criterion(outputs_real, real_labels)
        d_loss_real.backward()

        z = torch.randn(n_real, latent_dim).to(device)
        fake_images = generator(z)
        # detach(): the discriminator update must not backpropagate into the
        # generator.
        outputs_fake = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs_fake, fake_labels)
        d_loss_fake.backward()

        d_loss = d_loss_real + d_loss_fake
        optimizer_d.step()

        # ---- Generator step ----
        # The generator is rewarded when its fakes are scored as real.
        generator.zero_grad()
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_g.step()

        if epoch % save_iter != 0:
            continue

        shown_epoch = epoch + 1
        print(f"Epoch [{shown_epoch}/{epochs}]"
              f"Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}")

        # Generate and save a sample picture.
        with torch.no_grad():
            samples = generator(torch.randn(batch_size, latent_dim).to(device))
            samples = samples.view(samples.size(0), channel, width, length)

            # Preview of the first sample only, so a batch larger than one
            # does not hand a 4-D tensor to ToPILImage. ToPILImage expects
            # float tensors in [0, 1]; the generators in this file end in
            # Tanh ([-1, 1]), hence the rescale and clamp.
            # NOTE(review): nothing calls plt.show() here, so the preview is
            # only visible in an interactive session, and repeated imshow
            # calls pile up on the current axes - confirm it is still wanted.
            preview = samples[0].add(1).div(2).clamp(0, 1).cpu()
            plt.imshow(transforms.ToPILImage()(preview))

            # The log line now reports the same file name that is written
            # (it used to print `epoch` while saving under `epoch + 1`).
            extension = "jpg" if channel == 3 else "png"
            file_name = f"gan_generated_epoch_{shown_epoch}.{extension}"
            vutils.save_image(samples.squeeze(), f"{save_path}/{file_name}", normalize=True)
            print(f"生成的图像已保存为 {file_name}")

# def plot_generated_images(save_image_path):
#     img = Image.open(save_image_path)
#     plt.imshow(img)
#     plt.axis('off')
#     plt.show()
# plot_generated_images(f"gan_generated_epoch_{epoch + 1}.png")

# def show_images(images, title=None):
#     grid = vutils.make_grid(images, normalize=True)
#     print(grid.shape)
#     np_img = grid.numpy()
#     plt.figure(figsize=(8, 8))
#     plt.imshow(np.transpose(np_img, (1, 2, 0)), cmap='gray')
#     print(np.transpose(np_img, (1, 2, 0)).shape)
#     plt.axis('off')
#     if title:
#         plt.title(title)
#     plt.show()

1.5 True_Image.py

python
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# import matplotlib.pyplot as plt
from PIL import Image


def get_image(image_path):
    """Load one image file as a normalised 4-D tensor with a batch axis.

    The image is resized with ``Resize((150, 200))``, converted to a tensor
    and normalised with mean 0.5 and std 0.5, i.e. ``(x - 0.5) / 0.5``.

    Args:
        image_path: Path of the image file to open.

    Returns:
        tuple: ``(batch, channel, width, length, real_images)`` where
        ``real_images`` has a leading batch axis of size 1 and the four
        integers are its shape in order. Note that ``width`` and ``length``
        are simply dimensions 2 and 3 of that shape.
    """
    # Pipeline turning a PIL image into a normalised tensor.
    transform = transforms.Compose([
        transforms.Resize((150, 200)),  # change to suit your picture
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    # Open through a context manager so the underlying file handle is closed
    # as soon as the pixels have been converted; the tensor keeps no
    # reference to the file.
    with Image.open(image_path) as img:
        # unsqueeze(0) adds the batch axis -> (1, channel, 150, 200)
        real_images = transform(img).unsqueeze(0)
    batch, channel, width, length = real_images.shape
    return batch, channel, width, length, real_images

# 数据加载和预处理
# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5,), (0.5,))
# ])
#
# dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
#
# # 显示数据集中的一个样本图像
# data_iter = iter(dataloader)
# sample_data = next(data_iter)
# sample_image, _ = sample_data
# sample_image = sample_image.view(sample_image.size(0), 1, 28, 28)  # 将图像大小调整为[1, 28, 28]
# plt.imshow(sample_image[0, 0].numpy(), cmap='gray')
# plt.axis('off')
# plt.title('Sample Image from the Dataset')
# plt.show()

2、运行

可以运行python main.py -h查看帮助

python
usage: main.py [-h] --image_path IMAGE_PATH [--batch_size BATCH_SIZE] [--learning_rate LEARNING_RATE] [--epochs EPOCHS] [--optimizer OPTIMIZER] [--device DEVICE]
               [--latent_dim LATENT_DIM] [--seed SEED] [--save_iter SAVE_ITER] [--save_path SAVE_PATH]

GAN NetWork

optional arguments:
  -h, --help            show this help message and exit
  --image_path IMAGE_PATH
                        Path to dataset
  --batch_size BATCH_SIZE
                        Batch size
  --learning_rate LEARNING_RATE
                        Learning rate
  --epochs EPOCHS       Number of Epoch
  --optimizer OPTIMIZER
                        Optimizer to use
  --device DEVICE       Device to use
  --latent_dim LATENT_DIM
                        Latent dimension
  --seed SEED           Random seed
  --save_iter SAVE_ITER
                        Save iter
  --save_path SAVE_PATH
                        Path to save picture

运行示例如下:

bash
python main.py --device='cuda:0' --image_path='/path/of/your/picture'

运行后会生成一个 save 文件夹(可通过 --save_path 修改),其中保存每隔 save_iter 轮(默认 100 轮)训练生成的图片。

3、其他

使用KAN网络代替Linear网络

3.1 新建kan.py文件

需要在同级目录下新建一个文件kan.py

python
import torch
import torch.nn.functional as F
import math

"""
1.内存效率提升:原始实现需要扩展所有中间变量来执行不同的激活函数,而此代码中将计算重新制定为使用不同的基函数激活输入,
  然后线性组合它们。这种重新制定可以显著降低内存成本,并将计算变得更加高效。

2.正则化方法的改变:原始实现中使用的L1正则化需要对张量进行非线性操作,与重新制定的计算不兼容。
  因此,此代码中将L1正则化改为对权重的L1正则化,这更符合神经网络中常见的正则化方法,并且与重新制定的计算兼容。

3.激活函数缩放选项:原始实现中包括了每个激活函数的可学习缩放,但这个库提供了一个选项来禁用这个特性。
  禁用缩放可以使模型更加高效,但可能会影响结果。

4.参数初始化的改变:为了解决在MNIST数据集上的性能问题,此代码修改了参数的初始化方式,使用kaiming初始化。
"""

class KANLinear(torch.nn.Module):
    """Linear-like layer made of a base branch plus a B-spline branch.

    ``forward`` returns ``F.linear(base_activation(x), base_weight)`` plus a
    linear map of the B-spline bases of ``x`` weighted by ``spline_weight``
    (optionally multiplied per (out, in) pair by ``spline_scaler``).

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        grid_size: Number of grid intervals inside ``grid_range``.
        spline_order: Order of the B-splines.
        scale_noise: Scale of the random noise used to initialise the
            spline weights.
        scale_base: Factor applied to the ``a`` argument of the Kaiming
            initialisation of ``base_weight``.
        scale_spline: Scale used for the spline branch initialisation.
        enable_standalone_scale_spline: When true, a separate learnable
            ``spline_scaler`` parameter of shape (out, in) is created.
        base_activation: Module class instantiated as the base activation.
        grid_eps: Blend factor between the uniform and the data-adaptive
            grid in ``update_grid``.
        grid_range: Two-element sequence ``(low, high)`` spanned by the
            initial grid. The default is an immutable tuple; lists are still
            accepted.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Grid step, then the knot vector extended by `spline_order` knots on
        # each side: shape (in_features, grid_size + 2 * spline_order + 1).
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        # A buffer, not a parameter: it is stored with the module but is not
        # returned by parameters().
        self.register_buffer("grid", grid)

        # Allocated here, filled in by reset_parameters().
        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``base_weight``, ``spline_weight`` and ``spline_scaler``."""
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Small uniform noise centred on zero, one value per
            # (grid point, in, out).
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            # Spline coefficients are fitted so the curve passes through the
            # noise values at the interior grid points.
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval holding x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order one step at a time; each step combines two
        # neighbouring lower-order bases and shortens the last axis by one.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        # One least-squares problem per input feature.
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """
        Spline weights multiplied by ``spline_scaler`` when it is enabled.

        Returns:
            torch.Tensor: Tensor with the same shape as ``self.spline_weight``.
        """
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """
        Apply the layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_features).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        # Base branch: activation followed by a bias-free linear map.
        base_output = F.linear(self.base_activation(x), self.base_weight)
        # Spline branch: the (in, coeff) axes are flattened so a single
        # linear map handles every basis function at once.
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """
        Refit the grid to the distribution of ``x`` and re-solve the spline
        weights so the spline branch reproduces its previous outputs on ``x``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            margin (float): Padding added on both sides of the observed
                input range when building the uniform grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        # Spline-branch output per (sample, in, out) under the current grid.
        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        # Data-adaptive grid: evenly spaced ranks of the sorted inputs.
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        # Uniform grid over the observed range, padded by `margin`.
        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Extend by `spline_order` knots on each side. torch.cat is used
        # because torch.concatenate is only an alias added in newer PyTorch
        # releases.
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.

        Args:
            regularize_activation (float): Weight of the activation (L1) term.
            regularize_entropy (float): Weight of the entropy term.

        Returns:
            torch.Tensor: The weighted sum of both terms.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    """Stack of ``KANLinear`` layers applied one after another."""

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        """
        Build one ``KANLinear`` per consecutive pair of widths.

        Args:
            layers_hidden (list): Layer widths from input to output; each
                consecutive pair ``(layers_hidden[i], layers_hidden[i + 1])``
                becomes the (in_features, out_features) of one layer. Fewer
                than two entries yields a network with no layers.
            grid_size (int): Grid size forwarded to every layer.
            spline_order (int): Spline order forwarded to every layer.
            scale_noise (float): Forwarded to every layer.
            scale_base (float): Forwarded to every layer.
            scale_spline (float): Forwarded to every layer.
            base_activation (type): Activation module class forwarded to
                every layer.
            grid_eps (float): Forwarded to every layer.
            grid_range (sequence): Two-element ``(low, high)`` range
                forwarded to every layer. The default is an immutable tuple
                rather than a shared list; lists are still accepted.
        """
        super().__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        """
        Pass ``x`` through every layer in order.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            update_grid (bool): When true, each layer's ``update_grid`` is
                called with that layer's input before the layer is applied.

        Returns:
            torch.Tensor: Output of the last layer (``x`` itself when the
            network has no layers).
        """
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Sum the regularization losses of all layers.

        Args:
            regularize_activation (float): Weight of the activation term.
            regularize_entropy (float): Weight of the entropy term.

        Returns:
            The sum of each layer's ``regularization_loss`` (the integer 0
            when the network has no layers).
        """
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

3.2 修改GAN_Model.py文件

python
import torch
import torch.nn as nn
from kan import KAN

# Linear网络
# 生成器模型
# class Generator(nn.Module):
#     def __init__(self, latent_dim, image_size):
#         super(Generator, self).__init__()
#         # self.lantent_dim = lantent_dim
#         self.main = nn.Sequential(
#             nn.Linear(latent_dim, 256),
#             nn.ReLU(),
#             nn.Linear(256, 512),
#             nn.ReLU(),
#             nn.Linear(512, 1024),
#             nn.ReLU(),
#             nn.Linear(1024, image_size),
#             nn.Tanh()
#         )

#     def forward(self, x):
#         return self.main(x)


# # 判别器模型
# class Discriminator(nn.Module):
#     def __init__(self, image_size):
#         super(Discriminator, self).__init__()
#         self.main = nn.Sequential(
#             nn.Linear(image_size, 1024),
#             nn.LeakyReLU(0.2),
#             nn.Dropout(0.3),
#             nn.Linear(1024, 512),
#             nn.LeakyReLU(0.2),
#             nn.Dropout(0.3),
#             nn.Linear(512, 256),
#             nn.LeakyReLU(0.2),
#             nn.Dropout(0.3),
#             nn.Linear(256, 1),
#             nn.Sigmoid()
#         )

#     def forward(self, x):
#         return self.main(x)

# KAN网络
# 生成器模型
class Generator(nn.Module):
    """KAN-based generator: latent vector -> flat image squashed by Tanh.

    The KAN is built with the widths latent_dim -> 256 -> 512 -> 1024 ->
    image_size; its output goes through ``nn.Tanh``.
    """

    def __init__(self, latent_dim, image_size):
        super().__init__()
        layer_widths = [latent_dim, 256, 512, 1024, image_size]
        self.main = KAN(layer_widths)
        # "guiyi": the final squashing step applied to the KAN output.
        self.guiyi = nn.Tanh()

    def forward(self, x):
        """Return ``Tanh(KAN(x))``."""
        return self.guiyi(self.main(x))


# 判别器模型
class Discriminator(nn.Module):
    """KAN-based discriminator: flat image -> one Sigmoid score per sample.

    The KAN is built with the widths image_size -> 1024 -> 512 -> 256 -> 1;
    its output goes through ``nn.Sigmoid``.
    """

    def __init__(self, image_size):
        super().__init__()
        layer_widths = [image_size, 1024, 512, 256, 1]
        self.main = KAN(layer_widths)
        # "guiyi": the final squashing step applied to the KAN output.
        self.guiyi = nn.Sigmoid()

    def forward(self, x):
        """Return ``Sigmoid(KAN(x))``."""
        return self.guiyi(self.main(x))

WARNING

效果不怎么好

最近更新