
NSFNet4

Quick start on AI Studio

# VP_NSFNet4
# download the dataset
# linux
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/NSFNet/NSF4_data.zip -P ./data/
unzip ./data/NSF4_data.zip
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/NSFNet/NSF4_data.zip --create-dirs -o ./data/NSF4_data.zip
# unzip ./data/NSF4_data.zip

# train
python VP_NSFNet4.py data_dir=./data/
# evaluate with the pretrained model
python VP_NSFNet4.py mode=eval data_dir=./data/ EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/nsfnet/nsfnet4.pdparams
# export the model
python VP_NSFNet4.py mode=export
# run inference with the exported model
python VP_NSFNet4.py mode=infer

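1. Background

In recent years deep learning has achieved remarkable results in many fields, especially computer vision and natural language processing. Inspired by this rapid progress and by the strong function-approximation ability of deep networks, neural networks have also been applied successfully to scientific computing. Current research falls into two broad categories. The first adds physical information and physical constraints to the loss function used to train the network; representative methods are PINN and Deep Ritz Net. The second builds data-driven deep neural operators; representative methods are FNO and DeepONet. These methods are widely used in practice, for example in weather forecasting, quantum chemistry, bioengineering and computational fluid dynamics. To explore how well PINNs can solve fluid equations, the authors of the paper reproduced here designed NSFNets. They trained them on forward problems, using as references the 2D and 3D Navier-Stokes equations with analytical or numerical solutions, as well as a high-accuracy dataset computed with DNS. Their experiments show that PINNs solve the incompressible Navier-Stokes equations with excellent accuracy. The main goal of this project is to reproduce, with PaddleScience, the paper's code for the high-accuracy solution of the Navier-Stokes equations.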

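2. Problem Definition

This example uses the classic PINN model, which is not described again here.

This section introduces the Navier-Stokes equations solved in this example.

Let \(\mathbf{u}\) be the velocity, \(p\) the pressure and \(Re\) the Reynolds number. On a domain \(\Omega\) with a Dirichlet boundary \(\Gamma_D\) and a Neumann boundary \(\Gamma_N\), the incompressible Navier-Stokes equations read: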

\[\frac{\partial \mathbf{u}}{\partial t}+(\mathbf{u} \cdot \nabla) \mathbf{u} =-\nabla p+\frac{1}{Re} \nabla^2 \mathbf{u} \quad \text { in } \Omega, \]
\[\nabla \cdot \mathbf{u} =0 \quad \text { in } \Omega, \]
\[\mathbf{u} =\mathbf{u}_{\Gamma} \quad \text { on } \Gamma_D, \]
\[\frac{\partial \mathbf{u}}{\partial n} =0 \quad \text { on } \Gamma_N.\]
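Writing the system out component by component, with \(\mathbf{u}=(u,v,w)\), gives the four residuals that the PDE constraint in Section 3.4 drives to zero under the names `continuity`, `momentum_x`, `momentum_y` and `momentum_z`. This expansion assumes \(\rho=1\), which matches `rho=1.0` in the code. It is also an assumption that `ppsci.equation.NavierStokes` uses exactly this sign convention; that has not been checked here.

\[\text{continuity}:\quad \frac{\partial u}{\partial x}+\frac{\partial v}{\partial y}+\frac{\partial w}{\partial z}=0,\]
\[\text{momentum}_x:\quad \frac{\partial u}{\partial t}+u\frac{\partial u}{\partial x}+v\frac{\partial u}{\partial y}+w\frac{\partial u}{\partial z}+\frac{\partial p}{\partial x}-\frac{1}{Re}\nabla^2 u=0,\]
\[\text{momentum}_y:\quad \frac{\partial v}{\partial t}+u\frac{\partial v}{\partial x}+v\frac{\partial v}{\partial y}+w\frac{\partial v}{\partial z}+\frac{\partial p}{\partial y}-\frac{1}{Re}\nabla^2 v=0,\]
\[\text{momentum}_z:\quad \frac{\partial w}{\partial t}+u\frac{\partial w}{\partial x}+v\frac{\partial w}{\partial y}+w\frac{\partial w}{\partial z}+\frac{\partial p}{\partial z}-\frac{1}{Re}\nabla^2 w=0,\]

where \(\nabla^2=\partial_{xx}+\partial_{yy}+\partial_{zz}\).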

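2.1 JHTDB Dataset

The dataset is a high-accuracy DNS solution of three-dimensional, incompressible, forced isotropic turbulence at Re = 999.35. See the readme for the detailed parameters.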

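3. Solution

3.1 Model Construction

We train the classic MLP model used by PINNs. Its inputs are the space-time coordinates (x, y, z, t), and its outputs are the velocity components (u, v, w) and the pressure p.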

model = ppsci.arch.MLP(**cfg.MODEL)
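`ppsci.arch.MLP` maps the named inputs (x, y, z, t) to the named outputs (u, v, w, p). The NumPy sketch below shows that dictionary-in, dictionary-out interface for a plain fully connected network. Several details are illustrative assumptions: the layer sizes `[4, 50, 50, 4]`, the tanh activation and the Glorot-style initializer. The real settings come from `cfg.MODEL` in `VP_NSFNet4.yaml`, which is not shown here, and `init_mlp` and `mlp_forward` are hypothetical helper names.

```python
import numpy as np


def init_mlp(layer_sizes, seed=0):
    """Create (weight, bias) pairs; Glorot-normal weights are an illustrative choice."""
    rng = np.random.default_rng(seed)
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        std = np.sqrt(2.0 / (n_in + n_out))
        params.append((rng.normal(0.0, std, size=(n_in, n_out)), np.zeros(n_out)))
    return params


def mlp_forward(params, inputs, input_keys=("x", "y", "z", "t"),
                output_keys=("u", "v", "w", "p")):
    # stack the named (N, 1) inputs into one (N, 4) matrix
    h = np.concatenate([inputs[k] for k in input_keys], axis=1)
    for i, (w, b) in enumerate(params):
        h = h @ w + b
        if i < len(params) - 1:  # tanh on hidden layers, linear output layer
            h = np.tanh(h)
    # split the (N, 4) output back into named (N, 1) columns
    return {k: h[:, i : i + 1] for i, k in enumerate(output_keys)}


params = init_mlp([4, 50, 50, 4])  # hypothetical layer sizes
rng = np.random.default_rng(1)
batch = {k: rng.random((8, 1)) for k in ("x", "y", "z", "t")}
out = mlp_forward(params, batch)
print({k: v.shape for k, v in out.items()})
```

Keeping inputs and outputs as named columns is what lets constraints, labels and the input transform refer to `"x"` or `"u"` by name rather than by column index.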

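3.2 Data Generation

We first take the boundary points, the initial-condition points and the interior points used to evaluate the PDE residual (see Section 3.3 of the paper for how they are chosen), and then generate the test points. `generate_data` builds each set as follows:

- The boundary and initial points are loaded from the dataset files.
- The interior points are 100,000 grid nodes drawn at random with replacement, replicated at 17 time levels spaced 0.0065 apart.
- The test set consists of 3,000 locations evaluated at 5 time levels.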

# load data
(
    x_train,
    y_train,
    z_train,
    t_train,
    x0_train,
    y0_train,
    z0_train,
    t0_train,
    u0_train,
    v0_train,
    w0_train,
    xb_train,
    yb_train,
    zb_train,
    tb_train,
    ub_train,
    vb_train,
    wb_train,
    x_star,
    y_star,
    z_star,
    t_star,
    u_star,
    v_star,
    w_star,
    p_star,
) = generate_data(cfg.data_dir)
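The interior (residual) points in `generate_data` pair every sampled spatial point with every time level. `np.tile` stacks the block of spatial points once per time level, and `repeat` repeats each time value once per spatial point, so the rows line up. The sketch below reproduces that layout with 5 points instead of 100,000. `sample_space_time` is a hypothetical helper, and it uses a seeded `np.random.default_rng` instead of the global `np.random.choice` used in the source.

```python
import numpy as np


def sample_space_time(xnode, ynode, znode, n_points, n_times, dt, rng):
    """Draw n_points grid nodes per axis (with replacement) and pair them with n_times time levels."""
    x = xnode.reshape(-1, 1)[rng.choice(len(xnode), n_points, replace=True), :]
    y = ynode.reshape(-1, 1)[rng.choice(len(ynode), n_points, replace=True), :]
    z = znode.reshape(-1, 1)[rng.choice(len(znode), n_points, replace=True), :]
    # the whole block of spatial points is stacked once per time level ...
    x, y, z = (np.tile(a, (n_times, 1)) for a in (x, y, z))
    # ... and each time value is repeated n_points times, so row i * n_points + j
    # holds spatial point j at time level i
    t = (np.arange(n_times) * dt).repeat(n_points).reshape(-1, 1)
    return x, y, z, t


rng = np.random.default_rng(0)
x, y, z, t = sample_space_time(
    np.linspace(12.47, 12.66, 191),
    np.linspace(-1, -0.0031, 998),
    np.linspace(4.61, 4.82, 211),
    n_points=5,
    n_times=17,
    dt=0.0065,
    rng=rng,
)
print(x.shape, t.shape)
```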

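3.3 Normalization

The sampled region is a small cuboid whose extent differs along each axis. To turn it into a cube, we register a normalization function as an input transform of the network. Before entering the network, each of x, y, z and t is mapped linearly to [-1, 1] using the minimum and maximum of the boundary points.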

# normalization
Xb = np.concatenate([xb_train, yb_train, zb_train, tb_train], 1)
lowb = Xb.min(0)  # minimal number in each column
upb = Xb.max(0)
trans = Transform(paddle.to_tensor(lowb), paddle.to_tensor(upb))
model.register_input_transform(trans.input_trans)
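The transform registered above applies \(\hat{s} = 2(s - s_{\min})/(s_{\max} - s_{\min}) - 1\) to each coordinate. The pure-Python check below takes its bounds from the node ranges in `generate_data` and from \(t \in [0, 16 \times 0.0065]\). These bounds are assumptions, since the code actually uses the minimum and maximum of the boundary points. `Transform.input_trans` overwrites the entries of its input dict, whereas this hypothetical `to_unit_cube` returns a new dict.

```python
def to_unit_cube(point, lowb, upb):
    """Map each coordinate linearly from [lowb[k], upb[k]] to [-1, 1]."""
    return {k: 2.0 * (v - lowb[k]) / (upb[k] - lowb[k]) - 1.0 for k, v in point.items()}


# assumed bounds: node ranges from generate_data and t in [0, 16 * 0.0065]
lowb = {"x": 12.47, "y": -1.0, "z": 4.61, "t": 0.0}
upb = {"x": 12.66, "y": -0.0031, "z": 4.82, "t": 16 * 0.0065}

out = to_unit_cube({"x": 12.47, "y": -0.0031, "z": 4.715, "t": 8 * 0.0065}, lowb, upb)
print(out)  # x is mapped to -1, y to 1, z and t to (about) 0
```

Because x spans only about 0.19 while y spans about 1, the mapping gives every coordinate the same range and keeps the inputs of the first layer on a comparable scale.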

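3.4 Constraint Construction

The boundary and initial points come with reference velocities from the dataset, so we impose them as supervised constraints. `alpha` and `beta` weight the boundary and initial-condition loss terms. As in the paper, both are set to 100 in this code.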

sup_constraint_b = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg_b,
    ppsci.loss.MSELoss("mean", cfg.alpha),
    name="Sup_b",
)

# supervised constraint s.t ||u-u_0||
sup_constraint_0 = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg_ic,
    ppsci.loss.MSELoss("mean", cfg.beta),
    name="Sup_ic",
)
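Together with the residual constraint built next, the training objective is a weighted sum of mean squared errors. The PDE residuals are pushed towards zero, and the boundary and initial predictions are pushed towards the reference data with weights `alpha` and `beta`. The sketch below writes this out in plain Python. Summing the per-output-key losses is an assumption about how PaddleScience combines them, and `mse` and `total_loss` are hypothetical names.

```python
def mse(pred, label):
    """Mean squared error between two equal-length lists."""
    return sum((p - q) ** 2 for p, q in zip(pred, label)) / len(pred)


def total_loss(eq_residuals, bc_pred, bc_label, ic_pred, ic_label, alpha=100.0, beta=100.0):
    # PDE residuals are compared against zero (the labels 0 in the EQ constraint)
    loss_eq = sum(mse(r, [0.0] * len(r)) for r in eq_residuals.values())
    # boundary and initial predictions are compared against the reference velocities
    loss_b = sum(mse(bc_pred[k], bc_label[k]) for k in bc_label)
    loss_ic = sum(mse(ic_pred[k], ic_label[k]) for k in ic_label)
    return loss_eq + alpha * loss_b + beta * loss_ic


loss = total_loss(
    {"continuity": [0.1, -0.1]},
    {"u": [1.0]}, {"u": [0.9]},
    {"v": [0.0, 0.2]}, {"v": [0.0, 0.0]},
)
print(loss)
```

With weights of 100, an error of 0.1 on a single boundary value costs as much as a PDE residual of 1.0, which is how the data terms anchor the solution.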

The residual constraint of the Navier-Stokes equations is built on the interior points:

# set equation constraint s.t. ||F(u)||
equation = {
    "NavierStokes": ppsci.equation.NavierStokes(
        nu=1.0 / cfg.re, rho=1.0, dim=3, time=True
    ),
}

pde_constraint = ppsci.constraint.InteriorConstraint(
    equation["NavierStokes"].equations,
    {"continuity": 0, "momentum_x": 0, "momentum_y": 0, "momentum_z": 0},
    geom,
    {
        "dataset": {"name": "NamedArrayDataset"},
        "batch_size": cfg.ntrain,
        "iters_per_epoch": cfg.TRAIN.lr_scheduler.iters_per_epoch,
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
    },
    ppsci.loss.MSELoss("mean"),
    name="EQ",
)
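The four keys passed to `InteriorConstraint` are the component residuals of the equations in Section 2. A PINN evaluates them with automatic differentiation. The sketch below substitutes central finite differences so that it runs without a framework. It checks the residuals on the 2D Taylor-Green vortex embedded in 3D (w = 0, no z-dependence), an exact solution that is not part of this example. The residual form with \(\rho = 1\) is assumed to match `ppsci.equation.NavierStokes`, and `taylor_green` and `ns_residuals` are hypothetical helpers.

```python
import math

RE = 999.35  # assumed to equal cfg.re


def taylor_green(x, y, z, t, re=RE):
    """2D Taylor-Green vortex embedded in 3D (w = 0, no z-dependence); returns (u, v, w, p)."""
    f = math.exp(-2.0 * t / re)
    u = -math.cos(x) * math.sin(y) * f
    v = math.sin(x) * math.cos(y) * f
    p = -0.25 * (math.cos(2.0 * x) + math.cos(2.0 * y)) * f * f
    return u, v, 0.0, p


def ns_residuals(field, x, y, z, t, re=RE, h=1e-3):
    """Central-difference residuals of the 3D incompressible NS equations with rho = 1."""
    point = [x, y, z, t]

    def shifted(i, d):
        # evaluate the field with coordinate i (0..3 = x, y, z, t) moved by d
        q = list(point)
        q[i] += d
        return field(*q)

    f0 = field(*point)
    # d1[i][c]: derivative of component c (u, v, w, p) w.r.t. coordinate i (x, y, z, t)
    d1 = [[(a - b) / (2 * h) for a, b in zip(shifted(i, h), shifted(i, -h))] for i in range(4)]
    # Laplacian over x, y, z of each velocity component
    lap = [
        sum((shifted(i, h)[c] - 2 * f0[c] + shifted(i, -h)[c]) / h**2 for i in range(3))
        for c in range(3)
    ]
    u, v, w = f0[:3]
    res = {"continuity": d1[0][0] + d1[1][1] + d1[2][2]}
    for c, name in enumerate(("momentum_x", "momentum_y", "momentum_z")):
        convection = u * d1[0][c] + v * d1[1][c] + w * d1[2][c]
        res[name] = d1[3][c] + convection + d1[c][3] - lap[c] / re
    return res


res = ns_residuals(taylor_green, 0.3, 1.1, 0.7, 0.05)
print({k: f"{v:.1e}" for k, v in res.items()})

# a field that violates the equations: same velocity, pressure with flipped sign
bad = ns_residuals(lambda *q: taylor_green(*q)[:3] + (-taylor_green(*q)[3],), 0.3, 1.1, 0.7, 0.05)
print(f"{bad['momentum_x']:.2f}")
```

For the exact solution all four residuals are at the level of the finite-difference error. With the pressure sign flipped, the momentum residuals become of order one, and the `EQ` loss penalizes exactly this kind of mismatch.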

3.5 Validator Construction

The model is evaluated on a test set built from the test points generated in the data-generation step:

residual_validator = ppsci.validate.SupervisedValidator(
    valid_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    metric={"L2R": ppsci.metric.L2Rel()},
    name="Residual",
)
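The `L2R` metric reports a relative L2 error. A plain-Python version for one output component is shown below. It is an assumption that `ppsci.metric.L2Rel` computes this per output key over the whole test batch, and `l2_rel` is a hypothetical name.

```python
import math


def l2_rel(pred, label):
    """Relative L2 error ||pred - label||_2 / ||label||_2 of one output component."""
    num = math.sqrt(sum((p - q) ** 2 for p, q in zip(pred, label)))
    den = math.sqrt(sum(q * q for q in label))
    return num / den


print(l2_rel([3.0, 5.0], [3.0, 4.0]))  # 0.2
```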

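3.6 Optimizer Construction

As in the paper, we build an Adam optimizer with a piecewise-constant learning-rate schedule. The number of training epochs can be adjusted through `epoch_list`.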

# set optimizer
lr_scheduler = ppsci.optimizer.lr_scheduler.Piecewise(**cfg.TRAIN.lr_scheduler)()
optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
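A piecewise schedule keeps the learning rate constant between iteration boundaries. The sketch below assumes the convention of Paddle's `PiecewiseDecay`: `values` has one more entry than `boundaries`, and a step equal to a boundary already uses the next value. The boundaries and rates are placeholders, not the values in `cfg.TRAIN.lr_scheduler`.

```python
import bisect


def piecewise_lr(step, boundaries, values):
    """Learning rate of a piecewise-constant schedule at iteration `step`."""
    # values[0] before boundaries[0], values[i] on [boundaries[i-1], boundaries[i]),
    # values[-1] from boundaries[-1] on
    assert len(values) == len(boundaries) + 1
    return values[bisect.bisect_right(boundaries, step)]


boundaries = [5000, 10000, 50000]  # placeholder iteration boundaries
values = [1e-3, 1e-4, 1e-5, 1e-6]  # placeholder learning rates
for step in (0, 4999, 5000, 60000):
    print(step, piecewise_lr(step, boundaries, values))
```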

3.7 Model Training and Evaluation

Once everything above is set up, pass the instantiated objects to `ppsci.solver.Solver`:

# initialize solver
solver = ppsci.solver.Solver(
    model=model,
    constraint=constraint,
    output_dir=cfg.output_dir,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    epochs=cfg.epochs,
    iters_per_epoch=cfg.TRAIN.lr_scheduler.iters_per_epoch,
    log_freq=cfg.TRAIN.log_freq,
    save_freq=cfg.TRAIN.save_freq,
    eval_freq=cfg.TRAIN.eval_freq,
    eval_during_train=True,
    seed=cfg.seed,
    equation=equation,
    geom=geom,
    validator=validator,
    eval_with_no_grad=cfg.TRAIN.eval_with_no_grad,
)

Finally, start training:

# train model
solver.train()

4. Complete Code

VP_NSFNet4.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path as osp

import hydra
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import paddle
from omegaconf import DictConfig

import ppsci
from ppsci.utils import logger


def generate_data(data_dir):
    train_ini1 = np.load(osp.join(data_dir, "train_ini2.npy")).astype(
        paddle.get_default_dtype()
    )
    train_iniv1 = np.load(osp.join(data_dir, "train_iniv2.npy")).astype(
        paddle.get_default_dtype()
    )
    train_xb1 = np.load(osp.join(data_dir, "train_xb2.npy")).astype(
        paddle.get_default_dtype()
    )
    train_vb1 = np.load(osp.join(data_dir, "train_vb2.npy")).astype(
        paddle.get_default_dtype()
    )

    xnode = np.linspace(12.47, 12.66, 191).astype(paddle.get_default_dtype())
    ynode = np.linspace(-1, -0.0031, 998).astype(paddle.get_default_dtype())
    znode = np.linspace(4.61, 4.82, 211).astype(paddle.get_default_dtype())

    x0_train = train_ini1[:, 0:1]
    y0_train = train_ini1[:, 1:2]
    z0_train = train_ini1[:, 2:3]
    t0_train = np.zeros_like(train_ini1[:, 0:1]).astype(paddle.get_default_dtype())
    u0_train = train_iniv1[:, 0:1]
    v0_train = train_iniv1[:, 1:2]
    w0_train = train_iniv1[:, 2:3]

    xb_train = train_xb1[:, 0:1]
    yb_train = train_xb1[:, 1:2]
    zb_train = train_xb1[:, 2:3]
    tb_train = train_xb1[:, 3:4]
    ub_train = train_vb1[:, 0:1]
    vb_train = train_vb1[:, 1:2]
    wb_train = train_vb1[:, 2:3]

    x_train1 = xnode.reshape(-1, 1)[np.random.choice(191, 100000, replace=True), :]
    y_train1 = ynode.reshape(-1, 1)[np.random.choice(998, 100000, replace=True), :]
    z_train1 = znode.reshape(-1, 1)[np.random.choice(211, 100000, replace=True), :]
    x_train = np.tile(x_train1, (17, 1))
    y_train = np.tile(y_train1, (17, 1))
    z_train = np.tile(z_train1, (17, 1))

    total_times1 = (np.array(list(range(17))) * 0.0065).astype(
        paddle.get_default_dtype()
    )
    t_train1 = total_times1.repeat(100000)
    t_train = t_train1.reshape(-1, 1)
    # test data
    test_x = np.load(osp.join(data_dir, "test43_l.npy")).astype(
        paddle.get_default_dtype()
    )
    test_v = np.load(osp.join(data_dir, "test43_vp.npy")).astype(
        paddle.get_default_dtype()
    )
    t = np.array([0.0065, 4 * 0.0065, 7 * 0.0065, 10 * 0.0065, 13 * 0.0065]).astype(
        paddle.get_default_dtype()
    )
    t_star = np.tile(t.reshape(5, 1), (1, 3000)).reshape(-1, 1)
    x_star = np.tile(test_x[:, 0:1], (5, 1))
    y_star = np.tile(test_x[:, 1:2], (5, 1))
    z_star = np.tile(test_x[:, 2:3], (5, 1))
    u_star = test_v[:, 0:1]
    v_star = test_v[:, 1:2]
    w_star = test_v[:, 2:3]
    p_star = test_v[:, 3:4]

    return (
        x_train,
        y_train,
        z_train,
        t_train,
        x0_train,
        y0_train,
        z0_train,
        t0_train,
        u0_train,
        v0_train,
        w0_train,
        xb_train,
        yb_train,
        zb_train,
        tb_train,
        ub_train,
        vb_train,
        wb_train,
        x_star,
        y_star,
        z_star,
        t_star,
        u_star,
        v_star,
        w_star,
        p_star,
    )


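class Transform:
    """Map each input coordinate linearly from [lowb, upb] to [-1, 1]."""

    def __init__(self, lowb, upb) -> None:
        # per-coordinate lower and upper bounds, keyed by input name
        self.lowb = {"x": lowb[0], "y": lowb[1], "z": lowb[2], "t": lowb[3]}
        self.upb = {"x": upb[0], "y": upb[1], "z": upb[2], "t": upb[3]}

    def input_trans(self, input_dict):
        # note: the values of input_dict are replaced in place and the same dict is returned
        for key, v in input_dict.items():
            v = 2.0 * (v - self.lowb[key]) / (self.upb[key] - self.lowb[key]) - 1.0
            input_dict[key] = v
        return input_dict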


def train(cfg: DictConfig):
    # set model
    model = ppsci.arch.MLP(**cfg.MODEL)

    # load data
    (
        x_train,
        y_train,
        z_train,
        t_train,
        x0_train,
        y0_train,
        z0_train,
        t0_train,
        u0_train,
        v0_train,
        w0_train,
        xb_train,
        yb_train,
        zb_train,
        tb_train,
        ub_train,
        vb_train,
        wb_train,
        x_star,
        y_star,
        z_star,
        t_star,
        u_star,
        v_star,
        w_star,
        p_star,
    ) = generate_data(cfg.data_dir)

    # normalization
    Xb = np.concatenate([xb_train, yb_train, zb_train, tb_train], 1)
    lowb = Xb.min(0)  # minimal number in each column
    upb = Xb.max(0)
    trans = Transform(paddle.to_tensor(lowb), paddle.to_tensor(upb))
    model.register_input_transform(trans.input_trans)

    # set dataloader config
    train_dataloader_cfg_b = {
        "dataset": {
            "name": "NamedArrayDataset",
            "input": {"x": xb_train, "y": yb_train, "z": zb_train, "t": tb_train},
            "label": {"u": ub_train, "v": vb_train, "w": wb_train},
        },
        "batch_size": cfg.nb_train,
        "iters_per_epoch": cfg.TRAIN.lr_scheduler.iters_per_epoch,
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
    }

    train_dataloader_cfg_ic = {
        "dataset": {
            "name": "NamedArrayDataset",
            "input": {"x": x0_train, "y": y0_train, "z": z0_train, "t": t0_train},
            "label": {"u": u0_train, "v": v0_train, "w": w0_train},
        },
        "batch_size": cfg.n0_train,
        "iters_per_epoch": cfg.TRAIN.lr_scheduler.iters_per_epoch,
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
    }

    valid_dataloader_cfg = {
        "dataset": {
            "name": "NamedArrayDataset",
            "input": {"x": x_star, "y": y_star, "z": z_star, "t": t_star},
            "label": {"u": u_star, "v": v_star, "w": w_star, "p": p_star},
        },
        "total_size": u_star.shape[0],
        "batch_size": u_star.shape[0],
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
    }

    geom = ppsci.geometry.PointCloud(
        {"x": x_train, "y": y_train, "z": z_train, "t": t_train}, ("x", "y", "z", "t")
    )
    # supervised constraint s.t ||u-u_b||
    sup_constraint_b = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg_b,
        ppsci.loss.MSELoss("mean", cfg.alpha),
        name="Sup_b",
    )

    # supervised constraint s.t ||u-u_0||
    sup_constraint_0 = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg_ic,
        ppsci.loss.MSELoss("mean", cfg.beta),
        name="Sup_ic",
    )

    # set equation constraint s.t. ||F(u)||
    equation = {
        "NavierStokes": ppsci.equation.NavierStokes(
            nu=1.0 / cfg.re, rho=1.0, dim=3, time=True
        ),
    }

    pde_constraint = ppsci.constraint.InteriorConstraint(
        equation["NavierStokes"].equations,
        {"continuity": 0, "momentum_x": 0, "momentum_y": 0, "momentum_z": 0},
        geom,
        {
            "dataset": {"name": "NamedArrayDataset"},
            "batch_size": cfg.ntrain,
            "iters_per_epoch": cfg.TRAIN.lr_scheduler.iters_per_epoch,
            "sampler": {
                "name": "BatchSampler",
                "drop_last": False,
                "shuffle": True,
            },
        },
        ppsci.loss.MSELoss("mean"),
        name="EQ",
    )

    # wrap constraints
    constraint = {
        pde_constraint.name: pde_constraint,
        sup_constraint_b.name: sup_constraint_b,
        sup_constraint_0.name: sup_constraint_0,
    }

    residual_validator = ppsci.validate.SupervisedValidator(
        valid_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric={"L2R": ppsci.metric.L2Rel()},
        name="Residual",
    )

    # wrap validator
    validator = {residual_validator.name: residual_validator}

    # set optimizer
    lr_scheduler = ppsci.optimizer.lr_scheduler.Piecewise(**cfg.TRAIN.lr_scheduler)()
    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
    # initialize solver
    solver = ppsci.solver.Solver(
        model=model,
        constraint=constraint,
        output_dir=cfg.output_dir,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        epochs=cfg.epochs,
        iters_per_epoch=cfg.TRAIN.lr_scheduler.iters_per_epoch,
        log_freq=cfg.TRAIN.log_freq,
        save_freq=cfg.TRAIN.save_freq,
        eval_freq=cfg.TRAIN.eval_freq,
        eval_during_train=True,
        seed=cfg.seed,
        equation=equation,
        geom=geom,
        validator=validator,
        eval_with_no_grad=cfg.TRAIN.eval_with_no_grad,
    )
    # train model
    solver.train()

    # evaluate after finished training
    solver.eval()

    solver.plot_loss_history()


def evaluate(cfg: DictConfig):
    # set model
    model = ppsci.arch.MLP(**cfg.MODEL)

    # test Data
    test_x = np.load(osp.join(cfg.data_dir, "test43_l.npy")).astype(
        paddle.get_default_dtype()
    )
    test_v = np.load(osp.join(cfg.data_dir, "test43_vp.npy")).astype(
        paddle.get_default_dtype()
    )
    t = np.array([0.0065, 4 * 0.0065, 7 * 0.0065, 10 * 0.0065, 13 * 0.0065]).astype(
        paddle.get_default_dtype()
    )
    t_star = paddle.to_tensor(np.tile(t.reshape(5, 1), (1, 3000)).reshape(-1, 1))
    x_star = paddle.to_tensor(np.tile(test_x[:, 0:1], (5, 1)).reshape(-1, 1))
    y_star = paddle.to_tensor(np.tile(test_x[:, 1:2], (5, 1)).reshape(-1, 1))
    z_star = paddle.to_tensor(np.tile(test_x[:, 2:3], (5, 1)).reshape(-1, 1))
    u_star = paddle.to_tensor(test_v[:, 0:1])
    v_star = paddle.to_tensor(test_v[:, 1:2])
    w_star = paddle.to_tensor(test_v[:, 2:3])
    p_star = paddle.to_tensor(test_v[:, 3:4])

    # load pretrained weights
    ppsci.utils.load_pretrain(model, cfg.EVAL.pretrained_model_path)

    # print the relative error
    solution = model(
        {
            "x": x_star,
            "y": y_star,
            "z": z_star,
            "t": t_star,
        }
    )
    u_pred = solution["u"].reshape((5, -1))
    v_pred = solution["v"].reshape((5, -1))
    w_pred = solution["w"].reshape((5, -1))
    p_pred = solution["p"].reshape((5, -1))
    u_star = u_star.reshape((5, -1))
    v_star = v_star.reshape((5, -1))
    w_star = w_star.reshape((5, -1))
    p_star = p_star.reshape((5, -1))

    # The NS equations only involve the pressure gradient, so the pressure is determined
    # up to a constant; shift the prediction to the mean of the reference pressure
    p_pred = p_pred - p_pred.mean() + p_star.mean()

    # relative L2 error of each snapshot, normalized by the matching reference field
    u_error = paddle.linalg.norm(u_pred - u_star, axis=1) / paddle.linalg.norm(
        u_star, axis=1
    )
    v_error = paddle.linalg.norm(v_pred - v_star, axis=1) / paddle.linalg.norm(
        v_star, axis=1
    )
    w_error = paddle.linalg.norm(w_pred - w_star, axis=1) / paddle.linalg.norm(
        w_star, axis=1
    )
    p_error = paddle.linalg.norm(p_pred - p_star, axis=1) / paddle.linalg.norm(
        p_star, axis=1
    )
    t = np.array([0.0065, 4 * 0.0065, 7 * 0.0065, 10 * 0.0065, 13 * 0.0065])
    plt.plot(t, np.array(u_error))
    plt.plot(t, np.array(v_error))
    plt.plot(t, np.array(w_error))
    plt.plot(t, np.array(p_error))
    plt.legend(["u_error", "v_error", "w_error", "p_error"])
    plt.xlabel("t")
    plt.ylabel("Relative l2 Error")
    plt.title("Relative l2 Error, on test dataset")
    plt.savefig(osp.join(cfg.output_dir, "error.jpg"))
    logger.info("L2 error picture is saved")

    grid_x, grid_y = np.mgrid[
        x_star.min() : x_star.max() : 100j, y_star.min() : y_star.max() : 100j
    ].astype(paddle.get_default_dtype())
    x_plot = paddle.to_tensor(grid_x.reshape(-1, 1))
    y_plot = paddle.to_tensor(grid_y.reshape(-1, 1))
    z_plot = paddle.to_tensor(z_star.min() * paddle.ones(y_plot.shape))
    t_plot = paddle.to_tensor((t[-1]) * np.ones(x_plot.shape), paddle.float32)
    sol = model({"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot})
    fig, ax = plt.subplots(1, 4, figsize=(16, 4))
    cmap = matplotlib.colormaps.get_cmap("jet")
    for i, key in enumerate(("u", "v", "w", "p")):
        # each field gets its own contour plot and its own colour bar
        cf = ax[i].contourf(
            grid_x, grid_y, sol[key].numpy().reshape(grid_x.shape), levels=50, cmap=cmap
        )
        ax[i].set_title(f"{key} prediction")
        fig.colorbar(cf, ax=ax[i], orientation="horizontal")
    plt.savefig(osp.join(cfg.output_dir, "z=0 plane"))

    grid_y, grid_z = np.mgrid[
        y_star.min() : y_star.max() : 100j, z_star.min() : z_star.max() : 100j
    ].astype(paddle.get_default_dtype())
    z_plot = paddle.to_tensor(grid_z.reshape(-1, 1))
    y_plot = paddle.to_tensor(grid_y.reshape(-1, 1))
    x_plot = paddle.to_tensor(x_star.min() * paddle.ones(y_plot.shape))
    t_plot = paddle.to_tensor((t[-1]) * np.ones(x_plot.shape), paddle.float32)
    sol = model({"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot})
    fig, ax = plt.subplots(1, 4, figsize=(16, 4))
    cmap = matplotlib.colormaps.get_cmap("jet")
    for i, key in enumerate(("u", "v", "w", "p")):
        # each field gets its own contour plot and its own colour bar
        cf = ax[i].contourf(
            grid_y, grid_z, sol[key].numpy().reshape(grid_y.shape), levels=50, cmap=cmap
        )
        ax[i].set_title(f"{key} prediction")
        fig.colorbar(cf, ax=ax[i], orientation="horizontal")
    plt.savefig(osp.join(cfg.output_dir, "x=0 plane"))


def export(cfg: DictConfig):
    from paddle.static import InputSpec

    # set models
    model = ppsci.arch.MLP(**cfg.MODEL)

    # load pretrained model
    solver = ppsci.solver.Solver(
        model=model, pretrained_model_path=cfg.INFER.pretrained_model_path
    )

    # export models
    input_spec = [
        {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
    ]
    solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
    from deploy.python_infer import pinn_predictor

    # set model predictor
    predictor = pinn_predictor.PINNPredictor(cfg)

    # infer Data
    test_x = np.load(osp.join(cfg.data_dir, "test43_l.npy")).astype(np.float32)
    test_v = np.load(osp.join(cfg.data_dir, "test43_vp.npy")).astype(np.float32)
    t = np.array([0.0065, 4 * 0.0065, 7 * 0.0065, 10 * 0.0065, 13 * 0.0065]).astype(
        np.float32
    )
    t_star = np.tile(t.reshape(5, 1), (1, 3000)).reshape(-1, 1)
    x_star = np.tile(test_x[:, 0:1], (5, 1)).reshape(-1, 1)
    y_star = np.tile(test_x[:, 1:2], (5, 1)).reshape(-1, 1)
    z_star = np.tile(test_x[:, 2:3], (5, 1)).reshape(-1, 1)
    u_star = test_v[:, 0:1]
    v_star = test_v[:, 1:2]
    w_star = test_v[:, 2:3]
    p_star = test_v[:, 3:4]

    pred = predictor.predict(
        {
            "x": x_star,
            "y": y_star,
            "z": z_star,
            "t": t_star,
        },
        cfg.INFER.batch_size,
    )

    pred = {
        store_key: pred[infer_key]
        for store_key, infer_key in zip(cfg.INFER.output_keys, pred.keys())
    }

    u_pred = pred["u"].reshape((5, -1))
    v_pred = pred["v"].reshape((5, -1))
    w_pred = pred["w"].reshape((5, -1))
    p_pred = pred["p"].reshape((5, -1))
    u_star = u_star.reshape((5, -1))
    v_star = v_star.reshape((5, -1))
    w_star = w_star.reshape((5, -1))
    p_star = p_star.reshape((5, -1))

    # The NS equations only involve the pressure gradient, so the pressure is determined
    # up to a constant; shift the prediction to the mean of the reference pressure
    p_pred = p_pred - p_pred.mean() + p_star.mean()

    # relative L2 error of each snapshot, normalized by the matching reference field
    u_error = np.linalg.norm(u_pred - u_star, axis=1) / np.linalg.norm(u_star, axis=1)
    v_error = np.linalg.norm(v_pred - v_star, axis=1) / np.linalg.norm(v_star, axis=1)
    w_error = np.linalg.norm(w_pred - w_star, axis=1) / np.linalg.norm(w_star, axis=1)
    p_error = np.linalg.norm(p_pred - p_star, axis=1) / np.linalg.norm(p_star, axis=1)
    t = np.array([0.0065, 4 * 0.0065, 7 * 0.0065, 10 * 0.0065, 13 * 0.0065])
    plt.plot(t, np.array(u_error))
    plt.plot(t, np.array(v_error))
    plt.plot(t, np.array(w_error))
    plt.plot(t, np.array(p_error))
    plt.legend(["u_error", "v_error", "w_error", "p_error"])
    plt.xlabel("t")
    plt.ylabel("Relative l2 Error")
    plt.title("Relative l2 Error, on test dataset")
    plt.savefig(osp.join(cfg.output_dir, "error.jpg"))

    grid_x, grid_y = np.mgrid[
        x_star.min() : x_star.max() : 100j, y_star.min() : y_star.max() : 100j
    ].astype(np.float32)
    x_plot = grid_x.reshape(-1, 1)
    y_plot = grid_y.reshape(-1, 1)
    z_plot = (z_star.min() * np.ones(y_plot.shape)).astype(np.float32)
    t_plot = ((t[-1]) * np.ones(x_plot.shape)).astype(np.float32)
    sol = predictor.predict(
        {"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot}, cfg.INFER.batch_size
    )
    sol = {
        store_key: sol[infer_key]
        for store_key, infer_key in zip(cfg.INFER.output_keys, sol.keys())
    }
    fig, ax = plt.subplots(1, 4, figsize=(16, 4))
    cmap = matplotlib.colormaps.get_cmap("jet")
    for i, key in enumerate(("u", "v", "w", "p")):
        # each field gets its own contour plot and its own colour bar
        cf = ax[i].contourf(
            grid_x, grid_y, sol[key].reshape(grid_x.shape), levels=50, cmap=cmap
        )
        ax[i].set_title(f"{key} prediction")
        fig.colorbar(cf, ax=ax[i], orientation="horizontal")
    plt.savefig(osp.join(cfg.output_dir, "z=0 plane"))

    grid_y, grid_z = np.mgrid[
        y_star.min() : y_star.max() : 100j, z_star.min() : z_star.max() : 100j
    ].astype(np.float32)
    z_plot = grid_z.reshape(-1, 1)
    y_plot = grid_y.reshape(-1, 1)
    x_plot = (x_star.min() * np.ones(y_plot.shape)).astype(np.float32)
    t_plot = ((t[-1]) * np.ones(x_plot.shape)).astype(np.float32)
    sol = predictor.predict(
        {"x": x_plot, "y": y_plot, "z": z_plot, "t": t_plot}, cfg.INFER.batch_size
    )
    sol = {
        store_key: sol[infer_key]
        for store_key, infer_key in zip(cfg.INFER.output_keys, sol.keys())
    }
    fig, ax = plt.subplots(1, 4, figsize=(16, 4))
    cmap = matplotlib.colormaps.get_cmap("jet")
    for i, key in enumerate(("u", "v", "w", "p")):
        # each field gets its own contour plot and its own colour bar
        cf = ax[i].contourf(
            grid_y, grid_z, sol[key].reshape(grid_y.shape), levels=50, cmap=cmap
        )
        ax[i].set_title(f"{key} prediction")
        fig.colorbar(cf, ax=ax[i], orientation="horizontal")
    plt.savefig(osp.join(cfg.output_dir, "x=0 plane"))


@hydra.main(version_base=None, config_path="./conf", config_name="VP_NSFNet4.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()
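The error curves produced by `evaluate` and `inference` rely on two details. First, the flattened test arrays are reshaped to `(5, -1)`, so that each row is one time snapshot. Second, the predicted pressure is shifted to the mean of the reference pressure: the equations only contain \(\nabla p\), so they fix p only up to a constant. The NumPy sketch below isolates that computation on toy data, and `snapshot_errors` is a hypothetical helper.

```python
import numpy as np


def snapshot_errors(pred, ref, n_snapshots, shift_mean=False):
    """Relative L2 error per time snapshot; after reshaping, each row is one snapshot."""
    pred = pred.reshape(n_snapshots, -1)
    ref = ref.reshape(n_snapshots, -1)
    if shift_mean:
        # pressure is fixed only up to a constant, so align the global means first
        pred = pred - pred.mean() + ref.mean()
    return np.linalg.norm(pred - ref, axis=1) / np.linalg.norm(ref, axis=1)


p_ref = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])  # toy data: 2 snapshots of 3 points
p_pred = p_ref + 10.0  # wrong only by a constant offset
print(snapshot_errors(p_pred, p_ref, 2))
print(snapshot_errors(p_pred, p_ref, 2, shift_mean=True))
```

Without the shift, the constant offset dominates the error. With the shift, the error vanishes, which is why only the pressure is shifted before the errors are computed.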

5. Results

NSFNet4

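As the figure shows, the error of NSFNet stays relatively stable over time, without the error accumulation that often occurs in traditional methods. The three velocity components are not weighted differently during training. Even so, the network approximates the first component u best, the third component w second best, and the second component v worst, and v also shows noticeable error accumulation.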

The figure shows contour plots on the y-z plane at x = 12.47. From left to right, the panels show the velocity components u, v and w and the pressure p. The contours of u are smoother than those of v, w and p.

The figure shows contour plots on the x-y plane at z = 4.61. From left to right, the panels show the velocity components u, v and w and the pressure p. The contours of u are smoother than those of v, w and p.

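In summary, the network has to learn all three velocity components u, v and w, but for the JHTDB dataset the u component is smoother and therefore easier to learn. Future work could treat the components separately, putting more training effort into the harder components and less into the easier one.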

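6. Discussion of Results

We use a PINN to solve the incompressible Navier-Stokes equations numerically. Randomly sampled space-time coordinates are the network inputs, and the corresponding velocity and pressure fields are its outputs. The initial and boundary conditions enter the loss function as supervised constraints, and the Navier-Stokes equations themselves enter as unsupervised constraints. Training uses the high-accuracy JHTDB dataset. The decreasing loss indicates that the network converges on this problem, suggesting that PINNs can solve incompressible forced isotropic turbulence. The experiments show that the PINN approximates the high-accuracy forced isotropic turbulence dataset well. We also found that increasing the weights of the boundary and initial constraints improves the approximation. Finally, within the acceptable error range, inference with the trained PINN is faster than solving the same problem with the original DNS method.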

7. References