```python
from functools import partial

from physmodjax.models.recurrent import (
    DiscreteModalDynamicsAngleMag,
    LRUDynamicsVarying,
)
```
# Autoencoder models
::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’ 6=‘i’}
```python
from collections.abc import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
from einops import rearrange
from jax.typing import ArrayLike

from physmodjax.models.conv import ConvDecoder, ConvEncoder
from physmodjax.models.mlp import ModulatedSiren
from physmodjax.models.recurrent import LRUDynamics
from physmodjax.utils.data import create_grid
```
:::
## 2D Convolutional Autoencoder with Linear Dynamics
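`FourierAutoencoder2D` encodes each spatial frame with a 2D real FFT, keeps only the lowest `n_modes × n_modes` complex coefficients as the latent state, advances that state with a linear dynamics model, and inverts the FFT to reconstruct the field. A minimal standalone sketch of the truncation/reconstruction round trip (same frame size and mode count as the test further down; it only demonstrates shapes and the zero-padding used by the encoder, not the dynamics):

```python
import jax.numpy as jnp

W, H, C, n_modes = 41, 37, 2, 20
x = jnp.ones((W, H, C))  # a single spatial frame

# forward: real FFT over the two spatial axes (zero-padded as in the encoder), keep the lowest modes
X = jnp.fft.rfft2(x, s=(W * 2 - 1, H * 2 - 1), axes=(-3, -2), norm="ortho")
X = X[: n_modes, : n_modes, :]  # (n_modes, n_modes, C) complex latent

# inverse: back to the original spatial resolution
x_rec = jnp.fft.irfft2(X, s=(W, H), axes=(-3, -2), norm="ortho")
print(X.shape, x_rec.shape)  # (20, 20, 2) (41, 37, 2)
```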
::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class FourierAutoencoder2D(nn.Module):
    dynamics_model: nn.Module
    d_vars: int
    d_model: tuple[int, int]
    norm: str = "layer"
    training: bool = True
    use_positions: bool = False
    n_modes: int = 20

    def setup(self):
        self.dynamics = self.dynamics_model()
        if self.use_positions:
            self.grid = create_grid(self.d_model[1], self.d_model[0])

    def __call__(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:
        z = self.encode(x)
        z = self.advance(z, x)
        return self.decode(z)

    def encode(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:  # (hidden_dim) or (T, hidden_dim) complex
        """
        Spatial encoding of the input data.
        """
        if self.use_positions:
            x = jnp.concatenate([x, self.grid], axis=-1)

        *_, W, H, C = x.shape
        x = jnp.fft.rfft2(x, s=(W * 2 - 1, H * 2 - 1), axes=(-3, -2), norm="ortho")
        x = x[..., : self.n_modes, : self.n_modes, :]

        if len(x.shape) == 3:
            x = rearrange(x, "w h c -> (w h c)")
        elif len(x.shape) == 4:
            x = rearrange(x, "t w h c -> t (w h c)")

        return x

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,) complex
        steps: jnp.ndarray,  # (T,)
    ) -> jnp.ndarray:  # (T, hidden_dim,)
        # z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, steps.shape[0])
        # z = jnp.concatenate([z.real, z.imag], axis=-1)
        return z

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim) or (T, hidden_dim) complex
    ) -> jnp.ndarray:  # (W, H, C) or (T, W, H, C) real
        w_modes = self.n_modes
        h_modes = self.n_modes
        if len(z.shape) == 1:
            z = rearrange(z, "(w h c) -> w h c", w=w_modes, h=h_modes, c=self.d_vars)
        elif len(z.shape) == 2:
            z = rearrange(
                z, "t (w h c) -> t w h c", w=w_modes, h=h_modes, c=self.d_vars
            )
        z = jnp.fft.irfft2(
            z, s=(self.d_model[0], self.d_model[1]), axes=(-3, -2), norm="ortho"
        )
        return z


BatchedFourierAutoencoder2D = nn.vmap(
    FourierAutoencoder2D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)
```
:::
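All of the `Batched*` aliases in this notebook use the same `nn.vmap` pattern: the listed methods are mapped over a leading batch axis while a single set of parameters is shared across the batch (`variable_axes={"params": None}`, `split_rngs={"params": False}`). A minimal sketch of the same lifted transform applied to a plain `nn.Dense` (not part of the exported module; `methods` is omitted, so only `__call__` is mapped):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

# map __call__ over axis 0 of the input while sharing one kernel across the batch
BatchedDense = nn.vmap(
    nn.Dense,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
)

layer = BatchedDense(features=4)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((3, 7)))  # batch of 3 inputs of size 7
print(layer.apply(params, jnp.ones((3, 7))).shape)  # (3, 4)
```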
::: {#cell-7 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
d_hidden = 128
B, T, H, W, C = 5, 16, 41, 37, 2
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))

dynamics_model = partial(
    LRUDynamics,
    d_hidden=(20 * 20 * 2),
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
    prepend_ones=False,
)

model = BatchedFourierAutoencoder2D(
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=(H, W),
    norm="layer",
    training=True,
    n_modes=20,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, H, W, C)
```
:::
::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class DenseKoopmanAutoencoder2D(nn.Module):
    """
    Koopman Dense Autoencoder
    """

    encoder_model: nn.Module
    decoder_model: nn.Module
    dynamics_model: nn.Module
    d_vars: int
    d_model: tuple[int, int]
    n_steps: int
    norm: str = "layer"
    training: bool = True
    use_positions: bool = False

    def setup(self):
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()
        if self.use_positions:
            self.grid = create_grid(self.d_model[1], self.d_model[0])

    def __call__(
        self,
        x: jnp.ndarray,  # (W, H, C)
    ) -> jnp.ndarray:
        z = self.encode(x[0])
        z = self.advance(z)
        return self.decode(z)

    def encode(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:  # (hidden_dim) or (T, hidden_dim)
        if self.use_positions:
            x = jnp.concatenate([x, self.grid], axis=-1)

        if len(x.shape) == 3:
            x = rearrange(x, "w h c -> (w h c)")
        elif len(x.shape) == 4:
            x = rearrange(x, "t w h c -> t (w h c)")
        return self.encoder(x)

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim,)
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        z = jnp.concatenate([z.real, z.imag], axis=-1)
        return z

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, H, C) or (T, W, H, C)
        z = self.decoder(z)
        if len(z.shape) == 1:
            z = rearrange(
                z,
                "(w h c) -> w h c",
                w=self.d_model[0],
                h=self.d_model[1],
                c=self.d_vars,
            )
        elif len(z.shape) == 2:
            z = rearrange(
                z,
                "t (w h c) -> t w h c",
                w=self.d_model[0],
                h=self.d_model[1],
                c=self.d_vars,
            )
        return z


BatchedDenseKoopmanAutoencoder2D = nn.vmap(
    DenseKoopmanAutoencoder2D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)
```
:::
::: {#cell-9 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
n_steps = 16
d_hidden = 128
B, T, H, W, C = 5, n_steps, 41, 37, 3
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))

encoder_model = partial(nn.Dense, features=d_hidden * 2)
decoder_model = partial(nn.Dense, features=H * W * C)
dynamics_model = partial(
    LRUDynamics,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
    prepend_ones=False,
)

model = BatchedDenseKoopmanAutoencoder2D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    n_steps=n_steps,
    d_vars=C,
    d_model=(H, W),
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, H, W, C)
```
:::
::: {#cell-10 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
encoded = model.apply(vars, dummy, method="encode")
decoded = model.apply(vars, encoded, method="decode")

enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")
```
:::
::: {#cell-11 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class KoopmanAutoencoder2D(nn.Module):
    """
    Koopman Autoencoder
    """

    encoder_model: ConvEncoder
    decoder_model: ConvDecoder
    dynamics_model: LRUDynamics
    d_latent_channels: int
    d_latent_dims: tuple[int, int]
    n_steps: int
    norm: str = "layer"
    training: bool = True

    def setup(self):
        self.encoder = self.encoder_model(training=self.training, norm=self.norm)
        self.decoder = self.decoder_model(training=self.training, norm=self.norm)
        self.dynamics = self.dynamics_model()

    def __call__(
        self,
        x: jnp.ndarray,  # (H, W, C)
    ) -> jnp.ndarray:  # (T, H, W, C)
        z = self.encode(x[0])
        z = self.advance(z)
        x_hat = self.decode(z)
        return x_hat

    def encode(
        self,
        x: jnp.ndarray,  # (H, W, C) or (T, H, W, C)
    ) -> jnp.ndarray:  # (hidden_dim,) or (T, hidden_dim)
        z = self.encoder(x)
        if len(z.shape) == 4:
            z = rearrange(z, "t h w c -> t (h w c)")
        elif len(z.shape) == 3:
            z = rearrange(z, "h w c -> (h w c)")
        return z

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (H, W, C) or (T, H, W, C)
        if len(z.shape) == 2:
            z = rearrange(
                z,
                "t (h w c) -> t h w c",
                h=self.d_latent_dims[0],
                w=self.d_latent_dims[1],
                c=self.d_latent_channels,
            )
        elif len(z.shape) == 1:
            z = rearrange(
                z,
                "(h w c) -> h w c",
                h=self.d_latent_dims[0],
                w=self.d_latent_dims[1],
                c=self.d_latent_channels,
            )
        return self.decoder(z)

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:
        # convert to complex and back
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        z = jnp.concatenate([z.real, z.imag], axis=-1)
        return z


BatchedKoopmanAutoencoder2D = nn.vmap(
    KoopmanAutoencoder2D,
    in_axes=0,  # map over the first axis of the first input not the second
    out_axes=0,
    variable_axes={"params": None, "batch_stats": None, "cache": 0, "prime": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
    axis_name="batch",
)
```
:::
::: {#cell-12 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
n_steps = 16
d_hidden = 128
B, T, H, W, C = 5, n_steps, 40, 40, 3
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))

encoder_model = partial(ConvEncoder, block_size=(8, 16, 32))
decoder_model = partial(ConvDecoder, block_size=(8, 16, 32))
dynamics_model = partial(
    LRUDynamics,
    d_hidden=16 * 5 * 5,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
    prepend_ones=False,
)

model = BatchedKoopmanAutoencoder2D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_latent_channels=32,
    d_latent_dims=(5, 5),
    n_steps=16,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)
assert out.shape == (B, T, H, W, C)
```
:::
::: {#cell-13 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
encoded = model.apply(vars, dummy, method="encode")
decoded = model.apply(vars, encoded, method="decode")

enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")
```
:::
## Koopman Autoencoder 1D
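The 1D variants share the real-to-complex convention used by the models above: the encoder produces a real vector of size `2 * d_hidden`, `advance` interprets the first half as the real part and the second half as the imaginary part, runs the complex linear dynamics, and concatenates real and imaginary parts again before decoding. A standalone sketch of that packing with a toy `d_hidden` (no model from this module is involved):

```python
import jax.numpy as jnp

d_hidden = 4
z_real = jnp.arange(2 * d_hidden, dtype=jnp.float32)  # encoder output, real

# pack: first half -> real part, second half -> imaginary part
z = z_real[:d_hidden] + 1j * z_real[d_hidden:]  # (d_hidden,) complex

# ... the complex dynamics model would act on z here ...

# unpack: back to a real vector of size 2 * d_hidden for the decoder
z_out = jnp.concatenate([z.real, z.imag], axis=-1)
assert jnp.allclose(z_out, z_real)
```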
::: {#cell-15 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class KoopmanAutoencoder1D(nn.Module):
    """
    Koopman Autoencoder
    """

    encoder_model: nn.Module
    decoder_model: nn.Module
    dynamics_model: nn.Module
    d_vars: int
    d_model: int
    n_steps: int
    norm: str = "layer"
    training: bool = True

    def setup(self):
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, C)
    ) -> jnp.ndarray:
        z = self.encode(x[0])
        z = self.advance(z)
        x_hat = self.decode(z)
        return x_hat

    def encode(
        self,
        x: jnp.ndarray,  # (W, C) or (T, W, C)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        if len(x.shape) == 2:
            x = rearrange(x, "w c -> (w c)")
        elif len(x.shape) == 3:
            x = rearrange(x, "t w c -> t (w c)")
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, C) or (T, W, C)
        z = self.decoder(z)
        if len(z.shape) == 2:
            z = rearrange(z, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)
        elif len(z.shape) == 1:
            z = rearrange(z, "(w c) -> w c", w=self.d_model, c=self.d_vars)
        return z

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        # convert to complex and back
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        z = jnp.concatenate([z.real, z.imag], axis=-1)
        return z


BatchedKoopmanAutoencoder1D = nn.vmap(
    KoopmanAutoencoder1D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)
```
:::
::: {#cell-16 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
B, T, W, C = 5, 16, 101, 3
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros((B, T, W, C))

encoder_model = partial(
    nn.Dense,
    features=d_hidden * 2,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
decoder_model = partial(
    nn.Dense,
    features=W * C,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
dynamics_model = partial(
    LRUDynamicsVarying,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    model=nn.Dense(features=d_hidden * 2, kernel_init=nn.initializers.orthogonal()),
    clip_eigs=False,
    prepend_ones=False,
)

model = BatchedKoopmanAutoencoder1D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=W,
    n_steps=T,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, W, C)
```
:::
::: {#cell-17 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
print(dummy.shape)
encoded = model.apply(vars, dummy, method="encode")
print("encoded", encoded.shape)
decoded = model.apply(vars, encoded, method="decode")
assert decoded.shape == dummy.shape

enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")

assert dec_sequence.shape == (B, T, W, C)
```
:::
::: {#cell-18 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class KoopmanAutoencoder1DReal(nn.Module):
    """
    Koopman Autoencoder but with real encoding and decoding
    """

    encoder_model: nn.Module
    decoder_model: nn.Module
    dynamics_model: nn.Module
    d_vars: int
    d_model: int
    n_steps: int
    norm: str = "layer"
    training: bool = True

    def setup(self):
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, C)
    ) -> jnp.ndarray:
        z = self.encode(x[0])
        z = self.advance(z)
        x_hat = self.decode(z)
        return x_hat

    def encode(
        self,
        x: jnp.ndarray,  # (W, C) or (T, W, C)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        if len(x.shape) == 2:
            x = rearrange(x, "w c -> (w c)")
        elif len(x.shape) == 3:
            x = rearrange(x, "t w c -> t (w c)")
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, C) or (T, W, C)
        z = self.decoder(z)
        if len(z.shape) == 2:
            z = rearrange(z, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)
        elif len(z.shape) == 1:
            z = rearrange(z, "(w c) -> w c", w=self.d_model, c=self.d_vars)
        return z

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        # convert to complex and back
        z = z + 1j * 0.0
        z = self.dynamics(z, self.n_steps).real
        return z


BatchedKoopmanAutoencoder1DReal = nn.vmap(
    KoopmanAutoencoder1DReal,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)
```
:::
::: {#cell-19 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
B, T, W, C = 5, 16, 101, 3
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros((B, T, W, C))

encoder_model = partial(
    nn.Dense,
    features=d_hidden,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
decoder_model = partial(
    nn.Dense,
    features=W * C,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
dynamics_model = partial(
    LRUDynamicsVarying,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    model=nn.Dense(
        features=d_hidden * 2,
        kernel_init=nn.initializers.orthogonal(),
    ),
    clip_eigs=False,
    prepend_ones=False,
)

model = BatchedKoopmanAutoencoder1DReal(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=W,
    n_steps=T,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, W, C)
```
:::
## Tied autoencoder
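`TiedAutoencoder` uses a single learnable kernel for both directions: `encode` multiplies the flattened input by `kernel`, and `decode` multiplies by `kernel.T`, so the decoder weights are exactly the transpose of the encoder weights rather than an independent layer. A small sketch of the same idea with a fixed matrix, outside Flax (plain matmul stands in for the `dot_general` used below):

```python
import jax
import jax.numpy as jnp

d_in, d_hidden = 6, 3
kernel = jax.random.normal(jax.random.PRNGKey(0), (d_in, d_hidden))
x = jnp.ones((d_in,))

z = x @ kernel        # encode: (d_in,) -> (d_hidden,)
x_hat = z @ kernel.T  # decode with the *same* weights, transposed

print(z.shape, x_hat.shape)  # (3,) (6,)
```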
::: {#cell-21 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class TiedAutoencoder(nn.Module):
    """
    Tied (real) Autoencoder
    """

    d_model: int
    d_vars: int
    d_hidden: int
    dynamics_model: nn.Module
    n_steps: int
    norm: str = "layer"
    training: bool = True
    dtype = None
    param_dtype = jnp.float32
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()

    def setup(self):
        self.dynamics = self.dynamics_model()
        self.kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.d_model * self.d_vars, self.d_hidden),
            self.param_dtype,
        )

    def encode(self, x: ArrayLike) -> ArrayLike:
        if len(x.shape) == 2:
            x = rearrange(x, "w c -> (w c)")
        elif len(x.shape) == 3:
            x = rearrange(x, "t w c -> t (w c)")

        z = jax.lax.dot_general(
            x,
            self.kernel,
            (((x.ndim - 1,), (0,)), ((), ())),
        )
        return z

    def decode(self, z: ArrayLike) -> ArrayLike:
        x = jax.lax.dot_general(
            z,
            self.kernel.T,
            (((z.ndim - 1,), (0,)), ((), ())),
        )

        if len(x.shape) == 2:
            x = rearrange(x, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)
        elif len(x.shape) == 1:
            x = rearrange(x, "(w c) -> w c", w=self.d_model, c=self.d_vars)

        return x

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        # convert to complex and back
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        z = jnp.concatenate([z.real, z.imag], axis=-1)
        return z

    def __call__(
        self,
        inputs: ArrayLike,  # (T, H, C)
    ) -> ArrayLike:
        z = self.encode(inputs[0])
        z = self.advance(z)
        x_hat = self.decode(z)
        return x_hat


BatchedTiedAutoencoder = nn.vmap(
    TiedAutoencoder,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)
```
:::
::: {#cell-22 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
B, T, W, C = 5, 16, 101, 3

dummy = jnp.zeros((B, T, W * C))
target = jnp.zeros((B, T, W * C))

d_input = W * C
d_hidden = 128
n_steps = 16

model = BatchedTiedAutoencoder(
    d_vars=C,
    d_model=W,
    d_hidden=d_hidden * 2,
    dynamics_model=partial(
        LRUDynamicsVarying,
        d_hidden=d_hidden,
        r_min=0.9,
        r_max=1.0,
        max_phase=jnp.pi * 2,
        model=nn.Dense(
            features=d_hidden * 2,
            kernel_init=nn.initializers.orthogonal(),
        ),
        clip_eigs=False,
        prepend_ones=False,
    ),
    n_steps=n_steps,
)

output, variables = model.init_with_output(jax.random.PRNGKey(0), dummy)

assert output.shape == (B, T, W, C)
```
:::
```python
from physmodjax.models.recurrent import DiscreteModalDynamics
```
::: {#cell-24 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
B, T, W, C = 5, 16, 101, 3

dummy = jnp.zeros((B, T, W * C))
target = jnp.zeros((B, T, W * C))

d_input = W * C
d_hidden = 128
n_steps = 16

model = BatchedTiedAutoencoder(
    d_vars=C,
    d_model=W,
    d_hidden=d_hidden * 2,
    n_steps=n_steps,
    dynamics_model=partial(
        DiscreteModalDynamics,
        d_hidden=d_hidden,
        clip=True,
        r_min=0.3,
        r_max=1,
        min_phase=0.01,
        max_phase=jnp.pi / 2,
        prepend_ones=False,
    ),
)

output, variables = model.init_with_output(jax.random.PRNGKey(0), dummy)

jax.value_and_grad(lambda x: model.apply(variables, x).mean())(dummy)
assert output.shape == (B, T, W, C)
```
:::
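The test above only differentiates through the inputs to confirm that gradients flow. For orientation, here is a minimal, hypothetical training step for any of the batched models in this notebook; it is not part of the exported module, it assumes `optax` is available in the environment, and it reuses `model`, `variables`, `dummy`, and `target` from the test cell above (reshaping `target` to the model's output shape):

```python
import optax

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(variables["params"])


def mse_loss(params, batch_in, batch_target):
    pred = model.apply({"params": params}, batch_in)
    return jnp.mean((pred - batch_target) ** 2)


@jax.jit
def train_step(params, opt_state, batch_in, batch_target):
    # gradient with respect to the parameters, followed by one optimizer update
    loss, grads = jax.value_and_grad(mse_loss)(params, batch_in, batch_target)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


params, opt_state, loss = train_step(
    variables["params"], opt_state, dummy, target.reshape(B, T, W, C)
)
```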
::: {#cell-25 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class JAESAutoencoder1D(nn.Module):
    """
    JAES Autoencoder
    """

    encoder_model: nn.Module
    decoder_model: nn.Module
    dynamics_model: nn.Module
    d_vars: int
    d_model: int
    n_steps: int
    modulation_hidden_channels: Sequence[
        int
    ]  # layers and hidden channels including the output for the modulation
    norm: str = "layer"
    training: bool = True

    def setup(self):
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, C)
    ) -> jnp.ndarray:
        z = self.encode(x[0])

        x0 = x[0].reshape(-1)

        t_grid = jnp.linspace(0, 1, self.n_steps)[..., None]

        varying_amp = ModulatedSiren(
            hidden_channels=self.modulation_hidden_channels,
            w0_first_layer=1.0,
            name="varying_amp",
        )(t_grid, x0)
        varying_amp = nn.sigmoid(varying_amp) * 3.0

        varying_phase = ModulatedSiren(
            hidden_channels=self.modulation_hidden_channels,
            w0_first_layer=1.0,
            name="varying_phase",
        )(t_grid, x0)

        modulation = varying_amp * jnp.exp(1j * varying_phase)

        z = self.advance(z)
        z = z * modulation

        z = jnp.concatenate([z.real, z.imag], axis=-1)
        x_hat = self.decode(z)

        return x_hat

    def encode(
        self,
        x: jnp.ndarray,  # (W, C) or (T, W, C)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        if len(x.shape) == 2:
            x = rearrange(x, "w c -> (w c)")
        elif len(x.shape) == 3:
            x = rearrange(x, "t w c -> t (w c)")
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, C) or (T, W, C)
        z = self.decoder(z)
        if len(z.shape) == 2:
            z = rearrange(z, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)
        elif len(z.shape) == 1:
            z = rearrange(z, "(w c) -> w c", w=self.d_model, c=self.d_vars)
        return z

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        # convert to complex and return
        # IMPORTANT: we are returning a complex array unlike the other models
        # this is because we are modulating the latent space with a complex number
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        return z


BatchedJAESAutoencoder1D = nn.vmap(
    JAESAutoencoder1D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)
```
:::
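In the JAES models the latent trajectory from `advance` is multiplied elementwise by a time-varying complex modulation: two `ModulatedSiren` networks conditioned on the initial frame predict an amplitude (squashed to `(0, 3)` with a sigmoid) and a phase, combined as `amp * exp(i * phase)`. A toy sketch of just the modulation step, with random arrays standing in for the SIREN outputs (shapes follow the test below: one value per time step and latent channel):

```python
import jax
import jax.numpy as jnp

n_steps, d_hidden = 16, 8
key_a, key_p, key_z = jax.random.split(jax.random.PRNGKey(0), 3)

# stand-ins for the two ModulatedSiren outputs
varying_amp = jax.nn.sigmoid(jax.random.normal(key_a, (n_steps, d_hidden))) * 3.0
varying_phase = jax.random.normal(key_p, (n_steps, d_hidden))

# complex modulation applied elementwise to the latent trajectory
modulation = varying_amp * jnp.exp(1j * varying_phase)
z = jax.random.normal(key_z, (n_steps, d_hidden)) + 0j
z = z * modulation

print(z.shape, z.dtype)  # (16, 8) complex64
```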
::: {#cell-26 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
B, T, W, C = 5, 16, 101, 2
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros((B, T, W, C))

encoder_model = partial(
    nn.Dense,
    features=d_hidden * 2,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
decoder_model = partial(
    nn.Dense,
    features=W * C,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
dynamics_model = partial(
    DiscreteModalDynamicsAngleMag,
    d_hidden=d_hidden,
    r_min=0.5,
    r_max=1.0,
    min_phase=0.01,
    max_phase=jnp.pi * 2,
    clip=False,
    prepend_ones=False,
)

model = BatchedJAESAutoencoder1D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=W,
    n_steps=T,
    norm="layer",
    training=True,
    modulation_hidden_channels=[128, d_hidden],
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, W, C)
```
:::
::: {#cell-27 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class JAESAutoencoder2D(nn.Module):
    """
    JAES Autoencoder
    """

    encoder_model: ConvEncoder
    decoder_model: ConvDecoder
    dynamics_model: LRUDynamics
    d_vars: int
    d_model: tuple[int, int]
    n_steps: int
    modulation_hidden_channels: Sequence[
        int
    ]  # layers and hidden channels including the output for the modulation
    norm: str = "layer"
    training: bool = True

    def setup(self):
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,  # (T, H, W, C)
    ) -> jnp.ndarray:
        z = self.encode(x[0])

        x0 = x[0].reshape(-1)

        t_grid = jnp.linspace(0, 1, self.n_steps)[..., None]

        varying_amp = ModulatedSiren(
            hidden_channels=self.modulation_hidden_channels,
            w0_first_layer=1.0,
            name="varying_amp",
        )(t_grid, x0)
        varying_amp = nn.sigmoid(varying_amp) * 3.0

        varying_phase = ModulatedSiren(
            hidden_channels=self.modulation_hidden_channels,
            w0_first_layer=1.0,
            name="varying_phase",
        )(t_grid, x0)

        modulation = varying_amp * jnp.exp(1j * varying_phase)

        z = self.advance(z)
        z = z * modulation

        z = jnp.concatenate([z.real, z.imag], axis=-1)
        x_hat = self.decode(z)

        return x_hat

    def encode(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:  # (hidden_dim) or (T, hidden_dim)
        if len(x.shape) == 3:
            x = rearrange(x, "w h c -> (w h c)")
        elif len(x.shape) == 4:
            x = rearrange(x, "t w h c -> t (w h c)")
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, H, C) or (T, W, H, C)
        z = self.decoder(z)
        if len(z.shape) == 1:
            z = rearrange(
                z,
                "(w h c) -> w h c",
                w=self.d_model[0],
                h=self.d_model[1],
                c=self.d_vars,
            )
        elif len(z.shape) == 2:
            z = rearrange(
                z,
                "t (w h c) -> t w h c",
                w=self.d_model[0],
                h=self.d_model[1],
                c=self.d_vars,
            )
        return z

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:
        # convert to complex and back
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        return z


BatchedJAESAutoencoder2D = nn.vmap(
    JAESAutoencoder2D,
    in_axes=0,  # map over the first axis of the first input not the second
    out_axes=0,
    variable_axes={"params": None, "batch_stats": None, "cache": 0, "prime": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
    axis_name="batch",
)
```
:::