Losses

::: {#cell-3 .cell export="true"}

import jax.numpy as jnp
from jax import custom_vjp
from typing import Optional, Tuple
import optax
from einops import rearrange

:::

::: {#cell-4 .cell export="true"}

def to_db(
    x: jnp.ndarray,
    eps: float = 1e-10,
) -> jnp.ndarray:
    """Convert a magnitude to decibels; `eps` avoids log(0)."""
    return 20 * jnp.log10(x + eps)


def log_mag(
    x: jnp.ndarray,
    eps: float = 1e-10,
) -> jnp.ndarray:
    """Natural-log magnitude of a (complex) array; `eps` avoids log(0)."""
    return jnp.log(jnp.abs(x) + eps)

:::
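
A quick numeric check (illustrative, not exported): a unit magnitude sits at 0 dB and every factor-of-10 drop subtracts 20 dB, while `log_mag` reports natural-log magnitudes.

::: {.cell}

mags = jnp.array([1.0, 0.1, 0.01])
print(to_db(mags))  # ~[0., -20., -40.]
print(log_mag(mags))  # natural log of the magnitudes, ~[0., -2.30, -4.61]

:::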

::: {#cell-5 .cell export="true"}

def db_mag_loss(
    predicted: jnp.ndarray,  # Complex-valued FFT of the predicted signal
    target: jnp.ndarray,  # Complex-valued FFT of the target signal
    eps: float = 1e-10,  # Small constant to avoid log(0)
    distance: str = "l1",  # Distance metric: 'l1' or 'l2'
) -> jnp.ndarray:
    """
    Calculate the mean L1 or L2 loss between the decibel magnitudes of two FFT signals.

    :param predicted: FFT of the predicted signal.
    :param target: FFT of the target signal.
    :param eps: Small constant for numerical stability in the log computation.
    :param distance: Distance metric ('l1' or 'l2').
    :return: Mean L1 or L2 loss in decibel magnitude.
    """

    # Convert to decibel magnitude
    pred_db = to_db(jnp.abs(predicted), eps)
    target_db = to_db(jnp.abs(target), eps)

    # Compute loss based on the specified distance metric
    if distance == "l1":
        return jnp.mean(jnp.abs(pred_db - target_db))
    elif distance == "l2":
        return jnp.mean((pred_db - target_db) ** 2)
    else:
        raise ValueError("Invalid distance metric. Choose 'l1' or 'l2'.")

:::
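
As a quick illustration (made-up signal, not part of the library): halving the amplitude is a uniform \(20 \log_{10} 0.5 \approx -6.02\) dB offset at every bin, so the mean L1 dB loss comes out at about 6.02:

::: {.cell}

sig = jnp.sin(2 * jnp.pi * 5.5 * jnp.arange(256) / 256)  # off-bin frequency, so every bin has leakage
print(db_mag_loss(jnp.fft.rfft(0.5 * sig), jnp.fft.rfft(sig)))  # ~6.02

:::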

The spectral log magnitude loss is defined as:

\[ \mathcal{L} = \frac{1}{N} \left\| \log|Y| - \log|\hat{Y}| \right\|_p \]

::: {#cell-7 .cell export="true"}

def log_mag_loss(
    pred: jnp.ndarray,  # complex valued fft of the signal
    target: jnp.ndarray,  # complex valued fft of the signal
    eps: float = 1e-10,
    distance: str = "l1",
):
    """
    Spectral log magtinude loss but for a fft of a signal
    See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
    """

    pred_log_mag = log_mag(pred, eps)
    target_log_mag = log_mag(target, eps)

    # l1 spectral log magnitude loss
    if distance == "l1":
        return jnp.mean(jnp.abs(pred_log_mag - target_log_mag))
    # l2 spectral log magnitude loss
    elif distance == "l2":
        return jnp.mean((pred_log_mag - target_log_mag) ** 2)
    else:
        raise ValueError("Invalid distance metric. Choose 'l1' or 'l2'.")

:::
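
Since a constant amplitude scale \(\alpha\) shifts every log magnitude by \(\log \alpha\), the L1 variant of this loss evaluates to roughly \(|\log \alpha|\) (an illustrative check; `eps` only matters in near-empty bins):

::: {.cell}

sig = jnp.cos(2 * jnp.pi * 3.3 * jnp.arange(128) / 128)  # off-bin frequency, so every bin has leakage
print(log_mag_loss(jnp.fft.rfft(2.0 * sig), jnp.fft.rfft(sig)))  # ~log(2) ≈ 0.693

:::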

The spectral convergence loss is defined as:

\[ \mathcal{L} = \frac{\left\| |Y| - |\hat{Y}| \right\|_2}{\left\| |Y| \right\|_2} \]

::: {#cell-9 .cell export="true"}

def spectral_convergence_loss(
    pred: jnp.ndarray,  # complex valued fft of the signal
    target: jnp.ndarray,  # complex valued fft of the signal
):
    """
    Spectral convergence loss but for a fft of a signal
    See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
    """
    # l2 spectral convergence loss
    return jnp.linalg.norm(jnp.abs(target) - jnp.abs(pred)) / jnp.linalg.norm(
        jnp.abs(target)
    )

:::
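
Two properties worth noting: the loss is zero when the magnitude spectra match, and the denominator normalises by the target's spectral energy, so a uniform 50% amplitude error gives exactly 0.5 (illustrative check, not exported):

::: {.cell}

fft = jnp.fft.rfft(jnp.sin(2 * jnp.pi * 4.5 * jnp.arange(256) / 256))
print(spectral_convergence_loss(fft, fft))  # 0.0
print(spectral_convergence_loss(0.5 * fft, fft))  # 0.5

:::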

Loss functions used during training. Each returns the scalar loss together with the prediction and any mutated batch statistics.

::: {#cell-11 .cell export="true"}

from flax.training import train_state
from typing import Tuple

:::

::: {#cell-12 .cell export="true"}

def mse_loss(
    params,
    state: train_state.TrainState,
    x: jnp.ndarray,  # input sequence (batch, timesteps, grid_size, 1), zeros in our case
    y: jnp.ndarray,  # output sequence (batch, timesteps, grid_size, 1), u in our case
    dropout_key: Optional[jnp.ndarray] = None,
    norm: str = "layer",
) -> Tuple[jnp.ndarray, Tuple]:  # loss, (pred, mutated vars)
    if norm == "layer":
        # layer norm keeps no running statistics, so nothing is mutable
        pred = state.apply_fn({"params": params}, x, rngs={"dropout": dropout_key})
        mutated_vars = None
    else:
        # batch norm updates its running statistics, returned as mutated vars
        pred, mutated_vars = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats},
            x,
            rngs={"dropout": dropout_key},
            mutable=["batch_stats"],
        )

    loss = jnp.mean((pred - y) ** 2)
    return loss, (pred, mutated_vars)


def fft_loss(
    params,
    state: train_state.TrainState,
    x: jnp.ndarray,  # input sequence (batch, timesteps, grid_size, 1), zeros in our case
    y: jnp.ndarray,  # output sequence (batch, timesteps, grid_size, 1), u in our case
    dropout_key: Optional[jnp.ndarray] = None,
    norm: str = "layer",
) -> Tuple[jnp.ndarray, Tuple]:  # loss, (pred, mutated vars)
    if norm == "layer":
        pred = state.apply_fn(
            {"params": params},
            x,
            rngs={"dropout": dropout_key},
        )
        mutated_vars = None
    else:
        pred, mutated_vars = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats},
            x,
            rngs={"dropout": dropout_key},
            mutable=["batch_stats"],
        )

    # FFT of the predicted and target signals along the time axis
    pred_fft = jnp.fft.rfft(pred, axis=-3)
    y_fft = jnp.fft.rfft(y, axis=-3)

    # time-domain MSE loss
    # huber_loss = jnp.mean(optax.huber_loss(pred, y, delta=0.1))
    mse = jnp.mean((pred - y) ** 2)

    # spectral convergence loss on the magnitudes
    spec_conv_loss = jnp.mean(spectral_convergence_loss(pred_fft, y_fft))
    # mag_mse_loss = log_mag_loss(pred_fft, y_fft, distance="l2")
    return spec_conv_loss + mse, (pred, mutated_vars)

:::
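
Because each loss returns `(scalar, aux)`, it can be handed directly to `jax.value_and_grad(..., has_aux=True)`. A minimal sketch of one training step with a stand-in model (`TinyModel`, the key handling, and the optimiser are assumptions for illustration, not part of this library):

::: {.cell}

import jax
import flax.linen as nn


class TinyModel(nn.Module):
    # Stand-in model; any module with a compatible __call__ works here.
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)


key = jax.random.PRNGKey(0)
x = jnp.zeros((2, 16, 32, 1))  # (batch, timesteps, grid_size, 1)
y = jax.random.normal(key, (2, 16, 32, 1))

model = TinyModel()
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=model.init(key, x)["params"],
    tx=optax.adam(1e-3),
)

grad_fn = jax.value_and_grad(mse_loss, has_aux=True)
(loss, (pred, _)), grads = grad_fn(state.params, state, x, y, dropout_key=key)
state = state.apply_gradients(grads=grads)

:::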

::: {#cell-13 .cell export="true"}

def lindyn_loss(
    params,
    state: train_state.TrainState,
    x: jnp.ndarray,  # (B, T, G, C) pde solution
    y: jnp.ndarray,  # (B, T, G, C) shifted pde solution
    encdec_weight: float = 1.0,
    lindyn_weight: float = 0.01,
    pred_weight: float = 1.0,
    dropout_key: Optional[jnp.ndarray] = None,
    norm: str = "layer",
) -> Tuple[jnp.ndarray, Tuple]:  # loss, (pred, None)
    params = {"params": params}

    full_x = jnp.concatenate([x, y], axis=1)
    # encode the full trajectory
    encoded = state.apply_fn(
        params,
        full_x,
        method="encode",
    )

    # decode it straight back for the reconstruction term
    decoded = state.apply_fn(
        params,
        encoded,
        method="decode",
    )

    # advance the encoded initial state through the latent dynamics;
    # the returned states cover timesteps [1, n + 1]
    states = state.apply_fn(
        params,
        encoded[:, 0],
        method="advance",
    )

    # decode the advanced latent states into predictions
    pred = state.apply_fn(
        params,
        states,
        method="decode",
    )

    # reconstruction loss: encode followed by decode should be the identity
    reconstruction_loss = jnp.mean((decoded - full_x) ** 2)

    # consistency loss between the advanced latent states and the
    # ground-truth encodings, comparing timesteps [1, n] on both sides
    lindyn_mse_loss = jnp.mean((states - encoded[:, 1:]) ** 2)

    # prediction loss in the original space
    pred_mse_loss = jnp.mean((pred - y) ** 2)

    return (
        pred_weight * pred_mse_loss
        + encdec_weight * reconstruction_loss
        + lindyn_weight * lindyn_mse_loss,
        (pred, None),
    )

:::
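
`lindyn_loss` assumes the model exposes `encode`, `decode`, and `advance` methods reachable via `apply_fn(..., method=...)`. A hypothetical skeleton of that interface, with made-up names, sizes, and a plain linear latent operator (a sketch only; the real model lives elsewhere):

::: {.cell}

import flax.linen as nn


class LinDynModel(nn.Module):
    # Hypothetical sketch: lindyn_loss only needs these three methods to
    # exist and to compose as decode(advance(encode(x)[:, 0])).
    latent_dim: int = 32
    grid_size: int = 64
    steps: int = 15  # must equal total timesteps - 1 so `advance` output lines up with encoded[:, 1:]

    def setup(self):
        self.encoder = nn.Dense(self.latent_dim)
        self.decoder = nn.Dense(self.grid_size)
        self.dynamics = nn.Dense(self.latent_dim, use_bias=False)  # linear operator K

    def encode(self, x):  # (B, T, G, 1) -> (B, T, latent)
        return self.encoder(x.squeeze(-1))

    def decode(self, z):  # (B, T, latent) -> (B, T, G, 1)
        return self.decoder(z)[..., None]

    def advance(self, z0):  # (B, latent) -> (B, steps, latent)
        zs, z = [], z0
        for _ in range(self.steps):
            z = self.dynamics(z)  # z_{k+1} = K z_k
            zs.append(z)
        return jnp.stack(zs, axis=1)

    def __call__(self, x):  # exercised by init() so all submodule params are created
        return self.decode(self.advance(self.encode(x)[:, 0]))

:::

Flax resolves `method="encode"` in `apply` to the module method of that name, which is all `lindyn_loss` relies on.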