
EarthFormer

Before starting training or evaluation, please download:

ICAR-ENSO dataset

SEVIR dataset

# train on the ICAR-ENSO dataset
python examples/earthformer/earthformer_enso_train.py
# train on the SEVIR dataset
python examples/earthformer/earthformer_sevir_train.py
# evaluate the ICAR-ENSO model
python examples/earthformer/earthformer_enso_train.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/earthformer/earthformer_enso.pdparams
# evaluate the SEVIR model
python examples/earthformer/earthformer_sevir_train.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/earthformer/earthformer_sevir.pdparams
# export the ICAR-ENSO model
python examples/earthformer/earthformer_enso_train.py mode=export
# export the SEVIR model
python examples/earthformer/earthformer_sevir_train.py mode=export
# run inference with the ICAR-ENSO model
python examples/earthformer/earthformer_enso_train.py mode=infer
# run inference with the SEVIR model
python examples/earthformer/earthformer_sevir_train.py mode=infer
Model | Variable | C-Nino3.4-M | C-Nino3.4-WM | MSE(1E-4)
ENSO model | sst | 0.74130 | 2.28990 | 2.5000

Model | Variable | CSI-M | CSI-219 | CSI-181 | CSI-160 | CSI-133 | CSI-74 | CSI-16 | MSE(1E-4)
SEVIR model | vil | 0.4419 | 0.1791 | 0.2848 | 0.3232 | 0.4271 | 0.6860 | 0.7513 | 3.6957

1. Background

The Earth is a complex system. Changes in the Earth system, from routine events such as temperature fluctuations to extreme events such as droughts, hailstorms, and the El Niño/Southern Oscillation (ENSO), affect our daily lives. Among their many consequences, changes in the Earth system can affect crop yields, delay flights, and trigger floods and forest fires. Accurate and timely forecasts of these changes can help people take the necessary precautions to avoid crises, or make better use of natural resources such as wind and solar energy. Improving forecasting models for Earth-system changes (e.g., weather and climate) therefore has enormous socio-economic impact.

Earthformer is a space-time Transformer for Earth system forecasting. To better explore the design of spatiotemporal attention, the paper proposes Cuboid Attention, a generic building block for efficient spatiotemporal attention. The idea is to decompose the input tensor into non-overlapping cuboids and apply cuboid-level self-attention in parallel. Since the O(N²) self-attention is restricted to local cuboids, the overall complexity is greatly reduced, and different types of correlations can be captured through different cuboid decompositions. The paper also introduces a set of global vectors that attend to all local cuboids, collecting the overall state of the system. By attending to the global vectors, local cuboids can grasp the overall dynamics of the system and share information with each other.
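The decomposition idea can be illustrated with a small NumPy sketch (toy shapes chosen for illustration, not the actual PaddleScience implementation): split a (T, H, W, C) tensor into non-overlapping cuboids and compare the attention cost against full self-attention.

```python
import numpy as np

# Toy spatiotemporal tensor: (T, H, W, C)
T, H, W, C = 8, 16, 16, 4
x = np.random.rand(T, H, W, C)

# Decompose into non-overlapping cuboids of size (t, h, w)
t, h, w = 2, 4, 4
cuboids = x.reshape(T // t, t, H // h, h, W // w, w, C)
cuboids = cuboids.transpose(0, 2, 4, 1, 3, 5, 6)  # (nT, nH, nW, t, h, w, C)
cuboids = cuboids.reshape(-1, t * h * w, C)       # (num_cuboids, cuboid_len, C)

# Full self-attention cost scales as N^2 with N = T*H*W;
# cuboid attention costs num_cuboids * (t*h*w)^2, far smaller.
cost_full = (T * H * W) ** 2
cost_cuboid = cuboids.shape[0] * (t * h * w) ** 2
print(cuboids.shape, cost_full // cost_cuboid)
```

With these toy sizes the cuboid decomposition yields 64 cuboids of 32 tokens each, reducing the attention cost by a factor of 64.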

2. Model Principles

This chapter gives only a brief introduction to the principles of the EarthFormer model; for the detailed theoretical derivation, please read Earthformer: Exploring Space-Time Transformers for Earth System Forecasting.

Earthformer uses a hierarchical Transformer encoder-decoder based on Cuboid Attention. The idea is to decompose the data into cuboids and apply cuboid-level self-attention in parallel. These cuboids are further connected with a collection of global vectors.

The overall structure of the model is shown in the figure:

Earthformer-arch

The EarthFormer network model

The original EarthFormer code trains models that estimate sea surface temperature (sst) on the ICAR-ENSO dataset and vertically integrated liquid (vil) on the SEVIR dataset. The training and inference processes of these two models are described next.

2.1 Training and Inference of the ICAR-ENSO and SEVIR Models

In the pretraining stage, the model is trained from randomly initialized network weights, as shown in the figure below, where \([x_{i}]_{i=1}^{T}\) denotes the input meteorological data as a spatiotemporal sequence of length \(T\), \([y_{i}]_{i=1}^{K}\) denotes the predicted meteorological data for the next \(K\) steps, and \([y_{i}^{True}]_{i=1}^{K}\) denotes the ground-truth data for those \(K\) steps, such as sea surface temperature data and vertically integrated liquid data. Finally, the MSE loss is computed between the network's predicted output and the ground truth.

earthformer-pretraining

EarthFormer model pretraining
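The pretraining objective is simply the MSE between the \(K\) predicted frames and the ground truth. As a toy sketch (shapes follow the ENSO configuration; the values here are random):

```python
import numpy as np

rng = np.random.default_rng(0)
# Predicted and ground-truth future sequences: (K, lat, lon) with K = 14
y_pred = rng.normal(size=(14, 24, 48))
y_true = rng.normal(size=(14, 24, 48))

# MSE loss between the prediction and the ground truth
mse = np.mean((y_pred - y_true) ** 2)
print(mse)
```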

In the inference stage, given a sequence of length \(T\), the model produces a prediction sequence of length \(K\).

earthformer-pretraining

EarthFormer model inference

3. Sea Surface Temperature Model Implementation

This section explains how to implement training and inference of the EarthFormer model based on the PaddleScience code. For the remaining details of this example, please refer to the API documentation.

3.1 Dataset Introduction

The dataset is the ICAR-ENSO dataset preprocessed by EarthFormer.

This dataset is provided by the Institute for Climate and Application Research (ICAR). The data include historical simulation data from CMIP5/6 models and assimilated historical observation data covering more than 100 years, reconstructed by the US SODA model. Each sample contains the following meteorological and spatiotemporal variables: sea surface temperature anomaly (SST), heat content anomaly (T300), zonal wind anomaly (Ua), and meridional wind anomaly (Va), with data dimensions (year, month, lat, lon). The training data provide Nino3.4 index labels for the corresponding months. The initial-field data for testing are n randomly sampled sequences of 12 time steps provided by several international ocean data assimilation products, saved in NPY format.

Training data:

The first dimension (year) of each sample indicates the starting year of the data. The CMIP data cover 291 years in total: entries 1-2265 are 151 years of historical simulations from 15 CMIP6 models (151 years × 15 models = 2265), and entries 2266-4645 are 140 years of historical simulations from 17 CMIP5 models (140 years × 17 models = 2380). The assimilated historical observation data are the US SODA data.

Training data labels:

The label data are the Nino3.4 SST anomaly index, with dimensions (year, month).

The label corresponding to CMIP(SODA)_train.nc is the three-month moving average of the Nino3.4 SST anomaly index at the current time, so its dimensions are the same as those of the training data.

Note: the three-month moving average is the mean of the current month and the following two months.
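As a sketch of the note above (the index values here are made up, not taken from the dataset):

```python
import numpy as np

# A hypothetical monthly Nino3.4 SST anomaly index
nino = np.array([0.1, 0.4, 0.7, 1.0, 1.3, 1.6])

# Three-month moving average: mean of the current month
# and the following two months
smoothed = np.convolve(nino, np.ones(3) / 3, mode="valid")
print(smoothed)  # ≈ [0.4, 0.7, 1.0, 1.3]
```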

Test data:

The initial-field (input) data for testing are n randomly sampled sequences of 12 time steps provided by several international ocean data assimilation products, saved in NPY format, with dimensions (12, lat, lon, 4), where 12 covers time t and the preceding 11 time steps, and 4 is the number of predictors, stored in the order SST, T300, Ua, Va.

When training on the ICAR-ENSO dataset, the EarthFormer model is trained and evaluated only on the sea surface temperature (SST). It is trained on 12 steps (one year) of observed SST anomalies and predicts up to 14 steps of SST anomalies.

3.2 Model Pretraining

3.2.1 Constraint Construction

This example solves the problem with a data-driven approach, so a supervised constraint is built with the SupervisedConstraint class built into PaddleScience. Before defining the constraint, the parameters used by the constraint for data loading must first be specified.

The data-loading code is as follows:

examples/earthformer/earthformer_enso_train.py
train_dataloader_cfg = {
    "dataset": {
        "name": "ENSODataset",
        "data_dir": cfg.FILE_PATH,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.DATASET.label_keys,
        "in_len": cfg.DATASET.in_len,
        "out_len": cfg.DATASET.out_len,
        "in_stride": cfg.DATASET.in_stride,
        "out_stride": cfg.DATASET.out_stride,
        "train_samples_gap": cfg.DATASET.train_samples_gap,
        "eval_samples_gap": cfg.DATASET.eval_samples_gap,
        "normalize_sst": cfg.DATASET.normalize_sst,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": True,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": 8,
}

Here, the "dataset" field specifies that the Dataset class is ENSODataset, and the "sampler" field specifies that the Sampler class is BatchSampler, with batch_size set to 16 and num_workers set to 8.

The supervised constraint is defined as follows:

examples/earthformer/earthformer_enso_train.py
# set constraint
sup_constraint = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(enso_metric.train_mse_func),
    name="Sup",
)
constraint = {sup_constraint.name: sup_constraint}

The first argument of SupervisedConstraint is the data-loading configuration; here we use the train_dataloader_cfg defined above.

The second argument is the loss function; here we use the custom loss function mse_loss.

The third argument is the name of the constraint, which makes it convenient to look up later. Here it is named Sup.

3.2.2 Model Construction

In this example, the sea surface temperature model is implemented with the CuboidTransformer network model, expressed in PaddleScience code as follows:

examples/earthformer/earthformer_enso_train.py
model = ppsci.arch.CuboidTransformer(
    **cfg.MODEL,
)

The network model's parameters are set through the configuration file as follows:

examples/earthformer/conf/earthformer_enso_pretrain.yaml
# model settings
MODEL:
  input_keys: ["sst_data"]
  output_keys: ["sst_target","nino_target"]
  input_shape: [12, 24, 48, 1]
  target_shape: [14, 24, 48, 1]
  base_units: 64
  scale_alpha: 1.0

  enc_depth: [1, 1]
  dec_depth: [1, 1]
  enc_use_inter_ffn: true
  dec_use_inter_ffn: true
  dec_hierarchical_pos_embed: false

  downsample: 2
  downsample_type: "patch_merge"
  upsample_type: "upsample"

  num_global_vectors: 0
  use_dec_self_global: false
  dec_self_update_global: true
  use_dec_cross_global: false
  use_global_vector_ffn: false
  use_global_self_attn: false
  separate_global_qkv: false
  global_dim_ratio: 1

  self_pattern: "axial"
  cross_self_pattern: "axial"
  cross_pattern: "cross_1x1"
  dec_cross_last_n_frames: null

  attn_drop: 0.1
  proj_drop: 0.1
  ffn_drop: 0.1
  num_heads: 4

  ffn_activation: "gelu"
  gated_ffn: false
  norm_layer: "layer_norm"
  padding_type: "zeros"
  pos_embed_type: "t+h+w"
  use_relative_pos: true
  self_attn_use_final_proj: true
  dec_use_first_self_attn: false

  z_init_method: "zeros"
  initial_downsample_type: "conv"
  initial_downsample_activation: "leaky_relu"
  initial_downsample_scale: [1, 1, 2]
  initial_downsample_conv_layers: 2
  final_upsample_conv_layers: 1
  checkpoint_level: 2

  attn_linear_init_mode: "0"
  ffn_linear_init_mode: "0"
  conv_init_mode: "0"
  down_up_linear_init_mode: "0"
  norm_init_mode: "0"

Here, input_keys and output_keys are the names of the network model's input and output variables, respectively.

3.2.3 Learning Rate and Optimizer Construction

This example uses a Cosine learning rate schedule with the learning rate set to 2e-4. The optimizer is AdamW, with the parameters split into groups that use different weight_decay values, expressed in PaddleScience code as follows:

examples/earthformer/earthformer_enso_train.py
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if n in decay_parameters],
        "weight_decay": cfg.TRAIN.wd,
    },
    {
        "params": [
            p for n, p in model.named_parameters() if n not in decay_parameters
        ],
        "weight_decay": 0.0,
    },
]

# init optimizer and lr scheduler
lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(
    **lr_scheduler_cfg,
    iters_per_epoch=ITERS_PER_EPOCH,
    eta_min=cfg.TRAIN.min_lr_ratio * cfg.TRAIN.lr_scheduler.learning_rate,
    warmup_epoch=int(0.2 * cfg.TRAIN.epochs),
)()
optimizer = paddle.optimizer.AdamW(
    lr_scheduler, parameters=optimizer_grouped_parameters
)

3.2.4 Validator Construction

During training, this example evaluates the current model on the validation set at a fixed interval of epochs, which requires building a validator with SupervisedValidator. The code is as follows:

examples/earthformer/earthformer_enso_train.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "ENSODataset",
        "data_dir": cfg.FILE_PATH,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.DATASET.label_keys,
        "in_len": cfg.DATASET.in_len,
        "out_len": cfg.DATASET.out_len,
        "in_stride": cfg.DATASET.in_stride,
        "out_stride": cfg.DATASET.out_stride,
        "train_samples_gap": cfg.DATASET.train_samples_gap,
        "eval_samples_gap": cfg.DATASET.eval_samples_gap,
        "normalize_sst": cfg.DATASET.normalize_sst,
        "training": "eval",
    },
    "batch_size": cfg.EVAL.batch_size,
}

sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(enso_metric.train_mse_func),
    metric={
        "rmse": ppsci.metric.FunctionalMetric(enso_metric.eval_rmse_func),
    },
    name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}

The SupervisedValidator validator is quite similar to SupervisedConstraint, except that the validator requires the evaluation metric metric; here the custom evaluation metrics MAE, MSE, RMSE, corr_nino3.4_epoch, and corr_nino3.4_weighted_epoch are used.

3.2.5 Model Training and Evaluation

With the above settings complete, simply pass the instantiated objects in order to ppsci.solver.Solver, then start training and evaluation.

examples/earthformer/earthformer_enso_train.py
# initialize solver
solver = ppsci.solver.Solver(
    model,
    constraint,
    cfg.output_dir,
    optimizer,
    lr_scheduler,
    cfg.TRAIN.epochs,
    ITERS_PER_EPOCH,
    eval_during_train=cfg.TRAIN.eval_during_train,
    seed=cfg.seed,
    validator=validator,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
# train model
solver.train()
# evaluate after finished training
solver.eval()

3.3 Model Evaluation and Visualization

3.3.1 Evaluating the Model on the Test Set

The model is constructed with:

examples/earthformer/earthformer_enso_train.py
model = ppsci.arch.CuboidTransformer(
    **cfg.MODEL,
)

The validator is constructed with:

examples/earthformer/earthformer_enso_train.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "ENSODataset",
        "data_dir": cfg.FILE_PATH,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.DATASET.label_keys,
        "in_len": cfg.DATASET.in_len,
        "out_len": cfg.DATASET.out_len,
        "in_stride": cfg.DATASET.in_stride,
        "out_stride": cfg.DATASET.out_stride,
        "train_samples_gap": cfg.DATASET.train_samples_gap,
        "eval_samples_gap": cfg.DATASET.eval_samples_gap,
        "normalize_sst": cfg.DATASET.normalize_sst,
        "training": "test",
    },
    "batch_size": cfg.EVAL.batch_size,
}

sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(enso_metric.train_mse_func),
    metric={
        "rmse": ppsci.metric.FunctionalMetric(enso_metric.eval_rmse_func),
    },
    name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}

3.3.2 Model Export

The model is constructed with:

examples/earthformer/earthformer_enso_train.py
# set model
model = ppsci.arch.CuboidTransformer(
    **cfg.MODEL,
)

Instantiate ppsci.solver.Solver:

examples/earthformer/earthformer_enso_train.py
# initialize solver
solver = ppsci.solver.Solver(
    model,
    pretrained_model_path=cfg.INFER.pretrained_model_path,
)

Build the model input specification and export the static model:

examples/earthformer/earthformer_enso_train.py
input_spec = [
    {
        key: InputSpec([1, 12, 24, 48, 1], "float32", name=key)
        for key in model.input_keys
    },
]
solver.export(input_spec, cfg.INFER.export_path)

In InputSpec, the first argument sets the model input shape, the second argument sets the input data type, and the third argument sets the input key.

3.3.3 Model Inference

Create the predictor:

examples/earthformer/earthformer_enso_train.py
import predictor

predictor = predictor.EarthformerPredictor(cfg)

Prepare the prediction data:

examples/earthformer/earthformer_enso_train.py
train_cmip = xr.open_dataset(cfg.INFER.data_path).transpose(
    "year", "month", "lat", "lon"
)
# select longitudes
lon = train_cmip.lon.values
lon = lon[np.logical_and(lon >= 95, lon <= 330)]
train_cmip = train_cmip.sel(lon=lon)
data = train_cmip.sst.values
data = enso_dataset.fold(data)

idx_sst = enso_dataset.prepare_inputs_targets(
    len_time=data.shape[0],
    input_length=cfg.INFER.in_len,
    input_gap=cfg.INFER.in_stride,
    pred_shift=cfg.INFER.out_len * cfg.INFER.out_stride,
    pred_length=cfg.INFER.out_len,
    samples_gap=cfg.INFER.samples_gap,
)
data = data[idx_sst].astype("float32")

sst_data = data[..., np.newaxis]
idx = np.random.choice(len(data), None, False)
in_seq = sst_data[idx, : cfg.INFER.in_len, ...]  # ( in_len, lat, lon, 1)
in_seq = in_seq[np.newaxis, ...]

Run the model prediction and save the predicted values:

examples/earthformer/earthformer_enso_train.py
pred_data = predictor.predict(in_seq, cfg.INFER.batch_size)

# save predict data
save_path = osp.join(cfg.output_dir, "result_enso_pred.npy")
np.save(save_path, pred_data)
logger.info(f"Save output to {save_path}")

4. Vertically Integrated Liquid (vil) Model Implementation

4.1 Dataset Introduction

The dataset is the SEVIR dataset preprocessed by EarthFormer.

The Storm EVent ImagRy (SEVIR) dataset was collected and provided by MIT Lincoln Laboratory and Amazon. SEVIR is an annotated, curated, and spatiotemporally aligned dataset containing more than 10,000 weather events, each consisting of a sequence of images covering 384 km × 384 km over a span of 4 hours. The images in SEVIR are sampled and aligned across five data types: three channels (C02, C09, C13) from the GOES-16 Advanced Baseline Imager, NEXRAD vertically integrated liquid (vil), and GOES-16 Geostationary Lightning Mapper (GLM) flash maps.

The SEVIR dataset has two parts: the catalog and the data files. The catalog is a CSV file whose rows describe event metadata. The data files are a set of HDF5 files containing events for a particular sensor type. The data in these files are stored as 4D tensors of shape N x L x W x T, where N is the number of events in the file, L x W is the image size, and T is the number of time steps in the image sequence.

SEVIR

SEVIR sensor type descriptions

EarthFormer uses the NEXRAD vertically integrated liquid (VIL) from SEVIR as the benchmark for precipitation forecasting: given 65 minutes of VIL context, predict the VIL for the next 60 minutes. The task resolution is therefore 13x384x384 → 12x384x384.
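The 13-frames-in / 12-frames-out split (one frame every 5 minutes) can be sketched as follows; the event array here is synthetic:

```python
import numpy as np

# One synthetic SEVIR VIL event: 25 frames at 5-minute intervals
event = np.zeros((25, 384, 384), dtype=np.float32)

in_len, out_len = 13, 12                 # 65 min context, 60 min forecast
context = event[:in_len]                 # model input:  (13, 384, 384)
target = event[in_len:in_len + out_len]  # ground truth: (12, 384, 384)
print(context.shape, target.shape)
```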

4.2 Model Pretraining

4.2.1 Constraint Construction

This example solves the problem with a data-driven approach, so a supervised constraint is built with the SupervisedConstraint class built into PaddleScience. Before defining the constraint, the parameters used by the constraint for data loading must first be specified.

The data-loading code is as follows:

examples/earthformer/earthformer_sevir_train.py
# set train dataloader config
train_dataloader_cfg = {
    "dataset": {
        "name": "SEVIRDataset",
        "data_dir": cfg.FILE_PATH,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.DATASET.label_keys,
        "data_types": cfg.DATASET.data_types,
        "seq_len": cfg.DATASET.seq_len,
        "raw_seq_len": cfg.DATASET.raw_seq_len,
        "sample_mode": cfg.DATASET.sample_mode,
        "stride": cfg.DATASET.stride,
        "batch_size": cfg.DATASET.batch_size,
        "layout": cfg.DATASET.layout,
        "in_len": cfg.DATASET.in_len,
        "out_len": cfg.DATASET.out_len,
        "split_mode": cfg.DATASET.split_mode,
        "start_date": cfg.TRAIN.start_date,
        "end_date": cfg.TRAIN.end_date,
        "preprocess": cfg.DATASET.preprocess,
        "rescale_method": cfg.DATASET.rescale_method,
        "shuffle": True,
        "verbose": False,
        "training": True,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": True,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": 8,
}

Here, the "dataset" field specifies that the Dataset class is SEVIRDataset, and the "sampler" field specifies that the Sampler class is BatchSampler, with batch_size set to 1 and num_workers set to 8.

The supervised constraint is defined as follows:

examples/earthformer/earthformer_sevir_train.py
# set constraint
sup_constraint = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(sevir_metric.train_mse_func),
    name="Sup",
)
constraint = {sup_constraint.name: sup_constraint}

The first argument of SupervisedConstraint is the data-loading configuration; here we use the train_dataloader_cfg defined above.

The second argument is the loss function; here we use the custom loss function mse_loss.

The third argument is the name of the constraint, which makes it convenient to look up later. Here it is named Sup.

4.2.2 Model Construction

In this example, the vertically integrated liquid model is implemented with the CuboidTransformer network model, expressed in PaddleScience code as follows:

examples/earthformer/earthformer_sevir_train.py
model = ppsci.arch.CuboidTransformer(
    **cfg.MODEL,
)

The model's parameters are set through the configuration file as follows:

examples/earthformer/conf/earthformer_sevir_pretrain.yaml
MODEL:
  input_keys: ["input"]
  output_keys: ["vil"]
  input_shape: [13, 384, 384, 1]
  target_shape: [12, 384, 384, 1]
  base_units: 128
  scale_alpha: 1.0

  enc_depth: [1, 1]
  dec_depth: [1, 1]
  enc_use_inter_ffn: true
  dec_use_inter_ffn: true
  dec_hierarchical_pos_embed: false

  downsample: 2
  downsample_type: "patch_merge"
  upsample_type: "upsample"

  num_global_vectors: 8
  use_dec_self_global: false
  dec_self_update_global: true
  use_dec_cross_global: false
  use_global_vector_ffn: false
  use_global_self_attn: true
  separate_global_qkv: true
  global_dim_ratio: 1

  self_pattern: "axial"
  cross_self_pattern: "axial"
  cross_pattern: "cross_1x1"
  dec_cross_last_n_frames: null

  attn_drop: 0.1
  proj_drop: 0.1
  ffn_drop: 0.1
  num_heads: 4

  ffn_activation: "gelu"
  gated_ffn: false
  norm_layer: "layer_norm"
  padding_type: "zeros"
  pos_embed_type: "t+h+w"
  use_relative_pos: true
  self_attn_use_final_proj: true
  dec_use_first_self_attn: false

  z_init_method: "zeros"
  initial_downsample_type: "stack_conv"
  initial_downsample_activation: "leaky_relu"
  initial_downsample_stack_conv_num_layers: 3
  initial_downsample_stack_conv_dim_list: [16, 64, 128]
  initial_downsample_stack_conv_downscale_list: [3, 2, 2]
  initial_downsample_stack_conv_num_conv_list: [2, 2, 2]
  checkpoint_level: 2

  attn_linear_init_mode: "0"
  ffn_linear_init_mode: "0"
  conv_init_mode: "0"
  down_up_linear_init_mode: "0"
  norm_init_mode: "0"

Here, input_keys and output_keys are the names of the network model's input and output variables, respectively.

4.2.3 Learning Rate and Optimizer Construction

This example uses a Cosine learning rate schedule with the learning rate set to 1e-3. The optimizer is AdamW, with the parameters split into groups that use different weight_decay values, expressed in PaddleScience code as follows:

examples/earthformer/earthformer_sevir_train.py
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if n in decay_parameters],
        "weight_decay": cfg.TRAIN.wd,
    },
    {
        "params": [
            p for n, p in model.named_parameters() if n not in decay_parameters
        ],
        "weight_decay": 0.0,
    },
]

# init optimizer and lr scheduler
lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(
    **lr_scheduler_cfg,
    iters_per_epoch=ITERS_PER_EPOCH,
    eta_min=cfg.TRAIN.min_lr_ratio * cfg.TRAIN.lr_scheduler.learning_rate,
    warmup_epoch=int(0.2 * cfg.TRAIN.epochs),
)()
optimizer = paddle.optimizer.AdamW(
    lr_scheduler, parameters=optimizer_grouped_parameters
)

4.2.4 Validator Construction

During training, this example evaluates the current model on the validation set at a fixed interval of epochs, which requires building a validator with SupervisedValidator. The code is as follows:

examples/earthformer/earthformer_sevir_train.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "SEVIRDataset",
        "data_dir": cfg.FILE_PATH,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.DATASET.label_keys,
        "data_types": cfg.DATASET.data_types,
        "seq_len": cfg.DATASET.seq_len,
        "raw_seq_len": cfg.DATASET.raw_seq_len,
        "sample_mode": cfg.DATASET.sample_mode,
        "stride": cfg.DATASET.stride,
        "batch_size": cfg.DATASET.batch_size,
        "layout": cfg.DATASET.layout,
        "in_len": cfg.DATASET.in_len,
        "out_len": cfg.DATASET.out_len,
        "split_mode": cfg.DATASET.split_mode,
        "start_date": cfg.TRAIN.end_date,
        "end_date": cfg.EVAL.end_date,
        "preprocess": cfg.DATASET.preprocess,
        "rescale_method": cfg.DATASET.rescale_method,
        "shuffle": False,
        "verbose": False,
        "training": False,
    },
    "batch_size": cfg.EVAL.batch_size,
}

sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    loss=ppsci.loss.MSELoss(),
    metric={
        "rmse": ppsci.metric.FunctionalMetric(
            sevir_metric.eval_rmse_func(
                out_len=cfg.DATASET.seq_len,
                layout=cfg.DATASET.layout,
                metrics_mode=cfg.EVAL.metrics_mode,
                metrics_list=cfg.EVAL.metrics_list,
                threshold_list=cfg.EVAL.threshold_list,
            )
        ),
    },
    name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}

The SupervisedValidator validator is quite similar to SupervisedConstraint, except that the validator requires the evaluation metric metric; here the custom evaluation metrics MAE, MSE, csi, pod, sucr, and bias are used, and the last four metrics are each computed at the thresholds [16,74,133,160,181,219].

4.2.5 Model Training

With the above settings complete, simply pass the instantiated objects in order to ppsci.solver.Solver, then start training.

examples/earthformer/earthformer_sevir_train.py
# initialize solver
solver = ppsci.solver.Solver(
    model,
    constraint,
    cfg.output_dir,
    optimizer,
    lr_scheduler,
    cfg.TRAIN.epochs,
    ITERS_PER_EPOCH,
    eval_during_train=cfg.TRAIN.eval_during_train,
    seed=cfg.seed,
    validator=validator,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
# train model
solver.train()

4.2.6 Model Evaluation

PaddleScience currently supports two validation strategies. The first concatenates the model outputs over the entire validation dataset and then computes the evaluation metrics. The second computes the metrics per batch_size, concatenates them, and finally averages over all results; this method assumes the samples are independent of one another. The samples in the SEVIR dataset are correlated, so the second method does not apply; and because the SEVIR dataset is large, the first method would require a lot of GPU memory. The SEVIR dataset is therefore validated as follows:

  • 1. For each batch, compute the three quantities hits, misses, and fas.
  • 2. Accumulate these three values over all batches of the dataset.
  • 3. From the three accumulated sums, compute the four metrics csi, pod, sucr, and bias.
examples/earthformer/earthformer_sevir_train.py
# evaluate after finished training
metric = sevir_metric.eval_rmse_func(
    out_len=cfg.DATASET.seq_len,
    layout=cfg.DATASET.layout,
    metrics_mode=cfg.EVAL.metrics_mode,
    metrics_list=cfg.EVAL.metrics_list,
    threshold_list=cfg.EVAL.threshold_list,
)

with solver.no_grad_context_manager(True):
    for index, (input_, label, _) in enumerate(sup_validator.data_loader):
        truefield = label["vil"].squeeze(0)
        prefield = model(input_)["vil"].squeeze(0)
        metric.sevir_score.update(prefield, truefield)

metric_dict = metric.sevir_score.compute()
print(metric_dict)
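The three-step procedure above can be sketched stand-alone as follows (a simplified, single-threshold version; the actual sevir_metric module maintains these counts for every threshold in the list):

```python
import numpy as np

def contingency(pred, truth, threshold):
    """Count hits, misses, and false alarms at one threshold."""
    p, t = pred >= threshold, truth >= threshold
    return np.sum(p & t), np.sum(~p & t), np.sum(p & ~t)

# Steps 1-2: accumulate the counts over all batches
hits = misses = fas = 0
rng = np.random.default_rng(0)
for _ in range(4):  # four toy "batches" of random fields
    pred = rng.uniform(0, 255, size=(2, 64, 64))
    truth = rng.uniform(0, 255, size=(2, 64, 64))
    h, m, f = contingency(pred, truth, threshold=74)
    hits, misses, fas = hits + h, misses + m, fas + f

# Step 3: compute the four scores from the accumulated counts
eps = 1e-8
csi = hits / (hits + misses + fas + eps)
pod = hits / (hits + misses + eps)
sucr = hits / (hits + fas + eps)
bias = (hits + fas) / (hits + misses + eps)
print(csi, pod, sucr, bias)
```

Accumulating the raw counts first and dividing only once at the end gives exactly the dataset-level scores without ever holding all predictions in memory.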

4.3 Model Evaluation and Visualization

4.3.1 Evaluating the Model on the Test Set

The model is constructed with:

examples/earthformer/earthformer_sevir_train.py
model = ppsci.arch.CuboidTransformer(
    **cfg.MODEL,
)

The validator is constructed with:

examples/earthformer/earthformer_sevir_train.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "SEVIRDataset",
        "data_dir": cfg.FILE_PATH,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.DATASET.label_keys,
        "data_types": cfg.DATASET.data_types,
        "seq_len": cfg.DATASET.seq_len,
        "raw_seq_len": cfg.DATASET.raw_seq_len,
        "sample_mode": cfg.DATASET.sample_mode,
        "stride": cfg.DATASET.stride,
        "batch_size": cfg.DATASET.batch_size,
        "layout": cfg.DATASET.layout,
        "in_len": cfg.DATASET.in_len,
        "out_len": cfg.DATASET.out_len,
        "split_mode": cfg.DATASET.split_mode,
        "start_date": cfg.TEST.start_date,
        "end_date": cfg.TEST.end_date,
        "preprocess": cfg.DATASET.preprocess,
        "rescale_method": cfg.DATASET.rescale_method,
        "shuffle": False,
        "verbose": False,
        "training": False,
    },
    "batch_size": cfg.EVAL.batch_size,
}

sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    loss=ppsci.loss.MSELoss(),
    metric={
        "rmse": ppsci.metric.FunctionalMetric(
            sevir_metric.eval_rmse_func(
                out_len=cfg.DATASET.seq_len,
                layout=cfg.DATASET.layout,
                metrics_mode=cfg.EVAL.metrics_mode,
                metrics_list=cfg.EVAL.metrics_list,
                threshold_list=cfg.EVAL.threshold_list,
            )
        ),
    },
    name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}

Model evaluation:

examples/earthformer/earthformer_sevir_train.py
# evaluate
metric = sevir_metric.eval_rmse_func(
    out_len=cfg.DATASET.seq_len,
    layout=cfg.DATASET.layout,
    metrics_mode=cfg.EVAL.metrics_mode,
    metrics_list=cfg.EVAL.metrics_list,
    threshold_list=cfg.EVAL.threshold_list,
)

with solver.no_grad_context_manager(True):
    for index, (input_, label, _) in enumerate(sup_validator.data_loader):
        truefield = label["vil"].reshape([-1, *label["vil"].shape[2:]])
        prefield = model(input_)["vil"].reshape([-1, *label["vil"].shape[2:]])
        metric.sevir_score.update(prefield, truefield)

metric_dict = metric.sevir_score.compute()
print(metric_dict)

4.3.2 Model Export

The model is constructed with:

examples/earthformer/earthformer_sevir_train.py
# set model
model = ppsci.arch.CuboidTransformer(
    **cfg.MODEL,
)

Instantiate ppsci.solver.Solver:

examples/earthformer/earthformer_sevir_train.py
# initialize solver
solver = ppsci.solver.Solver(
    model,
    pretrained_model_path=cfg.INFER.pretrained_model_path,
)

Build the model input specification and export the static model:

examples/earthformer/earthformer_sevir_train.py
input_spec = [
    {
        key: InputSpec([1, 13, 384, 384, 1], "float32", name=key)
        for key in model.input_keys
    },
]
solver.export(input_spec, cfg.INFER.export_path)

In InputSpec, the first argument sets the model input shape, the second argument sets the input data type, and the third argument sets the input key.

4.3.3 Model Inference

Create the predictor:

examples/earthformer/earthformer_sevir_train.py
predictor = predictor.EarthformerPredictor(cfg)

Prepare the prediction data and apply the preprocessing for the corresponding mode:

examples/earthformer/earthformer_sevir_train.py
if cfg.INFER.rescale_method == "sevir":
    scale_dict = sevir_dataset.PREPROCESS_SCALE_SEVIR
    offset_dict = sevir_dataset.PREPROCESS_OFFSET_SEVIR
elif cfg.INFER.rescale_method == "01":
    scale_dict = sevir_dataset.PREPROCESS_SCALE_01
    offset_dict = sevir_dataset.PREPROCESS_OFFSET_01
else:
    raise ValueError(f"Invalid rescale option: {cfg.INFER.rescale_method}.")

# read h5 data
h5data = h5py.File(cfg.INFER.data_path, "r")
data = np.array(h5data[cfg.INFER.data_type]).transpose([0, 3, 1, 2])

idx = np.random.choice(len(data), None, False)
data = (
    scale_dict[cfg.INFER.data_type] * data[idx] + offset_dict[cfg.INFER.data_type]
)

input_data = data[: cfg.INFER.in_len, ...]
input_data = input_data.reshape(1, *input_data.shape, 1).astype(np.float32)

Run the model prediction and visualize the results:

examples/earthformer/earthformer_sevir_train.py
pred_data = predictor.predict(input_data, cfg.INFER.batch_size)

sevir_vis_seq.save_example_vis_results(
    save_dir=cfg.INFER.sevir_vis_save,
    save_prefix=f"data_{idx}",
    in_seq=input_data,
    target_seq=target_data,
    pred_seq=pred_data,
    layout=cfg.INFER.layout,
    plot_stride=cfg.INFER.plot_stride,
    label=cfg.INFER.logging_prefix,
    interval_real_time=cfg.INFER.interval_real_time,
)

5. Complete Code

examples/earthformer/earthformer_enso_train.py
from os import path as osp

import hydra
import numpy as np
import paddle
from omegaconf import DictConfig
from paddle import nn

import examples.earthformer.enso_metric as enso_metric
import ppsci
from ppsci.data.dataset import enso_dataset
from ppsci.utils import logger

try:
    import xarray as xr
except ModuleNotFoundError:
    raise ModuleNotFoundError("Please install xarray with `pip install xarray`.")


def get_parameter_names(model, forbidden_layer_types):
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result


def train(cfg: DictConfig):
    # set train dataloader config
    train_dataloader_cfg = {
        "dataset": {
            "name": "ENSODataset",
            "data_dir": cfg.FILE_PATH,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.DATASET.label_keys,
            "in_len": cfg.DATASET.in_len,
            "out_len": cfg.DATASET.out_len,
            "in_stride": cfg.DATASET.in_stride,
            "out_stride": cfg.DATASET.out_stride,
            "train_samples_gap": cfg.DATASET.train_samples_gap,
            "eval_samples_gap": cfg.DATASET.eval_samples_gap,
            "normalize_sst": cfg.DATASET.normalize_sst,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 8,
    }

    # set constraint
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(enso_metric.train_mse_func),
        name="Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    ITERS_PER_EPOCH = len(sup_constraint.data_loader)
    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ENSODataset",
            "data_dir": cfg.FILE_PATH,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.DATASET.label_keys,
            "in_len": cfg.DATASET.in_len,
            "out_len": cfg.DATASET.out_len,
            "in_stride": cfg.DATASET.in_stride,
            "out_stride": cfg.DATASET.out_stride,
            "train_samples_gap": cfg.DATASET.train_samples_gap,
            "eval_samples_gap": cfg.DATASET.eval_samples_gap,
            "normalize_sst": cfg.DATASET.normalize_sst,
            "training": "eval",
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(enso_metric.train_mse_func),
        metric={
            "rmse": ppsci.metric.FunctionalMetric(enso_metric.eval_rmse_func),
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    model = ppsci.arch.CuboidTransformer(
        **cfg.MODEL,
    )

    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
            "weight_decay": cfg.TRAIN.wd,
        },
        {
            "params": [
                p for n, p in model.named_parameters() if n not in decay_parameters
            ],
            "weight_decay": 0.0,
        },
    ]

    # init optimizer and lr scheduler
    lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
    lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(
        **lr_scheduler_cfg,
        iters_per_epoch=ITERS_PER_EPOCH,
        eta_min=cfg.TRAIN.min_lr_ratio * cfg.TRAIN.lr_scheduler.learning_rate,
        warmup_epoch=int(0.2 * cfg.TRAIN.epochs),
    )()
    optimizer = paddle.optimizer.AdamW(
        lr_scheduler, parameters=optimizer_grouped_parameters
    )

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        lr_scheduler,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        eval_during_train=cfg.TRAIN.eval_during_train,
        seed=cfg.seed,
        validator=validator,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # train model
    solver.train()
    # evaluate after finished training
    solver.eval()


def evaluate(cfg: DictConfig):
    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ENSODataset",
            "data_dir": cfg.FILE_PATH,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.DATASET.label_keys,
            "in_len": cfg.DATASET.in_len,
            "out_len": cfg.DATASET.out_len,
            "in_stride": cfg.DATASET.in_stride,
            "out_stride": cfg.DATASET.out_stride,
            "train_samples_gap": cfg.DATASET.train_samples_gap,
            "eval_samples_gap": cfg.DATASET.eval_samples_gap,
            "normalize_sst": cfg.DATASET.normalize_sst,
            "training": "test",
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(enso_metric.train_mse_func),
        metric={
            "rmse": ppsci.metric.FunctionalMetric(enso_metric.eval_rmse_func),
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    model = ppsci.arch.CuboidTransformer(
        **cfg.MODEL,
    )

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        output_dir=cfg.output_dir,
        log_freq=cfg.log_freq,
        seed=cfg.seed,
        validator=validator,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # evaluate
    solver.eval()


def export(cfg: DictConfig):
    # set model
    model = ppsci.arch.CuboidTransformer(
        **cfg.MODEL,
    )

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        pretrained_model_path=cfg.INFER.pretrained_model_path,
    )
    # export model
    from paddle.static import InputSpec

    input_spec = [
        {
            key: InputSpec([1, 12, 24, 48, 1], "float32", name=key)
            for key in model.input_keys
        },
    ]
    solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
    import predictor

    predictor = predictor.EarthformerPredictor(cfg)

    train_cmip = xr.open_dataset(cfg.INFER.data_path).transpose(
        "year", "month", "lat", "lon"
    )
    # select longitudes
    lon = train_cmip.lon.values
    lon = lon[np.logical_and(lon >= 95, lon <= 330)]
    train_cmip = train_cmip.sel(lon=lon)
    data = train_cmip.sst.values
    data = enso_dataset.fold(data)

    idx_sst = enso_dataset.prepare_inputs_targets(
        len_time=data.shape[0],
        input_length=cfg.INFER.in_len,
        input_gap=cfg.INFER.in_stride,
        pred_shift=cfg.INFER.out_len * cfg.INFER.out_stride,
        pred_length=cfg.INFER.out_len,
        samples_gap=cfg.INFER.samples_gap,
    )
    data = data[idx_sst].astype("float32")

    sst_data = data[..., np.newaxis]
    idx = np.random.choice(len(data), None, False)
    in_seq = sst_data[idx, : cfg.INFER.in_len, ...]  # ( in_len, lat, lon, 1)
    in_seq = in_seq[np.newaxis, ...]
    target_seq = sst_data[idx, cfg.INFER.in_len :, ...]  # ( out_len, lat, lon, 1)
    target_seq = target_seq[np.newaxis, ...]

    pred_data = predictor.predict(in_seq, cfg.INFER.batch_size)

    # save predict data
    save_path = osp.join(cfg.output_dir, "result_enso_pred.npy")
    np.save(save_path, pred_data)
    logger.info(f"Save output to {save_path}")


@hydra.main(
    version_base=None,
    config_path="./conf",
    config_name="earthformer_enso_pretrain.yaml",
)
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should be in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()
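The inference path above relies on `enso_dataset.prepare_inputs_targets` to build sliding-window indices over the time axis. The sketch below is an illustrative reimplementation under assumed semantics — the name `sliding_window_indices` and its body are ours, not the PaddleScience source — showing how `in_len`, `in_stride`, `out_len`, `out_stride` and `samples_gap` combine into index windows:

```python
import numpy as np


def sliding_window_indices(len_time, input_length, input_gap,
                           pred_length, pred_shift, samples_gap):
    """Build (n_samples, input_length + pred_length) index windows.

    Each window holds `input_length` input frames spaced `input_gap`
    apart, followed by `pred_length` target frames whose last frame
    lies `pred_shift` steps after the last input frame.
    """
    input_span = input_gap * (input_length - 1)     # offset of the last input frame
    pred_gap = pred_shift // pred_length            # spacing between target frames
    input_ind = np.arange(input_length) * input_gap
    target_ind = input_span + pred_gap * (np.arange(pred_length) + 1)
    window = np.concatenate([input_ind, target_ind])
    n_samples = len_time - input_span - pred_shift  # number of valid start positions
    starts = np.arange(n_samples)[:, None]
    return (window[None, :] + starts)[::samples_gap]


# e.g. 12 monthly inputs, 26-month forecast horizon, over a 40-month series
idx = sliding_window_indices(len_time=40, input_length=12, input_gap=1,
                             pred_length=26, pred_shift=26, samples_gap=1)
print(idx.shape)  # (3, 38)
```

Indexing `data[idx]` with such windows yields one `(in_len + out_len)`-frame sample per row, which the inference code then splits into `in_seq` and `target_seq`.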
examples/earthformer/earthformer_sevir_train.py
import h5py
import hydra
import numpy as np
import paddle
import sevir_metric
import sevir_vis_seq
from omegaconf import DictConfig
from paddle import nn

import ppsci


def get_parameter_names(model, forbidden_layer_types):
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result
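`get_parameter_names` walks the layer tree and drops every parameter owned by a forbidden layer type; together with the `"bias"` filter applied later, this reproduces the common AdamW convention of exempting LayerNorm weights and all biases from weight decay. Its recursion can be exercised without Paddle by mocking the layer tree — `Module`, `Linear` and `LayerNorm` below are stand-ins, not `paddle.nn` classes:

```python
# Minimal stand-ins for nn.Layer so the traversal logic can run without paddle.
class Module:
    def __init__(self, **children):
        self._children = children
        self._parameters = {}

    def named_children(self):
        return self._children.items()


class Linear(Module):
    def __init__(self):
        super().__init__()
        self._parameters = {"weight": None, "bias": None}


class LayerNorm(Module):
    def __init__(self):
        super().__init__()
        self._parameters = {"weight": None, "bias": None}


def get_parameter_names(model, forbidden_layer_types):
    # same traversal as above: recurse into children, prefix names,
    # and skip every parameter that lives under a forbidden layer type
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    result += list(model._parameters.keys())
    return result


model = Module(fc=Linear(), norm=LayerNorm())
names = get_parameter_names(model, [LayerNorm])
decay = [n for n in names if "bias" not in n]
print(decay)  # -> ['fc.weight']
```

Only `fc.weight` survives both filters, so it alone receives weight decay — exactly the grouping built for the AdamW optimizer below.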


def train(cfg: DictConfig):
    # set train dataloader config
    train_dataloader_cfg = {
        "dataset": {
            "name": "SEVIRDataset",
            "data_dir": cfg.FILE_PATH,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.DATASET.label_keys,
            "data_types": cfg.DATASET.data_types,
            "seq_len": cfg.DATASET.seq_len,
            "raw_seq_len": cfg.DATASET.raw_seq_len,
            "sample_mode": cfg.DATASET.sample_mode,
            "stride": cfg.DATASET.stride,
            "batch_size": cfg.DATASET.batch_size,
            "layout": cfg.DATASET.layout,
            "in_len": cfg.DATASET.in_len,
            "out_len": cfg.DATASET.out_len,
            "split_mode": cfg.DATASET.split_mode,
            "start_date": cfg.TRAIN.start_date,
            "end_date": cfg.TRAIN.end_date,
            "preprocess": cfg.DATASET.preprocess,
            "rescale_method": cfg.DATASET.rescale_method,
            "shuffle": True,
            "verbose": False,
            "training": True,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 8,
    }

    # set constraint
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(sevir_metric.train_mse_func),
        name="Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    ITERS_PER_EPOCH = len(sup_constraint.data_loader)
    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "SEVIRDataset",
            "data_dir": cfg.FILE_PATH,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.DATASET.label_keys,
            "data_types": cfg.DATASET.data_types,
            "seq_len": cfg.DATASET.seq_len,
            "raw_seq_len": cfg.DATASET.raw_seq_len,
            "sample_mode": cfg.DATASET.sample_mode,
            "stride": cfg.DATASET.stride,
            "batch_size": cfg.DATASET.batch_size,
            "layout": cfg.DATASET.layout,
            "in_len": cfg.DATASET.in_len,
            "out_len": cfg.DATASET.out_len,
            "split_mode": cfg.DATASET.split_mode,
            "start_date": cfg.TRAIN.end_date,
            "end_date": cfg.EVAL.end_date,
            "preprocess": cfg.DATASET.preprocess,
            "rescale_method": cfg.DATASET.rescale_method,
            "shuffle": False,
            "verbose": False,
            "training": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        loss=ppsci.loss.MSELoss(),
        metric={
            "rmse": ppsci.metric.FunctionalMetric(
                sevir_metric.eval_rmse_func(
                    out_len=cfg.DATASET.seq_len,
                    layout=cfg.DATASET.layout,
                    metrics_mode=cfg.EVAL.metrics_mode,
                    metrics_list=cfg.EVAL.metrics_list,
                    threshold_list=cfg.EVAL.threshold_list,
                )
            ),
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    model = ppsci.arch.CuboidTransformer(
        **cfg.MODEL,
    )

    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
            "weight_decay": cfg.TRAIN.wd,
        },
        {
            "params": [
                p for n, p in model.named_parameters() if n not in decay_parameters
            ],
            "weight_decay": 0.0,
        },
    ]

    # init optimizer and lr scheduler
    lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
    lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(
        **lr_scheduler_cfg,
        iters_per_epoch=ITERS_PER_EPOCH,
        eta_min=cfg.TRAIN.min_lr_ratio * cfg.TRAIN.lr_scheduler.learning_rate,
        warmup_epoch=int(0.2 * cfg.TRAIN.epochs),
    )()
    optimizer = paddle.optimizer.AdamW(
        lr_scheduler, parameters=optimizer_grouped_parameters
    )

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        lr_scheduler,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        eval_during_train=cfg.TRAIN.eval_during_train,
        seed=cfg.seed,
        validator=validator,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # train model
    solver.train()
    # evaluate after finished training
    metric = sevir_metric.eval_rmse_func(
        out_len=cfg.DATASET.seq_len,
        layout=cfg.DATASET.layout,
        metrics_mode=cfg.EVAL.metrics_mode,
        metrics_list=cfg.EVAL.metrics_list,
        threshold_list=cfg.EVAL.threshold_list,
    )

    with solver.no_grad_context_manager(True):
        for index, (input_, label, _) in enumerate(sup_validator.data_loader):
            truefield = label["vil"].squeeze(0)
            prefield = model(input_)["vil"].squeeze(0)
            metric.sevir_score.update(prefield, truefield)

    metric_dict = metric.sevir_score.compute()
    print(metric_dict)


def evaluate(cfg: DictConfig):
    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "SEVIRDataset",
            "data_dir": cfg.FILE_PATH,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.DATASET.label_keys,
            "data_types": cfg.DATASET.data_types,
            "seq_len": cfg.DATASET.seq_len,
            "raw_seq_len": cfg.DATASET.raw_seq_len,
            "sample_mode": cfg.DATASET.sample_mode,
            "stride": cfg.DATASET.stride,
            "batch_size": cfg.DATASET.batch_size,
            "layout": cfg.DATASET.layout,
            "in_len": cfg.DATASET.in_len,
            "out_len": cfg.DATASET.out_len,
            "split_mode": cfg.DATASET.split_mode,
            "start_date": cfg.TEST.start_date,
            "end_date": cfg.TEST.end_date,
            "preprocess": cfg.DATASET.preprocess,
            "rescale_method": cfg.DATASET.rescale_method,
            "shuffle": False,
            "verbose": False,
            "training": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        loss=ppsci.loss.MSELoss(),
        metric={
            "rmse": ppsci.metric.FunctionalMetric(
                sevir_metric.eval_rmse_func(
                    out_len=cfg.DATASET.seq_len,
                    layout=cfg.DATASET.layout,
                    metrics_mode=cfg.EVAL.metrics_mode,
                    metrics_list=cfg.EVAL.metrics_list,
                    threshold_list=cfg.EVAL.threshold_list,
                )
            ),
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    model = ppsci.arch.CuboidTransformer(
        **cfg.MODEL,
    )

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        output_dir=cfg.output_dir,
        log_freq=cfg.log_freq,
        seed=cfg.seed,
        validator=validator,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # evaluate
    metric = sevir_metric.eval_rmse_func(
        out_len=cfg.DATASET.seq_len,
        layout=cfg.DATASET.layout,
        metrics_mode=cfg.EVAL.metrics_mode,
        metrics_list=cfg.EVAL.metrics_list,
        threshold_list=cfg.EVAL.threshold_list,
    )

    with solver.no_grad_context_manager(True):
        for index, (input_, label, _) in enumerate(sup_validator.data_loader):
            truefield = label["vil"].reshape([-1, *label["vil"].shape[2:]])
            prefield = model(input_)["vil"].reshape([-1, *label["vil"].shape[2:]])
            metric.sevir_score.update(prefield, truefield)

    metric_dict = metric.sevir_score.compute()
    print(metric_dict)


def export(cfg: DictConfig):
    # set model
    model = ppsci.arch.CuboidTransformer(
        **cfg.MODEL,
    )

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        pretrained_model_path=cfg.INFER.pretrained_model_path,
    )
    # export model
    from paddle.static import InputSpec

    input_spec = [
        {
            key: InputSpec([1, 13, 384, 384, 1], "float32", name=key)
            for key in model.input_keys
        },
    ]
    solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
    import predictor

    from ppsci.data.dataset import sevir_dataset

    predictor = predictor.EarthformerPredictor(cfg)

    if cfg.INFER.rescale_method == "sevir":
        scale_dict = sevir_dataset.PREPROCESS_SCALE_SEVIR
        offset_dict = sevir_dataset.PREPROCESS_OFFSET_SEVIR
    elif cfg.INFER.rescale_method == "01":
        scale_dict = sevir_dataset.PREPROCESS_SCALE_01
        offset_dict = sevir_dataset.PREPROCESS_OFFSET_01
    else:
        raise ValueError(f"Invalid rescale option: {cfg.INFER.rescale_method}.")

    # read h5 data
    h5data = h5py.File(cfg.INFER.data_path, "r")
    data = np.array(h5data[cfg.INFER.data_type]).transpose([0, 3, 1, 2])

    idx = np.random.choice(len(data), None, False)
    data = (
        scale_dict[cfg.INFER.data_type] * data[idx] + offset_dict[cfg.INFER.data_type]
    )

    input_data = data[: cfg.INFER.in_len, ...]
    input_data = input_data.reshape(1, *input_data.shape, 1).astype(np.float32)
    target_data = data[cfg.INFER.in_len : cfg.INFER.in_len + cfg.INFER.out_len, ...]
    target_data = target_data.reshape(1, *target_data.shape, 1).astype(np.float32)

    pred_data = predictor.predict(input_data, cfg.INFER.batch_size)

    sevir_vis_seq.save_example_vis_results(
        save_dir=cfg.INFER.sevir_vis_save,
        save_prefix=f"data_{idx}",
        in_seq=input_data,
        target_seq=target_data,
        pred_seq=pred_data,
        layout=cfg.INFER.layout,
        plot_stride=cfg.INFER.plot_stride,
        label=cfg.INFER.logging_prefix,
        interval_real_time=cfg.INFER.interval_real_time,
    )


@hydra.main(
    version_base=None,
    config_path="./conf",
    config_name="earthformer_sevir_pretrain.yaml",
)
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should be in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()

6. Results

The figure below shows the vil model's output: given 65 minutes of input data, the model forecasts the following 60 minutes, shown alongside the ground truth.

SEVIR-predict

Predicted ("prediction") vs. ground-truth ("target") vil on SEVIR

Notes:

Hit: TP, Miss: FN, False Alarm: FP

Row 1: input data;

Row 2: ground truth;

Row 3: predictions;

Row 4: TP, FN and FP markings at a threshold of 74;

Row 5: TP, FN and FP markings across all thresholds.
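The CSI columns in the metrics table at the top of this page come from exactly these pixel-wise counts: a pixel is an "event" when its vil value meets the threshold, and hits, misses and false alarms are tallied between prediction and target. A minimal, framework-free sketch (the helper name `csi` is ours, not the library's):

```python
def csi(pred, target, threshold):
    """Critical Success Index: TP / (TP + FN + FP).

    A pixel counts as an event when its value is >= threshold;
    hits (TP), misses (FN) and false alarms (FP) are tallied
    pixel-wise, as in rows 4-5 of the figure above.
    """
    tp = fn = fp = 0
    for p, t in zip(pred, target):
        p_event, t_event = p >= threshold, t >= threshold
        tp += p_event and t_event          # hit
        fn += (not p_event) and t_event    # miss
        fp += p_event and not t_event      # false alarm
    total = tp + fn + fp
    return tp / total if total else 0.0


# toy 4-pixel field at vil threshold 74: 1 hit, 1 miss, 1 false alarm
print(csi([80, 10, 90, 70], [85, 80, 20, 60], 74))  # -> 0.3333...
```

CSI-M in the table is the mean of this score over the listed thresholds (16, 74, 133, 160, 181, 219).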