
tempoGAN (temporally Generative Adversarial Networks)


# model training
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat -P datasets/tempoGAN/
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat -P datasets/tempoGAN/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat --create-dirs -o ./datasets/tempoGAN/2d_train.mat
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat --create-dirs -o ./datasets/tempoGAN/2d_valid.mat
python tempoGAN.py

# model evaluation
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat -P datasets/tempoGAN/
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat -P datasets/tempoGAN/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat --create-dirs -o ./datasets/tempoGAN/2d_train.mat
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat --create-dirs -o ./datasets/tempoGAN/2d_valid.mat
python tempoGAN.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/tempoGAN/tempogan_pretrained.pdparams

# model export
python tempoGAN.py mode=export

# model inference
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat -P datasets/tempoGAN/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat --create-dirs -o ./datasets/tempoGAN/2d_valid.mat
python tempoGAN.py mode=infer
| Pretrained model | Metrics |
| :-- | :-- |
| tempogan_pretrained.pdparams | MSE: 4.21e-5<br>PSNR: 47.19<br>SSIM: 0.9974 |

1. Background

In fluid simulation, capturing the intricate details of turbulence has long been a challenge for numerical methods: resolving these details with discretized models incurs an enormous computational cost and quickly becomes infeasible for flows at human space and time scales. This motivates fluid super-resolution, which combines fluid-dynamics simulation with deep learning to reconstruct high-resolution flow fields from low-resolution simulation results, greatly reducing the cost of producing high-resolution fluid data. The technique can be applied to many kinds of fluid simulation, such as water flow, airflow, and flame simulation.

A Generative Adversarial Network (GAN) is a deep learning framework trained in an unsupervised fashion. A GAN contains (at least) two models: a generator and a discriminator. The generator produces candidate outputs for the problem, the discriminator judges whether an output is real or fake, and the two are optimized jointly through an adversarial game until the generator's output approaches the real data.

tempoGAN adds a second, time-aware discriminator (Discriminator_tempo) on top of the basic GAN. Its network structure is identical to the basic discriminator, but its input consists of several temporally consecutive frames instead of a single frame, which brings temporal coherence into the optimization.

In this example, the network takes low-density fluid data as input and produces the corresponding high-density fluid data, which saves a great deal of time and computational cost.

2. Problem Definition

This example involves three models: the generator (Generator), the discriminator (Discriminator), and the temporal discriminator (Discriminator_tempo). Following the GAN training procedure, the three models are trained alternately, in the order Discriminator, Discriminator_tempo, Generator. GAN training is unsupervised; in this example the target values are fed into the network as an additional input during training.

3. Solving the Problem

The following sections explain, step by step, how to express this problem in PaddleScience code and solve it with deep learning. To keep the walkthrough focused on understanding PaddleScience, only the key steps such as model construction and constraint construction are described; for the remaining details please refer to the API documentation.

3.1 Dataset

The dataset is a 2D fluid dataset generated with the open-source package mantaflow. It contains a number of consecutive frames of low- and high-density fluid images converted to numerical arrays, stored as a dictionary in .mat files.
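As a quick orientation, the sketch below shows how such a file can be loaded and inspected with hdf5storage, the package used by the full code; the key names are the ones referenced later in this document, and the shapes follow the [N, C, H, W] layout assumed by the transforms.

import hdf5storage

# load the training .mat file into a Python dict of numpy arrays
dataset_train = hdf5storage.loadmat("./datasets/tempoGAN/2d_train.mat")

# single-frame and 3-frame (tempo) density fields used in this example
for key in ("density_low", "density_high", "density_low_tempo", "density_high_tempo"):
    print(key, dataset_train[key].shape)  # arrays in [N, C, H, W] layout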

Before running the code for this example, please download the training dataset and the validation dataset, and place them at the following paths:

DATASET_PATH: ./datasets/tempoGAN/2d_train.mat
DATASET_PATH_VALID: ./datasets/tempoGAN/2d_valid.mat

3.2 Model Construction

(Figure: tempoGAN network architecture)

The figure above shows the complete tempoGAN architecture. This example only handles a relatively simple setting and does not include velocity or vorticity inputs, 3D data, data augmentation, or the advection operator. If you are interested in these parts, which are not covered in this document, you can modify the code and experiment further yourself.

As shown in the figure, the Generator takes the interpolated low-density fluid data as input and outputs generated high-density fluid data. The Discriminator takes as input the interpolated low-density data concatenated with either the Generator's high-density output or the target high-density data. The Discriminator_tempo takes several consecutive frames of either the Generator's high-density output or the target high-density data.

Although the inputs and outputs look fairly involved, they are all fluid density data, so the mapping function of all three networks is \(f: \mathbb{R}^1 \to \mathbb{R}^1\).

Unlike a simple MLP, the generator and discriminator of a GAN can take many different network structures depending on the problem to be solved, which we will not enumerate here. Because of this variability, the tempoGAN network used in this example is not a built-in PaddleScience model and has to be implemented additionally.

The Generator in this example is a model with 4 modified Res Blocks; the Discriminator and the Discriminator_tempo share one model with 4 convolutional layers, i.e. the two have identical network structures but different inputs. The network parameters of the Generator, Discriminator and Discriminator_tempo also need to be defined separately.

For the concrete code, see the gan.py file in the Complete Code section.

Because the intermediate results of the generator and the discriminators have to be passed to each other and enter each other's loss computations, the models are wrapped with a Model List. In PaddleScience code this is written as:

# define Generator model
model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)
model_gen.register_input_transform(gen_funcs.transform_in)
disc_funcs.model_gen = model_gen

model_tuple = (model_gen,)
# define Discriminators
if cfg.USE_SPATIALDISC:
    model_disc = ppsci.arch.Discriminator(**cfg.MODEL.disc_net)
    model_disc.register_input_transform(disc_funcs.transform_in)
    model_tuple += (model_disc,)

# define temporal Discriminators
if cfg.USE_TEMPODISC:
    model_disc_tempo = ppsci.arch.Discriminator(**cfg.MODEL.tempo_net)
    model_disc_tempo.register_input_transform(disc_funcs.transform_in_tempo)
    model_tuple += (model_disc_tempo,)

# define model_list
model_list = ppsci.arch.ModelList(model_tuple)

Note that the network inputs defined in the code above are not exactly the inputs the networks actually receive, so the inputs need to be transformed.

3.3 Transform Construction

The Generator's input is the interpolation of the low-density fluid data, while the dataset stores the raw low-density data, so an interpolation transform is needed.

def transform_in(self, _in):
    ratio = 2
    input_dict = reshape_input(_in)
    density_low = input_dict["density_low"]
    density_low_inp = interpolate(density_low, ratio, "nearest")
    return {"input_gen": density_low_inp}

The input transforms of the Discriminator and the Discriminator_tempo are more involved; they are, respectively:

def transform_in(self, _in):
    ratio = 2
    input_dict = reshape_input(_in)
    density_low = input_dict["density_low"]
    density_high_from_target = input_dict["density_high"]

    density_low_inp = interpolate(density_low, ratio, "nearest")

    density_high_from_gen = self.model_gen(input_dict)["output_gen"]
    density_high_from_gen.stop_gradient = True

    density_input_from_target = paddle.concat(
        [density_low_inp, density_high_from_target], axis=1
    )
    density_input_from_gen = paddle.concat(
        [density_low_inp, density_high_from_gen], axis=1
    )
    return {
        "input_disc_from_target": density_input_from_target,
        "input_disc_from_gen": density_input_from_gen,
    }

def transform_in_tempo(self, _in):
    density_high_from_target = _in["density_high"]

    input_dict = reshape_input(_in)
    density_high_from_gen = self.model_gen(input_dict)["output_gen"]
    density_high_from_gen.stop_gradient = True

    input_trans = {
        "input_tempo_disc_from_target": density_high_from_target,
        "input_tempo_disc_from_gen": density_high_from_gen,
    }

    return dereshape_input(input_trans, 3)

Here, the line

density_high_from_gen.stop_gradient = True

stops the gradient computation for this tensor. This is done because the variable is only used here as an input to the Discriminator and the Discriminator_tempo and should not participate in gradient backpropagation. Without this setting, since the variable comes from the Generator's output, gradients would flow back through it into the Generator during backpropagation and change the Generator's parameters, which is clearly not what we want.
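The following toy snippet (purely illustrative, not part of the example code) shows the effect: once stop_gradient is set on an intermediate tensor, no gradient flows back through it.

import paddle

x = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
y = x * 3.0
y.stop_gradient = True  # cut the graph here, as done for density_high_from_gen above
z = (y * x).sum()
z.backward()
print(x.grad)  # [3., 6.]: y is treated as a constant, so no gradient reaches it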

With this, we have instantiated a neural network model list containing the Generator, the Discriminator and the Discriminator_tempo, together with their input transforms.

3.4 Parameter and Hyperparameter Settings

We need to specify the problem-specific parameters, such as the dataset paths and the weights of the individual loss terms.

output_dir: ${hydra:run.dir}
log_freq: 20
DATASET_PATH: ./datasets/tempoGAN/2d_train.mat
DATASET_PATH_VALID: ./datasets/tempoGAN/2d_valid.mat

# set working condition
USE_AMP: true
USE_SPATIALDISC: true
USE_TEMPODISC: true
WEIGHT_GEN: [5.0, 0.0, 1.0]  # lambda_l1, lambda_l2, lambda_t
WEIGHT_GEN_LAYER: [-1.0e-5, -1.0e-5, -1.0e-5, -1.0e-5, -1.0e-5]

Note the three boolean options USE_AMP, USE_SPATIALDISC and USE_TEMPODISC, which control whether automatic mixed precision (AMP) training, the Discriminator, and the Discriminator_tempo are used, respectively. When both USE_SPATIALDISC and USE_TEMPODISC are set to False, the network degenerates into a plain Generator model and is no longer a GAN.

We also need to specify hyperparameters such as the number of training epochs and the learning rate. Note that because GAN training differs from training a single model, the EPOCHS settings also differ: the outer training loop runs epochs times, and within each outer iteration every solver trains for its own small number of epochs (e.g. epochs_gen for the generator).

# training settings
TRAIN:
  epochs: 40000
  epochs_gen: 1

3.5 Optimizer Construction

Training uses the Adam optimizer. The learning rate is reduced to \(1/20\) of its initial value when training reaches half of the epochs, so the Step scheduler is used as the learning-rate strategy. If by_epoch is set to True, the learning rate changes per epoch; otherwise it changes per iteration.
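As a quick numeric illustration of this schedule, using Paddle's built-in StepDecay (the initial learning rate 2e-4 below is only an assumed example value; the real value comes from cfg.TRAIN.lr_scheduler):

import paddle

epochs = 40000
sched = paddle.optimizer.lr.StepDecay(learning_rate=2e-4, step_size=epochs // 2, gamma=0.05)
print(sched.get_lr())  # 2e-4 during the first half of training
for _ in range(epochs // 2):
    sched.step()  # called once per epoch when by_epoch is True
print(sched.get_lr())  # 1e-5, i.e. the initial value divided by 20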

# initialize Adam optimizer
lr_scheduler_gen = ppsci.optimizer.lr_scheduler.Step(
    step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
)()
optimizer_gen = ppsci.optimizer.Adam(lr_scheduler_gen)(model_gen)
if cfg.USE_SPATIALDISC:
    lr_scheduler_disc = ppsci.optimizer.lr_scheduler.Step(
        step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
    )()
    optimizer_disc = ppsci.optimizer.Adam(lr_scheduler_disc)(model_disc)
if cfg.USE_TEMPODISC:
    lr_scheduler_disc_tempo = ppsci.optimizer.lr_scheduler.Step(
        step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
    )()
    optimizer_disc_tempo = ppsci.optimizer.Adam(lr_scheduler_disc_tempo)(
        (model_disc_tempo,)
    )

3.6 Constraint Construction

This example is trained in an unsupervised way. Although it is not supervised learning in the usual sense, the supervised constraint SupervisedConstraint can still be used. Before defining the constraint, we need to specify its data-loading configuration such as file paths. Because tempoGAN is self-supervised, the dataset contains no label data; instead, part of the input data is used as the label, which is why the constraint's output_expr has to be set:

{
    "output_gen": lambda out: out["output_gen"],
    "density_high": lambda out: out["density_high"],
},

3.6.1 Generator constraints

The concrete constraint is defined below; pay attention to the output_expr mentioned above.

sup_constraint_gen = ppsci.constraint.SupervisedConstraint(
    {
        "dataset": {
            "name": "NamedArrayDataset",
            "input": {
                "density_low": dataset_train["density_low"],
                "density_high": dataset_train["density_high"],
            },
            "transforms": (
                {
                    "FunctionalTransform": {
                        "transform_func": data_funcs.transform,
                    },
                },
            ),
        },
        "batch_size": cfg.TRAIN.batch_size.sup_constraint,
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
    },
    ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen),
    {
        "output_gen": lambda out: out["output_gen"],
        "density_high": lambda out: out["density_high"],
    },
    name="sup_constraint_gen",
)

The first argument of SupervisedConstraint is the data-loading configuration of the supervised constraint. Its dataset field describes the training dataset, with the following sub-fields:

  1. name: the dataset type; here NamedArrayDataset denotes a dataset read from arrays (loaded from the .mat files);
  2. input: the input data as arrays;
  3. label: the label data as arrays;
  4. transforms: the data transform methods; FunctionalTransform is the class PaddleScience provides for user-defined data transforms, allowing you to write your own transform for the input data; see Custom loss and data transform for the code.

The batch_size field specifies the batch size;

The sampler field specifies the sampling method, with the sub-fields:

  1. name: the sampler type; here BatchSampler denotes a batch sampler;
  2. drop_last: whether to drop the last samples that cannot form a full mini-batch, default False;
  3. shuffle: whether to shuffle the sample indices, default False.

The second argument is the loss function. FunctionalLoss is the class PaddleScience provides for user-defined losses; it lets you implement your own loss computation instead of using a built-in loss such as MSE; see Custom loss and data transform for the code.

The third argument is the constraint's output_expr, which, as described above, lets the program use input data as the label.

The fourth argument is the name of the constraint; every constraint needs a name so that it can be referenced later.

Once the constraints are built, they are wrapped into a dictionary keyed by their names for later access. Because USE_SPATIALDISC and USE_TEMPODISC are configurable, some of the Generator's constraints may not exist, so the constraint that always exists is put into the dictionary first, and the optional ones are added afterwards when present.
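The always-present constraint is wrapped first (this line is taken from the full tempoGAN.py listing below):

constraint_gen = {sup_constraint_gen.name: sup_constraint_gen}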

if cfg.USE_TEMPODISC:
    sup_constraint_gen_tempo = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": {
                    "density_low": dataset_train["density_low_tempo"],
                    "density_high": dataset_train["density_high_tempo"],
                },
                "transforms": (
                    {
                        "FunctionalTransform": {
                            "transform_func": data_funcs.transform,
                        },
                    },
                ),
            },
            "batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
            "sampler": {
                "name": "BatchSampler",
                "drop_last": False,
                "shuffle": False,
            },
        },
        ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen_tempo),
        {
            "output_gen": lambda out: out["output_gen"],
            "density_high": lambda out: out["density_high"],
        },
        name="sup_constraint_gen_tempo",
    )
    constraint_gen[sup_constraint_gen_tempo.name] = sup_constraint_gen_tempo

3.6.2 Discriminator constraints

if cfg.USE_SPATIALDISC:
    sup_constraint_disc = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": {
                    "density_low": dataset_train["density_low"],
                    "density_high": dataset_train["density_high"],
                },
                "label": {
                    "out_disc_from_target": np.ones(
                        (np.shape(dataset_train["density_high"])[0], 1),
                        dtype=paddle.get_default_dtype(),
                    ),
                    "out_disc_from_gen": np.ones(
                        (np.shape(dataset_train["density_high"])[0], 1),
                        dtype=paddle.get_default_dtype(),
                    ),
                },
                "transforms": (
                    {
                        "FunctionalTransform": {
                            "transform_func": data_funcs.transform,
                        },
                    },
                ),
            },
            "batch_size": cfg.TRAIN.batch_size.sup_constraint,
            "sampler": {
                "name": "BatchSampler",
                "drop_last": False,
                "shuffle": False,
            },
        },
        ppsci.loss.FunctionalLoss(disc_funcs.loss_func),
        name="sup_constraint_disc",
    )
    constraint_disc = {sup_constraint_disc.name: sup_constraint_disc}

The parameters have the same meaning as in the Generator's constraints.

3.6.3 Discriminator_tempo constraints

if cfg.USE_TEMPODISC:
    sup_constraint_disc_tempo = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": {
                    "density_low": dataset_train["density_low_tempo"],
                    "density_high": dataset_train["density_high_tempo"],
                },
                "label": {
                    "out_disc_tempo_from_target": np.ones(
                        (np.shape(dataset_train["density_high_tempo"])[0], 1),
                        dtype=paddle.get_default_dtype(),
                    ),
                    "out_disc_tempo_from_gen": np.ones(
                        (np.shape(dataset_train["density_high_tempo"])[0], 1),
                        dtype=paddle.get_default_dtype(),
                    ),
                },
                "transforms": (
                    {
                        "FunctionalTransform": {
                            "transform_func": data_funcs.transform,
                        },
                    },
                ),
            },
            "batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
            "sampler": {
                "name": "BatchSampler",
                "drop_last": False,
                "shuffle": False,
            },
        },
        ppsci.loss.FunctionalLoss(disc_funcs.loss_func_tempo),
        name="sup_constraint_disc_tempo",
    )
    constraint_disc_tempo = {
        sup_constraint_disc_tempo.name: sup_constraint_disc_tempo
    }

The parameters have the same meaning as in the Generator's constraints.

3.7 Visualizer Construction

Because of the way GANs are trained, this example does not use PaddleScience's built-in visualizer. Instead, a custom inference function is defined: it reads the validation data, runs a prediction, and saves the result as an image. Calling this function at fixed intervals during training makes it possible to monitor the training progress.

def predict_and_save_plot(
    output_dir: str,
    epoch_id: int,
    solver_gen: ppsci.solver.Solver,
    dataset_valid: np.ndarray,
    tile_ratio: int = 1,
):
    """Predicting and plotting.

    Args:
        output_dir (str): Output dir path.
        epoch_id (int): Which epoch it is.
        solver_gen (ppsci.solver.Solver): Solver for predicting.
        dataset_valid (np.ndarray): Valid dataset.
        tile_ratio (int, optional): How many tiles of one dim. Defaults to 1.
    """
    dir_pred = "predict/"
    os.makedirs(os.path.join(output_dir, dir_pred), exist_ok=True)

    start_idx = 190
    density_low = dataset_valid["density_low"][start_idx : start_idx + 3]
    density_high = dataset_valid["density_high"][start_idx : start_idx + 3]

    # tile
    density_low = (
        split_data(density_low, tile_ratio) if tile_ratio != 1 else density_low
    )
    density_high = (
        split_data(density_high, tile_ratio) if tile_ratio != 1 else density_high
    )

    pred_dict = solver_gen.predict(
        {
            "density_low": density_low,
            "density_high": density_high,
        },
        {"density_high": lambda out: out["output_gen"]},
        batch_size=tile_ratio * tile_ratio if tile_ratio != 1 else 3,
        no_grad=False,
    )
    if epoch_id == 1:
        # plot interpolated input image
        input_img = np.expand_dims(dataset_valid["density_low"][start_idx], axis=0)
        input_img = paddle.to_tensor(input_img, dtype=paddle.get_default_dtype())
        input_img = F.interpolate(
            input_img,
            [input_img.shape[-2] * 4, input_img.shape[-1] * 4],
            mode="nearest",
        ).numpy()
        Img.imsave(
            os.path.join(output_dir, dir_pred, "input.png"),
            np.squeeze(input_img),
            vmin=0.0,
            vmax=1.0,
            cmap="gray",
        )
        # plot target image
        Img.imsave(
            os.path.join(output_dir, dir_pred, "target.png"),
            np.squeeze(dataset_valid["density_high"][start_idx]),
            vmin=0.0,
            vmax=1.0,
            cmap="gray",
        )
    # plot pred image
    pred_img = (
        concat_data(pred_dict["density_high"].numpy(), tile_ratio)
        if tile_ratio != 1
        else np.squeeze(pred_dict["density_high"][0].numpy())
    )
    Img.imsave(
        os.path.join(output_dir, dir_pred, f"pred_epoch_{str(epoch_id)}.png"),
        pred_img,
        vmin=0.0,
        vmax=1.0,
        cmap="gray",
    )

3.8 Custom loss and data transform

Since this example uses unsupervised learning and the data contain no labels, the loss is computed from the model outputs and therefore has to be user-defined. The approach is to first define the loss function and then pass its name to FunctionalLoss. Note that the signature of a custom loss function must match that of PaddleScience's built-in losses such as MSE: the inputs are dictionaries such as the model output output_dict, and the output is the loss value as a paddle.Tensor.
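A minimal sketch of such a function is shown below; loss_demo is a hypothetical name used only for illustration, and the actual loss functions of this example follow in the next subsections.

import paddle.nn.functional as F

import ppsci

def loss_demo(output_dict, *args):
    # same calling convention as the loss functions below: take the output dict
    # (plus optional label/weight dicts) and return the loss value
    return F.mse_loss(output_dict["output_gen"], output_dict["density_high"], "mean")

loss = ppsci.loss.FunctionalLoss(loss_demo)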

3.8.1 Generator loss

The Generator loss consists of an l1 loss, an l2 loss, a loss from passing the output through the Discriminator, and a loss from passing the output through the Discriminator_tempo. Whether each term is present is controlled by its weight: if a term's weight is 0, that term is effectively not used during training.

def loss_func_gen(self, output_dict: Dict, *args) -> paddle.Tensor:
    """Calculate loss of generator when use spatial discriminator.
        The loss consists of l1 loss, l2 loss and layer loss when use spatial discriminator.
        Notice that all item of loss is optional because weight of them might be 0.

    Args:
        output_dict (Dict): output dict of model.

    Returns:
        paddle.Tensor: Loss of generator.
    """
    # l1 loss
    loss_l1 = F.l1_loss(
        output_dict["output_gen"], output_dict["density_high"], "mean"
    )
    losses = loss_l1 * self.weight_gen[0]

    # l2 loss
    loss_l2 = F.mse_loss(
        output_dict["output_gen"], output_dict["density_high"], "mean"
    )
    losses += loss_l2 * self.weight_gen[1]

    if self.weight_gen_layer is not None:
        # disc(generator_out) loss
        out_disc_from_gen = output_dict["out_disc_from_gen"][-1]
        label_ones = paddle.ones_like(out_disc_from_gen)
        loss_gen = F.binary_cross_entropy_with_logits(
            out_disc_from_gen, label_ones, reduction="mean"
        )
        losses += loss_gen

        # layer loss
        key_list = list(output_dict.keys())
        # ["out0_layer0","out0_layer1","out0_layer2","out0_layer3","out_disc_from_target",
        # "out1_layer0","out1_layer1","out1_layer2","out1_layer3","out_disc_from_gen"]
        loss_layer = 0
        for i in range(1, len(self.weight_gen_layer)):
            # i = 0,1,2,3
            loss_layer += (
                self.weight_gen_layer[i]
                * F.mse_loss(
                    output_dict[key_list[i]],
                    output_dict[key_list[5 + i]],
                    reduction="sum",
                )
                / 2
            )
        losses += loss_layer * self.weight_gen_layer[0]

    return {"output_gen": losses}

def loss_func_gen_tempo(self, output_dict: Dict, *args) -> paddle.Tensor:
    """Calculate loss of generator when use temporal discriminator.
        The loss is cross entropy loss when use temporal discriminator.

    Args:
        output_dict (Dict): output dict of model.

    Returns:
        paddle.Tensor: Loss of generator.
    """
    out_disc_tempo_from_gen = output_dict["out_disc_tempo_from_gen"][-1]
    label_t_ones = paddle.ones_like(out_disc_tempo_from_gen)

    loss_gen_t = F.binary_cross_entropy_with_logits(
        out_disc_tempo_from_gen, label_t_ones, reduction="mean"
    )
    losses = loss_gen_t * self.weight_gen[2]
    return {"out_disc_tempo_from_gen": losses}

3.8.2 Discriminator loss

The Discriminator judges whether data is real or fake, so its loss consists of the loss from classifying the Generator's output as fake and the loss from classifying the target data as real.

def loss_func(self, output_dict, *args):
    out_disc_from_target = output_dict["out_disc_from_target"]
    out_disc_from_gen = output_dict["out_disc_from_gen"]

    label_ones = paddle.ones_like(out_disc_from_target)
    label_zeros = paddle.zeros_like(out_disc_from_gen)

    loss_disc_from_target = F.binary_cross_entropy_with_logits(
        out_disc_from_target, label_ones, reduction="mean"
    )
    loss_disc_from_gen = F.binary_cross_entropy_with_logits(
        out_disc_from_gen, label_zeros, reduction="mean"
    )
    losses = loss_disc_from_target * self.weight_disc + loss_disc_from_gen
    return losses

3.8.3 Discriminator_tempo loss

The Discriminator_tempo loss has the same structure as the Discriminator loss, only the input data differ.

def loss_func_tempo(self, output_dict, *args):
    out_disc_tempo_from_target = output_dict["out_disc_tempo_from_target"]
    out_disc_tempo_from_gen = output_dict["out_disc_tempo_from_gen"]

    label_ones = paddle.ones_like(out_disc_tempo_from_target)
    label_zeros = paddle.zeros_like(out_disc_tempo_from_gen)

    loss_disc_tempo_from_target = F.binary_cross_entropy_with_logits(
        out_disc_tempo_from_target, label_ones, reduction="mean"
    )
    loss_disc_tempo_from_gen = F.binary_cross_entropy_with_logits(
        out_disc_tempo_from_gen, label_zeros, reduction="mean"
    )
    losses = (
        loss_disc_tempo_from_target * self.weight_disc + loss_disc_tempo_from_gen
    )
    return losses

3.8.4 Custom data transform

This example provides one input-data transform: a tile is randomly cropped from the input density field and its density is checked; if the density of the cropped tile is below a threshold, the crop is repeated until the density is sufficient or a maximum number of attempts is reached. This is mainly done to reduce the GPU memory needed for training, while the density check ensures that the cropped tile still carries enough information. The tile_ratio in Parameter and Hyperparameter Settings specifies how many times larger the original size is than the tile size: if tile_ratio is 2, a cropped tile is one quarter of the original image.

class DataFuncs:
    """All functions used for data transform.

    Args:
        tile_ratio (int, optional): How many tiles of one dim. Defaults to 1.
        density_min (float, optional): Minimize density of one tile. Defaults to 0.02.
        max_turn (int, optional): Maximize turn of taking a tile from one image. Defaults to 20.
    """

    def __init__(
        self, tile_ratio: int = 1, density_min: float = 0.02, max_turn: int = 20
    ) -> None:
        self.tile_ratio = tile_ratio
        self.density_min = density_min
        self.max_turn = max_turn

    def transform(
        self,
        input_item: Dict[str, np.ndarray],
        label_item: Dict[str, np.ndarray],
        weight_item: Dict[str, np.ndarray],
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        if self.tile_ratio == 1:
            return input_item, label_item, weight_item
        for _ in range(self.max_turn):
            rand_ratio = np.random.rand()
            density_low = self.cut_data(input_item["density_low"], rand_ratio)
            density_high = self.cut_data(input_item["density_high"], rand_ratio)
            if self.is_valid_tile(density_low):
                break

        input_item["density_low"] = density_low
        input_item["density_high"] = density_high
        return input_item, label_item, weight_item

    def cut_data(self, data: np.ndarray, rand_ratio: float) -> paddle.Tensor:
        # data: C,H,W
        _, H, W = data.shape
        if H % self.tile_ratio != 0 or W % self.tile_ratio != 0:
            exit(
                f"ERROR: input images cannot be divided into {self.tile_ratio} parts evenly!"
            )
        tile_shape = [H // self.tile_ratio, W // self.tile_ratio]
        rand_shape = np.floor(rand_ratio * (np.array([H, W]) - np.array(tile_shape)))
        start = [int(rand_shape[0]), int(rand_shape[1])]
        end = [int(rand_shape[0] + tile_shape[0]), int(rand_shape[1] + tile_shape[1])]
        data = paddle.slice(
            paddle.to_tensor(data), axes=[-2, -1], starts=start, ends=end
        )

        return data

    def is_valid_tile(self, tile: paddle.Tensor):
        img_density = tile[0].sum()
        return img_density >= (
            self.density_min * tile.shape[0] * tile.shape[1] * tile.shape[2]
        )

Note that this code only illustrates the idea of a data transform. The simple tiling scheme reduces the information contained in each input and will clearly affect training quality, so in this example tile_ratio should be set to 1 when GPU memory is sufficient; when memory is insufficient, prefer mixed-precision training first to reduce memory usage.

3.9 Model Training

After completing the settings above, pass the instantiated objects to ppsci.solver.Solver in order, and then start training.

solver_gen = ppsci.solver.Solver(
    model_list,
    constraint_gen,
    cfg.output_dir,
    optimizer_gen,
    lr_scheduler_gen,
    cfg.TRAIN.epochs_gen,
    cfg.TRAIN.iters_per_epoch,
    eval_during_train=cfg.TRAIN.eval_during_train,
    use_amp=cfg.USE_AMP,
    amp_level=cfg.TRAIN.amp_level,
)

Note that GAN training alternates between several models, which differs from training a single model or training multiple models stage by stage, so solver.train cannot simply be called once; the sketch below shows the structure, and the concrete code is in the tempoGAN.py file in Complete Code.
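The outer loop looks roughly as follows (condensed from the full tempoGAN.py listing; the three solvers share the same model list but use different constraints and optimizers):

for i in range(1, cfg.TRAIN.epochs + 1):
    # train the spatial discriminator, input: (x, y, G(x))
    if cfg.USE_SPATIALDISC:
        solver_disc.train()
    # train the temporal discriminator, input: (y_3, G(x)_3)
    if cfg.USE_TEMPODISC:
        solver_disc_tempo.train()
    # train the generator, input: (x,)
    solver_gen.train()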

3.10 Model Evaluation

3.10.1 Evaluation during training

During training, the target image and the model output for a particular sample are saved only at certain epochs; after training finishes, the output of the last epoch is evaluated once to give an intuitive picture of the optimization result. The built-in PaddleScience evaluators are not used, and no evaluation is performed inside the training loop itself:

for i in range(1, cfg.TRAIN.epochs + 1):
    logger.message(f"\nEpoch: {i}\n")
    # plotting during training
    if i == 1 or i % PRED_INTERVAL == 0 or i == cfg.TRAIN.epochs:
        func_module.predict_and_save_plot(
            cfg.output_dir, i, solver_gen, dataset_valid, cfg.TILE_RATIO
        )
############### evaluation for training ###############
img_target = (
    func_module.get_image_array(
        os.path.join(cfg.output_dir, "predict", "target.png")
    )
    / 255.0
)
img_pred = (
    func_module.get_image_array(
        os.path.join(
            cfg.output_dir, "predict", f"pred_epoch_{cfg.TRAIN.epochs}.png"
        )
    )
    / 255.0
)
eval_mse, eval_psnr, eval_ssim = func_module.evaluate_img(img_target, img_pred)
logger.message(f"MSE: {eval_mse}, PSNR: {eval_psnr}, SSIM: {eval_ssim}")

For the concrete code, see the tempoGAN.py file in Complete Code.

3.10.2 Evaluation in eval mode

For evaluation, the super-resolved outputs of the model are compared against the true high-resolution images using three image-similarity metrics: MSE (mean squared error), PSNR (peak signal-to-noise ratio) and SSIM (structural similarity). Therefore the built-in PaddleScience evaluators are not used and there is no Solver.eval() step.

def evaluate(cfg: DictConfig):
    if cfg.EVAL.save_outs:
        from matplotlib import image as Img

        os.makedirs(osp.join(cfg.output_dir, "eval_outs"), exist_ok=True)

    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")

    gen_funcs = func_module.GenFuncs(cfg.WEIGHT_GEN, None)

    # load dataset
    dataset_valid = hdf5storage.loadmat(cfg.DATASET_PATH_VALID)

    # define Generator model
    model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)
    model_gen.register_input_transform(gen_funcs.transform_in)

    # define model_list
    model_list = ppsci.arch.ModelList((model_gen,))

    # load pretrained model
    save_load.load_pretrain(model_list, cfg.EVAL.pretrained_model_path)

    # set validator
    eval_dataloader_cfg = {
        "dataset": {
            "name": "NamedArrayDataset",
            "input": {
                "density_low": dataset_valid["density_low"],
            },
            "label": {"density_high": dataset_valid["density_high"]},
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": 1,
    }
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.MSELoss("mean"),
        {"density_high": lambda out: out["output_gen"]},
        metric={"metric": ppsci.metric.L2Rel()},
        name="sup_validator_gen",
    )

    # customized evalution
    def scale(data):
        smax = np.max(data)
        smin = np.min(data)
        return (data - smin) / (smax - smin)

    eval_mse_list = []
    eval_psnr_list = []
    eval_ssim_list = []
    for i, (input, label, _) in enumerate(sup_validator.data_loader):
        output_dict = model_list({"density_low": input["density_low"]})
        output_arr = scale(np.squeeze(output_dict["output_gen"].numpy()))
        target_arr = scale(np.squeeze(label["density_high"].numpy()))

        eval_mse, eval_psnr, eval_ssim = func_module.evaluate_img(
            target_arr, output_arr
        )
        eval_mse_list.append(eval_mse)
        eval_psnr_list.append(eval_psnr)
        eval_ssim_list.append(eval_ssim)

        if cfg.EVAL.save_outs:
            Img.imsave(
                osp.join(cfg.output_dir, "eval_outs", f"out_{i}.png"),
                output_arr,
                vmin=0.0,
                vmax=1.0,
                cmap="gray",
            )
    logger.message(
        f"MSE: {np.mean(eval_mse_list)}, PSNR: {np.mean(eval_psnr_list)}, SSIM: {np.mean(eval_ssim_list)}"
    )

In addition, the following part:

if cfg.EVAL.save_outs:
    Img.imsave(
        osp.join(cfg.output_dir, "eval_outs", f"out_{i}.png"),
        output_arr,
        vmin=0.0,
        vmax=1.0,
        cmap="gray",
    )

provides the option of saving the model outputs so that the super-resolved results can be inspected directly; whether it is enabled is controlled by save_outs in the EVAL section of the config file:

  checkpoint_path: null

# evaluation settings
EVAL:

4. Complete Code

The complete code consists of the main PaddleScience training script tempoGAN.py and the custom-function module functions.py; in addition, the network-structure code gan.py was added to ppsci.arch. All of them are shown below and can serve as a reference if you need to implement a custom network structure.

tempoGAN.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from os import path as osp

import functions as func_module
import hydra
import numpy as np
import paddle
from omegaconf import DictConfig

import ppsci
from ppsci.utils import checker
from ppsci.utils import logger
from ppsci.utils import save_load

if not checker.dynamic_import_to_globals("hdf5storage"):
    raise ImportError(
        "Could not import hdf5storage python package. "
        "Please install it with `pip install hdf5storage`."
    )
import hdf5storage


def train(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")

    gen_funcs = func_module.GenFuncs(
        cfg.WEIGHT_GEN, (cfg.WEIGHT_GEN_LAYER if cfg.USE_SPATIALDISC else None)
    )
    disc_funcs = func_module.DiscFuncs(cfg.WEIGHT_DISC)
    data_funcs = func_module.DataFuncs(cfg.TILE_RATIO)

    # load dataset
    logger.message(
        "Attention! Start loading datasets, this will take tens of seconds to several minutes, please wait patiently."
    )
    dataset_train = hdf5storage.loadmat(cfg.DATASET_PATH)
    logger.message("Finish loading training dataset.")
    dataset_valid = hdf5storage.loadmat(cfg.DATASET_PATH_VALID)
    logger.message("Finish loading validation dataset.")

    # define Generator model
    model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)
    model_gen.register_input_transform(gen_funcs.transform_in)
    disc_funcs.model_gen = model_gen

    model_tuple = (model_gen,)
    # define Discriminators
    if cfg.USE_SPATIALDISC:
        model_disc = ppsci.arch.Discriminator(**cfg.MODEL.disc_net)
        model_disc.register_input_transform(disc_funcs.transform_in)
        model_tuple += (model_disc,)

    # define temporal Discriminators
    if cfg.USE_TEMPODISC:
        model_disc_tempo = ppsci.arch.Discriminator(**cfg.MODEL.tempo_net)
        model_disc_tempo.register_input_transform(disc_funcs.transform_in_tempo)
        model_tuple += (model_disc_tempo,)

    # define model_list
    model_list = ppsci.arch.ModelList(model_tuple)

    # initialize Adam optimizer
    lr_scheduler_gen = ppsci.optimizer.lr_scheduler.Step(
        step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
    )()
    optimizer_gen = ppsci.optimizer.Adam(lr_scheduler_gen)(model_gen)
    if cfg.USE_SPATIALDISC:
        lr_scheduler_disc = ppsci.optimizer.lr_scheduler.Step(
            step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
        )()
        optimizer_disc = ppsci.optimizer.Adam(lr_scheduler_disc)(model_disc)
    if cfg.USE_TEMPODISC:
        lr_scheduler_disc_tempo = ppsci.optimizer.lr_scheduler.Step(
            step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
        )()
        optimizer_disc_tempo = ppsci.optimizer.Adam(lr_scheduler_disc_tempo)(
            (model_disc_tempo,)
        )

    # Generator
    # manually build constraint(s)
    sup_constraint_gen = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": {
                    "density_low": dataset_train["density_low"],
                    "density_high": dataset_train["density_high"],
                },
                "transforms": (
                    {
                        "FunctionalTransform": {
                            "transform_func": data_funcs.transform,
                        },
                    },
                ),
            },
            "batch_size": cfg.TRAIN.batch_size.sup_constraint,
            "sampler": {
                "name": "BatchSampler",
                "drop_last": False,
                "shuffle": False,
            },
        },
        ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen),
        {
            "output_gen": lambda out: out["output_gen"],
            "density_high": lambda out: out["density_high"],
        },
        name="sup_constraint_gen",
    )
    constraint_gen = {sup_constraint_gen.name: sup_constraint_gen}
    if cfg.USE_TEMPODISC:
        sup_constraint_gen_tempo = ppsci.constraint.SupervisedConstraint(
            {
                "dataset": {
                    "name": "NamedArrayDataset",
                    "input": {
                        "density_low": dataset_train["density_low_tempo"],
                        "density_high": dataset_train["density_high_tempo"],
                    },
                    "transforms": (
                        {
                            "FunctionalTransform": {
                                "transform_func": data_funcs.transform,
                            },
                        },
                    ),
                },
                "batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
                "sampler": {
                    "name": "BatchSampler",
                    "drop_last": False,
                    "shuffle": False,
                },
            },
            ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen_tempo),
            {
                "output_gen": lambda out: out["output_gen"],
                "density_high": lambda out: out["density_high"],
            },
            name="sup_constraint_gen_tempo",
        )
        constraint_gen[sup_constraint_gen_tempo.name] = sup_constraint_gen_tempo

    # Discriminators
    # manually build constraint(s)
    if cfg.USE_SPATIALDISC:
        sup_constraint_disc = ppsci.constraint.SupervisedConstraint(
            {
                "dataset": {
                    "name": "NamedArrayDataset",
                    "input": {
                        "density_low": dataset_train["density_low"],
                        "density_high": dataset_train["density_high"],
                    },
                    "label": {
                        "out_disc_from_target": np.ones(
                            (np.shape(dataset_train["density_high"])[0], 1),
                            dtype=paddle.get_default_dtype(),
                        ),
                        "out_disc_from_gen": np.ones(
                            (np.shape(dataset_train["density_high"])[0], 1),
                            dtype=paddle.get_default_dtype(),
                        ),
                    },
                    "transforms": (
                        {
                            "FunctionalTransform": {
                                "transform_func": data_funcs.transform,
                            },
                        },
                    ),
                },
                "batch_size": cfg.TRAIN.batch_size.sup_constraint,
                "sampler": {
                    "name": "BatchSampler",
                    "drop_last": False,
                    "shuffle": False,
                },
            },
            ppsci.loss.FunctionalLoss(disc_funcs.loss_func),
            name="sup_constraint_disc",
        )
        constraint_disc = {sup_constraint_disc.name: sup_constraint_disc}

    # temporal Discriminators
    # manually build constraint(s)
    if cfg.USE_TEMPODISC:
        sup_constraint_disc_tempo = ppsci.constraint.SupervisedConstraint(
            {
                "dataset": {
                    "name": "NamedArrayDataset",
                    "input": {
                        "density_low": dataset_train["density_low_tempo"],
                        "density_high": dataset_train["density_high_tempo"],
                    },
                    "label": {
                        "out_disc_tempo_from_target": np.ones(
                            (np.shape(dataset_train["density_high_tempo"])[0], 1),
                            dtype=paddle.get_default_dtype(),
                        ),
                        "out_disc_tempo_from_gen": np.ones(
                            (np.shape(dataset_train["density_high_tempo"])[0], 1),
                            dtype=paddle.get_default_dtype(),
                        ),
                    },
                    "transforms": (
                        {
                            "FunctionalTransform": {
                                "transform_func": data_funcs.transform,
                            },
                        },
                    ),
                },
                "batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
                "sampler": {
                    "name": "BatchSampler",
                    "drop_last": False,
                    "shuffle": False,
                },
            },
            ppsci.loss.FunctionalLoss(disc_funcs.loss_func_tempo),
            name="sup_constraint_disc_tempo",
        )
        constraint_disc_tempo = {
            sup_constraint_disc_tempo.name: sup_constraint_disc_tempo
        }

    # initialize solver
    solver_gen = ppsci.solver.Solver(
        model_list,
        constraint_gen,
        cfg.output_dir,
        optimizer_gen,
        lr_scheduler_gen,
        cfg.TRAIN.epochs_gen,
        cfg.TRAIN.iters_per_epoch,
        eval_during_train=cfg.TRAIN.eval_during_train,
        use_amp=cfg.USE_AMP,
        amp_level=cfg.TRAIN.amp_level,
    )
    if cfg.USE_SPATIALDISC:
        solver_disc = ppsci.solver.Solver(
            model_list,
            constraint_disc,
            cfg.output_dir,
            optimizer_disc,
            lr_scheduler_disc,
            cfg.TRAIN.epochs_disc,
            cfg.TRAIN.iters_per_epoch,
            eval_during_train=cfg.TRAIN.eval_during_train,
            use_amp=cfg.USE_AMP,
            amp_level=cfg.TRAIN.amp_level,
        )
    if cfg.USE_TEMPODISC:
        solver_disc_tempo = ppsci.solver.Solver(
            model_list,
            constraint_disc_tempo,
            cfg.output_dir,
            optimizer_disc_tempo,
            lr_scheduler_disc_tempo,
            cfg.TRAIN.epochs_disc_tempo,
            cfg.TRAIN.iters_per_epoch,
            eval_during_train=cfg.TRAIN.eval_during_train,
            use_amp=cfg.USE_AMP,
            amp_level=cfg.TRAIN.amp_level,
        )

    PRED_INTERVAL = 200
    for i in range(1, cfg.TRAIN.epochs + 1):
        logger.message(f"\nEpoch: {i}\n")
        # plotting during training
        if i == 1 or i % PRED_INTERVAL == 0 or i == cfg.TRAIN.epochs:
            func_module.predict_and_save_plot(
                cfg.output_dir, i, solver_gen, dataset_valid, cfg.TILE_RATIO
            )

        disc_funcs.model_gen = model_gen
        # train disc, input: (x,y,G(x))
        if cfg.USE_SPATIALDISC:
            solver_disc.train()

        # train disc tempo, input: (y_3,G(x)_3)
        if cfg.USE_TEMPODISC:
            solver_disc_tempo.train()

        # train gen, input: (x,)
        solver_gen.train()

    ############### evaluation for training ###############
    img_target = (
        func_module.get_image_array(
            os.path.join(cfg.output_dir, "predict", "target.png")
        )
        / 255.0
    )
    img_pred = (
        func_module.get_image_array(
            os.path.join(
                cfg.output_dir, "predict", f"pred_epoch_{cfg.TRAIN.epochs}.png"
            )
        )
        / 255.0
    )
    eval_mse, eval_psnr, eval_ssim = func_module.evaluate_img(img_target, img_pred)
    logger.message(f"MSE: {eval_mse}, PSNR: {eval_psnr}, SSIM: {eval_ssim}")


def evaluate(cfg: DictConfig):
    if cfg.EVAL.save_outs:
        from matplotlib import image as Img

        os.makedirs(osp.join(cfg.output_dir, "eval_outs"), exist_ok=True)

    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")

    gen_funcs = func_module.GenFuncs(cfg.WEIGHT_GEN, None)

    # load dataset
    dataset_valid = hdf5storage.loadmat(cfg.DATASET_PATH_VALID)

    # define Generator model
    model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)
    model_gen.register_input_transform(gen_funcs.transform_in)

    # define model_list
    model_list = ppsci.arch.ModelList((model_gen,))

    # load pretrained model
    save_load.load_pretrain(model_list, cfg.EVAL.pretrained_model_path)

    # set validator
    eval_dataloader_cfg = {
        "dataset": {
            "name": "NamedArrayDataset",
            "input": {
                "density_low": dataset_valid["density_low"],
            },
            "label": {"density_high": dataset_valid["density_high"]},
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": 1,
    }
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.MSELoss("mean"),
        {"density_high": lambda out: out["output_gen"]},
        metric={"metric": ppsci.metric.L2Rel()},
        name="sup_validator_gen",
    )

    # customized evalution
    def scale(data):
        smax = np.max(data)
        smin = np.min(data)
        return (data - smin) / (smax - smin)

    eval_mse_list = []
    eval_psnr_list = []
    eval_ssim_list = []
    for i, (input, label, _) in enumerate(sup_validator.data_loader):
        output_dict = model_list({"density_low": input["density_low"]})
        output_arr = scale(np.squeeze(output_dict["output_gen"].numpy()))
        target_arr = scale(np.squeeze(label["density_high"].numpy()))

        eval_mse, eval_psnr, eval_ssim = func_module.evaluate_img(
            target_arr, output_arr
        )
        eval_mse_list.append(eval_mse)
        eval_psnr_list.append(eval_psnr)
        eval_ssim_list.append(eval_ssim)

        if cfg.EVAL.save_outs:
            Img.imsave(
                osp.join(cfg.output_dir, "eval_outs", f"out_{i}.png"),
                output_arr,
                vmin=0.0,
                vmax=1.0,
                cmap="gray",
            )
    logger.message(
        f"MSE: {np.mean(eval_mse_list)}, PSNR: {np.mean(eval_psnr_list)}, SSIM: {np.mean(eval_ssim_list)}"
    )


def export(cfg: DictConfig):
    from paddle.static import InputSpec

    # set models
    gen_funcs = func_module.GenFuncs(cfg.WEIGHT_GEN, None)
    model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)
    model_gen.register_input_transform(gen_funcs.transform_in)

    # define model_list
    model_list = ppsci.arch.ModelList((model_gen,))

    # load pretrained model
    solver = ppsci.solver.Solver(
        model=model_list, pretrained_model_path=cfg.INFER.pretrained_model_path
    )

    # export models
    input_spec = [
        {"density_low": InputSpec([None, 1, 128, 128], "float32", name="density_low")},
    ]
    solver.export(input_spec, cfg.INFER.export_path, skip_prune_program=True)


def inference(cfg: DictConfig):
    from matplotlib import image as Img

    from deploy.python_infer import pinn_predictor

    # set model predictor
    predictor = pinn_predictor.PINNPredictor(cfg)

    # load dataset
    dataset_infer = {
        "density_low": hdf5storage.loadmat(cfg.DATASET_PATH_VALID)["density_low"]
    }

    output_dict = predictor.predict(dataset_infer, cfg.INFER.batch_size)

    # mapping data to cfg.INFER.output_keys
    output = [output_dict[key] for key in output_dict]

    def scale(data):
        smax = np.max(data)
        smin = np.min(data)
        return (data - smin) / (smax - smin)

    for i, img in enumerate(output[0]):
        img = scale(np.squeeze(img))
        Img.imsave(
            osp.join(cfg.output_dir, f"out_{i}.png"),
            img,
            vmin=0.0,
            vmax=1.0,
            cmap="gray",
        )


@hydra.main(version_base=None, config_path="./conf", config_name="tempogan.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()
functions.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Dict
from typing import List
from typing import Tuple

import numpy as np
import paddle
import paddle.nn.functional as F
from matplotlib import image as Img
from PIL import Image
from skimage.metrics import mean_squared_error
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import structural_similarity

import ppsci
from ppsci.utils import logger


# train
def interpolate(
    data: paddle.Tensor, ratio: int, mode: str = "nearest"
) -> paddle.Tensor:
    """Interpolate twice.

    Args:
        data (paddle.Tensor): The data to be interpolated.
        ratio (int): Ratio of one interpolation.
        mode (str, optional): Interpolation method. Defaults to "nearest".

    Returns:
        paddle.Tensor: Data interpolated.
    """
    for _ in range(2):
        data = F.interpolate(
            data,
            [data.shape[-2] * ratio, data.shape[-1] * ratio],
            mode=mode,
        )
    return data


def reshape_input(input_dict: Dict[str, paddle.Tensor]) -> Dict[str, paddle.Tensor]:
    """Reshape input data for temporally Discriminator. Reshape data from N, C, W, H to N * C, 1, H, W.
        Which will merge N dimension and C dimension to 1 dimension but still keep 4 dimensions
        to ensure the data can be used for training.

    Args:
        input_dict (Dict[str, paddle.Tensor]): input data dict.

    Returns:
        Dict[str, paddle.Tensor]: reshaped data dict.
    """
    for key in input_dict:
        input = input_dict[key]
        N, C, H, W = input.shape
        input_dict[key] = paddle.reshape(input, [N * C, 1, H, W])
    return input_dict


def dereshape_input(
    input_dict: Dict[str, paddle.Tensor], C: int
) -> Dict[str, paddle.Tensor]:
    """Dereshape input data for temporally Discriminator. Deeshape data from N * C, 1, H, W to N, C, W, H.

    Args:
        input_dict (Dict[str, paddle.Tensor]): input data dict.
        C (int): Channel of dereshape.

    Returns:
        Dict[str, paddle.Tensor]: dereshaped data dict.
    """
    for key in input_dict:
        input = input_dict[key]
        N, _, H, W = input.shape
        if N < C:
            logger.warning(
                f"batch_size is smaller than {C}! Tempo needs at least {C} frames, input will be copied."
            )
            input_dict[key] = paddle.concat([input[:1]] * C, axis=1)
        else:
            N_new = int(N // C)
            input_dict[key] = paddle.reshape(input[: N_new * C], [-1, C, H, W])
    return input_dict


# predict
def split_data(data: np.ndarray, tile_ratio: int) -> np.ndarray:
    """Split a numpy image to tiles equally.

    Args:
        data (np.ndarray): The image to be Split.
        tile_ratio (int): How many tiles of one dim.
            Number of result tiles is tile_ratio * tile_ratio for a 2d image.

    Returns:
        np.ndarray: Tiles in [N,C,H,W] shape.
    """
    _, _, h, w = data.shape
    tile_h, tile_w = h // tile_ratio, w // tile_ratio
    tiles = []
    for i in range(tile_ratio):
        for j in range(tile_ratio):
            tiles.append(
                data[
                    :1,
                    :,
                    i * tile_h : i * tile_h + tile_h,
                    j * tile_w : j * tile_w + tile_w,
                ],
            )
    return np.concatenate(tiles, axis=0)


def concat_data(data: np.ndarray, tile_ratio: int) -> np.ndarray:
    """Concat numpy tiles to a image equally.

    Args:
        data (np.ndarray): The tiles to be upsplited.
        tile_ratio (int): How many tiles of one dim.
            Number of input tiles is tile_ratio * tile_ratio for 2d result.

    Returns:
        np.ndarray: Image in [H,W] shape.
    """
    _, _, tile_h, tile_w = data.shape
    h, w = tile_h * tile_ratio, tile_w * tile_ratio
    data_whole = np.ones([h, w], dtype=paddle.get_default_dtype())
    tile_idx = 0
    for i in range(tile_ratio):
        for j in range(tile_ratio):
            data_whole[
                i * tile_h : i * tile_h + tile_h,
                j * tile_w : j * tile_w + tile_w,
            ] = data[tile_idx][0]
            tile_idx += 1
    return data_whole


def predict_and_save_plot(
    output_dir: str,
    epoch_id: int,
    solver_gen: ppsci.solver.Solver,
    dataset_valid: np.ndarray,
    tile_ratio: int = 1,
):
    """Predicting and plotting.

    Args:
        output_dir (str): Output dir path.
        epoch_id (int): Which epoch it is.
        solver_gen (ppsci.solver.Solver): Solver for predicting.
        dataset_valid (np.ndarray): Valid dataset.
        tile_ratio (int, optional): How many tiles of one dim. Defaults to 1.
    """
    dir_pred = "predict/"
    os.makedirs(os.path.join(output_dir, dir_pred), exist_ok=True)

    start_idx = 190
    density_low = dataset_valid["density_low"][start_idx : start_idx + 3]
    density_high = dataset_valid["density_high"][start_idx : start_idx + 3]

    # tile
    density_low = (
        split_data(density_low, tile_ratio) if tile_ratio != 1 else density_low
    )
    density_high = (
        split_data(density_high, tile_ratio) if tile_ratio != 1 else density_high
    )

    pred_dict = solver_gen.predict(
        {
            "density_low": density_low,
            "density_high": density_high,
        },
        {"density_high": lambda out: out["output_gen"]},
        batch_size=tile_ratio * tile_ratio if tile_ratio != 1 else 3,
        no_grad=False,
    )
    if epoch_id == 1:
        # plot interpolated input image
        input_img = np.expand_dims(dataset_valid["density_low"][start_idx], axis=0)
        input_img = paddle.to_tensor(input_img, dtype=paddle.get_default_dtype())
        input_img = F.interpolate(
            input_img,
            [input_img.shape[-2] * 4, input_img.shape[-1] * 4],
            mode="nearest",
        ).numpy()
        Img.imsave(
            os.path.join(output_dir, dir_pred, "input.png"),
            np.squeeze(input_img),
            vmin=0.0,
            vmax=1.0,
            cmap="gray",
        )
        # plot target image
        Img.imsave(
            os.path.join(output_dir, dir_pred, "target.png"),
            np.squeeze(dataset_valid["density_high"][start_idx]),
            vmin=0.0,
            vmax=1.0,
            cmap="gray",
        )
    # plot pred image
    pred_img = (
        concat_data(pred_dict["density_high"].numpy(), tile_ratio)
        if tile_ratio != 1
        else np.squeeze(pred_dict["density_high"][0].numpy())
    )
    Img.imsave(
        os.path.join(output_dir, dir_pred, f"pred_epoch_{str(epoch_id)}.png"),
        pred_img,
        vmin=0.0,
        vmax=1.0,
        cmap="gray",
    )


# evaluation
def evaluate_img(
    img_target: np.ndarray, img_pred: np.ndarray
) -> Tuple[float, float, float]:
    """Evaluate two images.

    Args:
        img_target (np.ndarray): Target image.
        img_pred (np.ndarray): Image generated by prediction.

    Returns:
        Tuple[float, float, float]: MSE, PSNR, SSIM.
    """
    eval_mse = mean_squared_error(img_target, img_pred)
    eval_psnr = peak_signal_noise_ratio(img_target, img_pred)
    eval_ssim = structural_similarity(img_target, img_pred, data_range=1.0)
    return eval_mse, eval_psnr, eval_ssim


def get_image_array(img_path):
    return np.array(Image.open(img_path).convert("L"))


class GenFuncs:
    """All functions used for Generator, including functions of transform and loss.

    Args:
        weight_gen (List[float]): Weights of the generator's l1, l2 and temporal adversarial losses.
        weight_gen_layer (List[float], optional): Weights of the layer (feature) loss. Defaults to None.
    """

    def __init__(
        self, weight_gen: List[float], weight_gen_layer: List[float] = None
    ) -> None:
        self.weight_gen = weight_gen
        self.weight_gen_layer = weight_gen_layer

    def transform_in(self, _in):
        ratio = 2
        input_dict = reshape_input(_in)
        density_low = input_dict["density_low"]
        density_low_inp = interpolate(density_low, ratio, "nearest")
        return {"input_gen": density_low_inp}

    def loss_func_gen(self, output_dict: Dict, *args) -> paddle.Tensor:
        """Calculate loss of generator when use spatial discriminator.
            The loss consists of l1 loss, l2 loss and layer loss when use spatial discriminator.
            Notice that all item of loss is optional because weight of them might be 0.

        Args:
            output_dict (Dict): output dict of model.

        Returns:
            paddle.Tensor: Loss of generator.
        """
        # l1 loss
        loss_l1 = F.l1_loss(
            output_dict["output_gen"], output_dict["density_high"], "mean"
        )
        losses = loss_l1 * self.weight_gen[0]

        # l2 loss
        loss_l2 = F.mse_loss(
            output_dict["output_gen"], output_dict["density_high"], "mean"
        )
        losses += loss_l2 * self.weight_gen[1]

        if self.weight_gen_layer is not None:
            # disc(generator_out) loss
            out_disc_from_gen = output_dict["out_disc_from_gen"][-1]
            label_ones = paddle.ones_like(out_disc_from_gen)
            loss_gen = F.binary_cross_entropy_with_logits(
                out_disc_from_gen, label_ones, reduction="mean"
            )
            losses += loss_gen

            # layer loss
            key_list = list(output_dict.keys())
            # ["out0_layer0","out0_layer1","out0_layer2","out0_layer3","out_disc_from_target",
            # "out1_layer0","out1_layer1","out1_layer2","out1_layer3","out_disc_from_gen"]
            loss_layer = 0
            for i in range(1, len(self.weight_gen_layer)):
                # i = 1, 2, ..., len(weight_gen_layer) - 1; weight_gen_layer[0] scales the total layer loss below
                loss_layer += (
                    self.weight_gen_layer[i]
                    * F.mse_loss(
                        output_dict[key_list[i]],
                        output_dict[key_list[5 + i]],
                        reduction="sum",
                    )
                    / 2
                )
            losses += loss_layer * self.weight_gen_layer[0]

        return {"output_gen": losses}

    def loss_func_gen_tempo(self, output_dict: Dict, *args) -> paddle.Tensor:
        """Calculate loss of generator when use temporal discriminator.
            The loss is cross entropy loss when use temporal discriminator.

        Args:
            output_dict (Dict): output dict of model.

        Returns:
            paddle.Tensor: Loss of generator.
        """
        out_disc_tempo_from_gen = output_dict["out_disc_tempo_from_gen"][-1]
        label_t_ones = paddle.ones_like(out_disc_tempo_from_gen)

        loss_gen_t = F.binary_cross_entropy_with_logits(
            out_disc_tempo_from_gen, label_t_ones, reduction="mean"
        )
        losses = loss_gen_t * self.weight_gen[2]
        return {"out_disc_tempo_from_gen": losses}


class DiscFuncs:
    """All functions used for Discriminator and temporally Discriminator, including functions of transform and loss.

    Args:
        weight_disc (float): Weight of the discriminator loss term computed on the real (target) samples.
    """

    def __init__(self, weight_disc: float) -> None:
        self.weight_disc = weight_disc
        self.model_gen = None

    def transform_in(self, _in):
        ratio = 2
        input_dict = reshape_input(_in)
        density_low = input_dict["density_low"]
        density_high_from_target = input_dict["density_high"]

        density_low_inp = interpolate(density_low, ratio, "nearest")

        density_high_from_gen = self.model_gen(input_dict)["output_gen"]
        density_high_from_gen.stop_gradient = True

        density_input_from_target = paddle.concat(
            [density_low_inp, density_high_from_target], axis=1
        )
        density_input_from_gen = paddle.concat(
            [density_low_inp, density_high_from_gen], axis=1
        )
        return {
            "input_disc_from_target": density_input_from_target,
            "input_disc_from_gen": density_input_from_gen,
        }

    def transform_in_tempo(self, _in):
        density_high_from_target = _in["density_high"]

        input_dict = reshape_input(_in)
        density_high_from_gen = self.model_gen(input_dict)["output_gen"]
        density_high_from_gen.stop_gradient = True

        input_trans = {
            "input_tempo_disc_from_target": density_high_from_target,
            "input_tempo_disc_from_gen": density_high_from_gen,
        }

        return dereshape_input(input_trans, 3)

    def loss_func(self, output_dict, *args):
        out_disc_from_target = output_dict["out_disc_from_target"]
        out_disc_from_gen = output_dict["out_disc_from_gen"]

        label_ones = paddle.ones_like(out_disc_from_target)
        label_zeros = paddle.zeros_like(out_disc_from_gen)

        loss_disc_from_target = F.binary_cross_entropy_with_logits(
            out_disc_from_target, label_ones, reduction="mean"
        )
        loss_disc_from_gen = F.binary_cross_entropy_with_logits(
            out_disc_from_gen, label_zeros, reduction="mean"
        )
        losses = loss_disc_from_target * self.weight_disc + loss_disc_from_gen
        return losses

    def loss_func_tempo(self, output_dict, *args):
        out_disc_tempo_from_target = output_dict["out_disc_tempo_from_target"]
        out_disc_tempo_from_gen = output_dict["out_disc_tempo_from_gen"]

        label_ones = paddle.ones_like(out_disc_tempo_from_target)
        label_zeros = paddle.zeros_like(out_disc_tempo_from_gen)

        loss_disc_tempo_from_target = F.binary_cross_entropy_with_logits(
            out_disc_tempo_from_target, label_ones, reduction="mean"
        )
        loss_disc_tempo_from_gen = F.binary_cross_entropy_with_logits(
            out_disc_tempo_from_gen, label_zeros, reduction="mean"
        )
        losses = (
            loss_disc_tempo_from_target * self.weight_disc + loss_disc_tempo_from_gen
        )
        return losses
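
# The discriminators are trained with standard GAN labels: outputs on real (target)
# inputs are pushed towards 1 and outputs on generated inputs towards 0, with
# weight_disc scaling the real-sample term. Note that DiscFuncs.model_gen is None after
# construction; the training script is expected to assign the generator and register the
# input transforms before any discriminator forward pass, e.g. (a sketch, assuming
# model_gen, model_disc and model_disc_tempo have already been built):
#     disc_funcs = DiscFuncs(weight_disc=1.0)
#     disc_funcs.model_gen = model_gen
#     model_disc.register_input_transform(disc_funcs.transform_in)
#     model_disc_tempo.register_input_transform(disc_funcs.transform_in_tempo)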


class DataFuncs:
    """All functions used for data transform.

    Args:
        tile_ratio (int, optional): How many tiles of one dim. Defaults to 1.
        density_min (float, optional): Minimum density required for a tile to be accepted. Defaults to 0.02.
        max_turn (int, optional): Maximum number of attempts to take a valid tile from one image. Defaults to 20.
    """

    def __init__(
        self, tile_ratio: int = 1, density_min: float = 0.02, max_turn: int = 20
    ) -> None:
        self.tile_ratio = tile_ratio
        self.density_min = density_min
        self.max_turn = max_turn

    def transform(
        self,
        input_item: Dict[str, np.ndarray],
        label_item: Dict[str, np.ndarray],
        weight_item: Dict[str, np.ndarray],
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        if self.tile_ratio == 1:
            return input_item, label_item, weight_item
        for _ in range(self.max_turn):
            rand_ratio = np.random.rand()
            density_low = self.cut_data(input_item["density_low"], rand_ratio)
            density_high = self.cut_data(input_item["density_high"], rand_ratio)
            if self.is_valid_tile(density_low):
                break

        input_item["density_low"] = density_low
        input_item["density_high"] = density_high
        return input_item, label_item, weight_item

    def cut_data(self, data: np.ndarray, rand_ratio: float) -> paddle.Tensor:
        # data: C,H,W
        _, H, W = data.shape
        if H % self.tile_ratio != 0 or W % self.tile_ratio != 0:
            exit(
                f"ERROR: input images cannot be divided into {self.tile_ratio} parts evenly!"
            )
        tile_shape = [H // self.tile_ratio, W // self.tile_ratio]
        rand_shape = np.floor(rand_ratio * (np.array([H, W]) - np.array(tile_shape)))
        start = [int(rand_shape[0]), int(rand_shape[1])]
        end = [int(rand_shape[0] + tile_shape[0]), int(rand_shape[1] + tile_shape[1])]
        data = paddle.slice(
            paddle.to_tensor(data), axes=[-2, -1], starts=start, ends=end
        )

        return data

    def is_valid_tile(self, tile: paddle.Tensor):
        img_density = tile[0].sum()
        return img_density >= (
            self.density_min * tile.shape[0] * tile.shape[1] * tile.shape[2]
        )
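
# DataFuncs.transform is meant to be used as a dataset transform when tile_ratio > 1:
# it crops the same random window from density_low and density_high (cut_data) and
# retries up to max_turn times until the crop is dense enough (is_valid_tile), i.e.
# until the summed density of the tile's first channel reaches density_min times the
# total number of elements in the tile.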
gan.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Dict
from typing import List
from typing import Tuple

import paddle
import paddle.nn as nn

from ppsci.arch import activation as act_mod
from ppsci.arch import base


class Conv2DBlock(nn.Layer):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        stride,
        use_bn,
        act,
        mean,
        std,
        value,
    ):
        super().__init__()
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=mean, std=std)
        )
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(value=value))
        self.conv_2d = nn.Conv2D(
            in_channel,
            out_channel,
            kernel_size,
            stride,
            padding="SAME",
            weight_attr=weight_attr,
            bias_attr=bias_attr,
        )
        self.bn = nn.BatchNorm2D(out_channel) if use_bn else None
        self.act = act_mod.get_activation(act) if act else None

    def forward(self, x):
        y = x
        y = self.conv_2d(y)
        if self.bn:
            y = self.bn(y)
        if self.act:
            y = self.act(y)
        return y


class VariantResBlock(nn.Layer):
    def __init__(
        self,
        in_channel,
        out_channels,
        kernel_sizes,
        strides,
        use_bns,
        acts,
        mean,
        std,
        value,
    ):
        super().__init__()
        self.conv_2d_0 = Conv2DBlock(
            in_channel=in_channel,
            out_channel=out_channels[0],
            kernel_size=kernel_sizes[0],
            stride=strides[0],
            use_bn=use_bns[0],
            act=acts[0],
            mean=mean,
            std=std,
            value=value,
        )
        self.conv_2d_1 = Conv2DBlock(
            in_channel=out_channels[0],
            out_channel=out_channels[1],
            kernel_size=kernel_sizes[1],
            stride=strides[1],
            use_bn=use_bns[1],
            act=acts[1],
            mean=mean,
            std=std,
            value=value,
        )

        self.conv_2d_2 = Conv2DBlock(
            in_channel=in_channel,
            out_channel=out_channels[2],
            kernel_size=kernel_sizes[2],
            stride=strides[2],
            use_bn=use_bns[2],
            act=acts[2],
            mean=mean,
            std=std,
            value=value,
        )

        self.act = act_mod.get_activation("relu")

    def forward(self, x):
        y = x
        y = self.conv_2d_0(y)
        y = self.conv_2d_1(y)
        short = self.conv_2d_2(x)
        y = paddle.add(y, short)
        y = self.act(y)
        return y


class FCBlock(nn.Layer):
    def __init__(self, in_channel, act, mean, std, value):
        super().__init__()
        self.flatten = nn.Flatten()
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=mean, std=std)
        )
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(value=value))
        self.linear = nn.Linear(
            in_channel,
            1,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
        )
        self.act = act_mod.get_activation(act) if act else None

    def forward(self, x):
        y = x
        y = self.flatten(y)
        y = self.linear(y)
        if self.act:
            y = self.act(y)
        return y


class Generator(base.Arch):
    """Generator Net of GAN. Attention, the net using a kind of variant of ResBlock which is
        unique to "tempoGAN" example but not an open source network.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
        in_channel (int): Number of input channels of the first conv layer.
        out_channels_tuple (Tuple[Tuple[int, ...], ...]): Number of output channels of all conv layers,
            such as [[out_res0_conv0, out_res0_conv1], [out_res1_conv0, out_res1_conv1]]
        kernel_sizes_tuple (Tuple[Tuple[int, ...], ...]): Kernel sizes of all conv layers,
            such as [[kernel_size_res0_conv0, kernel_size_res0_conv1], [kernel_size_res1_conv0, kernel_size_res1_conv1]]
        strides_tuple (Tuple[Tuple[int, ...], ...]): Strides of all conv layers,
            such as [[stride_res0_conv0, stride_res0_conv1], [stride_res1_conv0, stride_res1_conv1]]
        use_bns_tuple (Tuple[Tuple[bool, ...], ...]): Whether to use the batch_norm layer after each conv layer.
        acts_tuple (Tuple[Tuple[str, ...], ...]): Whether to use an activation layer after each conv layer and, if so, which activation to use,
            such as [[act_res0_conv0, act_res0_conv1], [act_res1_conv0, act_res1_conv1]]

    Examples:
        >>> import ppsci
        >>> in_channel = 1
        >>> rb_channel0 = (2, 8, 8)
        >>> rb_channel1 = (128, 128, 128)
        >>> rb_channel2 = (32, 8, 8)
        >>> rb_channel3 = (2, 1, 1)
        >>> out_channels_tuple = (rb_channel0, rb_channel1, rb_channel2, rb_channel3)
        >>> kernel_sizes_tuple = (((5, 5), ) * 2 + ((1, 1), ), ) * 4
        >>> strides_tuple = ((1, 1, 1), ) * 4
        >>> use_bns_tuple = ((True, True, True), ) * 3 + ((False, False, False), )
        >>> acts_tuple = (("relu", None, None), ) * 4
        >>> model = ppsci.arch.Generator(("in",), ("out",), in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple)
        >>> batch_size = 4
        >>> height = 64
        >>> width = 64
        >>> input_data = paddle.randn([batch_size, in_channel, height, width])
        >>> input_dict = {'in': input_data}
        >>> output_data = model(input_dict)
        >>> print(output_data['out'].shape)
        [4, 1, 64, 64]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        in_channel: int,
        out_channels_tuple: Tuple[Tuple[int, ...], ...],
        kernel_sizes_tuple: Tuple[Tuple[int, ...], ...],
        strides_tuple: Tuple[Tuple[int, ...], ...],
        use_bns_tuple: Tuple[Tuple[bool, ...], ...],
        acts_tuple: Tuple[Tuple[str, ...], ...],
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.in_channel = in_channel
        self.out_channels_tuple = out_channels_tuple
        self.kernel_sizes_tuple = kernel_sizes_tuple
        self.strides_tuple = strides_tuple
        self.use_bns_tuple = use_bns_tuple
        self.acts_tuple = acts_tuple

        self.init_blocks()

    def init_blocks(self):
        blocks_list = []
        for i in range(len(self.out_channels_tuple)):
            in_channel = (
                self.in_channel if i == 0 else self.out_channels_tuple[i - 1][-1]
            )
            blocks_list.append(
                VariantResBlock(
                    in_channel=in_channel,
                    out_channels=self.out_channels_tuple[i],
                    kernel_sizes=self.kernel_sizes_tuple[i],
                    strides=self.strides_tuple[i],
                    use_bns=self.use_bns_tuple[i],
                    acts=self.acts_tuple[i],
                    mean=0.0,
                    std=0.04,
                    value=0.1,
                )
            )
        self.blocks = nn.LayerList(blocks_list)

    def forward_tensor(self, x):
        y = x
        for block in self.blocks:
            y = block(y)
        return y

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        y = self.concat_to_tensor(x, self.input_keys, axis=-1)
        y = self.forward_tensor(y)
        y = self.split_to_dict(y, self.output_keys, axis=-1)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y


class Discriminator(base.Arch):
    """Discriminator Net of GAN.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
        in_channel (int):  Number of input channels of the first conv layer.
        out_channels (Tuple[int, ...]): Number of output channels of all conv layers,
            such as (out_conv0, out_conv1, out_conv2).
        fc_channel (int):  Number of input features of linear layer. Number of output features of the layer
            is set to 1 in this Net to construct a fully_connected layer.
        kernel_sizes (Tuple[int, ...]): Kernel sizes of all conv layers,
            such as (kernel_size_conv0, kernel_size_conv1, kernel_size_conv2).
        strides (Tuple[int, ...]): Strides of all conv layers,
            such as (stride_conv0, stride_conv1, stride_conv2).
        use_bns (Tuple[bool, ...]): Whether to use the batch_norm layer after each conv layer.
        acts (Tuple[str, ...]): Whether to use an activation layer after each conv layer and, if so, which activation to use,
            such as (act_conv0, act_conv1, act_conv2).

    Examples:
        >>> import ppsci
        >>> in_channel = 2
        >>> in_channel_tempo = 3
        >>> out_channels = (32, 64, 128, 256)
        >>> fc_channel = 65536
        >>> kernel_sizes = ((4, 4), (4, 4), (4, 4), (4, 4))
        >>> strides = (2, 2, 2, 1)
        >>> use_bns = (False, True, True, True)
        >>> acts = ("leaky_relu", "leaky_relu", "leaky_relu", "leaky_relu", None)
        >>> output_keys_disc = ("out_1", "out_2", "out_3", "out_4", "out_5", "out_6", "out_7", "out_8", "out_9", "out_10")
        >>> model = ppsci.arch.Discriminator(("in_1","in_2"), output_keys_disc, in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
        >>> input_data = [paddle.to_tensor(paddle.randn([1, in_channel, 128, 128])),paddle.to_tensor(paddle.randn([1, in_channel, 128, 128]))]
        >>> input_dict = {"in_1": input_data[0],"in_2": input_data[1]}
        >>> out_dict = model(input_dict)
        >>> for k, v in out_dict.items():
        ...     print(k, v.shape)
        out_1 [1, 32, 64, 64]
        out_2 [1, 64, 32, 32]
        out_3 [1, 128, 16, 16]
        out_4 [1, 256, 16, 16]
        out_5 [1, 1]
        out_6 [1, 32, 64, 64]
        out_7 [1, 64, 32, 32]
        out_8 [1, 128, 16, 16]
        out_9 [1, 256, 16, 16]
        out_10 [1, 1]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        in_channel: int,
        out_channels: Tuple[int, ...],
        fc_channel: int,
        kernel_sizes: Tuple[int, ...],
        strides: Tuple[int, ...],
        use_bns: Tuple[bool, ...],
        acts: Tuple[str, ...],
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.in_channel = in_channel
        self.out_channels = out_channels
        self.fc_channel = fc_channel
        self.kernel_sizes = kernel_sizes
        self.strides = strides
        self.use_bns = use_bns
        self.acts = acts

        self.init_layers()

    def init_layers(self):
        layers_list = []
        for i in range(len(self.out_channels)):
            in_channel = self.in_channel if i == 0 else self.out_channels[i - 1]
            layers_list.append(
                Conv2DBlock(
                    in_channel=in_channel,
                    out_channel=self.out_channels[i],
                    kernel_size=self.kernel_sizes[i],
                    stride=self.strides[i],
                    use_bn=self.use_bns[i],
                    act=self.acts[i],
                    mean=0.0,
                    std=0.04,
                    value=0.1,
                )
            )

        layers_list.append(
            FCBlock(self.fc_channel, self.acts[4], mean=0.0, std=0.04, value=0.1)
        )
        self.layers = nn.LayerList(layers_list)

    def forward_tensor(self, x):
        y = x
        y_list = []
        for layer in self.layers:
            y = layer(y)
            y_list.append(y)
        return y_list  # y_conv1, y_conv2, y_conv3, y_conv4, y_fc(y_out)

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        y_list = []
        # y1_conv1, y1_conv2, y1_conv3, y1_conv4, y1_fc, y2_conv1, y2_conv2, y2_conv3, y2_conv4, y2_fc
        for k in x:
            y_list.extend(self.forward_tensor(x[k]))

        y = self.split_to_dict(y_list, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)

        return y

    @staticmethod
    def split_to_dict(
        data_list: List[paddle.Tensor], keys: Tuple[str, ...]
    ) -> Dict[str, paddle.Tensor]:
        """Overwrite of split_to_dict() method belongs to Class base.Arch.

        Reason for overwriting is there is no concat_to_tensor() method called in "tempoGAN" example.
        That is because input in "tempoGAN" example is not in a regular format, but a format like:
        {
            "input1": paddle.concat([in1, in2], axis=1),
            "input2": paddle.concat([in1, in3], axis=1),
        }

        Args:
            data_list (List[paddle.Tensor]): The data to be split. It should be a list of tensor(s), but not a paddle.Tensor.
            keys (Tuple[str, ...]): Keys of outputs.

        Returns:
            Dict[str, paddle.Tensor]: Dict with split data.
        """
        if len(keys) == 1:
            return {keys[0]: data_list[0]}
        return {key: data_list[i] for i, key in enumerate(keys)}

5. Results

After training with mixed precision, the MSE, PSNR and SSIM between the predictions and the targets are evaluated on the test set. The metric values are:

| MSE | PSNR | SSIM |
| :--- | :--- | :--- |
| 4.21e-5 | 47.19 | 0.9974 |
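
These metrics can be reproduced from the saved grayscale images with the get_image_array and evaluate_img helpers defined in functions.py above. The snippet below is a minimal sketch assuming the images were written by predict_and_save_plot; the output directory and the epoch number in the file name are placeholders:

img_target = get_image_array("./outputs_tempoGAN/predict/target.png") / 255.0  # scale to [0, 1]
img_pred = get_image_array("./outputs_tempoGAN/predict/pred_epoch_40000.png") / 255.0
mse, psnr, ssim = evaluate_img(img_target, img_pred)
print(f"MSE: {mse:.2e}, PSNR: {psnr:.2f}, SSIM: {ssim:.4f}")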

The input of one fluid super-resolution sample, the model's prediction, and the result generated directly by the open-source package mantaflow (introduced in the dataset section) are shown below. The prediction is essentially consistent with the generated target.

(Figure: input, the low-density input fluid)

(Figure: pred-amp02, the high-density fluid inferred after mixed-precision training)

(Figure: target, the target high-density fluid)

6. References