
Loss.mtl (Multi-Task Learning) Module

ppsci.loss.mtl

AGDA

Bases: LossAggregator

Adaptive Gradient Descent Algorithm

Physics-informed neural network based on a new adaptive gradient descent algorithm for solving partial differential equations of flow problems

NOTE: This loss aggregator is only suitable for two-task learning and the first task loss must be PDE loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Layer | Training model. | required |
| M | int | Smoothing period. Defaults to 100. | 100 |
| gamma | float | Smooth factor. Defaults to 0.999. | 0.999 |

Examples:

>>> import paddle
>>> from ppsci.loss import mtl
>>> model = paddle.nn.Linear(3, 4)
>>> loss_aggregator = mtl.AGDA(model)
>>> for i in range(5):
...     x1 = paddle.randn([8, 3])
...     x2 = paddle.randn([8, 3])
...     y1 = model(x1)
...     y2 = model(x2)
...     pde_loss = paddle.sum(y1)
...     bc_loss = paddle.sum((y2 - 2) ** 2)
...     loss_aggregator({'pde_loss': pde_loss, 'bc_loss': bc_loss}).backward()
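The backward() pass of this aggregator rebalances the two task gradients toward their mean magnitude and then projects away conflicts. Below is a minimal standalone sketch of that rebalancing, based on the source further down; the flattened gradients gf, gu and the ratios rf, ru are hypothetical placeholders for the quantities computed in _refine_grads:

import paddle

# Hypothetical flattened gradients of the PDE loss (gf) and the other task (gu),
# plus the smoothed loss ratios rf, ru from eq.(18) of the paper.
gf = paddle.randn([10])
gu = paddle.randn([10])
rf, ru = 0.5, 0.5

# Step 1: rescale each gradient toward the mean magnitude E(g(n)).
gf_magn = (gf * gf).sum().sqrt()
gu_magn = (gu * gu).sum().sqrt()
Eg = (gf_magn + gu_magn) / 2
gf_bar = ((rf * (Eg - gf_magn) + gf_magn) / gf_magn) * gf
gu_bar = ((ru * (Eg - gu_magn) + gu_magn) / gu_magn) * gu

# Step 2: if the rescaled gradients conflict, project gu_bar onto the
# plane orthogonal to gf_bar before summing.
dot = (gf_bar * gu_bar).sum()
if dot < 0:
    gu_bar = gu_bar - (dot / (gf_bar * gf_bar).sum()) * gf_bar

final_grad = gf_bar + gu_bar  # what gets written back to the model parameters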
Source code in ppsci/loss/mtl/agda.py
class AGDA(base.LossAggregator):
    r"""
    **A**daptive **G**radient **D**escent **A**lgorithm

    [Physics-informed neural network based on a new adaptive gradient descent algorithm for solving partial differential equations of flow problems](https://pubs.aip.org/aip/pof/article-abstract/35/6/063608/2899773/Physics-informed-neural-network-based-on-a-new)

    NOTE: This loss aggregator is only suitable for two-task learning and the first task loss must be PDE loss.

    Args:
        model (nn.Layer): Training model.
        M (int, optional): Smoothing period. Defaults to 100.
        gamma (float, optional): Smooth factor. Defaults to 0.999.

    Examples:
        >>> import paddle
        >>> from ppsci.loss import mtl
        >>> model = paddle.nn.Linear(3, 4)
        >>> loss_aggregator = mtl.AGDA(model)
        >>> for i in range(5):
        ...     x1 = paddle.randn([8, 3])
        ...     x2 = paddle.randn([8, 3])
        ...     y1 = model(x1)
        ...     y2 = model(x2)
        ...     pde_loss = paddle.sum(y1)
        ...     bc_loss = paddle.sum((y2 - 2) ** 2)
        ...     loss_aggregator({'pde_loss': pde_loss, 'bc_loss': bc_loss}).backward()
    """

    def __init__(self, model: nn.Layer, M: int = 100, gamma: float = 0.999) -> None:
        super().__init__(model)
        self.M = M
        self.gamma = gamma
        self.Lf_smooth = 0
        self.Lu_smooth = 0
        self.Lf_tilde_acc = 0.0
        self.Lu_tilde_acc = 0.0

    def __call__(self, losses, step: int = 0) -> "AGDA":
        if len(losses) != 2:
            raise ValueError(
                f"Number of losses(tasks) for AGDA shoule be 2, but got {len(losses)}"
            )
        return super().__call__(losses, step)

    def backward(self) -> None:
        grads_list = self._compute_grads()
        with paddle.no_grad():
            refined_grads = self._refine_grads(grads_list)
            self._set_grads(refined_grads)

    def _compute_grads(self) -> List[paddle.Tensor]:
        # compute all gradients derived by each loss
        grads_list = []  # num_params x num_losses
        for key in self.losses:
            # backward with current loss
            self.losses[key].backward()
            grads_list.append(
                paddle.concat(
                    [
                        param.grad.clone().reshape([-1])
                        for param in self.model.parameters()
                        if param.grad is not None
                    ],
                    axis=0,
                )
            )
            # clear gradients for current loss for not affecting other loss
            self.model.clear_gradients()

        return grads_list

    def _refine_grads(self, grads_list: List[paddle.Tensor]) -> List[paddle.Tensor]:
        # compute moving average of L^smooth_i(n) - eq.(16)
        losses_seq = list(self.losses.values())
        self.Lf_smooth = (
            self.gamma * self.Lf_smooth + (1 - self.gamma) * losses_seq[0].item()
        )
        self.Lu_smooth = (
            self.gamma * self.Lu_smooth + (1 - self.gamma) * losses_seq[1].item()
        )

        # compute L^smooth_i(kM) - eq.(17)
        # store L^smooth_i(kM) on self so it stays available between period boundaries
        if self.step % self.M == 0:
            self.Lf_smooth_kM = self.Lf_smooth
            self.Lu_smooth_kM = self.Lu_smooth
        Lf_tilde = self.Lf_smooth / self.Lf_smooth_kM
        Lu_tilde = self.Lu_smooth / self.Lu_smooth_kM

        # compute r_i(n) - eq.(18)
        self.Lf_tilde_acc += Lf_tilde
        self.Lu_tilde_acc += Lu_tilde
        rf = Lf_tilde / self.Lf_tilde_acc
        ru = Lu_tilde / self.Lu_tilde_acc

        # compute E(g(n)) - step1(1)
        gf_magn = (grads_list[0] * grads_list[0]).sum().sqrt()
        gu_magn = (grads_list[1] * grads_list[1]).sum().sqrt()
        Eg = (gf_magn + gu_magn) / 2

        # compute \omega_f(n) - step1(2)
        omega_f = (rf * (Eg - gf_magn) + gf_magn) / gf_magn
        omega_u = (ru * (Eg - gu_magn) + gu_magn) / gu_magn

        # compute g_bar(n) - step1(3)
        gf_bar = omega_f * grads_list[0]
        gu_bar = omega_u * grads_list[1]

        # compute gradient projection - step2(1)
        dot_product = (gf_bar * gu_bar).sum()
        if dot_product < 0:
            gu_bar = gu_bar - (dot_product / (gf_bar * gf_bar).sum()) * gf_bar
        grads_list = [gf_bar, gu_bar]

        proj_grads: List[paddle.Tensor] = []
        for j in range(len(self.losses)):
            start_idx = 0
            for idx, var in enumerate(self.model.parameters()):
                grad_shape = var.shape
                flatten_dim = var.numel()
                refined_grad = grads_list[j][start_idx : start_idx + flatten_dim]
                refined_grad = paddle.reshape(refined_grad, grad_shape)
                if len(proj_grads) < self.param_num:
                    proj_grads.append(refined_grad)
                else:
                    proj_grads[idx] += refined_grad
                start_idx += flatten_dim
        return proj_grads

    def _set_grads(self, grads_list: List[paddle.Tensor]) -> None:
        for i, param in enumerate(self.model.parameters()):
            param.grad = grads_list[i]

GradNorm

Bases: LossAggregator

GradNorm loss weighting algorithm.

reference: https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/jaxpi/models.py#L132-L146

\[ \begin{align*} L^t &= \sum_{i=1}^{N}{\tilde{w}_i^t\cdot L_i^t}, \\ \text{where } \\ \tilde{w}_i^0&=1, \\ \tilde{w}_i^t&=\tilde{w}_i^{t-1}\cdot m+w_i^t\cdot (1-m), t\ge1\\ w_i^t&=\dfrac{\overline{\Vert \nabla_{\theta}{L_i^t} \Vert_2}}{\Vert \nabla_{\theta}{L_i^t} \Vert_2}, \\ \overline{\Vert \nabla_{\theta}{L_i^t} \Vert_2}&=\dfrac{1}{N}\sum_{i=1}^N{\Vert \nabla_{\theta}{L_i^t} \Vert_2}, \\ &t \text{ is the training step started from 0}. \end{align*} \]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Layer | Training model. | required |
| num_losses | int | Number of losses. Defaults to 1. | 1 |
| update_freq | int | Weight updating frequency. Defaults to 1000. | 1000 |
| momentum | float | Momentum \(m\) for moving weight. Defaults to 0.9. | 0.9 |
| init_weights | List[float] | Initial weights list. Defaults to None. | None |

Examples:

>>> import paddle
>>> from ppsci.loss import mtl
>>> model = paddle.nn.Linear(3, 4)
>>> loss_aggregator = mtl.GradNorm(model, num_losses=2)
>>> for i in range(5):
...     x1 = paddle.randn([8, 3])
...     x2 = paddle.randn([8, 3])
...     y1 = model(x1)
...     y2 = model(x2)
...     loss1 = paddle.sum(y1)
...     loss2 = paddle.sum((y2 - 2) ** 2)
...     loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
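A minimal sketch of the weight update in the formula above, starting from hypothetical per-task gradient norms (the implementation below computes them by backpropagating each loss separately):

import paddle

# Hypothetical per-task gradient norms ||grad_theta L_i||_2, one per loss.
grad_norms = paddle.to_tensor([2.0, 0.5])

mean_grad_norm = paddle.mean(grad_norms)
raw_weights = mean_grad_norm / grad_norms  # w_i^t in the formula above

# Exponential moving average with momentum m, starting from w~_i^0 = 1.
momentum = 0.9
weights = paddle.ones_like(grad_norms)
weights = momentum * weights + (1 - momentum) * raw_weights
# The weighted total loss is then L^t = sum_i weights[i] * L_i^t.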
Source code in ppsci/loss/mtl/grad_norm.py
class GradNorm(base.LossAggregator):
    r"""GradNorm loss weighting algorithm.

    reference: [https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/jaxpi/models.py#L132-L146](https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/jaxpi/models.py#L132-L146)

    $$
    \begin{align*}
    L^t &= \sum_{i=1}^{N}{\tilde{w}_i^t\cdot L_i^t}, \\
        \text{where } \\
        \tilde{w}_i^0&=1, \\
        \tilde{w}_i^t&=\tilde{w}_i^{t-1}\cdot m+w_i^t\cdot (1-m), t\ge1\\
        w_i^t&=\dfrac{\overline{\Vert \nabla_{\theta}{L_i^t} \Vert_2}}{\Vert \nabla_{\theta}{L_i^t} \Vert_2}, \\
        \overline{\Vert \nabla_{\theta}{L_i^t} \Vert_2}&=\dfrac{1}{N}\sum_{i=1}^N{\Vert \nabla_{\theta}{L_i^t} \Vert_2}, \\
        &t \text{ is the training step started from 0}.
    \end{align*}
    $$

    Args:
        model (nn.Layer): Training model.
        num_losses (int, optional): Number of losses. Defaults to 1.
        update_freq (int, optional): Weight updating frequency. Defaults to 1000.
        momentum (float, optional): Momentum $m$ for moving weight. Defaults to 0.9.
        init_weights (List[float]): Initial weights list. Defaults to None.

    Examples:
        >>> import paddle
        >>> from ppsci.loss import mtl
        >>> model = paddle.nn.Linear(3, 4)
        >>> loss_aggregator = mtl.GradNorm(model, num_losses=2)
        >>> for i in range(5):
        ...     x1 = paddle.randn([8, 3])
        ...     x2 = paddle.randn([8, 3])
        ...     y1 = model(x1)
        ...     y2 = model(x2)
        ...     loss1 = paddle.sum(y1)
        ...     loss2 = paddle.sum((y2 - 2) ** 2)
        ...     loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
    """
    weight: paddle.Tensor

    def __init__(
        self,
        model: nn.Layer,
        num_losses: int = 1,
        update_freq: int = 1000,
        momentum: float = 0.9,
        init_weights: List[float] = None,
    ) -> None:
        super().__init__(model)
        self.step = 0
        self.num_losses = num_losses
        self.update_freq = update_freq
        self.momentum = momentum
        if init_weights is not None and num_losses != len(init_weights):
            raise ValueError(
                f"Length of init_weights({len(init_weights)}) should be equal to "
                f"num_losses({num_losses})."
            )
        self.register_buffer(
            "weight",
            paddle.to_tensor(init_weights, dtype="float32")
            if init_weights is not None
            else paddle.ones([num_losses]),
        )

    def _compute_weight(self, losses: List["paddle.Tensor"]) -> List["paddle.Tensor"]:
        grad_norms = []
        for loss in losses:
            loss.backward(retain_graph=True)  # NOTE: Keep graph for loss backward
            with paddle.no_grad():
                grad_vector = paddle.concat(
                    [
                        p.grad.reshape([-1])
                        for p in self.model.parameters()
                        if p.grad is not None
                    ]
                )
                grad_norms.append(paddle.linalg.norm(grad_vector, p=2))
                self.model.clear_gradients()

        mean_grad_norm = paddle.mean(paddle.stack(grad_norms))
        weight = [(mean_grad_norm / x) for x in grad_norms]

        return weight

    def __call__(
        self, losses: Dict[str, "paddle.Tensor"], step: int = 0
    ) -> "paddle.Tensor":
        assert len(losses) == self.num_losses, (
            f"Length of given losses({len(losses)}) should be equal to "
            f"num_losses({self.num_losses})."
        )
        self.step = step

        # compute current loss with moving weights
        loss = 0.0
        for i, key in enumerate(losses):
            if i == 0:
                loss = self.weight[i] * losses[key]
            else:
                loss += self.weight[i] * losses[key]

        # update moving weights every 'update_freq' steps
        if self.step % self.update_freq == 0:
            weight = self._compute_weight(list(losses.values()))
            for i in range(self.num_losses):
                self.weight[i].set_value(
                    self.momentum * self.weight[i] + (1 - self.momentum) * weight[i]
                )
            # logger.message(f"weight at step {self.step}: {self.weight.numpy()}")

        return loss

LossAggregator

Bases: Layer

Base class of loss aggregators, mainly used for multi-task learning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Layer | Training model. | required |
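Subclasses keep a reference to the model, receive the loss dict through forward (which returns the aggregator itself), and implement backward. A minimal sketch of a custom aggregator, assuming LossAggregator is importable from ppsci.loss.mtl as documented on this page; MeanAggregator is a hypothetical example class:

import paddle
from ppsci.loss import mtl

class MeanAggregator(mtl.LossAggregator):
    """Hypothetical aggregator that backpropagates the mean of all given losses."""

    def backward(self) -> None:
        # self.losses and self.loss_num are populated by forward() in the base class.
        total = sum(self.losses.values()) / self.loss_num
        total.backward()

model = paddle.nn.Linear(3, 4)
aggregator = MeanAggregator(model)
loss1 = paddle.sum(model(paddle.randn([8, 3])))
loss2 = paddle.sum((model(paddle.randn([8, 3])) - 2) ** 2)
aggregator({'loss1': loss1, 'loss2': loss2}).backward()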
Source code in ppsci/loss/mtl/base.py
class LossAggregator(nn.Layer):
    """Base class of loss aggregator mainly for multitask learning.

    Args:
        model (nn.Layer): Training model.
    """

    def __init__(self, model: nn.Layer) -> None:
        super().__init__()
        self.model = model
        self.step = 0
        self.param_num = 0
        for param in self.model.parameters():
            if not param.stop_gradient:
                self.param_num += 1

    def forward(
        self, losses: Dict[str, "paddle.Tensor"], step: int = 0
    ) -> Union["paddle.Tensor", "LossAggregator"]:
        self.losses = losses
        self.loss_num = len(losses)
        self.step = step
        return self

    def backward(self) -> None:
        raise NotImplementedError(
            f"'backward' should be implemented in subclass {self.__class__.__name__}"
        )

PCGrad

Bases: LossAggregator

Projecting Conflicting Gradients

Gradient Surgery for Multi-Task Learning

Code reference: https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Layer | Training model. | required |

Examples:

>>> import paddle
>>> from ppsci.loss import mtl
>>> model = paddle.nn.Linear(3, 4)
>>> loss_aggregator = mtl.PCGrad(model)
>>> for i in range(5):
...     x1 = paddle.randn([8, 3])
...     x2 = paddle.randn([8, 3])
...     y1 = model(x1)
...     y2 = model(x2)
...     loss1 = paddle.sum(y1)
...     loss2 = paddle.sum((y2 - 2) ** 2)
...     loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
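The core operation of PCGrad is gradient surgery: whenever two task gradients conflict (negative inner product), one is projected onto the normal plane of the other. A minimal sketch of that projection on two hypothetical flattened gradients, following _refine_grads in the source below:

import paddle

# Hypothetical flattened gradients of two tasks.
g1 = paddle.randn([10])
g2 = paddle.randn([10])

inner = paddle.sum(g1 * g2)
if inner < 0:  # the gradients conflict
    # remove the component of g1 that points against g2
    g1 = g1 - (inner / paddle.sum(g2 * g2)) * g2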
Source code in ppsci/loss/mtl/pcgrad.py
class PCGrad(base.LossAggregator):
    r"""
    **P**rojecting **C**onflicting Gradients

    [Gradient Surgery for Multi-Task Learning](https://papers.nips.cc/paper/2020/hash/3fe78a8acf5fda99de95303940a2420c-Abstract.html)

    Code reference: [https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py](https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py)

    Args:
        model (nn.Layer): Training model.

    Examples:
        >>> import paddle
        >>> from ppsci.loss import mtl
        >>> model = paddle.nn.Linear(3, 4)
        >>> loss_aggregator = mtl.PCGrad(model)
        >>> for i in range(5):
        ...     x1 = paddle.randn([8, 3])
        ...     x2 = paddle.randn([8, 3])
        ...     y1 = model(x1)
        ...     y2 = model(x2)
        ...     loss1 = paddle.sum(y1)
        ...     loss2 = paddle.sum((y2 - 2) ** 2)
        ...     loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
    """

    def __init__(self, model: nn.Layer) -> None:
        super().__init__(model)
        self._zero = paddle.zeros([])

    def backward(self) -> None:
        # shuffle order of losses
        keys = list(self.losses.keys())
        np.random.shuffle(keys)
        self.losses = {key: self.losses[key] for key in keys}

        grads_list = self._compute_grads()
        with paddle.no_grad():
            refined_grads = self._refine_grads(grads_list)
            self._set_grads(refined_grads)

    def _compute_grads(self) -> List[paddle.Tensor]:
        # compute all gradients derived by each loss
        grads_list = []  # num_params x num_losses
        for key in self.losses:
            # backward with current loss
            self.losses[key].backward()
            grads_list.append(
                paddle.concat(
                    [
                        param.grad.clone().reshape([-1])
                        for param in self.model.parameters()
                        if param.grad is not None
                    ],
                    axis=0,
                )
            )
            # clear gradients for current loss for not affecting other loss
            self.model.clear_gradients()

        return grads_list

    def _refine_grads(self, grads_list: List[paddle.Tensor]) -> List[paddle.Tensor]:
        def proj_grad(grad: paddle.Tensor):
            for k in range(self.loss_num):
                inner_product = paddle.sum(grad * grads_list[k])
                proj_direction = inner_product / paddle.sum(
                    grads_list[k] * grads_list[k]
                )
                grad = grad - paddle.minimum(proj_direction, self._zero) * grads_list[k]
            return grad

        grads_list = [proj_grad(grad) for grad in grads_list]

        # Unpack flattened projected gradients back to their original shapes.
        proj_grads: List[paddle.Tensor] = []
        for j in range(self.loss_num):
            start_idx = 0
            for idx, var in enumerate(self.model.parameters()):
                grad_shape = var.shape
                flatten_dim = var.numel()
                refined_grad = grads_list[j][start_idx : start_idx + flatten_dim]
                refined_grad = paddle.reshape(refined_grad, grad_shape)
                if len(proj_grads) < self.param_num:
                    proj_grads.append(refined_grad)
                else:
                    proj_grads[idx] += refined_grad
                start_idx += flatten_dim
        return proj_grads

    def _set_grads(self, grads_list: List[paddle.Tensor]) -> None:
        for i, param in enumerate(self.model.parameters()):
            param.grad = grads_list[i]

Relobralo

Bases: Layer

Relative Loss Balancing with Random Lookback

Multi-Objective Loss Balancing for Physics-Informed Deep Learning

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_losses | int | Number of losses. | required |
| alpha | float | Exponential-decay factor controlling how much past information is remembered ("ability for remembering past" in the paper). Defaults to 0.95. | 0.95 |
| beta | float | Parameter for generating \(\rho\) from a Bernoulli distribution; \(E[\rho](=\beta)\) should be close to 1. Defaults to 0.99. | 0.99 |
| tau | float | Temperature factor. Equivalent to softmax when \(\tau\)=1.0 and to argmax when \(\tau\)=0. Defaults to 1.0. | 1.0 |
| eps | float | \(\epsilon\) added to avoid division by zero in losses. Defaults to 1e-8. | 1e-08 |

Examples:

>>> import paddle
>>> from ppsci.loss import mtl
>>> model = paddle.nn.Linear(3, 4)
>>> loss_aggregator = mtl.Relobralo(num_losses=2)
>>> for i in range(5):
...     x1 = paddle.randn([8, 3])
...     x2 = paddle.randn([8, 3])
...     y1 = model(x1)
...     y2 = model(x2)
...     loss1 = paddle.sum(y1)
...     loss2 = paddle.sum((y2 - 2) ** 2)
...     loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
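At each step after the first, Relobralo rescales the loss weights λ by comparing the current losses with the initial and previous losses through a temperature-scaled softmax, mixed by a Bernoulli "random lookback" variable ρ. A minimal sketch of one update with hypothetical loss values, following the source below (F.softmax plays the role of the internal _softmax helper):

import paddle
import paddle.nn.functional as F

num_losses, alpha, beta, tau, eps = 2, 0.95, 0.99, 1.0, 1e-8

losses = paddle.to_tensor([1.2, 0.3])       # current losses L_i
losses_init = paddle.to_tensor([2.0, 0.5])  # losses recorded at step 0
losses_prev = paddle.to_tensor([1.5, 0.4])  # losses from the previous step
lmbda = paddle.ones([num_losses])

def bal(cur, ref):
    # num_losses * softmax(L_i / (tau * L_i_ref + eps)), cf. _compute_bal below
    return num_losses * F.softmax(cur / (tau * ref + eps))

rho = paddle.bernoulli(paddle.to_tensor(beta))  # random lookback
lmbda_hist = rho * lmbda + (1 - rho) * bal(losses, losses_init)
lmbda = alpha * lmbda_hist + (1 - alpha) * bal(losses, losses_prev)

total_loss = (losses * lmbda).sum()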
Source code in ppsci/loss/mtl/relobralo.py
class Relobralo(nn.Layer):
    r"""
    **Re**lative **Lo**ss **B**alancing with **Ra**ndom **Lo**okback

    [Multi-Objective Loss Balancing for Physics-Informed Deep Learning](https://arxiv.org/abs/2110.09813)

    Args:
        num_losses (int): Number of losses.
        alpha (float, optional): Exponential-decay factor controlling how much past
            information is remembered ("ability for remembering past" in the paper).
            Defaults to 0.95.
        beta (float, optional): Parameter for generating $\rho$ from a Bernoulli
            distribution; $E[\rho](=\beta)$ should be close to 1. Defaults to 0.99.
        tau (float, optional): Temperature factor. Equivalent to softmax when $\tau$=1.0,
            equivalent to argmax when $\tau$=0. Defaults to 1.0.
        eps (float, optional): $\epsilon$ added to avoid division by zero in losses.
            Defaults to 1e-8.

    Examples:
        >>> import paddle
        >>> from ppsci.loss import mtl
        >>> model = paddle.nn.Linear(3, 4)
        >>> loss_aggregator = mtl.Relobralo(num_losses=2)
        >>> for i in range(5):
        ...     x1 = paddle.randn([8, 3])
        ...     x2 = paddle.randn([8, 3])
        ...     y1 = model(x1)
        ...     y2 = model(x2)
        ...     loss1 = paddle.sum(y1)
        ...     loss2 = paddle.sum((y2 - 2) ** 2)
        ...     loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
    """

    def __init__(
        self,
        num_losses: int,
        alpha: float = 0.95,
        beta: float = 0.99,
        tau: float = 1.0,
        eps: float = 1e-8,
    ) -> None:
        super().__init__()
        self.step = 0
        self.num_losses: int = num_losses
        self.alpha: float = alpha
        self.beta: float = beta
        self.tau: float = tau
        self.eps: float = eps
        self.register_buffer("losses_init", paddle.zeros([self.num_losses]))
        self.register_buffer("losses_prev", paddle.zeros([self.num_losses]))
        self.register_buffer("lmbda", paddle.ones([self.num_losses]))

    def _softmax(self, vec: "paddle.Tensor") -> "paddle.Tensor":
        max_item = vec.max()
        result = paddle.exp(vec - max_item) / paddle.exp(vec - max_item).sum()
        return result

    def _compute_bal(
        self, losses_vec1: "paddle.Tensor", losses_vec2: "paddle.Tensor"
    ) -> "paddle.Tensor":
        return self.num_losses * (
            self._softmax(losses_vec1 / (self.tau * losses_vec2 + self.eps))
        )

    def __call__(
        self, losses: Dict[str, "paddle.Tensor"], step: int = 0
    ) -> "paddle.Tensor":
        assert len(losses) == self.num_losses, (
            f"Length of given losses({len(losses)}) should be equal to "
            f"num_losses({self.num_losses})."
        )
        self.step = step
        losses_stacked = paddle.stack(list(losses.values()))  # [num_losses, ]

        if self.step == 0:
            loss = losses_stacked.sum()
            with paddle.no_grad():
                paddle.assign(losses_stacked.detach(), self.losses_init)
        else:
            with paddle.no_grad():
                # 1. update lambda_hist
                rho = paddle.bernoulli(paddle.to_tensor(self.beta))
                lmbda_hist = rho * self.lmbda + (1 - rho) * self._compute_bal(
                    losses_stacked, self.losses_init
                )

                # 2. update lambda
                paddle.assign(
                    self.alpha * lmbda_hist
                    + (1 - self.alpha)
                    * self._compute_bal(losses_stacked, self.losses_prev),
                    self.lmbda,
                )

            # 3. compute reweighted total loss with lambda
            loss = (losses_stacked * self.lmbda).sum()

        # update losses_prev at the end of each step
        with paddle.no_grad():
            paddle.assign(losses_stacked.detach(), self.losses_prev)

        return loss

Sum

Bases: LossAggregator

Default loss aggregator, which performs a simple summation of the given losses as below.

\[ loss = \sum_{i=1}^{N} losses_i \]
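Unlike the other aggregators on this page, Sum takes no constructor arguments and returns the summed loss tensor directly. A hedged usage sketch modeled on the examples above, assuming Sum is exported from ppsci.loss.mtl like the other aggregators:

import paddle
from ppsci.loss import mtl

model = paddle.nn.Linear(3, 4)
loss_aggregator = mtl.Sum()
for i in range(5):
    y1 = model(paddle.randn([8, 3]))
    y2 = model(paddle.randn([8, 3]))
    loss1 = paddle.sum(y1)
    loss2 = paddle.sum((y2 - 2) ** 2)
    # Sum.__call__ returns the summed loss tensor, which can be backpropagated directly.
    total = loss_aggregator({'loss1': loss1, 'loss2': loss2}, step=i)
    total.backward()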
Source code in ppsci/loss/mtl/sum.py
class Sum(LossAggregator):
    r"""
    **Default loss aggregator** which performs a simple summation of the given losses as below.

    $$
    loss = \sum_i^N losses_i
    $$
    """

    def __init__(self) -> None:
        self.step = 0

    def __call__(
        self, losses: Dict[str, "paddle.Tensor"], step: int = 0
    ) -> "paddle.Tensor":
        assert (
            len(losses) > 0
        ), f"Given losses should not be empty, but got {len(losses)}."
        self.step = step

        total_loss = 0.0
        for i, key in enumerate(losses):
            if i == 0:
                total_loss = losses[key]
            else:
                total_loss += losses[key]

        return total_loss