from functools import partial
from physmodjax.models.recurrent import (
DiscreteModalDynamicsAngleMag,
LRUDynamicsVarying,
)
Autoencoder models
::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’ 6=‘i’}
from collections.abc import Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
from einops import rearrange
from jax.typing import ArrayLike
from physmodjax.models.conv import ConvDecoder, ConvEncoder
from physmodjax.models.mlp import ModulatedSiren
from physmodjax.models.recurrent import LRUDynamics
from physmodjax.utils.data import create_grid
2025-01-07 11:25:03.453384: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.6.85). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
:::
2D Convolutional Autoencoder with Linear Dynamics
::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class FourierAutoencoder2D(nn.Module):
    """
    Autoencoder whose encoder and decoder are a truncated 2D real FFT and its
    inverse, with a dynamics model acting on the kept Fourier coefficients.
    """

    dynamics_model: nn.Module  # called with no arguments in `setup`
    d_vars: int  # channel count used to un-flatten latents in `decode`
    d_model: tuple[int, int]  # spatial size requested from the inverse FFT
    norm: str = "layer"  # not read anywhere in this class
    training: bool = True  # not read anywhere in this class
    use_positions: bool = False  # concatenate a coordinate grid to the input
    n_modes: int = 20  # coefficients kept along each transformed axis

    def setup(self):
        """Instantiate the dynamics and, when requested, the position grid."""
        self.dynamics = self.dynamics_model()
        if self.use_positions:
            # Note the argument order: second entry of `d_model` goes first.
            self.grid = create_grid(self.d_model[1], self.d_model[0])

    def __call__(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:
        """Encode `x`, roll the latent forward in time and decode the result."""
        # NOTE(review): unlike the other autoencoders in this file, the whole
        # input (not `x[0]`) is encoded before advancing -- confirm intended.
        latent = self.encode(x)
        # `x` doubles as the `steps` argument; only its leading length is read.
        rolled = self.advance(latent, x)
        return self.decode(rolled)

    def encode(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:  # (hidden_dim) or (T, hidden_dim) complex
        """
        Spatial encoding of the input data.

        Takes a 2D real FFT over the two spatial axes (padded to 2 * size - 1),
        keeps the first `n_modes` coefficients along each of them and flattens
        every frame into a single complex vector.
        """
        if self.use_positions:
            # NOTE(review): the extra grid channels are not accounted for by
            # `decode`, which un-flattens with `d_vars` channels -- confirm.
            x = jnp.concatenate([x, self.grid], axis=-1)

        *_, size_w, size_h, _channels = x.shape
        padded = (size_w * 2 - 1, size_h * 2 - 1)
        coeffs = jnp.fft.rfft2(x, s=padded, axes=(-3, -2), norm="ortho")
        coeffs = coeffs[..., : self.n_modes, : self.n_modes, :]

        # Only single frames (rank 3) and sequences (rank 4) are flattened;
        # any other rank is returned untouched.
        pattern = {
            3: "w h c -> (w h c)",
            4: "t w h c -> t (w h c)",
        }.get(coeffs.ndim)
        if pattern is None:
            return coeffs
        return rearrange(coeffs, pattern)

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,) complex
        steps: jnp.ndarray,  # (T,)
    ) -> jnp.ndarray:  # (T, hidden_dim,)
        """Run the dynamics for as many steps as `steps` has leading entries."""
        # The latent is already complex, so there is no real/imag packing here.
        n_steps = steps.shape[0]
        return self.dynamics(z, n_steps)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim) or (T, hidden_dim) complex
    ) -> jnp.ndarray:  # (W, H, C) or (T, W, H, C) real
        """
        Un-flatten the latent into (n_modes, n_modes, d_vars) coefficient
        blocks and apply the inverse 2D real FFT at spatial size `d_model`.
        """
        pattern = {
            1: "(w h c) -> w h c",
            2: "t (w h c) -> t w h c",
        }.get(z.ndim)
        if pattern is not None:
            z = rearrange(
                z, pattern, w=self.n_modes, h=self.n_modes, c=self.d_vars
            )

        out_size = (self.d_model[0], self.d_model[1])
        return jnp.fft.irfft2(z, s=out_size, axes=(-3, -2), norm="ortho")
# Batched variant of FourierAutoencoder2D: `nn.vmap` maps the listed methods
# over axis 0 of their inputs and outputs. The "params" collection is not
# mapped (axis None) and the "params" RNG is not split.
BatchedFourierAutoencoder2D = nn.vmap(
FourierAutoencoder2D,
in_axes=0,
out_axes=0,
variable_axes={"params": None},
split_rngs={"params": False},
methods=["__call__", "decode", "encode", "advance"],
):::
::: {#cell-7 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
# Smoke test: initialise a batched FourierAutoencoder2D on zeros and check
# that the output has the same shape as the input.
d_hidden = 128  # not used below in this cell
B, T, H, W, C = 5, 16, 41, 37, 2
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))  # not used below in this cell
dynamics_model = partial(
LRUDynamics,
# n_modes * n_modes * C: the flattened size that `encode` produces here.
d_hidden=(20 * 20 * 2),
r_min=0.9,
r_max=1.0,
max_phase=jnp.pi * 2,
clip_eigs=False,
prepend_ones=False,
)
model = BatchedFourierAutoencoder2D(
dynamics_model=dynamics_model,
d_vars=C,
d_model=(H, W),
norm="layer",
training=True,
n_modes=20,
)
vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)
assert out.shape == (B, T, H, W, C):::
::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class DenseKoopmanAutoencoder2D(nn.Module):
    """
    Koopman Dense Autoencoder

    Frames are flattened before encoding; the real latent is read as a complex
    state, rolled out by the dynamics, and each state is decoded to a frame.
    """

    encoder_model: nn.Module  # called with no arguments in `setup`
    decoder_model: nn.Module  # called with no arguments in `setup`
    dynamics_model: nn.Module  # called with no arguments in `setup`
    d_vars: int  # channels of a decoded frame
    d_model: tuple[int, int]  # spatial size of a decoded frame
    n_steps: int  # number of steps requested from the dynamics
    norm: str = "layer"  # not read anywhere in this class
    training: bool = True  # not read anywhere in this class
    use_positions: bool = False  # concatenate a coordinate grid to the input

    def setup(self):
        """Instantiate encoder, decoder, dynamics and the optional grid."""
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()
        if self.use_positions:
            # Note the argument order: second entry of `d_model` goes first.
            self.grid = create_grid(self.d_model[1], self.d_model[0])

    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, H, C); only the first frame is read
    ) -> jnp.ndarray:
        """Encode the first frame, roll it out and decode every step."""
        initial_frame = x[0]
        latent = self.encode(initial_frame)
        trajectory = self.advance(latent)
        return self.decode(trajectory)

    def encode(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:  # (hidden_dim) or (T, hidden_dim)
        """Flatten each frame (after the optional grid concat) and encode it."""
        if self.use_positions:
            x = jnp.concatenate([x, self.grid], axis=-1)

        # Inputs of any other rank reach the encoder unflattened.
        pattern = {
            3: "w h c -> (w h c)",
            4: "t w h c -> t (w h c)",
        }.get(x.ndim)
        if pattern is not None:
            x = rearrange(x, pattern)
        return self.encoder(x)

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim,)
        """Roll the latent out for `n_steps` steps in the complex domain."""
        # First half of the latent is the real part, second half the imaginary.
        half = z.shape[0] // 2
        state = z[:half] + 1j * z[half:]
        rollout = self.dynamics(state, self.n_steps)
        # Pack back into a real vector: [real | imag] on the last axis.
        return jnp.concatenate([rollout.real, rollout.imag], axis=-1)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, H, C) or (T, W, H, C)
        """Decode latents and reshape the flat output back into frames."""
        flat = self.decoder(z)
        pattern = {
            1: "(w h c) -> w h c",
            2: "t (w h c) -> t w h c",
        }.get(flat.ndim)
        if pattern is None:
            return flat
        return rearrange(
            flat,
            pattern,
            w=self.d_model[0],
            h=self.d_model[1],
            c=self.d_vars,
        )
# Batched variant of DenseKoopmanAutoencoder2D: `nn.vmap` maps the listed
# methods over axis 0 of their inputs and outputs. The "params" collection is
# not mapped (axis None) and the "params" RNG is not split.
BatchedDenseKoopmanAutoencoder2D = nn.vmap(
DenseKoopmanAutoencoder2D,
in_axes=0,
out_axes=0,
variable_axes={"params": None},
split_rngs={"params": False},
methods=["__call__", "decode", "encode", "advance"],
):::
::: {#cell-9 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
# Smoke test: initialise a batched DenseKoopmanAutoencoder2D with dense
# encoder/decoder and check that the output shape matches the input shape.
n_steps = 16
d_hidden = 128
B, T, H, W, C = 5, n_steps, 41, 37, 3
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))
# The real latent is split into two halves (real/imag) by `advance`, hence
# twice the dynamics' `d_hidden`.
encoder_model = partial(nn.Dense, features=d_hidden * 2)
# The decoder emits one flat frame of H * W * C values.
decoder_model = partial(nn.Dense, features=H * W * C)
dynamics_model = partial(
LRUDynamics,
d_hidden=d_hidden,
r_min=0.9,
r_max=1.0,
max_phase=jnp.pi * 2,
clip_eigs=False,
prepend_ones=False,
)
model = BatchedDenseKoopmanAutoencoder2D(
encoder_model=encoder_model,
decoder_model=decoder_model,
dynamics_model=dynamics_model,
n_steps=n_steps,
d_vars=C,
d_model=(H, W),
norm="layer",
training=True,
)
vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)
assert out.shape == (B, T, H, W, C):::
::: {#cell-10 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
# Exercise `encode` and `decode` on their own through the batched module,
# reusing `model`, `vars`, `dummy` and `target` from the previous cell.
# Results are not asserted; the cell only checks that the calls run.
encoded = model.apply(vars, dummy, method="encode")
decoded = model.apply(vars, encoded, method="decode")
enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode"):::
::: {#cell-11 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class KoopmanAutoencoder2D(nn.Module):
    """
    Koopman Autoencoder

    The encoder output is flattened into a real latent, rolled out by the
    dynamics in the complex domain, reshaped to the latent grid and decoded.
    """

    encoder_model: ConvEncoder  # called with `training` and `norm` in `setup`
    decoder_model: ConvDecoder  # called with `training` and `norm` in `setup`
    dynamics_model: LRUDynamics  # called with no arguments in `setup`
    d_latent_channels: int  # channels of the latent grid fed to the decoder
    d_latent_dims: tuple[int, int]  # spatial size of that latent grid
    n_steps: int  # number of steps requested from the dynamics
    norm: str = "layer"  # forwarded to encoder and decoder
    training: bool = True  # forwarded to encoder and decoder

    def setup(self):
        """Instantiate encoder, decoder and dynamics."""
        shared = dict(training=self.training, norm=self.norm)
        self.encoder = self.encoder_model(**shared)
        self.decoder = self.decoder_model(**shared)
        self.dynamics = self.dynamics_model()

    def __call__(
        self,
        x: jnp.ndarray,  # (T, H, W, C); only the first frame is read
    ) -> jnp.ndarray:  # (T, H, W, C)
        """Encode the first frame, roll it out and decode every step."""
        latent = self.encode(x[0])
        trajectory = self.advance(latent)
        return self.decode(trajectory)

    def encode(
        self,
        x: jnp.ndarray,  # (H, W, C) or (T, H, W, C)
    ) -> jnp.ndarray:  # (hidden_dim,) or (T, hidden_dim)
        """Run the encoder and flatten its output per frame."""
        features = self.encoder(x)
        # Encoder outputs of any other rank are returned as they are.
        pattern = {
            4: "t h w c -> t (h w c)",
            3: "h w c -> (h w c)",
        }.get(features.ndim)
        if pattern is None:
            return features
        return rearrange(features, pattern)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (H, W, C) or (T, H, W, C)
        """Reshape flat latents to the latent grid and run the decoder."""
        pattern = {
            2: "t (h w c) -> t h w c",
            1: "(h w c) -> h w c",
        }.get(z.ndim)
        if pattern is not None:
            z = rearrange(
                z,
                pattern,
                h=self.d_latent_dims[0],
                w=self.d_latent_dims[1],
                c=self.d_latent_channels,
            )
        return self.decoder(z)

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:
        """Roll the latent out for `n_steps` steps in the complex domain."""
        # First half of the latent is the real part, second half the imaginary.
        half = z.shape[0] // 2
        state = z[:half] + 1j * z[half:]
        rollout = self.dynamics(state, self.n_steps)
        # Pack back into a real vector: [real | imag] on the last axis.
        return jnp.concatenate([rollout.real, rollout.imag], axis=-1)
# Batched variant of KoopmanAutoencoder2D. The "params", "batch_stats" and
# "prime" collections are not mapped (axis None); "cache" is mapped along
# axis 0. The "params" RNG is not split and the mapped axis is named "batch".
BatchedKoopmanAutoencoder2D = nn.vmap(
KoopmanAutoencoder2D,
in_axes=0, # map over the first axis of the first input not the second
out_axes=0,
variable_axes={"params": None, "batch_stats": None, "cache": 0, "prime": None},
split_rngs={"params": False},
methods=["__call__", "decode", "encode", "advance"],
axis_name="batch",
):::
::: {#cell-12 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
# Smoke test: initialise a batched KoopmanAutoencoder2D with convolutional
# encoder/decoder and check that the output shape matches the input shape.
n_steps = 16
d_hidden = 128  # not used below in this cell
B, T, H, W, C = 5, n_steps, 40, 40, 3
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))
encoder_model = partial(ConvEncoder, block_size=(8, 16, 32))
decoder_model = partial(ConvDecoder, block_size=(8, 16, 32))
dynamics_model = partial(
LRUDynamics,
# Half of the flattened latent (32 * 5 * 5): `advance` splits the latent
# into real and imaginary halves.
d_hidden=16 * 5 * 5,
r_min=0.9,
r_max=1.0,
max_phase=jnp.pi * 2,
clip_eigs=False,
prepend_ones=False,
)
model = BatchedKoopmanAutoencoder2D(
encoder_model=encoder_model,
decoder_model=decoder_model,
dynamics_model=dynamics_model,
d_latent_channels=32,
d_latent_dims=(5, 5),
n_steps=16,
norm="layer",
training=True,
)
vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)
assert out.shape == (B, T, H, W, C):::
::: {#cell-13 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
# Exercise `encode` and `decode` on their own through the batched module,
# reusing `model`, `vars`, `dummy` and `target` from the previous cell.
# Results are not asserted; the cell only checks that the calls run.
encoded = model.apply(vars, dummy, method="encode")
decoded = model.apply(vars, encoded, method="decode")
enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode"):::
Koopman Autoencoder 1D
::: {#cell-15 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class KoopmanAutoencoder1D(nn.Module):
    """
    Koopman Autoencoder

    1D version: frames of shape (W, C) are flattened, encoded to a real
    latent, rolled out in the complex domain and decoded back to frames.
    """

    encoder_model: nn.Module  # called with no arguments in `setup`
    decoder_model: nn.Module  # called with no arguments in `setup`
    dynamics_model: nn.Module  # called with no arguments in `setup`
    d_vars: int  # channels of a decoded frame
    d_model: int  # spatial length of a decoded frame
    n_steps: int  # number of steps requested from the dynamics
    norm: str = "layer"  # not read anywhere in this class
    training: bool = True  # not read anywhere in this class

    def setup(self):
        """Instantiate encoder, decoder and dynamics."""
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, C)
    ) -> jnp.ndarray:
        """Encode the first frame, roll it out and decode every step."""
        latent = self.encode(x[0])
        trajectory = self.advance(latent)
        return self.decode(trajectory)

    def encode(
        self,
        x: jnp.ndarray,  # (W, C) or (T, W, C)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        """Flatten each frame and pass it through the encoder."""
        # Inputs of any other rank reach the encoder unflattened.
        pattern = {
            2: "w c -> (w c)",
            3: "t w c -> t (w c)",
        }.get(x.ndim)
        if pattern is not None:
            x = rearrange(x, pattern)
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, C) or (T, W, C)
        """Decode latents and reshape the flat output back into frames."""
        flat = self.decoder(z)
        pattern = {
            2: "t (w c) -> t w c",
            1: "(w c) -> w c",
        }.get(flat.ndim)
        if pattern is None:
            return flat
        return rearrange(flat, pattern, w=self.d_model, c=self.d_vars)

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        """Roll the latent out for `n_steps` steps in the complex domain."""
        # First half of the latent is the real part, second half the imaginary.
        half = z.shape[0] // 2
        state = z[:half] + 1j * z[half:]
        rollout = self.dynamics(state, self.n_steps)
        # Pack back into a real vector: [real | imag] on the last axis.
        return jnp.concatenate([rollout.real, rollout.imag], axis=-1)
# Batched variant of KoopmanAutoencoder1D: `nn.vmap` maps the listed methods
# over axis 0 of their inputs and outputs. The "params" collection is not
# mapped (axis None) and the "params" RNG is not split.
BatchedKoopmanAutoencoder1D = nn.vmap(
KoopmanAutoencoder1D,
in_axes=0,
out_axes=0,
variable_axes={"params": None},
split_rngs={"params": False},
methods=["__call__", "decode", "encode", "advance"],
):::
::: {#cell-16 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
# Smoke test: initialise a batched KoopmanAutoencoder1D with orthogonally
# initialised, bias-free dense encoder/decoder and LRUDynamicsVarying.
# NOTE: in the recorded run below, `model.init` failed with a cuSolver
# internal error raised from the QR step of the orthogonal initializer.
B, T, W, C = 5, 16, 101, 3
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros((B, T, W, C))
# The real latent is split into two halves (real/imag) by `advance`, hence
# twice the dynamics' `d_hidden`.
encoder_model = partial(
nn.Dense,
features=d_hidden * 2,
kernel_init=nn.initializers.orthogonal(),
use_bias=False,
)
# The decoder emits one flat frame of W * C values.
decoder_model = partial(
nn.Dense,
features=W * C,
kernel_init=nn.initializers.orthogonal(),
use_bias=False,
)
dynamics_model = partial(
LRUDynamicsVarying,
d_hidden=d_hidden,
r_min=0.9,
r_max=1.0,
max_phase=jnp.pi * 2,
model=nn.Dense(features=d_hidden * 2, kernel_init=nn.initializers.orthogonal()),
clip_eigs=False,
prepend_ones=False,
)
model = BatchedKoopmanAutoencoder1D(
encoder_model=encoder_model,
decoder_model=decoder_model,
dynamics_model=dynamics_model,
d_vars=C,
d_model=W,
n_steps=T,
norm="layer",
training=True,
)
vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)
assert out.shape == (B, T, W, C)
E0107 11:25:15.482350 37996 pjrt_stream_executor_client.cc:2826] Execution of replica 0 failed: INTERNAL: CustomCall failed: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
--------------------------------------------------------------------------- XlaRuntimeError Traceback (most recent call last) Cell In[13], line 42 20 dynamics_model = partial( 21 LRUDynamicsVarying, 22 d_hidden=d_hidden, (...) 28 prepend_ones=False, 29 ) 31 model = BatchedKoopmanAutoencoder1D( 32 encoder_model=encoder_model, 33 decoder_model=decoder_model, (...) 39 training=True, 40 ) ---> 42 vars = model.init(jax.random.PRNGKey(0), dummy) 43 out = model.apply(vars, dummy) 45 assert out.shape == (B, T, W, C) [... skipping hidden 17 frame] Cell In[12], line 27, in KoopmanAutoencoder1D.__call__(self, x) 23 def __call__( 24 self, 25 x: jnp.ndarray, # (T, W, C) 26 ) -> jnp.ndarray: ---> 27 z = self.encode(x[0]) 28 z = self.advance(z) 29 x_hat = self.decode(z) [... skipping hidden 2 frame] Cell In[12], line 40, in KoopmanAutoencoder1D.encode(self, x) 38 elif len(x.shape) == 3: 39 x = rearrange(x, "t w c -> t (w c)") ---> 40 return self.encoder(x) [... skipping hidden 2 frame] File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/flax/linen/linear.py:251, in Dense.__call__(self, inputs) 241 @compact 242 def __call__(self, inputs: Array) -> Array: 243 """Applies a linear transformation to the inputs along the last dimension. 244 245 Args: (...) 249 The transformed input. 250 """ --> 251 kernel = self.param( 252 'kernel', 253 self.kernel_init, 254 (jnp.shape(inputs)[-1], self.features), 255 self.param_dtype, 256 ) 257 if self.use_bias: 258 bias = self.param( 259 'bias', self.bias_init, (self.features,), self.param_dtype 260 ) [... 
skipping hidden 2 frame] File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/jax/_src/nn/initializers.py:611, in orthogonal.<locals>.init(key, shape, dtype) 609 matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols) 610 A = random.normal(key, matrix_shape, dtype) --> 611 Q, R = jnp.linalg.qr(A) 612 diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim) 613 Q *= diag_sign # needed for a uniform distribution [... skipping hidden 10 frame] File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1185, in ExecuteReplicated.__call__(self, *args) 1183 self._handle_token_bufs(result_token_bufs, sharded_runtime_token) 1184 else: -> 1185 results = self.xla_executable.execute_sharded(input_bufs) 1186 if dispatch.needs_check_special(): 1187 out_arrays = results.disassemble_into_single_device_arrays() XlaRuntimeError: INTERNAL: CustomCall failed: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
:::
::: {#cell-17 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
# Exercise `encode` and `decode` on their own through the batched module,
# reusing `model`, `vars`, `dummy` and `target` from the previous cell, and
# check that decoding the encoded sequence restores the input shape.
print(dummy.shape)
encoded = model.apply(vars, dummy, method="encode")
print("encoded", encoded.shape)
decoded = model.apply(vars, encoded, method="decode")
assert decoded.shape == dummy.shape
enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")
assert dec_sequence.shape == (B, T, W, C):::
::: {#cell-18 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class KoopmanAutoencoder1DReal(nn.Module):
    """
    Koopman Autoencoder but with real encoding and decoding

    The latent is not split into real/imaginary halves: it is promoted to a
    complex state with zero imaginary part, and only the real part of the
    rollout is decoded.
    """

    encoder_model: nn.Module  # called with no arguments in `setup`
    decoder_model: nn.Module  # called with no arguments in `setup`
    dynamics_model: nn.Module  # called with no arguments in `setup`
    d_vars: int  # channels of a decoded frame
    d_model: int  # spatial length of a decoded frame
    n_steps: int  # number of steps requested from the dynamics
    norm: str = "layer"  # not read anywhere in this class
    training: bool = True  # not read anywhere in this class

    def setup(self):
        """Instantiate encoder, decoder and dynamics."""
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, C)
    ) -> jnp.ndarray:
        """Encode the first frame, roll it out and decode every step."""
        latent = self.encode(x[0])
        trajectory = self.advance(latent)
        return self.decode(trajectory)

    def encode(
        self,
        x: jnp.ndarray,  # (W, C) or (T, W, C)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        """Flatten each frame and pass it through the encoder."""
        # Inputs of any other rank reach the encoder unflattened.
        pattern = {
            2: "w c -> (w c)",
            3: "t w c -> t (w c)",
        }.get(x.ndim)
        if pattern is not None:
            x = rearrange(x, pattern)
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, C) or (T, W, C)
        """Decode latents and reshape the flat output back into frames."""
        flat = self.decoder(z)
        pattern = {
            2: "t (w c) -> t w c",
            1: "(w c) -> w c",
        }.get(flat.ndim)
        if pattern is None:
            return flat
        return rearrange(flat, pattern, w=self.d_model, c=self.d_vars)

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        """Roll the latent out for `n_steps` steps and keep the real part."""
        # Promote to complex by adding a zero imaginary component.
        zero_imag = 1j * 0.0
        state = z + zero_imag
        rollout = self.dynamics(state, self.n_steps)
        return rollout.real
# Batched variant of KoopmanAutoencoder1DReal: `nn.vmap` maps the listed
# methods over axis 0 of their inputs and outputs. The "params" collection is
# not mapped (axis None) and the "params" RNG is not split.
BatchedKoopmanAutoencoder1DReal = nn.vmap(
KoopmanAutoencoder1DReal,
in_axes=0,
out_axes=0,
variable_axes={"params": None},
split_rngs={"params": False},
methods=["__call__", "decode", "encode", "advance"],
):::
::: {#cell-19 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
# Smoke test: initialise a batched KoopmanAutoencoder1DReal and check that
# the output shape matches the input shape.
B, T, W, C = 5, 16, 101, 3
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros((B, T, W, C))  # not used below in this cell
# The latent is used whole (no real/imag split in `advance`), so the encoder
# width equals the dynamics' `d_hidden`.
encoder_model = partial(
nn.Dense,
features=d_hidden,
kernel_init=nn.initializers.orthogonal(),
use_bias=False,
)
# The decoder emits one flat frame of W * C values.
decoder_model = partial(
nn.Dense,
features=W * C,
kernel_init=nn.initializers.orthogonal(),
use_bias=False,
)
dynamics_model = partial(
LRUDynamicsVarying,
d_hidden=d_hidden,
r_min=0.9,
r_max=1.0,
max_phase=jnp.pi * 2,
model=nn.Dense(
features=d_hidden * 2,
kernel_init=nn.initializers.orthogonal(),
),
clip_eigs=False,
prepend_ones=False,
)
model = BatchedKoopmanAutoencoder1DReal(
encoder_model=encoder_model,
decoder_model=decoder_model,
dynamics_model=dynamics_model,
d_vars=C,
d_model=W,
n_steps=T,
norm="layer",
training=True,
)
vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)
assert out.shape == (B, T, W, C):::
Tied autoencoder
::: {#cell-21 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class TiedAutoencoder(nn.Module):
    """
    Tied (real) Autoencoder

    A single kernel of shape (d_model * d_vars, d_hidden) encodes a flattened
    frame, and its transpose decodes the latent. The latent is rolled out by
    the dynamics in the complex domain (first half real, second half
    imaginary).
    """

    d_model: int  # spatial length of a frame
    d_vars: int  # channels of a frame
    d_hidden: int  # width of the real latent (split in two by `advance`)
    dynamics_model: nn.Module  # called with no arguments in `setup`
    n_steps: int  # number of steps requested from the dynamics
    norm: str = "layer"  # not read anywhere in this class
    training: bool = True  # not read anywhere in this class
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    # These two were un-annotated class attributes, which a Flax module does
    # not treat as fields, so they could not be set at construction. They are
    # annotated now and placed after `kernel_init` so the positional order of
    # the existing fields is unchanged.
    dtype: jnp.dtype | None = None  # not read anywhere in this class
    param_dtype: jnp.dtype = jnp.float32  # dtype of the shared kernel

    def setup(self):
        """Instantiate the dynamics and create the shared kernel parameter."""
        self.dynamics = self.dynamics_model()
        self.kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.d_model * self.d_vars, self.d_hidden),
            self.param_dtype,
        )

    def encode(self, x: ArrayLike) -> ArrayLike:
        """
        Project frames onto the latent space with the shared kernel.

        (W, C) and (T, W, C) inputs are flattened per frame first; inputs of
        any other rank (e.g. already-flat frames) are used as they are.
        """
        if x.ndim == 2:
            x = rearrange(x, "w c -> (w c)")
        elif x.ndim == 3:
            x = rearrange(x, "t w c -> t (w c)")
        # Contract the last axis of `x` with the first axis of the kernel.
        z = jax.lax.dot_general(
            x,
            self.kernel,
            (((x.ndim - 1,), (0,)), ((), ())),
        )
        return z

    def decode(self, z: ArrayLike) -> ArrayLike:
        """
        Map latents back to frames with the transposed kernel.

        The flat output is reshaped to (W, C) for a single latent or
        (T, W, C) for a sequence of latents.
        """
        # Contract the last axis of `z` with the first axis of `kernel.T`.
        x = jax.lax.dot_general(
            z,
            self.kernel.T,
            (((z.ndim - 1,), (0,)), ((), ())),
        )
        if x.ndim == 2:
            x = rearrange(x, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)
        elif x.ndim == 1:
            x = rearrange(x, "(w c) -> w c", w=self.d_model, c=self.d_vars)
        return x

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        """Roll the latent out for `n_steps` steps in the complex domain."""
        # First half of the latent is the real part, second half the imaginary.
        half = z.shape[0] // 2
        z = z[:half] + 1j * z[half:]
        z = self.dynamics(z, self.n_steps)
        # Pack back into a real vector: [real | imag] on the last axis.
        z = jnp.concatenate([z.real, z.imag], axis=-1)
        return z

    def __call__(
        self,
        inputs: ArrayLike,  # (T, H, C)
    ) -> ArrayLike:
        """Encode the first frame, roll it out and decode every step."""
        z = self.encode(inputs[0])
        z = self.advance(z)
        x_hat = self.decode(z)
        return x_hat
# Batched variant of TiedAutoencoder: `nn.vmap` maps the listed methods over
# axis 0 of their inputs and outputs. The "params" collection is not mapped
# (axis None) and the "params" RNG is not split.
BatchedTiedAutoencoder = nn.vmap(
TiedAutoencoder,
in_axes=0,
out_axes=0,
variable_axes={"params": None},
split_rngs={"params": False},
methods=["__call__", "decode", "encode", "advance"],
):::
::: {#cell-22 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
# Smoke test: initialise a batched TiedAutoencoder with LRUDynamicsVarying
# and check that the output is reshaped to (B, T, W, C).
B, T, W, C = 5, 16, 101, 3
# Frames are passed already flattened to W * C values, so `encode` receives a
# rank-1 frame and skips its rearrange.
dummy = jnp.zeros((B, T, W * C))
target = jnp.zeros((B, T, W * C))  # not used below in this cell
d_input = W * C  # not used below in this cell
d_hidden = 128
n_steps = 16
model = BatchedTiedAutoencoder(
d_vars=C,
d_model=W,
# The real latent is split into two halves (real/imag) by `advance`, hence
# twice the dynamics' `d_hidden`.
d_hidden=d_hidden * 2,
dynamics_model=partial(
LRUDynamicsVarying,
d_hidden=d_hidden,
r_min=0.9,
r_max=1.0,
max_phase=jnp.pi * 2,
model=nn.Dense(
features=d_hidden * 2,
kernel_init=nn.initializers.orthogonal(),
),
clip_eigs=False,
prepend_ones=False,
),
n_steps=n_steps,
)
output, variables = model.init_with_output(jax.random.PRNGKey(0), dummy)
assert output.shape == (B, T, W, C):::
from physmodjax.models.recurrent import DiscreteModalDynamics
::: {#cell-24 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
# Smoke test: initialise a batched TiedAutoencoder with DiscreteModalDynamics,
# run a value-and-gradient pass and check the output shape.
B, T, W, C = 5, 16, 101, 3
# Frames are passed already flattened to W * C values, so `encode` receives a
# rank-1 frame and skips its rearrange.
dummy = jnp.zeros((B, T, W * C))
target = jnp.zeros((B, T, W * C))  # not used below in this cell
d_input = W * C  # not used below in this cell
d_hidden = 128
n_steps = 16
model = BatchedTiedAutoencoder(
d_vars=C,
d_model=W,
# The real latent is split into two halves (real/imag) by `advance`, hence
# twice the dynamics' `d_hidden`.
d_hidden=d_hidden * 2,
n_steps=n_steps,
dynamics_model=partial(
DiscreteModalDynamics,
d_hidden=d_hidden,
clip=True,
r_min=0.3,
r_max=1,
min_phase=0.01,
max_phase=jnp.pi / 2,
prepend_ones=False,
),
)
output, variables = model.init_with_output(jax.random.PRNGKey(0), dummy)
# Differentiate the mean output w.r.t. the input; the result is discarded, so
# this only checks that the backward pass runs.
jax.value_and_grad(lambda x: model.apply(variables, x).mean())(dummy)
assert output.shape == (B, T, W, C):::
::: {#cell-25 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class JAESAutoencoder1D(nn.Module):
    """
    JAES Autoencoder

    Like the 1D Koopman autoencoder, but the complex rollout is multiplied by
    a time-varying complex modulation produced by two ModulatedSiren networks
    that are conditioned on the flattened first frame.
    """

    encoder_model: nn.Module  # called with no arguments in `setup`
    decoder_model: nn.Module  # called with no arguments in `setup`
    dynamics_model: nn.Module  # called with no arguments in `setup`
    d_vars: int  # channels of a decoded frame
    d_model: int  # spatial length of a decoded frame
    n_steps: int  # rollout length and number of modulation time samples
    # layers and hidden channels including the output for the modulation
    modulation_hidden_channels: Sequence[int]
    norm: str = "layer"  # not read anywhere in this class
    training: bool = True  # not read anywhere in this class

    def setup(self):
        """Instantiate encoder, decoder and dynamics."""
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, C)
    ) -> jnp.ndarray:
        """Encode the first frame, roll out, modulate and decode every step."""
        first_frame = x[0]
        latent = self.encode(first_frame)

        # Both modulation networks see the flattened first frame and a time
        # grid of `n_steps` points in [0, 1] with a trailing unit axis.
        conditioning = first_frame.reshape(-1)
        times = jnp.linspace(0, 1, self.n_steps)[..., None]
        siren_kwargs = dict(
            hidden_channels=self.modulation_hidden_channels,
            w0_first_layer=1.0,
        )

        amp_net = ModulatedSiren(name="varying_amp", **siren_kwargs)
        # The sigmoid keeps the amplitude inside (0, 3).
        amplitude = nn.sigmoid(amp_net(times, conditioning)) * 3.0

        phase_net = ModulatedSiren(name="varying_phase", **siren_kwargs)
        phase = phase_net(times, conditioning)

        modulation = amplitude * jnp.exp(1j * phase)

        rollout = self.advance(latent) * modulation
        # Pack into a real vector: [real | imag] on the last axis.
        packed = jnp.concatenate([rollout.real, rollout.imag], axis=-1)
        return self.decode(packed)

    def encode(
        self,
        x: jnp.ndarray,  # (W, C) or (T, W, C)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        """Flatten each frame and pass it through the encoder."""
        # Inputs of any other rank reach the encoder unflattened.
        pattern = {
            2: "w c -> (w c)",
            3: "t w c -> t (w c)",
        }.get(x.ndim)
        if pattern is not None:
            x = rearrange(x, pattern)
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, C) or (T, W, C)
        """Decode latents and reshape the flat output back into frames."""
        flat = self.decoder(z)
        pattern = {
            2: "t (w c) -> t w c",
            1: "(w c) -> w c",
        }.get(flat.ndim)
        if pattern is None:
            return flat
        return rearrange(flat, pattern, w=self.d_model, c=self.d_vars)

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        """Roll the latent out for `n_steps` steps, returning complex states."""
        # IMPORTANT: unlike the other models, the result stays complex because
        # `__call__` multiplies it by a complex modulation before packing.
        half = z.shape[0] // 2
        state = z[:half] + 1j * z[half:]
        return self.dynamics(state, self.n_steps)
# Batched variant of JAESAutoencoder1D: `nn.vmap` maps the listed methods over
# axis 0 of their inputs and outputs. The "params" collection is not mapped
# (axis None) and the "params" RNG is not split.
BatchedJAESAutoencoder1D = nn.vmap(
JAESAutoencoder1D,
in_axes=0,
out_axes=0,
variable_axes={"params": None},
split_rngs={"params": False},
methods=["__call__", "decode", "encode", "advance"],
):::
::: {#cell-26 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
# Smoke test: initialise a batched JAESAutoencoder1D with
# DiscreteModalDynamicsAngleMag and check the output shape.
B, T, W, C = 5, 16, 101, 2
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros((B, T, W, C))  # not used below in this cell
# The real latent is split into two halves (real/imag) by `advance`, hence
# twice the dynamics' `d_hidden`.
encoder_model = partial(
nn.Dense,
features=d_hidden * 2,
kernel_init=nn.initializers.orthogonal(),
use_bias=False,
)
# The decoder emits one flat frame of W * C values.
decoder_model = partial(
nn.Dense,
features=W * C,
kernel_init=nn.initializers.orthogonal(),
use_bias=False,
)
dynamics_model = partial(
DiscreteModalDynamicsAngleMag,
d_hidden=d_hidden,
r_min=0.5,
r_max=1.0,
min_phase=0.01,
max_phase=jnp.pi * 2,
clip=False,
prepend_ones=False,
)
model = BatchedJAESAutoencoder1D(
encoder_model=encoder_model,
decoder_model=decoder_model,
dynamics_model=dynamics_model,
d_vars=C,
d_model=W,
n_steps=T,
norm="layer",
training=True,
# Last entry is d_hidden: the modulation multiplies the complex rollout.
modulation_hidden_channels=[128, d_hidden],
)
vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)
assert out.shape == (B, T, W, C):::
::: {#cell-27 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class JAESAutoencoder2D(nn.Module):
    """
    JAES Autoencoder

    2D version: frames are flattened before encoding, and the complex rollout
    is multiplied by a time-varying complex modulation produced by two
    ModulatedSiren networks conditioned on the flattened first frame.
    """

    # The three factories are called with no arguments in `setup`. Frames are
    # flattened before they reach the encoder (see `encode`).
    encoder_model: ConvEncoder
    decoder_model: ConvDecoder
    dynamics_model: LRUDynamics
    d_vars: int  # channels of a decoded frame
    d_model: tuple[int, int]  # spatial size of a decoded frame
    n_steps: int  # rollout length and number of modulation time samples
    # layers and hidden channels including the output for the modulation
    modulation_hidden_channels: Sequence[int]
    norm: str = "layer"  # not read anywhere in this class
    training: bool = True  # not read anywhere in this class

    def setup(self):
        """Instantiate encoder, decoder and dynamics."""
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,  # (T, H, W, C)
    ) -> jnp.ndarray:
        """Encode the first frame, roll out, modulate and decode every step."""
        first_frame = x[0]
        latent = self.encode(first_frame)

        # Both modulation networks see the flattened first frame and a time
        # grid of `n_steps` points in [0, 1] with a trailing unit axis.
        conditioning = first_frame.reshape(-1)
        times = jnp.linspace(0, 1, self.n_steps)[..., None]
        siren_kwargs = dict(
            hidden_channels=self.modulation_hidden_channels,
            w0_first_layer=1.0,
        )

        amp_net = ModulatedSiren(name="varying_amp", **siren_kwargs)
        # The sigmoid keeps the amplitude inside (0, 3).
        amplitude = nn.sigmoid(amp_net(times, conditioning)) * 3.0

        phase_net = ModulatedSiren(name="varying_phase", **siren_kwargs)
        phase = phase_net(times, conditioning)

        modulation = amplitude * jnp.exp(1j * phase)

        rollout = self.advance(latent) * modulation
        # Pack into a real vector: [real | imag] on the last axis.
        packed = jnp.concatenate([rollout.real, rollout.imag], axis=-1)
        return self.decode(packed)

    def encode(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:  # (hidden_dim) or (T, hidden_dim)
        """Flatten each frame and pass it through the encoder."""
        # Inputs of any other rank reach the encoder unflattened.
        pattern = {
            3: "w h c -> (w h c)",
            4: "t w h c -> t (w h c)",
        }.get(x.ndim)
        if pattern is not None:
            x = rearrange(x, pattern)
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, H, C) or (T, W, H, C)
        """Decode latents and reshape the flat output back into frames."""
        flat = self.decoder(z)
        pattern = {
            1: "(w h c) -> w h c",
            2: "t (w h c) -> t w h c",
        }.get(flat.ndim)
        if pattern is None:
            return flat
        return rearrange(
            flat,
            pattern,
            w=self.d_model[0],
            h=self.d_model[1],
            c=self.d_vars,
        )

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:
        """Roll the latent out for `n_steps` steps, returning complex states."""
        # The result stays complex: `__call__` multiplies it by a complex
        # modulation before packing it into real/imag halves.
        half = z.shape[0] // 2
        state = z[:half] + 1j * z[half:]
        return self.dynamics(state, self.n_steps)
# Batched variant of JAESAutoencoder2D. The "params", "batch_stats" and
# "prime" collections are not mapped (axis None); "cache" is mapped along
# axis 0. The "params" RNG is not split and the mapped axis is named "batch".
BatchedJAESAutoencoder2D = nn.vmap(
JAESAutoencoder2D,
in_axes=0, # map over the first axis of the first input not the second
out_axes=0,
variable_axes={"params": None, "batch_stats": None, "cache": 0, "prime": None},
split_rngs={"params": False},
methods=["__call__", "decode", "encode", "advance"],
axis_name="batch",
):::