Losses
::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
import jax.numpy as jnp
from jax import custom_vjp
from typing import Optional, Tuple
import optax
from einops import rearrange
:::
::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def to_db(
    x: jnp.ndarray,  # magnitude spectrum (non-negative values)
    eps: float = 1e-10,  # small constant to avoid log(0)
):
    return 20 * jnp.log10(x + eps)


def log_mag(
    x: jnp.ndarray,  # complex-valued FFT of a signal
    eps: float = 1e-10,  # small constant to avoid log(0)
):
    return jnp.log(jnp.abs(x) + eps)
:::
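A quick sanity check of the helpers (illustrative values; note that `to_db` expects magnitudes, while `log_mag` takes the complex spectrum and applies `jnp.abs` itself):

```python
x = jnp.array([1.0, 10.0, 100.0])
print(to_db(x))    # ~ [ 0., 20., 40.] dB
print(log_mag(x))  # ~ [ 0., 2.303, 4.605] (natural log)
```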
::: {#cell-5 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def db_mag_loss(
    predicted: jnp.ndarray,  # Complex-valued FFT of the predicted signal
    target: jnp.ndarray,  # Complex-valued FFT of the target signal
    eps: float = 1e-10,  # Small constant to avoid log(0)
    distance: str = "l1",  # Distance metric: 'l1' or 'l2'
) -> jnp.ndarray:
    """
    Calculate the mean L1 or L2 loss between the decibel magnitudes of two FFT signals.

    :param predicted: FFT of the predicted signal.
    :param target: FFT of the target signal.
    :param eps: Small constant for numerical stability in log computation.
    :param distance: Type of distance metric ('l1' or 'l2').
    :return: Mean L1 or L2 loss in decibel magnitude.
    """
    # Convert to decibel magnitude
    pred_db = to_db(jnp.abs(predicted), eps)
    target_db = to_db(jnp.abs(target), eps)

    # Compute loss based on the specified distance metric
    if distance == "l1":
        return jnp.mean(jnp.abs(pred_db - target_db))
    elif distance == "l2":
        return jnp.mean((pred_db - target_db) ** 2)
    else:
        raise ValueError("Invalid distance metric. Choose 'l1' or 'l2'.")
:::
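As a usage sketch (the sine signals below are made up for illustration), the loss compares the spectra of two signals, so it is applied to their FFTs rather than the raw waveforms:

```python
t = jnp.linspace(0.0, 1.0, 1024)
clean = jnp.sin(2 * jnp.pi * 5 * t)
noisy = clean + 0.1 * jnp.sin(2 * jnp.pi * 50 * t)  # add a high-frequency component

# nonzero loss for different spectra, ~0 for identical ones
print(db_mag_loss(jnp.fft.rfft(noisy), jnp.fft.rfft(clean), distance="l1"))
print(db_mag_loss(jnp.fft.rfft(clean), jnp.fft.rfft(clean), distance="l1"))
```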
The spectral log magnitude loss is defined as:
\[ \mathcal{L} = \frac{1}{N} \left\| \log|Y| - \log|\hat{Y}| \right\|_p \]
::: {#cell-7 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def log_mag_loss(
    pred: jnp.ndarray,  # complex-valued FFT of the predicted signal
    target: jnp.ndarray,  # complex-valued FFT of the target signal
    eps: float = 1e-10,
    distance: str = "l1",
):
    """
    Spectral log magnitude loss for the FFT of a signal.
    See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
    """
    pred_log_mag = log_mag(pred, eps)
    target_log_mag = log_mag(target, eps)

    # l1 spectral log magnitude loss
    if distance == "l1":
        return jnp.mean(jnp.abs(pred_log_mag - target_log_mag))
    # l2 spectral log magnitude loss
    elif distance == "l2":
        return jnp.mean((pred_log_mag - target_log_mag) ** 2)
    else:
        raise ValueError("Invalid distance metric. Choose 'l1' or 'l2'.")
:::
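Note that `to_db` and `log_mag` differ only by the constant factor 20 / ln 10 ≈ 8.69, so for a given distance metric `db_mag_loss` is just a rescaled `log_mag_loss` (by that factor for L1, and by its square for L2).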
The spectral convergence loss is defined as:
\[ \mathcal{L} = \frac{\left\| |Y| - |\hat{Y}| \right\|_2}{\left\| |Y| \right\|_2} \]
::: {#cell-9 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def spectral_convergence_loss(
    pred: jnp.ndarray,  # complex-valued FFT of the predicted signal
    target: jnp.ndarray,  # complex-valued FFT of the target signal
):
    """
    Spectral convergence loss for the FFT of a signal.
    See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
    """
    # l2 spectral convergence loss
    return jnp.linalg.norm(jnp.abs(target) - jnp.abs(pred)) / jnp.linalg.norm(
        jnp.abs(target)
    )
:::
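Because the target's norm sits in the denominator, this is a relative error: scaling both spectra by the same factor leaves the loss unchanged. A quick check (the signal is illustrative):

```python
y = jnp.fft.rfft(jnp.sin(2 * jnp.pi * 3 * jnp.linspace(0.0, 1.0, 512)))
y_hat = 0.9 * y  # 10% magnitude error

print(spectral_convergence_loss(y_hat, y))            # ~ 0.1
print(spectral_convergence_loss(10 * y_hat, 10 * y))  # same value: the scale cancels
```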
Loss functions for training
::: {#cell-11 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
from flax.training import train_state
from typing import Tuple
:::
::: {#cell-12 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def mse_loss(
    params,
    state: train_state.TrainState,
    x: jnp.ndarray,  # input sequence (batch, timesteps, grid_size, 1), zeros in our case
    y: jnp.ndarray,  # output sequence (batch, timesteps, grid_size, 1), u in our case
    dropout_key: jnp.ndarray = None,
    norm: str = "layer",
) -> Tuple[jnp.ndarray, jnp.ndarray]:  # loss, pred
    if norm in ["layer"]:
        pred = state.apply_fn({"params": params}, x, rngs={"dropout": dropout_key})
        vars = None
    else:
        pred, vars = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats},
            x,
            rngs={"dropout": dropout_key},
            mutable=["batch_stats"],
        )
    mse_loss = jnp.mean((pred - y) ** 2)
    return mse_loss, (pred, vars)


def fft_loss(
    params,
    state: train_state.TrainState,
    x: jnp.ndarray,  # input sequence (batch, timesteps, grid_size, 1), zeros in our case
    y: jnp.ndarray,  # output sequence (batch, timesteps, grid_size, 1), u in our case
    dropout_key: jnp.ndarray = None,
    norm: str = "layer",
) -> Tuple[jnp.ndarray, jnp.ndarray]:  # loss, pred
    if norm in ["layer"]:
        pred = state.apply_fn(
            {"params": params},
            x,
            rngs={"dropout": dropout_key},
        )
        vars = None
    else:
        pred, vars = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats},
            x,
            rngs={"dropout": dropout_key},
            mutable=["batch_stats"],
        )
    # take the fft of the predicted and target signals
    pred_fft = jnp.fft.rfft(pred, axis=-3)
    y_fft = jnp.fft.rfft(y, axis=-3)

    # huber_loss = jnp.mean(optax.huber_loss(pred, y, delta=0.1))
    mse_loss = jnp.mean((pred - y) ** 2)

    # spectral convergence loss
    spec_conv_loss = jnp.mean(spectral_convergence_loss(pred_fft, y_fft))
    # magnitude mse loss
    # mag_mse_loss = log_mag_loss(pred_fft, y_fft, distance="l2")
    return spec_conv_loss + mse_loss, (pred, vars)
:::
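A minimal sketch of how these losses plug into a training step, assuming `state` is a Flax `TrainState` whose `apply_fn` matches the signature used above (the `train_step` name is a placeholder). Since each loss returns `(loss, (pred, vars))`, the gradient is taken with `has_aux=True`:

```python
import jax

@jax.jit
def train_step(state, x, y, dropout_key):
    # differentiate fft_loss w.r.t. params (its first argument)
    (loss, (pred, _vars)), grads = jax.value_and_grad(fft_loss, has_aux=True)(
        state.params, state, x, y, dropout_key=dropout_key
    )
    state = state.apply_gradients(grads=grads)
    return state, loss
```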
::: {#cell-13 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def lindyn_loss(
    params,
    state: train_state.TrainState,
    x: jnp.ndarray,  # (B, T, G, C) pde solution
    y: jnp.ndarray,  # (B, T, G, C) shifted pde solution
    encdec_weight: float = 1.0,
    lindyn_weight: float = 0.01,
    pred_weight: float = 1.0,
    dropout_key: jnp.ndarray = None,
    norm: str = "layer",
) -> Tuple[jnp.ndarray, jnp.ndarray]:  # loss, pred
    params = {"params": params}

    full_x = jnp.concatenate([x, y], axis=1)
    # encode the initial state
    encoded = state.apply_fn(
        params,
        full_x,
        method="encode",
    )
    decoded = state.apply_fn(
        params,
        encoded,
        method="decode",
    )
    # advance the initial state
    # states are [1, n+1]
    states = state.apply_fn(
        params,
        encoded[:, 0],
        method="advance",
    )
    # decode the encoded states
    pred = state.apply_fn(
        params,
        states,
        method="decode",
    )
    # reconstruction loss between the initial state encoded and decoded
    reconstruction_loss = jnp.mean((decoded - full_x) ** 2)

    # consistency loss between the predicted encoded states and the gt encoded states
    # compare only [1, n] with [1, n]
    lindyn_mse_loss = jnp.mean((states - encoded[:, 1:]) ** 2)

    # prediction loss
    pred_mse_loss = jnp.mean((pred - y) ** 2)

    return (
        pred_weight * pred_mse_loss
        + encdec_weight * reconstruction_loss
        + lindyn_weight * lindyn_mse_loss,
        (pred, None),
    )
:::
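In summary (notation mine): writing E, A, and D for the model's `encode`, `advance`, and `decode` methods, u for the concatenated trajectory, and z = E(u), the combined objective is

\[ \mathcal{L} = w_{\text{pred}} \, \mathrm{MSE}(D(A(z_0)), y) + w_{\text{encdec}} \, \mathrm{MSE}(D(z), u) + w_{\text{lindyn}} \, \mathrm{MSE}(A(z_0), z_{1:}) \]

where the three weights correspond to `pred_weight`, `encdec_weight`, and `lindyn_weight`.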