
FourCastNet

Quick experience on AI Studio

Before starting training or evaluation, please download the dataset first.

# wind model pre-training
python train_pretrain.py
# wind model fine-tuning
python train_finetune.py
# precipitation model training
python train_precip.py
# evaluate the pre-trained wind model
python train_pretrain.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/fourcastnet/pretrain.pdparams
# evaluate the fine-tuned wind model
python train_finetune.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/fourcastnet/finetune.pdparams
# evaluate the precipitation model
python train_precip.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/fourcastnet/precip.pdparams WIND_MODEL_PATH=https://paddle-org.bj.bcebos.com/paddlescience/models/fourcastnet/finetune.pdparams

| Model | Variable | ACC/RMSE (6h) | ACC/RMSE (30h) | ACC/RMSE (60h) | ACC/RMSE (120h) | ACC/RMSE (192h) |
| --- | --- | --- | --- | --- | --- | --- |
| Wind model | U10 | 0.991/0.567 | 0.963/1.130 | 0.891/1.930 | 0.645/3.438 | 0.371/4.915 |

| Model | Variable | ACC/RMSE (6h) | ACC/RMSE (12h) | ACC/RMSE (24h) | ACC/RMSE (36h) |
| --- | --- | --- | --- | --- | --- |
| Precipitation model | TP | 0.808/1.390 | 0.760/1.540 | 0.668/1.690 | 0.590/1.920 |

1. Background

Weather forecasting can be approached with physics-based methods or data-driven methods. Physics-based methods typically rely on physical equations, modeling the physical relationships among atmospheric variables to produce a forecast; for example, the IFS model predicts the weather using more than 150 atmospheric variables distributed over more than 50 vertical levels. Data-driven methods do not depend on physical equations but require large amounts of training data: the neural network is usually treated as a black box that learns the functional relationship between inputs and outputs, so that it can predict the output for a given input. FourCastNet is a data-driven weather forecasting algorithm that trains and predicts with an Adaptive Fourier Neural Operator (AFNO). It focuses on predicting two key meteorological variables, the wind speed 10 meters above the Earth's surface and the 6-hour total precipitation, to provide early warnings of extreme weather and natural disasters. Compared with the IFS model, it uses only 20 atmospheric variables on 5 vertical levels, so it needs far fewer input variables and runs inference quickly.

2. Model Principles

This section gives only a brief introduction to the principles of FourCastNet. For the detailed theoretical derivation, please read FourCastNet: A Global Data-driven High-resolution Weather Model using Adaptive Fourier Neural Operators.

FourCastNet uses the AFNO network, which was previously applied to image segmentation tasks. This network compensates for the weaknesses of ViT through the FNO: token mixing is carried out with the Fourier transform, which significantly reduces the cost of self-attention in ViT at high resolution. For the principles of AFNO, FNO, and ViT, please refer to the corresponding papers.

The overall architecture of the model is shown in the figure below:

Figure: the FourCastNet network architecture

The FourCastNet paper trains a wind model and a precipitation model; the following describes the training and inference procedures of these two models.

2.1 Training and Inference of the Wind Model

Training proceeds in two steps: pre-training and fine-tuning.

In the pre-training stage, the model is trained from randomly initialized weights, as shown in the figure below, where \(X(k)\) denotes the atmospheric data at time \(k\), \(X(k+1)\) denotes the atmospheric data predicted by the model for time \(k+1\), and \(X_{true}(k+1)\) denotes the true atmospheric data at time \(k+1\). An L2 loss is computed between the network's prediction and the ground truth.

Figure: wind model pre-training

The second training stage is fine-tuning, which mainly improves the model's accuracy for medium- and long-range forecasts. Concretely, after the model takes the data at time \(k\) and predicts the data at time \(k+1\), that prediction is fed back in to predict the data at time \(k+2\); training on two consecutive prediction steps in this way improves the model's long-range forecasting ability.
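
Conceptually, the two-step objective can be written as the following sketch (illustrative pseudocode, not the actual PaddleScience training loop; model and loss_fn are placeholders):

def finetune_loss(model, x_k, x_k1_true, x_k2_true, loss_fn):
    """Two-step autoregressive fine-tuning loss (illustrative sketch)."""
    x_k1_pred = model(x_k)        # predict the state at time k+1 from time k
    x_k2_pred = model(x_k1_pred)  # feed the prediction back in to reach time k+2
    # the losses of both rollout steps are accumulated
    return loss_fn(x_k1_pred, x_k1_true) + loss_fn(x_k2_pred, x_k2_true)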

Figure: wind model fine-tuning

At inference time, given the data at time \(k\), the predictions for times \(k+1\), \(k+2\), \(k+3\), and so on are obtained by iterating the model.
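
This rollout can be sketched in a few lines (again illustrative; the actual example drives it through the num_timestamps mechanism described in Section 3.3):

def rollout(model, x_k, num_steps):
    """Autoregressively predict num_steps future states starting from x_k."""
    preds = []
    x = x_k
    for _ in range(num_steps):
        x = model(x)   # each call advances the state by one 6-hour step
        preds.append(x)
    return preds       # states at times k+1, k+2, ..., k+num_steps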

Figure: wind model inference

2.2 Training and Inference of the Precipitation Model

Training the precipitation model depends on the wind model. As shown in the figure below, the atmospheric data \(X(k)\) at time \(k\) is fed into the trained wind model, which predicts the atmospheric data \(X(k+1)\) at time \(k+1\). The precipitation model takes \(X(k+1)\) as input and outputs the precipitation forecast \(p(k+1)\) for time \(k+1\). During training, an L2 loss between \(p(k+1)\) and the ground truth \(p_{true}(k+1)\) supervises the network.

Figure: precipitation model training

Note that during precipitation training the wind model's parameters are frozen and are not updated by the optimizer.
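
A minimal sketch of this setup, assuming generic paddle.nn.Layer models as stand-ins for the actual AFNONet and PrecipNet classes:

import paddle

def freeze(wind_model: paddle.nn.Layer) -> None:
    # exclude the wind model's weights from optimizer updates
    for param in wind_model.parameters():
        param.stop_gradient = True

def precip_step(wind_model, precip_net, x_k):
    x_k1 = wind_model(x_k)   # the frozen wind model advances the atmosphere by 6 hours
    return precip_net(x_k1)  # the precipitation head maps X(k+1) to p(k+1)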

At inference time, given the data at time \(k\), the wind model is iterated to produce the atmospheric predictions at times \(k+1\), \(k+2\), \(k+3\), and so on, which are fed to the precipitation model to predict the precipitation at the corresponding times.

Figure: precipitation model inference

3. Wind Model Implementation

This section explains how to train and run inference for the FourCastNet wind model with PaddleScience. For the remaining details of this example, please refer to the API documentation.

Info

A full reproduction requires 5+ TB of storage and 64 GPUs. If you only want to study the FourCastNet algorithm, we recommend training on a small subset of the training data to reduce the cost.

3.1 Dataset

The dataset is the ERA5 data preprocessed for FourCastNet. Its resolution is 0.25 degrees and each variable has shape \(720 \times 1440\), so a single grid point covers roughly 30 km. FourCastNet uses data from 1979-2018, split by year into training, validation, and test sets as follows:

| Split | Years |
| --- | --- |
| Training set | 1979-2015 |
| Validation set | 2016-2017 |
| Test set | 2018 |

The dataset can be downloaded here.

The model is trained on 20 atmospheric variables distributed over 5 pressure levels, as listed in the table below.

Figure: the 20 atmospheric variables

Here \(T\), \(U\), \(V\), \(Z\), and \(RH\) denote the temperature, zonal wind speed, meridional wind speed, geopotential, and relative humidity at the given vertical level; \(U_{10}\) and \(V_{10}\) are the zonal and meridional wind speed 10 m above the ground, and \(T_{2m}\) is the temperature 2 m above the ground. \(sp\) is the surface pressure, \(mslp\) is the mean sea-level pressure, and \(TCWV\) is the total column water vapor.

The 24-hourly data of each day are sampled at 6-hour intervals, giving the global data of the 20 atmospheric variables at 0:00/6:00/12:00/18:00; these samples are used for training and inference. That is, given the 20 variables at 0:00, the model predicts the same 20 variables at 6:00.
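
With this layout, a timestamp maps to a sample index by dividing the number of hours since January 1st 00:00 by 6 (the same arithmetic appears in the get_vis_data functions later in this document); a small illustrative helper:

def time_index(hours_since_jan_01: int) -> int:
    """Map hours since Jan 1st 00:00 to the 6-hourly sample index."""
    return hours_since_jan_01 // 6

# e.g. 18:00 on Jan 1st -> index 3; the training pair is (data[3], data[4])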

3.2 Model Pre-training

First, the parameters defined in the code are listed below; the meaning of each is explained when it is used.

examples/fourcastnet/conf/fourcastnet_pretrain.yaml
# set training hyper-parameters
IMG_H: 720
IMG_W: 1440
# FourCastNet uses 20 atmospheric variables; their indices in the dataset are 0 to 19.
# The variable names are 'u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z000',
# 'u850', 'v850', 'z850',  'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv'.
# You can obtain detailed information about each variable from
# https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
VARS_CHANNEL: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
USE_SAMPLED_DATA: false

# set train data path
TRAIN_FILE_PATH: ./datasets/era5/train
DATA_MEAN_PATH: ./datasets/era5/stat/global_means.npy
DATA_STD_PATH: ./datasets/era5/stat/global_stds.npy
DATA_TIME_MEAN_PATH: ./datasets/era5/stat/time_means.npy

# set evaluate data path
VALID_FILE_PATH: ./datasets/era5/test

3.2.1 Constraint Construction

This example solves the problem with a data-driven method, so the built-in SupervisedConstraint of PaddleScience is used to build the supervised constraint. Before defining the constraint, the parameters used for data loading must be specified. We start with the data preprocessing, shown below:

examples/fourcastnet/train_pretrain.py
data_mean, data_std = fourcast_utils.get_mean_std(
    cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
)
data_time_mean = fourcast_utils.get_time_mean(
    cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
)
data_time_mean_normalize = np.expand_dims(
    (data_time_mean[0] - data_mean) / data_std, 0
)
# set train transforms
transforms = [
    {"SqueezeData": {}},
    {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
    {"Normalize": {"mean": data_mean, "std": data_std}},
]

The data preprocessing consists of three transforms (a NumPy sketch of the three steps follows the list):

  1. SqueezeData: compresses the dimensions of the training data; if the input has 4 dimensions, dimensions 0 and 1 are folded together, reducing the input to 3 dimensions.
  2. CropData: crops the data to the specified region. The raw ERA5 arrays have shape \(721 \times 1440\); following the original paper, the training data are cropped to \(720 \times 1440\).
  3. Normalize: standardizes the data with the mean and standard deviation computed on the training set.
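
The three transforms behave roughly like the following NumPy sketch (shapes and statistics here are illustrative, not the actual transform implementations):

import numpy as np

x = np.random.rand(1, 20, 721, 1440).astype("float32")  # (time, channel, lat, lon)
mean = x.mean(axis=(0, 2, 3)).reshape(20, 1, 1)
std = x.std(axis=(0, 2, 3)).reshape(20, 1, 1)

x = x.reshape(-1, *x.shape[2:])  # SqueezeData: fold dims 0 and 1 -> (20, 721, 1440)
x = x[:, 0:720, 0:1440]          # CropData: keep rows [0, 720) and cols [0, 1440)
x = (x - mean) / std             # Normalize: per-channel standardization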

Because a full reproduction of FourCastNet requires 5 TB+ of storage and 64 GPUs, two training modes are provided (experiments show that the loss curves of the two modes are essentially identical; use mode b when storage is limited).

Mode a: when storage is plentiful, the data are not partitioned; every node holds a full 5 TB+ copy of the training data and the training program is launched directly, with each node sampling randomly from the complete dataset. In this mode the training data are loaded with a global shuffle, as shown below.

Figure: global shuffle (mode a)

In this mode, the data loading code is:

examples/fourcastnet/train_pretrain.py
train_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.TRAIN_FILE_PATH,
        "input_keys": cfg.MODEL.afno.input_keys,
        "label_keys": cfg.MODEL.afno.output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "transforms": transforms,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": True,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": 8,
}

Here the "dataset" field sets the Dataset class to ERA5Dataset and the "sampler" field sets the Sampler class to BatchSampler, with batch_size set to 1 and num_workers set to 8.

Mode b: when storage is limited, the dataset must be split evenly across the nodes. This example provides a random sampling script, ppsci/fourcastnet/sample_data.py, which you can run and adapt as needed. Mode a is the default, so to train with mode b you must manually set USE_SAMPLED_DATA to True. In this mode the training data are loaded with a local shuffle, as shown below: the training data are first split evenly across 8 nodes, and during training each node samples randomly from its own shard. Each node then needs only about 1.2 TB of storage, so mode b greatly reduces the storage requirement compared with mode a.

Figure: local shuffle (mode b)

In this mode, the data loading code is:

examples/fourcastnet/train_pretrain.py
NUM_GPUS_PER_NODE = 8
train_dataloader_cfg = {
    "dataset": {
        "name": "ERA5SampledDataset",
        "file_path": cfg.TRAIN_FILE_PATH,
        "input_keys": cfg.MODEL.afno.input_keys,
        "label_keys": cfg.MODEL.afno.output_keys,
    },
    "sampler": {
        "name": "DistributedBatchSampler",
        "drop_last": True,
        "shuffle": True,
        "num_replicas": NUM_GPUS_PER_NODE,
        "rank": dist.get_rank() % NUM_GPUS_PER_NODE,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": 8,
}

Here the "dataset" field sets the Dataset class to ERA5SampledDataset and the "sampler" field sets the Sampler class to DistributedBatchSampler, with batch_size set to 1 and num_workers set to 8.

When a full reproduction of FourCastNet is not needed, simply use the default setting of this example (mode a).

The supervised constraint is defined as follows:

examples/fourcastnet/train_pretrain.py
# set constraint
sup_constraint = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    name="Sup",
)
constraint = {sup_constraint.name: sup_constraint}

The first argument of SupervisedConstraint is the data loading configuration; here the train_dataloader_cfg defined above is used.

The second argument is the loss function; here L2RelLoss is used.
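
Assuming the usual definition of a relative L2 error, L2RelLoss between a prediction \(\hat{y}\) and a label \(y\) takes the form \(L = \lVert \hat{y} - y \rVert_2 / \lVert y \rVert_2\), which makes the loss insensitive to the absolute magnitude of each sample.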

The third argument is the name of the constraint, used to index it later; here it is named "Sup".

3.2.2 Model Construction

In this example, the wind model is based on the AFNONet architecture; in PaddleScience code:

examples/fourcastnet/train_pretrain.py
# set model
model = ppsci.arch.AFNONet(**cfg.MODEL.afno)

The network parameters are set through the configuration file:

examples/fourcastnet/conf/fourcastnet_pretrain.yaml
# model settings
MODEL:
  afno:
    input_keys: ["input"]
    output_keys: ["output"]

Here input_keys and output_keys are the names of the model's input and output variables, respectively.
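
The model then consumes and returns dictionaries keyed by these names; a usage sketch (the batch size and grid shape below are assumed, not taken from the configuration):

import paddle

out = model({"input": paddle.randn([1, 20, 720, 1440])})  # 20 variables on a 720 x 1440 grid
print(out["output"].shape)  # the prediction is returned under the "output" key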

3.2.3 Learning Rate and Optimizer Construction

This example uses a Cosine learning-rate schedule with an initial learning rate of 5e-4 and the Adam optimizer; in PaddleScience code:

examples/fourcastnet/train_pretrain.py
# init optimizer and lr scheduler
lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()

optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
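
Assuming a standard cosine annealing schedule, the learning rate at step \(t\) of \(T\) total steps decays as \(\eta(t) = \eta_0 \cdot \frac{1}{2}\left(1 + \cos\frac{\pi t}{T}\right)\), so with \(\eta_0 = 5 \times 10^{-4}\) it falls smoothly from the initial value toward zero over the course of training.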

3.2.4 Validator Construction

During training, the model is evaluated on the validation set at a fixed epoch interval, which requires a SupervisedValidator. The code is:

examples/fourcastnet/train_pretrain.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.VALID_FILE_PATH,
        "input_keys": cfg.MODEL.afno.input_keys,
        "label_keys": cfg.MODEL.afno.output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "transforms": transforms,
        "training": False,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": False,
    },
    "batch_size": cfg.EVAL.batch_size,
}

# set validator
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    metric={
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=cfg.IMG_H,
            std=data_std,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=cfg.IMG_H,
            mean=data_time_mean_normalize,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
    },
    name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}

SupervisedValidator is quite similar to SupervisedConstraint, except that a validator also needs the metric argument; three metrics are used here: MAE, LatitudeWeightedRMSE, and LatitudeWeightedACC.
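
The latitude weighting compensates for the shrinking area of grid cells toward the poles. Assuming the definitions used in the FourCastNet paper, the weight of latitude row \(j\) and the weighted RMSE of a variable \(X\) are

\(w(j) = \dfrac{\cos \phi_j}{\frac{1}{N_{lat}} \sum_{j'=1}^{N_{lat}} \cos \phi_{j'}}, \qquad \mathrm{RMSE} = \sqrt{\dfrac{1}{N_{lat} N_{lon}} \sum_{j,k} w(j) \left(\hat{X}_{j,k} - X_{j,k}\right)^2}\)

and LatitudeWeightedACC is the analogously weighted anomaly correlation between prediction and ground truth after subtracting the climatological mean (passed here through the mean argument).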

3.2.5 Model Training and Evaluation

With the above in place, pass the instantiated objects to ppsci.solver.Solver in order, then start training and evaluation.

examples/fourcastnet/train_pretrain.py
# initialize solver
solver = ppsci.solver.Solver(
    model,
    constraint,
    cfg.output_dir,
    optimizer,
    lr_scheduler,
    cfg.TRAIN.epochs,
    ITERS_PER_EPOCH,
    eval_during_train=True,
    seed=cfg.seed,
    validator=validator,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
# train model
solver.train()
# evaluate after finished training
solver.eval()

3.3 Model Fine-tuning

The previous section covered pre-training the wind model; this section covers fine-tuning from that pre-trained model. Since the fine-tuning steps largely mirror pre-training, the shared parts are not repeated and only what is specific to fine-tuning is described. The parameters defined in the code are listed below; the meaning of each is explained when it is used.

examples/fourcastnet/conf/fourcastnet_finetune.yaml
# set training hyper-parameters
IMG_H: 720
IMG_W: 1440
# FourCastNet uses 20 atmospheric variables; their indices in the dataset are 0 to 19.
# The variable names are 'u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z000',
# 'u850', 'v850', 'z850',  'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv'.
# You can obtain detailed information about each variable from
# https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
VARS_CHANNEL: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

# set train data path
TRAIN_FILE_PATH: ./datasets/era5/train
DATA_MEAN_PATH: ./datasets/era5/stat/global_means.npy
DATA_STD_PATH: ./datasets/era5/stat/global_stds.npy
DATA_TIME_MEAN_PATH: ./datasets/era5/stat/time_means.npy

# set evaluate data path
VALID_FILE_PATH: ./datasets/era5/test

# set test data path
TEST_FILE_PATH: ./datasets/era5/out_of_sample/2018.h5

The fine-tuning program adds a num_timestamps parameter that controls how many time steps are iterated during fine-tuning. It is first used in the data loading configuration, where it sets the number of ground-truth time steps the dataset produces:

examples/fourcastnet/train_finetune.py
# set train dataloader config
train_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.TRAIN_FILE_PATH,
        "input_keys": cfg.MODEL.afno.input_keys,
        "label_keys": output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "num_label_timestamps": cfg.TRAIN.num_timestamps,
        "transforms": transforms,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": True,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": 8,
}

The num_timestamps parameter is set through the configuration file:

examples/fourcastnet/conf/fourcastnet_finetune.yaml
num_timestamps: 2

In addition, unlike pre-training, constructing the fine-tuned model also requires the num_timestamps parameter, which controls how many prediction time steps the model outputs:

examples/fourcastnet/train_finetune.py
# set model
model_cfg = dict(cfg.MODEL.afno)
model_cfg.update(
    {"output_keys": output_keys, "num_timestamps": cfg.TRAIN.num_timestamps}
)

The fine-tuning program also adds code for evaluating the model on the test set and for visualization; these two parts are described in detail next.

3.3.1 Evaluation on the Test Set

Following the setup in the paper, for test-set evaluation num_timestamps is set to 32 in the configuration file (a 32-step rollout at 6-hour steps covers 192 hours), and adjacent test samples are 8 steps, i.e. 48 hours, apart.

examples/fourcastnet/conf/fourcastnet_finetune.yaml
# evaluation settings
EVAL:
  num_timestamps: 32

The model is constructed with:

examples/fourcastnet/train_finetune.py
# set model
model_cfg = dict(cfg.MODEL.afno)
model_cfg.update(
    {"output_keys": output_keys, "num_timestamps": cfg.EVAL.num_timestamps}
)
model = ppsci.arch.AFNONet(**model_cfg)

The validator is constructed with:

examples/fourcastnet/train_finetune.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.TEST_FILE_PATH,
        "input_keys": cfg.MODEL.afno.input_keys,
        "label_keys": output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "transforms": transforms,
        "num_label_timestamps": cfg.EVAL.num_timestamps,
        "training": False,
        "stride": 8,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": False,
    },
    "batch_size": cfg.EVAL.batch_size,
}

# set metric
metric = {
    "MAE": ppsci.metric.MAE(keep_batch=True),
    "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
        num_lat=cfg.IMG_H,
        std=data_std,
        keep_batch=True,
        variable_dict={"u10": 0, "v10": 1},
    ),
    "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
        num_lat=cfg.IMG_H,
        mean=data_time_mean_normalize,
        keep_batch=True,
        variable_dict={"u10": 0, "v10": 1},
    ),
}

# set validator for testing
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    metric=metric,
    name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}

3.3.2 Visualizer Construction

The wind model runs inference autoregressively, so the input data for inference must be set up first:

examples/fourcastnet/train_finetune.py
# set visualizer data
DATE_STRINGS = ("2018-09-08 00:00:00",)
vis_data = get_vis_data(
    cfg.TEST_FILE_PATH,
    DATE_STRINGS,
    cfg.EVAL.num_timestamps,
    cfg.VARS_CHANNEL,
    cfg.IMG_H,
    data_mean,
    data_std,
)
examples/fourcastnet/train_finetune.py
def get_vis_data(
    file_path: str,
    date_strings: Tuple[str, ...],
    num_timestamps: int,
    vars_channel: Tuple[int, ...],
    img_h: int,
    data_mean: np.ndarray,
    data_std: np.ndarray,
):
    _file = h5py.File(file_path, "r")["fields"]
    data = []
    for date_str in date_strings:
        hours_since_jan_01_epoch = fourcast_utils.date_to_hours(date_str)
        ic = int(hours_since_jan_01_epoch / 6)
        data.append(_file[ic : ic + num_timestamps + 1, vars_channel, 0:img_h])
    data = np.asarray(data)

    vis_data = {"input": (data[:, 0] - data_mean) / data_std}
    for t in range(num_timestamps):
        hour = (t + 1) * 6
        data_t = data[:, t + 1]
        wind_data = []
        for i in range(data_t.shape[0]):
            wind_data.append((data_t[i][0] ** 2 + data_t[i][1] ** 2) ** 0.5)
        vis_data[f"target_{hour}h"] = np.asarray(wind_data)
    return vis_data

The code above reads the data corresponding to the configured DATE_STRINGS as model input; the get_vis_data function also reads the ground truth at the corresponding times, which is visualized as well for comparison with the model's predictions.

Because the model predicts the zonal and meridional wind components separately, they must be combined into the actual wind speed \(\sqrt{u^2 + v^2}\), as follows:

examples/fourcastnet/train_finetune.py
def output_wind_func(d, var_name, data_mean, data_std):
    output = (d[var_name] * data_std) + data_mean
    wind_data = []
    for i in range(output.shape[0]):
        wind_data.append((output[i][0] ** 2 + output[i][1] ** 2) ** 0.5)
    return paddle.to_tensor(wind_data, paddle.get_default_dtype())

vis_output_expr = {}
for i in range(cfg.EVAL.num_timestamps):
    hour = (i + 1) * 6
    vis_output_expr[f"output_{hour}h"] = functools.partial(
        output_wind_func,
        var_name=f"output_{i}",
        data_mean=paddle.to_tensor(data_mean, paddle.get_default_dtype()),
        data_std=paddle.to_tensor(data_std, paddle.get_default_dtype()),
    )
    vis_output_expr[f"target_{hour}h"] = lambda d, hour=hour: d[f"target_{hour}h"]

Finally, the visualizer is constructed as follows:

examples/fourcastnet/train_finetune.py
# set visualizer
visualizer = {
    "visualize_wind": ppsci.visualize.VisualizerWeather(
        vis_data,
        vis_output_expr,
        xticks=np.linspace(0, 1439, 13),
        xticklabels=[str(i) for i in range(360, -1, -30)],
        yticks=np.linspace(0, 719, 7),
        yticklabels=[str(i) for i in range(90, -91, -30)],
        vmin=0,
        vmax=25,
        colorbar_label="m/s",
        batch_size=cfg.EVAL.batch_size,
        num_timestamps=cfg.EVAL.num_timestamps,
        prefix="wind",
    )
}

The model, validator, and visualizer built above are passed to ppsci.solver.Solver for evaluation on the test set and visualization.

examples/fourcastnet/train_finetune.py
solver = ppsci.solver.Solver(
    model,
    output_dir=cfg.output_dir,
    validator=validator,
    visualizer=visualizer,
    pretrained_model_path=cfg.EVAL.pretrained_model_path,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
solver.eval()
# visualize prediction from pretrained_model_path
solver.visualize()

4. Precipitation Model Implementation

First, the parameters defined in the code are listed below; the meaning of each is explained when it is used.

examples/fourcastnet/conf/fourcastnet_precip.yaml
# set training hyper-parameters
IMG_H: 720
IMG_W: 1440
# FourCastNet uses 20 atmospheric variables; their indices in the dataset are 0 to 19.
# The variable names are 'u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z000',
# 'u850', 'v850', 'z850',  'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv'.
# You can obtain detailed information about each variable from
# https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
VARS_CHANNEL: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

# set train data path
WIND_TRAIN_FILE_PATH: ./datasets/era5/train
WIND_MEAN_PATH: ./datasets/era5/stat/global_means.npy
WIND_STD_PATH: ./datasets/era5/stat/global_stds.npy
WIND_TIME_MEAN_PATH: ./datasets/era5/stat/time_means.npy

TRAIN_FILE_PATH: ./datasets/era5/precip/train
TIME_MEAN_PATH: ./datasets/era5/stat/precip/time_means.npy

# set evaluate data path
WIND_VALID_FILE_PATH: ./datasets/era5/test
VALID_FILE_PATH: ./datasets/era5/precip/test

# set test data path
WIND_TEST_FILE_PATH: ./datasets/era5/out_of_sample/2018.h5
TEST_FILE_PATH: ./datasets/era5/precip/out_of_sample/2018.h5

# set wind model path
WIND_MODEL_PATH: outputs_fourcastnet_finetune/checkpoints/latest

4.1 Constraint Construction

This example solves the problem with a data-driven method, so the built-in SupervisedConstraint of PaddleScience is used to build the supervised constraint. Before defining the constraint, the parameters used for data loading must be specified. We start with the data preprocessing, shown below:

examples/fourcastnet/train_precip.py
wind_data_mean, wind_data_std = fourcast_utils.get_mean_std(
    cfg.WIND_MEAN_PATH, cfg.WIND_STD_PATH, cfg.VARS_CHANNEL
)
data_time_mean = fourcast_utils.get_time_mean(
    cfg.TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W
)

# set train transforms
transforms = [
    {"SqueezeData": {}},
    {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
    {
        "Normalize": {
            "mean": wind_data_mean,
            "std": wind_data_std,
            "apply_keys": ("input",),
        }
    },
    {"Log1p": {"scale": 1e-5, "apply_keys": ("label",)}},
]

The data preprocessing consists of four transforms (a sketch of the Log1p mapping and its inverse follows the list):

  1. SqueezeData: compresses the dimensions of the training data; if the input has 4 dimensions, dimensions 0 and 1 are folded together, reducing the input to 3 dimensions.
  2. CropData: crops the data to the specified region. The raw ERA5 arrays have shape \(721 \times 1440\); following the original paper, the training data are cropped to \(720 \times 1440\).
  3. Normalize: standardizes the data with the mean and standard deviation computed on the training set; the apply_keys field restricts this transform to the input data only.
  4. Log1p: maps the data to log space; the apply_keys field restricts this transform to the label data only.
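
Assuming Log1p applies \(y = \log(1 + x/\text{scale})\) with scale = 1e-5, the inverse mapping used later in output_precip_func (Section 4.7) recovers precipitation in millimeters, since \(10^{-5}(e^{y}-1)\) is in meters and multiplying by 1000 gives \(10^{-2}(e^{y}-1)\) mm:

import numpy as np

scale = 1e-5
tp_m = 0.004                 # total precipitation in meters (example value, i.e. 4 mm)
y = np.log1p(tp_m / scale)   # forward transform applied to the labels

tp_mm = 1e-2 * np.expm1(y)   # inverse mapping back to millimeters
print(tp_mm)                 # -> 4.0 (up to floating-point rounding)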

The data loading code is:

examples/fourcastnet/train_precip.py
# set train dataloader config
train_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.WIND_TRAIN_FILE_PATH,
        "input_keys": cfg.MODEL.precip.input_keys,
        "label_keys": cfg.MODEL.precip.output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "precip_file_path": cfg.TRAIN_FILE_PATH,
        "transforms": transforms,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": True,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": 8,
}

Here the "dataset" field sets the Dataset class to ERA5Dataset and the "sampler" field sets the Sampler class to BatchSampler, with batch_size set to 1 and num_workers set to 8.

The supervised constraint is defined as follows:

examples/fourcastnet/train_precip.py
# set constraint
sup_constraint = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    name="Sup",
)
constraint = {sup_constraint.name: sup_constraint}

The first argument of SupervisedConstraint is the data loading configuration; here the train_dataloader_cfg defined above is used.

The second argument is the loss function; L2RelLoss is used again here.

The third argument is the name of the constraint, used to index it later; here it is named "Sup".

4.2 Model Construction

In this example, the wind model's architecture is defined first and its trained parameters are loaded; the precipitation model is then defined on top of it. In PaddleScience code:

examples/fourcastnet/train_precip.py
# set model
wind_model = ppsci.arch.AFNONet(**cfg.MODEL.afno)
ppsci.utils.save_load.load_pretrain(wind_model, path=cfg.WIND_MODEL_PATH)
model_cfg = dict(cfg.MODEL.precip)
model_cfg.update({"wind_model": wind_model})
model = ppsci.arch.PrecipNet(**model_cfg)

The model parameters are set through the configuration:

examples/fourcastnet/conf/fourcastnet_precip.yaml
# model settings
MODEL:
  afno:
    input_keys: ["input"]
    output_keys: ["output"]
  precip:
    input_keys: ["input"]
    output_keys: ["output"]

Here input_keys and output_keys are the names of the model's input and output variables, respectively.

4.3 Learning Rate and Optimizer Construction

This example uses a Cosine learning-rate schedule with an initial learning rate of 2.5e-4 and the Adam optimizer; in PaddleScience code:

examples/fourcastnet/train_precip.py
# init optimizer and lr scheduler
lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()
optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

4.4 Validator Construction

During training, the model is evaluated on the validation set at a fixed epoch interval, which requires a SupervisedValidator. The code is:

examples/fourcastnet/train_precip.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.WIND_VALID_FILE_PATH,
        "input_keys": cfg.MODEL.precip.input_keys,
        "label_keys": cfg.MODEL.precip.output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "precip_file_path": cfg.VALID_FILE_PATH,
        "transforms": transforms,
        "training": False,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": False,
    },
    "batch_size": cfg.EVAL.batch_size,
}

# set metric
metric = {
    "MAE": ppsci.metric.MAE(keep_batch=True),
    "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
        num_lat=cfg.IMG_H, keep_batch=True, unlog=True
    ),
    "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
        num_lat=cfg.IMG_H, mean=data_time_mean, keep_batch=True, unlog=True
    ),
}

# set validator
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    metric=metric,
    name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}

SupervisedValidator is quite similar to SupervisedConstraint, except that a validator also needs the metric argument; three metrics are used here: MAE, LatitudeWeightedRMSE, and LatitudeWeightedACC.

4.5 Model Training and Evaluation

With the above in place, pass the instantiated objects to ppsci.solver.Solver in order, then start training and evaluation.

examples/fourcastnet/train_precip.py
# initialize solver
solver = ppsci.solver.Solver(
    model,
    constraint,
    cfg.output_dir,
    optimizer,
    lr_scheduler,
    cfg.TRAIN.epochs,
    ITERS_PER_EPOCH,
    eval_during_train=True,
    validator=validator,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
# train model
solver.train()
# evaluate after finished training
solver.eval()

4.6 Evaluation on the Test Set

Following the setup in the paper, for test-set evaluation num_timestamps is set to 6 (a 36-hour rollout at 6-hour steps), and adjacent test samples are 8 steps apart.

The model is constructed with:

examples/fourcastnet/train_precip.py
# set model for testing
wind_model = ppsci.arch.AFNONet(**cfg.MODEL.afno)
ppsci.utils.save_load.load_pretrain(wind_model, path=cfg.WIND_MODEL_PATH)
model_cfg = dict(cfg.MODEL.precip)
model_cfg.update(
    {
        "output_keys": output_keys,
        "num_timestamps": cfg.EVAL.num_timestamps,
        "wind_model": wind_model,
    }
)
model = ppsci.arch.PrecipNet(**model_cfg)

The validator is constructed with:

examples/fourcastnet/train_precip.py
eval_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.WIND_TEST_FILE_PATH,
        "input_keys": cfg.MODEL.precip.input_keys,
        "label_keys": output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "precip_file_path": cfg.TEST_FILE_PATH,
        "num_label_timestamps": cfg.EVAL.num_timestamps,
        "stride": 8,
        "transforms": transforms,
        "training": False,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": False,
    },
    "batch_size": cfg.EVAL.batch_size,
}
# set metric
metric = {
    "MAE": ppsci.metric.MAE(keep_batch=True),
    "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
        num_lat=cfg.IMG_H, keep_batch=True, unlog=True
    ),
    "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
        num_lat=cfg.IMG_H, mean=data_time_mean, keep_batch=True, unlog=True
    ),
}

# set validator for testing
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    metric=metric,
    name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}

4.7 Visualizer Construction

The precipitation model runs inference autoregressively, so the input data for inference must be set up first:

examples/fourcastnet/train_precip.py
# set visualizer data
DATE_STRINGS = ("2018-04-04 00:00:00",)
vis_data = get_vis_data(
    cfg.WIND_TEST_FILE_PATH,
    cfg.TEST_FILE_PATH,
    DATE_STRINGS,
    cfg.EVAL.num_timestamps,
    cfg.VARS_CHANNEL,
    cfg.IMG_H,
    wind_data_mean,
    wind_data_std,
)
examples/fourcastnet/train_precip.py
def get_vis_data(
    wind_file_path: str,
    file_path: str,
    date_strings: Tuple[str, ...],
    num_timestamps: int,
    vars_channel: Tuple[int, ...],
    img_h: int,
    data_mean: np.ndarray,
    data_std: np.ndarray,
):
    __wind_file = h5py.File(wind_file_path, "r")["fields"]
    _file = h5py.File(file_path, "r")["tp"]
    wind_data = []
    data = []
    for date_str in date_strings:
        hours_since_jan_01_epoch = fourcast_utils.date_to_hours(date_str)
        ic = int(hours_since_jan_01_epoch / 6)
        wind_data.append(__wind_file[ic, vars_channel, 0:img_h])
        data.append(_file[ic + 1 : ic + num_timestamps + 1, 0:img_h])
    wind_data = np.asarray(wind_data)
    data = np.asarray(data)

    vis_data = {"input": (wind_data - data_mean) / data_std}
    for t in range(num_timestamps):
        hour = (t + 1) * 6
        data_t = data[:, t]
        vis_data[f"target_{hour}h"] = np.asarray(data_t)
    return vis_data

The code above reads the data corresponding to the configured DATE_STRINGS as model input; the get_vis_data function also reads the ground truth at the corresponding times, which is visualized as well for comparison with the model's predictions.

Because the precipitation values were log-transformed, the model output must be mapped back to linear space (in millimeters), as follows:

examples/fourcastnet/train_precip.py
def output_precip_func(d, var_name):
    output = 1e-2 * paddle.expm1(d[var_name][0])
    return output

visu_output_expr = {}
for i in range(cfg.EVAL.num_timestamps):
    hour = (i + 1) * 6
    visu_output_expr[f"output_{hour}h"] = functools.partial(
        output_precip_func,
        var_name=f"output_{i}",
    )
    visu_output_expr[f"target_{hour}h"] = (
        lambda d, hour=hour: d[f"target_{hour}h"] * 1000
    )

Finally, the visualizer is constructed as follows:

examples/fourcastnet/train_precip.py
# set visualizer
visualizer = {
    "visualize_precip": ppsci.visualize.VisualizerWeather(
        vis_data,
        visu_output_expr,
        xticks=np.linspace(0, 1439, 13),
        xticklabels=[str(i) for i in range(360, -1, -30)],
        yticks=np.linspace(0, 719, 7),
        yticklabels=[str(i) for i in range(90, -91, -30)],
        vmin=0.001,
        vmax=130,
        colorbar_label="mm",
        log_norm=True,
        batch_size=cfg.EVAL.batch_size,
        num_timestamps=cfg.EVAL.num_timestamps,
        prefix="precip",
    )
}

The model, validator, and visualizer built above are passed to ppsci.solver.Solver for evaluation on the test set and visualization.

examples/fourcastnet/train_precip.py
solver = ppsci.solver.Solver(
    model,
    output_dir=cfg.output_dir,
    validator=validator,
    visualizer=visualizer,
    pretrained_model_path=cfg.EVAL.pretrained_model_path,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
solver.eval()
# visualize prediction
solver.visualize()

5. Complete Code

examples/fourcastnet/train_pretrain.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from os import path as osp

import hydra
import numpy as np
import paddle.distributed as dist
from omegaconf import DictConfig

import examples.fourcastnet.utils as fourcast_utils
import ppsci
from ppsci.utils import logger


def get_data_stat(cfg: DictConfig):
    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
    )
    data_time_mean_normalize = np.expand_dims(
        (data_time_mean[0] - data_mean) / data_std, 0
    )
    return data_mean, data_std, data_time_mean_normalize


def train(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")

    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
    )
    data_time_mean_normalize = np.expand_dims(
        (data_time_mean[0] - data_mean) / data_std, 0
    )
    # set train transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {"Normalize": {"mean": data_mean, "std": data_std}},
    ]

    # set train dataloader config
    if not cfg.USE_SAMPLED_DATA:
        train_dataloader_cfg = {
            "dataset": {
                "name": "ERA5Dataset",
                "file_path": cfg.TRAIN_FILE_PATH,
                "input_keys": cfg.MODEL.afno.input_keys,
                "label_keys": cfg.MODEL.afno.output_keys,
                "vars_channel": cfg.VARS_CHANNEL,
                "transforms": transforms,
            },
            "sampler": {
                "name": "BatchSampler",
                "drop_last": True,
                "shuffle": True,
            },
            "batch_size": cfg.TRAIN.batch_size,
            "num_workers": 8,
        }
    else:
        NUM_GPUS_PER_NODE = 8
        train_dataloader_cfg = {
            "dataset": {
                "name": "ERA5SampledDataset",
                "file_path": cfg.TRAIN_FILE_PATH,
                "input_keys": cfg.MODEL.afno.input_keys,
                "label_keys": cfg.MODEL.afno.output_keys,
            },
            "sampler": {
                "name": "DistributedBatchSampler",
                "drop_last": True,
                "shuffle": True,
                "num_replicas": NUM_GPUS_PER_NODE,
                "rank": dist.get_rank() % NUM_GPUS_PER_NODE,
            },
            "batch_size": cfg.TRAIN.batch_size,
            "num_workers": 8,
        }
    # set constraint
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        name="Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.VALID_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": cfg.MODEL.afno.output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "transforms": transforms,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric={
            "MAE": ppsci.metric.MAE(keep_batch=True),
            "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
                num_lat=cfg.IMG_H,
                std=data_std,
                keep_batch=True,
                variable_dict={"u10": 0, "v10": 1},
            ),
            "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
                num_lat=cfg.IMG_H,
                mean=data_time_mean_normalize,
                keep_batch=True,
                variable_dict={"u10": 0, "v10": 1},
            ),
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set model
    model = ppsci.arch.AFNONet(**cfg.MODEL.afno)

    # init optimizer and lr scheduler
    lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
    lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
    lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()

    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        lr_scheduler,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        eval_during_train=True,
        seed=cfg.seed,
        validator=validator,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # train model
    solver.train()
    # evaluate after finished training
    solver.eval()


def evaluate(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")

    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
    )
    data_time_mean_normalize = np.expand_dims(
        (data_time_mean[0] - data_mean) / data_std, 0
    )
    # set train transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {"Normalize": {"mean": data_mean, "std": data_std}},
    ]

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.VALID_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": cfg.MODEL.afno.output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "transforms": transforms,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric={
            "MAE": ppsci.metric.MAE(keep_batch=True),
            "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
                num_lat=cfg.IMG_H,
                std=data_std,
                keep_batch=True,
                variable_dict={"u10": 0, "v10": 1},
            ),
            "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
                num_lat=cfg.IMG_H,
                mean=data_time_mean_normalize,
                keep_batch=True,
                variable_dict={"u10": 0, "v10": 1},
            ),
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set model
    model = ppsci.arch.AFNONet(**cfg.MODEL.afno)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        output_dir=cfg.output_dir,
        log_freq=cfg.log_freq,
        seed=cfg.seed,
        validator=validator,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # evaluate
    solver.eval()


@hydra.main(
    version_base=None, config_path="./conf", config_name="fourcastnet_pretrain.yaml"
)
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
examples/fourcastnet/train_finetune.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from os import path as osp
from typing import Tuple

import h5py
import hydra
import numpy as np
import paddle
from omegaconf import DictConfig

import examples.fourcastnet.utils as fourcast_utils
import ppsci
from ppsci.utils import logger


def get_vis_data(
    file_path: str,
    date_strings: Tuple[str, ...],
    num_timestamps: int,
    vars_channel: Tuple[int, ...],
    img_h: int,
    data_mean: np.ndarray,
    data_std: np.ndarray,
):
    _file = h5py.File(file_path, "r")["fields"]
    data = []
    for date_str in date_strings:
        hours_since_jan_01_epoch = fourcast_utils.date_to_hours(date_str)
        ic = int(hours_since_jan_01_epoch / 6)
        data.append(_file[ic : ic + num_timestamps + 1, vars_channel, 0:img_h])
    data = np.asarray(data)

    vis_data = {"input": (data[:, 0] - data_mean) / data_std}
    for t in range(num_timestamps):
        hour = (t + 1) * 6
        data_t = data[:, t + 1]
        wind_data = []
        for i in range(data_t.shape[0]):
            wind_data.append((data_t[i][0] ** 2 + data_t[i][1] ** 2) ** 0.5)
        vis_data[f"target_{hour}h"] = np.asarray(wind_data)
    return vis_data


def train(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.set_random_seed(cfg.seed)

    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")

    # set training hyper-parameters
    output_keys = tuple(f"output_{i}" for i in range(cfg.TRAIN.num_timestamps))

    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
    )
    data_time_mean_normalize = np.expand_dims(
        (data_time_mean[0] - data_mean) / data_std, 0
    )

    # set transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {"Normalize": {"mean": data_mean, "std": data_std}},
    ]
    # set train dataloader config
    train_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.TRAIN_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "num_label_timestamps": cfg.TRAIN.num_timestamps,
            "transforms": transforms,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 8,
    }
    # set constraint
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        name="Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.VALID_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "transforms": transforms,
            "num_label_timestamps": cfg.TRAIN.num_timestamps,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set metric
    metric = {
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=cfg.IMG_H,
            std=data_std,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=cfg.IMG_H,
            mean=data_time_mean_normalize,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set model
    model_cfg = dict(cfg.MODEL.afno)
    model_cfg.update(
        {"output_keys": output_keys, "num_timestamps": cfg.TRAIN.num_timestamps}
    )

    model = ppsci.arch.AFNONet(**model_cfg)

    # init optimizer and lr scheduler
    lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
    lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
    lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()
    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        lr_scheduler,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        eval_during_train=True,
        validator=validator,
        pretrained_model_path=cfg.TRAIN.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # train model
    solver.train()
    # evaluate after finished training
    solver.eval()


def evaluate(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")

    # set testing hyper-parameters
    output_keys = tuple(f"output_{i}" for i in range(cfg.EVAL.num_timestamps))

    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
    )
    data_time_mean_normalize = np.expand_dims(
        (data_time_mean[0] - data_mean) / data_std, 0
    )

    # set transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {"Normalize": {"mean": data_mean, "std": data_std}},
    ]

    # set model
    model_cfg = dict(cfg.MODEL.afno)
    model_cfg.update(
        {"output_keys": output_keys, "num_timestamps": cfg.EVAL.num_timestamps}
    )
    model = ppsci.arch.AFNONet(**model_cfg)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.TEST_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "transforms": transforms,
            "num_label_timestamps": cfg.EVAL.num_timestamps,
            "training": False,
            "stride": 8,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set metric
    metric = {
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=cfg.IMG_H,
            std=data_std,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=cfg.IMG_H,
            mean=data_time_mean_normalize,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
    }

    # set validator for testing
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set visualizer data
    DATE_STRINGS = ("2018-09-08 00:00:00",)
    vis_data = get_vis_data(
        cfg.TEST_FILE_PATH,
        DATE_STRINGS,
        cfg.EVAL.num_timestamps,
        cfg.VARS_CHANNEL,
        cfg.IMG_H,
        data_mean,
        data_std,
    )

    def output_wind_func(d, var_name, data_mean, data_std):
        output = (d[var_name] * data_std) + data_mean
        wind_data = []
        for i in range(output.shape[0]):
            wind_data.append((output[i][0] ** 2 + output[i][1] ** 2) ** 0.5)
        return paddle.to_tensor(wind_data, paddle.get_default_dtype())

    vis_output_expr = {}
    for i in range(cfg.EVAL.num_timestamps):
        hour = (i + 1) * 6
        vis_output_expr[f"output_{hour}h"] = functools.partial(
            output_wind_func,
            var_name=f"output_{i}",
            data_mean=paddle.to_tensor(data_mean, paddle.get_default_dtype()),
            data_std=paddle.to_tensor(data_std, paddle.get_default_dtype()),
        )
        vis_output_expr[f"target_{hour}h"] = lambda d, hour=hour: d[f"target_{hour}h"]
    # set visualizer
    visualizer = {
        "visualize_wind": ppsci.visualize.VisualizerWeather(
            vis_data,
            vis_output_expr,
            xticks=np.linspace(0, 1439, 13),
            xticklabels=[str(i) for i in range(360, -1, -30)],
            yticks=np.linspace(0, 719, 7),
            yticklabels=[str(i) for i in range(90, -91, -30)],
            vmin=0,
            vmax=25,
            colorbar_label="m/s",
            batch_size=cfg.EVAL.batch_size,
            num_timestamps=cfg.EVAL.num_timestamps,
            prefix="wind",
        )
    }

    solver = ppsci.solver.Solver(
        model,
        output_dir=cfg.output_dir,
        validator=validator,
        visualizer=visualizer,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    solver.eval()
    # visualize prediction from pretrained_model_path
    solver.visualize()


@hydra.main(
    version_base=None, config_path="./conf", config_name="fourcastnet_finetune.yaml"
)
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
examples/fourcastnet/train_precip.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os.path as osp
from typing import Tuple

import h5py
import hydra
import numpy as np
import paddle
from omegaconf import DictConfig

import examples.fourcastnet.utils as fourcast_utils
import ppsci
from ppsci.utils import logger


def get_vis_data(
    wind_file_path: str,
    file_path: str,
    date_strings: Tuple[str, ...],
    num_timestamps: int,
    vars_channel: Tuple[int, ...],
    img_h: int,
    data_mean: np.ndarray,
    data_std: np.ndarray,
):
    __wind_file = h5py.File(wind_file_path, "r")["fields"]
    _file = h5py.File(file_path, "r")["tp"]
    wind_data = []
    data = []
    for date_str in date_strings:
        hours_since_jan_01_epoch = fourcast_utils.date_to_hours(date_str)
        ic = int(hours_since_jan_01_epoch / 6)
        wind_data.append(__wind_file[ic, vars_channel, 0:img_h])
        data.append(_file[ic + 1 : ic + num_timestamps + 1, 0:img_h])
    wind_data = np.asarray(wind_data)
    data = np.asarray(data)

    vis_data = {"input": (wind_data - data_mean) / data_std}
    for t in range(num_timestamps):
        hour = (t + 1) * 6
        data_t = data[:, t]
        vis_data[f"target_{hour}h"] = np.asarray(data_t)
    return vis_data


def train(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", f"{cfg.output_dir}/train.log", "info")

    wind_data_mean, wind_data_std = fourcast_utils.get_mean_std(
        cfg.WIND_MEAN_PATH, cfg.WIND_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W
    )

    # set train transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {
            "Normalize": {
                "mean": wind_data_mean,
                "std": wind_data_std,
                "apply_keys": ("input",),
            }
        },
        {"Log1p": {"scale": 1e-5, "apply_keys": ("label",)}},
    ]

    # set train dataloader config
    train_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.WIND_TRAIN_FILE_PATH,
            "input_keys": cfg.MODEL.precip.input_keys,
            "label_keys": cfg.MODEL.precip.output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "precip_file_path": cfg.TRAIN_FILE_PATH,
            "transforms": transforms,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 8,
    }
    # set constraint
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        name="Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.WIND_VALID_FILE_PATH,
            "input_keys": cfg.MODEL.precip.input_keys,
            "label_keys": cfg.MODEL.precip.output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "precip_file_path": cfg.VALID_FILE_PATH,
            "transforms": transforms,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set metric
    metric = {
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=cfg.IMG_H, keep_batch=True, unlog=True
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=cfg.IMG_H, mean=data_time_mean, keep_batch=True, unlog=True
        ),
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set model
    wind_model = ppsci.arch.AFNONet(**cfg.MODEL.afno)
    ppsci.utils.save_load.load_pretrain(wind_model, path=cfg.WIND_MODEL_PATH)
    model_cfg = dict(cfg.MODEL.precip)
    model_cfg.update({"wind_model": wind_model})
    model = ppsci.arch.PrecipNet(**model_cfg)

    # init optimizer and lr scheduler
    lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
    lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
    lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()
    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        lr_scheduler,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        eval_during_train=True,
        validator=validator,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # train model
    solver.train()
    # evaluate after finished training
    solver.eval()


def evaluate(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")

    # set testing hyper-parameters
    output_keys = tuple(f"output_{i}" for i in range(cfg.EVAL.num_timestamps))

    # set model for testing
    wind_model = ppsci.arch.AFNONet(**cfg.MODEL.afno)
    ppsci.utils.save_load.load_pretrain(wind_model, path=cfg.WIND_MODEL_PATH)
    model_cfg = dict(cfg.MODEL.precip)
    model_cfg.update(
        {
            "output_keys": output_keys,
            "num_timestamps": cfg.EVAL.num_timestamps,
            "wind_model": wind_model,
        }
    )
    model = ppsci.arch.PrecipNet(**model_cfg)

    wind_data_mean, wind_data_std = fourcast_utils.get_mean_std(
        cfg.WIND_MEAN_PATH, cfg.WIND_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W
    )

    # set train transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {
            "Normalize": {
                "mean": wind_data_mean,
                "std": wind_data_std,
                "apply_keys": ("input",),
            }
        },
        {"Log1p": {"scale": 1e-5, "apply_keys": ("label",)}},
    ]

    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.WIND_TEST_FILE_PATH,
            "input_keys": cfg.MODEL.precip.input_keys,
            "label_keys": output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "precip_file_path": cfg.TEST_FILE_PATH,
            "num_label_timestamps": cfg.EVAL.num_timestamps,
            "stride": 8,
            "transforms": transforms,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }
    # set metric
    metric = {
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=cfg.IMG_H, keep_batch=True, unlog=True
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=cfg.IMG_H, mean=data_time_mean, keep_batch=True, unlog=True
        ),
    }

    # set validator for testing
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set visualizer data
    DATE_STRINGS = ("2018-04-04 00:00:00",)
    vis_data = get_vis_data(
        cfg.WIND_TEST_FILE_PATH,
        cfg.TEST_FILE_PATH,
        DATE_STRINGS,
        cfg.EVAL.num_timestamps,
        cfg.VARS_CHANNEL,
        cfg.IMG_H,
        wind_data_mean,
        wind_data_std,
    )

    def output_precip_func(d, var_name):
        output = 1e-2 * paddle.expm1(d[var_name][0])
        return output

    visu_output_expr = {}
    for i in range(cfg.EVAL.num_timestamps):
        hour = (i + 1) * 6
        visu_output_expr[f"output_{hour}h"] = functools.partial(
            output_precip_func,
            var_name=f"output_{i}",
        )
        visu_output_expr[f"target_{hour}h"] = (
            lambda d, hour=hour: d[f"target_{hour}h"] * 1000
        )
    # set visualizer
    visualizer = {
        "visualize_precip": ppsci.visualize.VisualizerWeather(
            vis_data,
            visu_output_expr,
            xticks=np.linspace(0, 1439, 13),
            xticklabels=[str(i) for i in range(360, -1, -30)],
            yticks=np.linspace(0, 719, 7),
            yticklabels=[str(i) for i in range(90, -91, -30)],
            vmin=0.001,
            vmax=130,
            colorbar_label="mm",
            log_norm=True,
            batch_size=cfg.EVAL.batch_size,
            num_timestamps=cfg.EVAL.num_timestamps,
            prefix="precip",
        )
    }

    solver = ppsci.solver.Solver(
        model,
        output_dir=cfg.output_dir,
        validator=validator,
        visualizer=visualizer,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    solver.eval()
    # visualize prediction
    solver.visualize()


@hydra.main(
    version_base=None, config_path="./conf", config_name="fourcastnet_precip.yaml"
)
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()

6. Results

The figure below shows the wind model's predictions at 6-hour intervals alongside the ground truth.

Figure: wind model predictions ("output") vs. ground truth ("target")

The figure below shows the precipitation model's predictions at 6-hour intervals alongside the ground truth.

Figure: precipitation model predictions ("output") vs. ground truth ("target")