Losses
log_mag
log_mag (x:jax.Array, eps:float=1e-10)
to_db
to_db (x:jax.Array, eps:float=1e-10)
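A minimal sketch of what these two helpers compute, assuming the standard conventions (natural log for `log_mag`, `20·log10` for `to_db`); the library's exact scaling is an assumption:

```python
import jax.numpy as jnp

def log_mag(x, eps=1e-10):
    # Natural log of the magnitude of a (possibly complex) array,
    # with eps added to avoid log(0).
    return jnp.log(jnp.abs(x) + eps)

def to_db(x, eps=1e-10):
    # Magnitude in decibels: 20 * log10(|x|), guarded against log(0).
    return 20.0 * jnp.log10(jnp.abs(x) + eps)
```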
db_mag_loss
db_mag_loss (predicted:jax.Array, target:jax.Array, eps:float=1e-10, distance:str='l1')
Calculate the mean L1 or L2 loss between the decibel magnitudes of two FFT signals.
|  | Type | Default | Details |
|---|---|---|---|
| predicted | Array |  | Complex-valued FFT of the predicted signal |
| target | Array |  | Complex-valued FFT of the target signal |
| eps | float | 1e-10 | Small constant to avoid log(0) |
| distance | str | l1 | Distance metric: 'l1' or 'l2' |
| **Returns** | Array |  |  |
The spectral log magnitude loss is defined as:
\[ \mathcal{L} = \frac{1}{N}\lVert \log(Y) - \log(\hat{Y}) \rVert_p \]
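A sketch of `db_mag_loss` following the definition above; the exact dB conversion (here `20·log10`) is an assumption:

```python
import jax.numpy as jnp

def db_mag_loss(predicted, target, eps=1e-10, distance="l1"):
    # Convert complex FFT values to decibel magnitudes (assumed scaling).
    to_db = lambda x: 20.0 * jnp.log10(jnp.abs(x) + eps)
    diff = to_db(target) - to_db(predicted)
    if distance == "l1":
        return jnp.mean(jnp.abs(diff))  # mean L1 distance
    return jnp.mean(diff ** 2)          # mean L2 (squared) distance
```

Identical inputs give zero loss; scaling one signal changes only its dB offset, so the loss grows with the magnitude ratio.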
log_mag_loss
log_mag_loss (pred:jax.Array, target:jax.Array, eps:float=1e-10, distance:str='l1')
Spectral log magnitude loss for the FFT of a signal. See Arik et al., 2018.

|  | Type | Default | Details |
|---|---|---|---|
| pred | Array |  | Complex-valued FFT of the signal |
| target | Array |  | Complex-valued FFT of the signal |
| eps | float | 1e-10 | Small constant to avoid log(0) |
| distance | str | l1 | Distance metric: 'l1' or 'l2' |
The spectral convergence loss is defined as:
\[ \mathcal{L} = \frac{\lVert Y - \hat{Y} \rVert^2_2}{\lVert Y \rVert^2_2} \]
spectral_convergence_loss
spectral_convergence_loss (pred:jax.Array, target:jax.Array)
Spectral convergence loss for the FFT of a signal. See Arik et al., 2018.

|  | Type | Details |
|---|---|---|
| pred | Array | Complex-valued FFT of the signal |
| target | Array | Complex-valued FFT of the signal |
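A sketch of the spectral convergence loss as defined above; the small `eps` in the denominator is an added assumption to guard against division by zero:

```python
import jax.numpy as jnp

def spectral_convergence_loss(pred, target, eps=1e-10):
    # Ratio of the squared error energy to the target's spectral energy,
    # matching the squared-norm form of the definition above.
    num = jnp.sum(jnp.abs(target - pred) ** 2)
    den = jnp.sum(jnp.abs(target) ** 2) + eps
    return num / den
```

The loss is 0 for a perfect match and 1 when the prediction is identically zero, which makes it scale-invariant with respect to the target's overall energy.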
Loss functions for training
fft_loss
fft_loss (params, state:flax.training.train_state.TrainState, x:jax.Array, y:jax.Array, dropout_key:jax.Array=None, norm:str='layer')
|  | Type | Default | Details |
|---|---|---|---|
| params |  |  |  |
| state | TrainState |  |  |
| x | Array |  | Input sequence (batch, timesteps, grid_size, 1); zeros in our case |
| y | Array |  | Output sequence (batch, timesteps, grid_size, 1); u in our case |
| dropout_key | Array | None |  |
| norm | str | layer |  |
| **Returns** | Tuple |  | loss, pred |
mse_loss
mse_loss (params, state:flax.training.train_state.TrainState, x:jax.Array, y:jax.Array, dropout_key:jax.Array=None, norm:str='layer')
|  | Type | Default | Details |
|---|---|---|---|
| params |  |  |  |
| state | TrainState |  |  |
| x | Array |  | Input sequence (batch, timesteps, grid_size, 1); zeros in our case |
| y | Array |  | Output sequence (batch, timesteps, grid_size, 1); u in our case |
| dropout_key | Array | None |  |
| norm | str | layer |  |
| **Returns** | Tuple |  | loss, pred |
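A sketch of the training-loss pattern shared by `mse_loss` and `fft_loss`: run the model via the `TrainState`, compare the prediction to the target, and return `(loss, pred)` so the function works with `jax.value_and_grad(..., has_aux=True)`. The exact `state.apply_fn` call signature (dropout rngs, how `norm` is consumed) is an assumption:

```python
import jax.numpy as jnp

def mse_loss(params, state, x, y, dropout_key=None, norm="layer"):
    # Forward pass through the model stored on the TrainState
    # (apply_fn signature is an assumption).
    rngs = {"dropout": dropout_key} if dropout_key is not None else None
    pred = state.apply_fn({"params": params}, x, rngs=rngs)
    loss = jnp.mean((pred - y) ** 2)  # mean squared error over all axes
    return loss, pred
```

Returning `pred` as the auxiliary output lets a training step log or inspect predictions without a second forward pass.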
lindyn_loss
lindyn_loss (params, state:flax.training.train_state.TrainState, x:jax.Array, y:jax.Array, encdec_weight:float=1.0, lindyn_weight:float=0.01, pred_weight:float=1.0, dropout_key:jax.Array=None, norm:str='layer')
|  | Type | Default | Details |
|---|---|---|---|
| params |  |  |  |
| state | TrainState |  |  |
| x | Array |  | (B, T, G, C) PDE solution |
| y | Array |  | (B, T, G, C) shifted PDE solution |
| encdec_weight | float | 1.0 |  |
| lindyn_weight | float | 0.01 |  |
| pred_weight | float | 1.0 |  |
| dropout_key | Array | None |  |
| norm | str | layer |  |
| **Returns** | Tuple |  | loss, pred |
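The three weights above suggest `lindyn_loss` is a weighted sum of three terms. How the model exposes its encoder/decoder and latent dynamics is an assumption, so this sketch shows only the weighting pattern, with the terms standing in for (1) autoencoder reconstruction error, (2) a linear-dynamics consistency penalty in latent space, and (3) prediction error against the shifted solution `y`:

```python
def combine_lindyn_terms(encdec_term, lindyn_term, pred_term,
                         encdec_weight=1.0, lindyn_weight=0.01, pred_weight=1.0):
    # Weighted sum of the three scalar loss terms; the defaults mirror
    # the signature of lindyn_loss above.
    return (encdec_weight * encdec_term
            + lindyn_weight * lindyn_term
            + pred_weight * pred_term)
```

The small default `lindyn_weight` keeps the linearity penalty from dominating the reconstruction and prediction objectives early in training.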