```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

phase = 0.0
omegas = jax.random.uniform(jax.random.PRNGKey(0), shape=(10,)) * 100
mag = 0.99

def osc_bank(t, omegas):
    # Bank of sinusoids, one row per frequency in `omegas`.
    return mag * jnp.sin(omegas[..., None] * jnp.pi * 2 * t[None] + phase)
```
Losses
A collection of losses including:
- Spectral log magnitude loss
- Spectral convergence loss
- Wasserstein loss
wasserstein_1d
wasserstein_1d (u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True)
This is a port of the `wasserstein_1d` function from POT to JAX. It computes the 1-dimensional OT loss [15] between two (batched) empirical distributions

$$OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p \, dq$$

It is formally the p-Wasserstein distance raised to the power p. We do so in a vectorized way by first building the individual quantile functions and then integrating them. This function should be preferred to `emd_1d` whenever the backend is different from numpy, and when gradients over either sample positions or weights are required.
| | Type | Default | Details |
|---|---|---|---|
| u_values | | | |
| v_values | | | |
| u_weights | NoneType | None | |
| v_weights | NoneType | None | |
| p | int | 1 | |
| require_sort | bool | True | |
| Returns | cost: float/array-like, shape (…) | | the batched EMD |
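As a small usage sketch (values chosen purely for illustration): with uniform weights and `p=1` the loss is the mean absolute difference between the sorted samples, so the example below evaluates to 0.5.

```python
import jax.numpy as jnp

u = jnp.array([0.0, 1.0, 2.0])
v = jnp.array([0.5, 1.5, 2.5])

# u_weights=v_weights=None means uniform weights; the sorted samples
# pair up one-to-one and each pair is 0.5 apart.
d = wasserstein_1d(u, v, p=1)
```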
quantile_function
quantile_function (qs, cws, xs)
Computes the quantile function of an empirical distribution
| | Type | Details |
|---|---|---|
| qs | | |
| cws | | |
| xs | | |
| Returns | q: array-like, shape (…, n) | The quantiles of the distribution |
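The table above leaves the argument details blank; a minimal 1D sketch of the idea, assuming `qs` are quantile levels, `cws` are the cumulative weights of the sorted samples, and `xs` are the sorted sample positions (this interpretation is an assumption, not stated on this page), would be:

```python
import jax.numpy as jnp

def quantile_function_sketch(qs, cws, xs):
    # Index of the first cumulative weight >= each quantile level,
    # clipped so a quantile level of exactly 1.0 stays in range.
    idx = jnp.searchsorted(cws, qs)
    return jnp.take(xs, jnp.clip(idx, 0, xs.shape[0] - 1))
```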
compute_mag
compute_mag (x:jax.Array)
| | Type | Details |
|---|---|---|
| x | Array | (b, t) |
| Returns | Array | |
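The implementation is not shown on this page; a minimal sketch consistent with the `(b, t)` input shape, assuming it is simply the magnitude of a real FFT along the time axis, would be:

```python
import jax
import jax.numpy as jnp

def compute_mag_sketch(x: jax.Array) -> jax.Array:
    # (b, t) real signals -> (b, t // 2 + 1) magnitude spectra.
    return jnp.abs(jnp.fft.rfft(x, axis=-1))
```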
spectral_wasserstein
spectral_wasserstein (x, y, squared=True, is_mag=False)
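Only the signature is documented. One plausible reading, consistent with how the function is used below (magnitude or power spectra treated as 1D distributions over frequency bins and compared with `wasserstein_1d`), is sketched here; the actual implementation may differ.

```python
import jax.numpy as jnp

def spectral_wasserstein_sketch(x, y, squared=True, is_mag=False):
    # x, y are time-domain signals, or magnitude spectra if is_mag=True.
    x_mag = x if is_mag else jnp.abs(jnp.fft.rfft(x))
    y_mag = y if is_mag else jnp.abs(jnp.fft.rfft(y))
    if squared:
        x_mag, y_mag = x_mag**2, y_mag**2
    # Normalize the spectra into distributions over the frequency bins.
    freqs = jnp.arange(x_mag.shape[-1], dtype=x_mag.dtype)
    return wasserstein_1d(
        freqs, freqs,
        u_weights=x_mag / x_mag.sum(),
        v_weights=y_mag / y_mag.sum(),
        p=1,
    )
```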
log_mag_loss
log_mag_loss (pred:jax.Array, target:jax.Array, eps:float=1e-10, distance:str='l1')
Spectral log magnitude loss, but for the FFT of a signal. See Arik et al., 2018
| | Type | Default | Details |
|---|---|---|---|
| pred | Array | | complex valued fft of the signal |
| target | Array | | complex valued fft of the signal |
| eps | float | 1e-10 | |
| distance | str | l1 | |
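The loss itself is not spelled out here; a minimal sketch of a log-magnitude loss on complex FFTs matching the signature above (the `l2` fallback branch is an assumption) would be:

```python
import jax
import jax.numpy as jnp

def log_mag_loss_sketch(pred: jax.Array, target: jax.Array,
                        eps: float = 1e-10, distance: str = "l1") -> jax.Array:
    # Compare the log magnitudes of the two complex spectra.
    diff = jnp.log(jnp.abs(pred) + eps) - jnp.log(jnp.abs(target) + eps)
    if distance == "l1":
        return jnp.mean(jnp.abs(diff))
    return jnp.mean(diff**2)  # assumed l2 variant
```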
log_mag
log_mag (x:jax.Array, eps:float=1e-10)
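`log_mag` carries no docstring; a minimal sketch consistent with its name and `eps` argument, assuming `x` is a complex or magnitude spectrum, would be:

```python
import jax
import jax.numpy as jnp

def log_mag_sketch(x: jax.Array, eps: float = 1e-10) -> jax.Array:
    # Log magnitude, with eps keeping the log finite at zero-valued bins.
    return jnp.log(jnp.abs(x) + eps)
```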
```python
# Evaluate the oscillator bank on a time grid.
t = jnp.linspace(0, 1, 1000)
gt_osc_values = osc_bank(t, omegas)
# print(gt_osc_values)
print(gt_osc_values.shape)
```

```
(10, 1000)
```

```python
# The spectral Wasserstein distance of each oscillator to itself is zero.
a = jax.vmap(spectral_wasserstein)(gt_osc_values, gt_osc_values)
print(a)
```

```
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
```
```python
def loss_fn(omega):
    # L2 distance between magnitude spectra.
    pred_osc_values = osc_bank(t, omega)
    x_fft = compute_mag(gt_osc_values)
    y_fft = compute_mag(pred_osc_values)
    l2_mag_loss = jnp.mean((x_fft - y_fft) ** 2)
    return l2_mag_loss


def ot_loss_fn(omega):
    # Mean spectral Wasserstein distance across the oscillator batch.
    pred_osc_values = osc_bank(t, omega)
    ot_loss = jnp.mean(
        jax.vmap(spectral_wasserstein)(gt_osc_values, pred_osc_values),
    )
    return ot_loss
```
```python
x_fft = compute_mag(gt_osc_values) ** 2
plt.plot(x_fft[:5].T)
```
```python
# Sweep the frequencies around the ground truth to compare loss landscapes.
ranges = jnp.linspace(-50, 50, 100)
omegas_scan = omegas + ranges[:, None]
# print(omegas_scan.shape)

loss, grad = jax.vmap(jax.value_and_grad(loss_fn))(omegas_scan)
loss_ot, grad_ot = jax.vmap(jax.value_and_grad(ot_loss_fn))(omegas_scan)

print(loss.shape, loss.dtype)
print(loss_ot.shape, loss_ot.dtype)
plt.plot(ranges, loss_ot)
```

```
(100,) float32
(100,) float32
```
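To make the landscape comparison concrete, here is a minimal, hypothetical gradient-descent loop on `ot_loss_fn`; the starting offset, learning rate, and step count are arbitrary choices, not values taken from this notebook.

```python
# Hypothetical example: recover the frequencies by descending the OT loss.
omega_hat = omegas + 20.0   # start away from the true frequencies
lr = 1.0                    # arbitrary learning rate
for _ in range(100):
    _, g = jax.value_and_grad(ot_loss_fn)(omega_hat)
    omega_hat = omega_hat - lr * g
```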
```python
omegas_gt = jax.random.uniform(jax.random.PRNGKey(0), shape=(10,)) * 1000
omegas_pred = omegas_gt * 1

# Mix each oscillator bank down to a single signal and compare spectra.
pred_osc_values = osc_bank(t, omegas_pred).mean(axis=0)
gt_osc_values = osc_bank(t, omegas_gt)

x_mag = compute_mag(gt_osc_values.mean(axis=0))
y_mag = compute_mag(pred_osc_values)

plt.semilogx(x_mag)
plt.semilogx(y_mag)
```
spectral_convergence_loss
spectral_convergence_loss (pred:jax.Array, target:jax.Array)
Spectral convergence loss, but for the FFT of a signal. See Arik et al., 2018
| | Type | Details |
|---|---|---|
| pred | Array | magnitude of the fft of the predicted signal |
| target | Array | magnitude of the fft of the target signal |
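For reference, a minimal sketch of a spectral convergence loss on magnitude spectra matching this signature (the small denominator guard is an assumption, not taken from this page):

```python
import jax
import jax.numpy as jnp

def spectral_convergence_loss_sketch(pred: jax.Array, target: jax.Array) -> jax.Array:
    # Frobenius norm of the magnitude error, relative to the target's norm.
    eps = 1e-10  # assumed guard against division by zero
    return jnp.linalg.norm(target - pred) / (jnp.linalg.norm(target) + eps)
```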