Losses



log_mag

 log_mag (x:jax.Array, eps:float=1e-10)
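
The page gives only the signature of `log_mag`. Here is a minimal sketch of what it plausibly computes. It assumes the function is the natural log of the magnitude, with `eps` added so that zero-magnitude bins stay finite. The actual implementation may differ, for example by using `log10`:

```python
import jax
import jax.numpy as jnp


def log_mag(x: jax.Array, eps: float = 1e-10) -> jax.Array:
    # Assumption: natural log of |x|; eps keeps log(0) finite for empty bins
    return jnp.log(jnp.abs(x) + eps)


# An impulse has a flat spectrum (|X| = 1 in every bin), so its log magnitude is ~0
X = jnp.fft.rfft(jnp.array([1.0, 0.0, 0.0, 0.0]))
lm = log_mag(X)
```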


to_db

 to_db (x:jax.Array, eps:float=1e-10)
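
The same applies to `to_db`. This sketch assumes the standard amplitude-to-decibel conversion `20 * log10(|x| + eps)`. If the implementation treats `x` as a power spectrum, it would use `10 * log10` instead:

```python
import jax
import jax.numpy as jnp


def to_db(x: jax.Array, eps: float = 1e-10) -> jax.Array:
    # Assumption: amplitude (not power) to decibels; eps avoids log10(0)
    return 20.0 * jnp.log10(jnp.abs(x) + eps)


db = to_db(jnp.array([1.0, 10.0, 0.1]))  # approximately [0, 20, -20] dB
```

Under this convention, a factor of 10 in amplitude corresponds to 20 dB.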


db_mag_loss

 db_mag_loss (predicted:jax.Array, target:jax.Array, eps:float=1e-10,
              distance:str='l1')

Calculate the mean L1 or L2 loss between the decibel magnitudes of two FFT signals.

:param predicted: FFT of the predicted signal.
:param target: FFT of the target signal.
:param eps: Small constant for numerical stability in the log computation.
:param distance: Type of distance metric (‘l1’ or ‘l2’).
:return: Mean L1 or L2 loss in decibel magnitude.
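
|           | Type  | Default | Details                                    |
|-----------|-------|---------|--------------------------------------------|
| predicted | Array |         | Complex-valued FFT of the predicted signal |
| target    | Array |         | Complex-valued FFT of the target signal    |
| eps       | float | 1e-10   | Small constant to avoid log(0)             |
| distance  | str   | l1      | Distance metric: ‘l1’ or ‘l2’              |
| Returns   | Array |         |                                            |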
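
Below is a minimal sketch of `db_mag_loss`. It rests on two assumptions: the dB conversion is `20 * log10(|X| + eps)`, and `'l2'` means the mean squared difference (the library may instead use a root-mean-square or a norm). The local `to_db` helper re-implements that assumed conversion and is not the library function:

```python
import jax
import jax.numpy as jnp


def to_db(x: jax.Array, eps: float = 1e-10) -> jax.Array:
    # Assumed amplitude-to-dB conversion
    return 20.0 * jnp.log10(jnp.abs(x) + eps)


def db_mag_loss(predicted: jax.Array, target: jax.Array,
                eps: float = 1e-10, distance: str = "l1") -> jax.Array:
    # Compare magnitudes only: the phase of the complex FFT is discarded
    diff = to_db(predicted, eps) - to_db(target, eps)
    if distance == "l1":
        return jnp.mean(jnp.abs(diff))
    if distance == "l2":
        return jnp.mean(diff ** 2)  # assumption: mean squared difference
    raise ValueError(f"unknown distance: {distance!r}")


impulse = jnp.zeros(8).at[0].set(1.0)
target = jnp.fft.rfft(impulse)  # flat spectrum, |X| = 1 in every bin
predicted = 10.0 * target       # +20 dB in every bin
l1 = db_mag_loss(predicted, target)
l2 = db_mag_loss(predicted, target, distance="l2")
```

Because the loss works on decibels, it depends only on amplitude ratios. For bins well above `eps`, scaling every bin of the prediction by 10 gives an L1 loss of about 20, whatever the shape of the target spectrum.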

The spectral log magnitude loss is defined as:

\[ \mathcal{L} = \frac{1}{N} \left\| \log(|Y| + \epsilon) - \log(|\hat{Y}| + \epsilon) \right\|_p \]



log_mag_loss

 log_mag_loss (pred:jax.Array, target:jax.Array, eps:float=1e-10,
               distance:str='l1')

Spectral log-magnitude loss computed on the FFT of a signal (see Arik et al., 2018).

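|          | Type  | Default | Details                                    |
|----------|-------|---------|--------------------------------------------|
| pred     | Array |         | Complex-valued FFT of the predicted signal |
| target   | Array |         | Complex-valued FFT of the target signal    |
| eps      | float | 1e-10   | Small constant to avoid log(0)             |
| distance | str   | l1      | Distance metric: ‘l1’ or ‘l2’              |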
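
Here is a sketch of `log_mag_loss` that follows the equation above. It assumes natural logarithms, averaging over all N frequency bins, and `'l2'` meaning the mean squared difference:

```python
import jax
import jax.numpy as jnp


def log_mag_loss(pred: jax.Array, target: jax.Array,
                 eps: float = 1e-10, distance: str = "l1") -> jax.Array:
    # log(|Y| + eps) - log(|Y_hat| + eps), bin by bin
    diff = jnp.log(jnp.abs(pred) + eps) - jnp.log(jnp.abs(target) + eps)
    if distance == "l1":
        return jnp.mean(jnp.abs(diff))
    if distance == "l2":
        return jnp.mean(diff ** 2)  # assumption: mean squared difference
    raise ValueError(f"unknown distance: {distance!r}")


target = jnp.fft.fft(jnp.zeros(8).at[0].set(1.0))  # |Y| = 1 in every bin
pred = jnp.e * target                              # uniform gain error of e
l1 = log_mag_loss(pred, target)
```

The log domain compresses the dynamic range, so low-energy bins carry as much weight as dominant ones. A uniform gain error of a factor e produces an L1 loss of 1.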

The spectral convergence loss is defined as:

\[ \mathcal{L} = \frac{\| Y - \hat{Y} \|_2^2}{\| Y \|_2^2} \]



spectral_convergence_loss

 spectral_convergence_loss (pred:jax.Array, target:jax.Array)

Spectral convergence loss computed on the FFT of a signal (see Arik et al., 2018).

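|        | Type  | Details                                    |
|--------|-------|--------------------------------------------|
| pred   | Array | Complex-valued FFT of the predicted signal |
| target | Array | Complex-valued FFT of the target signal    |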
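
Here is a sketch of `spectral_convergence_loss` that follows the equation above literally, using squared L2 norms over the complex spectra. It takes Y to be the target and Ŷ the prediction, which is an assumption: the page does not say which argument plays which role.

```python
import jax
import jax.numpy as jnp


def spectral_convergence_loss(pred: jax.Array, target: jax.Array) -> jax.Array:
    # ||Y - Y_hat||_2^2 / ||Y||_2^2 over all complex bins (Y = target, Y_hat = pred)
    num = jnp.sum(jnp.abs(target - pred) ** 2)
    den = jnp.sum(jnp.abs(target) ** 2)
    return num / den


target = jnp.fft.fft(jnp.arange(8.0))
half = spectral_convergence_loss(0.5 * target, target)
zero = spectral_convergence_loss(jnp.zeros_like(target), target)
```

Normalizing by the target energy makes the loss scale-free: predicting half the target spectrum gives 0.25, and predicting zeros gives 1.0. A variant that compared the magnitudes |Y| and |Ŷ| instead would ignore phase errors.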

Loss functions for training



fft_loss

 fft_loss (params, state:flax.training.train_state.TrainState,
           x:jax.Array, y:jax.Array, dropout_key:jax.Array=None,
           norm:str='layer')
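
|             | Type       | Default | Details                                                         |
|-------------|------------|---------|-----------------------------------------------------------------|
| params      |            |         |                                                                 |
| state       | TrainState |         |                                                                 |
| x           | Array      |         | Input sequence of shape (batch, timesteps, grid_size, 1); all zeros in our case |
| y           | Array      |         | Output sequence of shape (batch, timesteps, grid_size, 1); u in our case |
| dropout_key | Array      | None    |                                                                 |
| norm        | str        | layer   |                                                                 |
| Returns     | Tuple      |         | loss, pred                                                      |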
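
The page documents only the interface of `fft_loss`. The sketch below shows one way such a loss can fit into a JAX training loop, and it makes several assumptions:

- The spectra are taken along the `grid_size` axis.
- The distance is a mean L1 over the complex spectra.
- `dropout_key` and `norm` are omitted.
- `State` and the toy `apply_fn` are hypothetical stand-ins for a flax `TrainState` and model.

The `(loss, pred)` return value is what `jax.value_and_grad(..., has_aux=True)` expects:

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    # Hypothetical stand-in for flax's TrainState; only apply_fn is used here
    apply_fn: Callable


def fft_loss(params, state, x, y):
    # Forward pass, following the flax convention apply_fn({"params": params}, x)
    pred = state.apply_fn({"params": params}, x)
    # Assumption: FFT along the grid axis of (batch, timesteps, grid_size, 1)
    pred_hat = jnp.fft.rfft(pred, axis=2)
    y_hat = jnp.fft.rfft(y, axis=2)
    # Assumption: mean L1 distance between the complex spectra
    loss = jnp.mean(jnp.abs(pred_hat - y_hat))
    return loss, pred


def apply_fn(variables, x):
    # Hypothetical toy model: scale plus a learned bias per grid point
    p = variables["params"]
    return x * p["w"] + p["b"]


state = State(apply_fn=apply_fn)
params = {"w": jnp.array(1.0), "b": jnp.zeros((8, 1))}
x = jnp.zeros((2, 3, 8, 1))  # zeros, as in the table above
u = 1.0 + jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 8))
y = jnp.broadcast_to(u[None, None, :, None], (2, 3, 8, 1))

# has_aux=True because the loss returns (loss, pred)
(loss, pred), grads = jax.value_and_grad(fft_loss, has_aux=True)(params, state, x, y)
```

Returning `pred` as auxiliary output lets the training step log or plot predictions without a second forward pass.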


mse_loss

 mse_loss (params, state:flax.training.train_state.TrainState,
           x:jax.Array, y:jax.Array, dropout_key:jax.Array=None,
           norm:str='layer')
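
|             | Type       | Default | Details                                                         |
|-------------|------------|---------|-----------------------------------------------------------------|
| params      |            |         |                                                                 |
| state       | TrainState |         |                                                                 |
| x           | Array      |         | Input sequence of shape (batch, timesteps, grid_size, 1); all zeros in our case |
| y           | Array      |         | Output sequence of shape (batch, timesteps, grid_size, 1); u in our case |
| dropout_key | Array      | None    |                                                                 |
| norm        | str        | layer   |                                                                 |
| Returns     | Tuple      |         | loss, pred                                                      |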
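
Below is a sketch of an `mse_loss` with this interface, followed by one gradient step. It uses a hypothetical `State` stand-in and a toy `apply_fn`, and plain SGD replaces an optax optimizer. With a real flax `TrainState`, the update would usually be `state.apply_gradients(grads=grads)`. `dropout_key` and `norm` are omitted:

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    # Hypothetical stand-in for flax's TrainState; only apply_fn is used here
    apply_fn: Callable


def mse_loss(params, state, x, y):
    pred = state.apply_fn({"params": params}, x)
    loss = jnp.mean((pred - y) ** 2)  # mean squared error over all elements
    return loss, pred


def apply_fn(variables, x):
    # Hypothetical toy model: scale plus a learned bias per grid point
    p = variables["params"]
    return x * p["w"] + p["b"]


state = State(apply_fn=apply_fn)
params = {"w": jnp.array(1.0), "b": jnp.zeros((8, 1))}
x = jnp.zeros((2, 3, 8, 1))
u = jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 8))
y = jnp.broadcast_to(u[None, None, :, None], (2, 3, 8, 1))

grad_fn = jax.value_and_grad(mse_loss, has_aux=True)
(loss0, pred), grads = grad_fn(params, state, x, y)
# Plain SGD update applied leaf by leaf over the params pytree
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
(loss1, _), _ = grad_fn(params, state, x, y)
```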


lindyn_loss

 lindyn_loss (params, state:flax.training.train_state.TrainState,
              x:jax.Array, y:jax.Array, encdec_weight:float=1.0,
              lindyn_weight:float=0.01, pred_weight:float=1.0,
              dropout_key:jax.Array=None, norm:str='layer')
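|               | Type       | Default | Details                                   |
|---------------|------------|---------|-------------------------------------------|
| params        |            |         |                                           |
| state         | TrainState |         |                                           |
| x             | Array      |         | PDE solution of shape (B, T, G, C)        |
| y             | Array      |         | Shifted PDE solution of shape (B, T, G, C) |
| encdec_weight | float      | 1.0     |                                           |
| lindyn_weight | float      | 0.01    |                                           |
| pred_weight   | float      | 1.0     |                                           |
| dropout_key   | Array      | None    |                                           |
| norm          | str        | layer   |                                           |
| Returns       | Tuple      |         | loss, pred                                |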
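
Judging by its parameters, `lindyn_loss` combines three weighted terms. The sketch below guesses at a typical structure for such a loss (an autoencoder with linear latent dynamics) and is not the library's implementation. A hypothetical linear encoder `enc`, decoder `dec`, and latent operator `K` act on the channel axis, and `state`, `dropout_key` and `norm` are left out. The weights apply as follows:

- `encdec_weight` scales the reconstruction term.
- `lindyn_weight` scales the latent-dynamics term.
- `pred_weight` scales the prediction of `y`.

```python
import jax
import jax.numpy as jnp


def lindyn_loss(params, x, y, encdec_weight=1.0, lindyn_weight=0.01, pred_weight=1.0):
    # Hypothetical linear encoder/decoder on the channel axis of (B, T, G, C)
    encode = lambda u: u @ params["enc"]  # (B, T, G, C) -> (B, T, G, D)
    decode = lambda z: z @ params["dec"]  # (B, T, G, D) -> (B, T, G, C)

    z_x, z_y = encode(x), encode(y)
    z_next = z_x @ params["K"]  # one step of linear latent dynamics
    pred = decode(z_next)       # predicted shifted solution

    recon_loss = jnp.mean((decode(z_x) - x) ** 2)  # encoder/decoder consistency
    dyn_loss = jnp.mean((z_next - z_y) ** 2)       # latent step matches encode(y)
    pred_loss = jnp.mean((pred - y) ** 2)          # decoded prediction matches y
    loss = encdec_weight * recon_loss + lindyn_weight * dyn_loss + pred_weight * pred_loss
    return loss, pred


# Toy check with C = D = 1: identity encoder/decoder and K = 2, so y = 2 * x is fit exactly
params = {"enc": jnp.eye(1), "dec": jnp.eye(1), "K": 2.0 * jnp.eye(1)}
x = jnp.ones((2, 3, 4, 1))
y = 2.0 * x
(loss, pred), grads = jax.value_and_grad(lindyn_loss, has_aux=True)(params, x, y)
```

In this sketch, all three terms vanish when the encoder and decoder are exact inverses and `K` maps the encoding of `x` onto the encoding of `y`. The toy parameters above satisfy both conditions.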