Autoencoder models

::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’ 6=‘i’}

from collections.abc import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
from einops import rearrange
from jax.typing import ArrayLike

from physmodjax.models.conv import ConvDecoder, ConvEncoder
from physmodjax.models.mlp import ModulatedSiren
from physmodjax.models.recurrent import LRUDynamics
from physmodjax.utils.data import create_grid

:::

from functools import partial

from physmodjax.models.recurrent import (
    DiscreteModalDynamicsAngleMag,
    LRUDynamicsVarying,
)

2D Autoencoders with Linear Dynamics
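
`FourierAutoencoder2D` (below) encodes each frame with a zero-padded real 2D FFT, keeps only the first `n_modes × n_modes` coefficients along the two spatial frequency axes as a flattened complex latent vector, and `decode` inverts the truncated spectrum back to the original grid. A minimal sketch of that round trip on a random field (shapes chosen for illustration, not taken from the tests below):

import jax
import jax.numpy as jnp

W, H, C, n_modes = 41, 37, 2, 20
x = jax.random.normal(jax.random.PRNGKey(0), (W, H, C))

# forward: padded real FFT over the two spatial axes, then keep the low-order modes
X = jnp.fft.rfft2(x, s=(W * 2 - 1, H * 2 - 1), axes=(-3, -2), norm="ortho")
X_trunc = X[:n_modes, :n_modes, :]  # (n_modes, n_modes, C), complex

# inverse: back to the original spatial resolution (a low-pass reconstruction)
x_rec = jnp.fft.irfft2(X_trunc, s=(W, H), axes=(-3, -2), norm="ortho")
assert x_rec.shape == (W, H, C)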

::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class FourierAutoencoder2D(nn.Module):
    dynamics_model: nn.Module
    d_vars: int
    d_model: tuple[int, int]
    norm: str = "layer"
    training: bool = True
    use_positions: bool = False
    n_modes: int = 20

    def setup(self):
        self.dynamics = self.dynamics_model()
        if self.use_positions:
            self.grid = create_grid(self.d_model[1], self.d_model[0])

    def __call__(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:
        z = self.encode(x)
        z = self.advance(z, x)
        return self.decode(z)

    def encode(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:  # (hidden_dim) or (T, hidden_dim) complex
        """
        Spatial encoding of the input data.
        """

        if self.use_positions:
            x = jnp.concatenate([x, self.grid], axis=-1)

        *_, W, H, C = x.shape
        x = jnp.fft.rfft2(x, s=(W * 2 - 1, H * 2 - 1), axes=(-3, -2), norm="ortho")
        x = x[..., : self.n_modes, : self.n_modes, :]

        if len(x.shape) == 3:
            x = rearrange(x, "w h c -> (w h c)")
        elif len(x.shape) == 4:
            x = rearrange(x, "t w h c -> t (w h c)")

        return x

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,) complex
        steps: jnp.ndarray,  # array whose leading axis has length T (only steps.shape[0] is used)
    ) -> jnp.ndarray:  # (T, hidden_dim,)
        # z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, steps.shape[0])
        # z = jnp.concatenate([z.real, z.imag], axis=-1)
        return z

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim) or (T, hidden_dim) complex
    ) -> jnp.ndarray:  # (W, H, C) or (T, W, H, C) real
        w_modes = self.n_modes
        h_modes = self.n_modes
        if len(z.shape) == 1:
            z = rearrange(z, "(w h c) -> w h c", w=w_modes, h=h_modes, c=self.d_vars)
        elif len(z.shape) == 2:
            z = rearrange(
                z, "t (w h c) -> t w h c", w=w_modes, h=h_modes, c=self.d_vars
            )

        z = jnp.fft.irfft2(
            z, s=(self.d_model[0], self.d_model[1]), axes=(-3, -2), norm="ortho"
        )
        return z


BatchedFourierAutoencoder2D = nn.vmap(
    FourierAutoencoder2D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)

:::
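
The `Batched*` wrappers in this notebook all follow the same `nn.vmap` lifting pattern: the listed methods are mapped over a leading batch axis, the parameters are shared across the batch (`variable_axes={"params": None}`), and a single RNG is used at initialisation (`split_rngs={"params": False}`). A minimal sketch of the same pattern on a hypothetical one-layer module (here only `__call__` is lifted; the wrappers above and below additionally pass `methods=[...]`):

import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyMLP(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)


BatchedTinyMLP = nn.vmap(
    TinyMLP,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},  # broadcast (share) the parameters over the batch
    split_rngs={"params": False},  # initialise with one RNG, not one per batch element
)

mlp = BatchedTinyMLP(features=4)
variables = mlp.init(jax.random.PRNGKey(0), jnp.zeros((8, 3)))  # (batch, d_in)
y = mlp.apply(variables, jnp.zeros((8, 3)))
assert y.shape == (8, 4)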

::: {#cell-7 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

d_hidden = 128
B, T, H, W, C = 5, 16, 41, 37, 2
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))

dynamics_model = partial(
    LRUDynamics,
    d_hidden=(20 * 20 * 2),
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
    prepend_ones=False,
)

model = BatchedFourierAutoencoder2D(
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=(H, W),
    norm="layer",
    training=True,
    n_modes=20,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, H, W, C)

:::

::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class DenseKoopmanAutoencoder2D(nn.Module):
    """
    Koopman Dense Autoencoder
    """

    encoder_model: nn.Module
    decoder_model: nn.Module
    dynamics_model: nn.Module
    d_vars: int
    d_model: tuple[int, int]
    n_steps: int
    norm: str = "layer"
    training: bool = True
    use_positions: bool = False

    def setup(self):
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()
        if self.use_positions:
            self.grid = create_grid(self.d_model[1], self.d_model[0])

    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, H, C)
    ) -> jnp.ndarray:  # (T, W, H, C)
        z = self.encode(x[0])
        z = self.advance(z)
        return self.decode(z)

    def encode(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:  # (hidden_dim) or (T, hidden_dim)
        if self.use_positions:
            x = jnp.concatenate([x, self.grid], axis=-1)

        if len(x.shape) == 3:
            x = rearrange(x, "w h c -> (w h c)")
        elif len(x.shape) == 4:
            x = rearrange(x, "t w h c -> t (w h c)")
        return self.encoder(x)

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim,)
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        z = jnp.concatenate([z.real, z.imag], axis=-1)
        return z

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, H, C) or (T, W, H, C)
        z = self.decoder(z)
        if len(z.shape) == 1:
            z = rearrange(
                z,
                "(w h c) -> w h c",
                w=self.d_model[0],
                h=self.d_model[1],
                c=self.d_vars,
            )
        elif len(z.shape) == 2:
            z = rearrange(
                z,
                "t (w h c) -> t w h c",
                w=self.d_model[0],
                h=self.d_model[1],
                c=self.d_vars,
            )
        return z


BatchedDenseKoopmanAutoencoder2D = nn.vmap(
    DenseKoopmanAutoencoder2D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)

:::
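
`DenseKoopmanAutoencoder2D.advance` (like the other Koopman-style `advance` methods in this notebook) packs a real latent vector of size `2 * d_hidden` into `d_hidden` complex numbers before the linear dynamics and unpacks the result back into real and imaginary parts afterwards. A minimal sketch of the round trip with hypothetical sizes:

import jax.numpy as jnp

d_hidden = 4
z_real = jnp.arange(2 * d_hidden, dtype=jnp.float32)  # (2 * d_hidden,)

# pack: first half -> real part, second half -> imaginary part
z_c = z_real[:d_hidden] + 1j * z_real[d_hidden:]  # (d_hidden,) complex

# ... the complex linear dynamics would act on z_c here ...

# unpack: back to a real vector of the original size
z_back = jnp.concatenate([z_c.real, z_c.imag], axis=-1)
assert jnp.allclose(z_back, z_real)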

::: {#cell-9 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

n_steps = 16
d_hidden = 128
B, T, H, W, C = 5, n_steps, 41, 37, 3
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))

encoder_model = partial(nn.Dense, features=d_hidden * 2)
decoder_model = partial(nn.Dense, features=H * W * C)
dynamics_model = partial(
    LRUDynamics,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
    prepend_ones=False,
)

model = BatchedDenseKoopmanAutoencoder2D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    n_steps=n_steps,
    d_vars=C,
    d_model=(H, W),
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, H, W, C)

:::

::: {#cell-10 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

encoded = model.apply(vars, dummy, method="encode")
decoded = model.apply(vars, encoded, method="decode")

enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")

:::

::: {#cell-11 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class KoopmanAutoencoder2D(nn.Module):
    """
    Koopman Autoencoder
    """

    encoder_model: ConvEncoder
    decoder_model: ConvDecoder
    dynamics_model: LRUDynamics
    d_latent_channels: int
    d_latent_dims: tuple[int, int]
    n_steps: int
    norm: str = "layer"
    training: bool = True

    def setup(self):
        self.encoder = self.encoder_model(training=self.training, norm=self.norm)
        self.decoder = self.decoder_model(training=self.training, norm=self.norm)
        self.dynamics = self.dynamics_model()

    def __call__(
        self,
        x: jnp.ndarray,  # (T, H, W, C)
    ) -> jnp.ndarray:  # (T, H, W, C)
        z = self.encode(x[0])
        z = self.advance(z)
        x_hat = self.decode(z)
        return x_hat

    def encode(
        self,
        x: jnp.ndarray,  # (H, W, C) or (T, H, W, C)
    ) -> jnp.ndarray:  # (hidden_dim,) or (T, hidden_dim)
        z = self.encoder(x)
        if len(z.shape) == 4:
            z = rearrange(z, "t h w c -> t (h w c)")
        elif len(z.shape) == 3:
            z = rearrange(z, "h w c -> (h w c)")
        return z

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,)  or (T, hidden_dim)
    ) -> jnp.ndarray:  # (H, W, C) or (T, H, W, C)
        if len(z.shape) == 2:
            z = rearrange(
                z,
                "t (h w c) -> t h w c",
                h=self.d_latent_dims[0],
                w=self.d_latent_dims[1],
                c=self.d_latent_channels,
            )
        elif len(z.shape) == 1:
            z = rearrange(
                z,
                "(h w c) -> h w c",
                h=self.d_latent_dims[0],
                w=self.d_latent_dims[1],
                c=self.d_latent_channels,
            )
        return self.decoder(z)

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:
        # convert to complex and back
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        z = jnp.concatenate([z.real, z.imag], axis=-1)
        return z


BatchedKoopmanAutoencoder2D = nn.vmap(
    KoopmanAutoencoder2D,
    in_axes=0,  # map over the leading (batch) axis of the inputs
    out_axes=0,
    variable_axes={"params": None, "batch_stats": None, "cache": 0, "prime": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
    axis_name="batch",
)

:::

::: {#cell-12 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

n_steps = 16
d_hidden = 128
B, T, H, W, C = 5, n_steps, 40, 40, 3
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))

encoder_model = partial(ConvEncoder, block_size=(8, 16, 32))
decoder_model = partial(ConvDecoder, block_size=(8, 16, 32))
dynamics_model = partial(
    LRUDynamics,
    d_hidden=16 * 5 * 5,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
    prepend_ones=False,
)

model = BatchedKoopmanAutoencoder2D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_latent_channels=32,
    d_latent_dims=(5, 5),
    n_steps=16,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)
assert out.shape == (B, T, H, W, C)

:::

::: {#cell-13 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

encoded = model.apply(vars, dummy, method="encode")
decoded = model.apply(vars, encoded, method="decode")

enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")

:::

Koopman Autoencoder 1D

::: {#cell-15 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class KoopmanAutoencoder1D(nn.Module):
    """
    Koopman Autoencoder
    """

    encoder_model: nn.Module
    decoder_model: nn.Module
    dynamics_model: nn.Module
    d_vars: int
    d_model: int
    n_steps: int
    norm: str = "layer"
    training: bool = True

    def setup(self):
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, C)
    ) -> jnp.ndarray:
        z = self.encode(x[0])
        z = self.advance(z)
        x_hat = self.decode(z)
        return x_hat

    def encode(
        self,
        x: jnp.ndarray,  # (W, C) or (T, W, C)
    ) -> jnp.ndarray:  # (hidden_dim,) or (T, hidden_dim)
        if len(x.shape) == 2:
            x = rearrange(x, "w c -> (w c)")
        elif len(x.shape) == 3:
            x = rearrange(x, "t w c -> t (w c)")
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, C) or (T, W, C)
        z = self.decoder(z)
        if len(z.shape) == 2:
            z = rearrange(z, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)
        elif len(z.shape) == 1:
            z = rearrange(z, "(w c) -> w c", w=self.d_model, c=self.d_vars)
        return z

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        # convert to complex and back
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        z = jnp.concatenate([z.real, z.imag], axis=-1)
        return z


BatchedKoopmanAutoencoder1D = nn.vmap(
    KoopmanAutoencoder1D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)

:::

::: {#cell-16 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

B, T, W, C = 5, 16, 101, 3
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros((B, T, W, C))

encoder_model = partial(
    nn.Dense,
    features=d_hidden * 2,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
decoder_model = partial(
    nn.Dense,
    features=W * C,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
dynamics_model = partial(
    LRUDynamicsVarying,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    model=nn.Dense(features=d_hidden * 2, kernel_init=nn.initializers.orthogonal()),
    clip_eigs=False,
    prepend_ones=False,
)

model = BatchedKoopmanAutoencoder1D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=W,
    n_steps=T,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, W, C)

:::

::: {#cell-17 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

print(dummy.shape)
encoded = model.apply(vars, dummy, method="encode")
print("encoded", encoded.shape)
decoded = model.apply(vars, encoded, method="decode")
assert decoded.shape == dummy.shape

enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")

assert dec_sequence.shape == (B, T, W, C)

:::

::: {#cell-18 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class KoopmanAutoencoder1DReal(nn.Module):
    """
    Koopman Autoencoder but with real encoding and decoding
    """

    encoder_model: nn.Module
    decoder_model: nn.Module
    dynamics_model: nn.Module
    d_vars: int
    d_model: int
    n_steps: int
    norm: str = "layer"
    training: bool = True

    def setup(self):
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, C)
    ) -> jnp.ndarray:
        z = self.encode(x[0])
        z = self.advance(z)
        x_hat = self.decode(z)
        return x_hat

    def encode(
        self,
        x: jnp.ndarray,  # (W, C) or (T, W, C)
    ) -> jnp.ndarray:  # (hidden_dim,) or (T, hidden_dim)
        if len(x.shape) == 2:
            x = rearrange(x, "w c -> (w c)")
        elif len(x.shape) == 3:
            x = rearrange(x, "t w c -> t (w c)")
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, C) or (T, W, C)
        z = self.decoder(z)
        if len(z.shape) == 2:
            z = rearrange(z, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)
        elif len(z.shape) == 1:
            z = rearrange(z, "(w c) -> w c", w=self.d_model, c=self.d_vars)
        return z

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        # convert to complex and back
        z = z + 1j * 0.0
        z = self.dynamics(z, self.n_steps).real
        return z


BatchedKoopmanAutoencoder1DReal = nn.vmap(
    KoopmanAutoencoder1DReal,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)

:::

::: {#cell-19 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

B, T, W, C = 5, 16, 101, 3
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros((B, T, W, C))

encoder_model = partial(
    nn.Dense,
    features=d_hidden,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
decoder_model = partial(
    nn.Dense,
    features=W * C,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
dynamics_model = partial(
    LRUDynamicsVarying,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    model=nn.Dense(
        features=d_hidden * 2,
        kernel_init=nn.initializers.orthogonal(),
    ),
    clip_eigs=False,
    prepend_ones=False,
)

model = BatchedKoopmanAutoencoder1DReal(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=W,
    n_steps=T,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, W, C)

:::

Tied autoencoder
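
`TiedAutoencoder` uses a single weight matrix for both directions: `encode` multiplies by the kernel and `decode` multiplies by its transpose, so the decoder has no parameters of its own. A minimal sketch of the tied mapping with a plain matrix, using an orthogonal kernel for illustration (the module itself defaults to `lecun_normal`): with orthonormal columns, `decode(encode(x))` is the orthogonal projection of `x` onto the kernel's column space.

import jax
import jax.numpy as jnp

d_in, d_hidden = 12, 4
key = jax.random.PRNGKey(0)

# one shared kernel: encode uses W, decode uses W.T (no separate decoder weights)
W = jax.nn.initializers.orthogonal()(key, (d_in, d_hidden))  # orthonormal columns

x = jax.random.normal(jax.random.fold_in(key, 1), (d_in,))
z = x @ W  # encode: (d_hidden,)
x_hat = z @ W.T  # decode: (d_in,)

# decode(encode(x)) equals the projection of x onto the span of the kernel's columns
proj = W @ (W.T @ x)
assert jnp.allclose(x_hat, proj, atol=1e-5)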

::: {#cell-21 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class TiedAutoencoder(nn.Module):
    """
    Tied (real) Autoencoder
    """

    d_model: int
    d_vars: int
    d_hidden: int
    dynamics_model: nn.Module
    n_steps: int
    norm: str = "layer"
    training: bool = True
    dtype: jnp.dtype | None = None
    param_dtype: jnp.dtype = jnp.float32
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()

    def setup(self):
        self.dynamics = self.dynamics_model()
        self.kernel = self.param(
            "kernel",
            self.kernel_init,
            (self.d_model * self.d_vars, self.d_hidden),
            self.param_dtype,
        )

    def encode(self, x: ArrayLike) -> ArrayLike:
        if len(x.shape) == 2:
            x = rearrange(x, "w c -> (w c)")
        elif len(x.shape) == 3:
            x = rearrange(x, "t w c -> t (w c)")

        z = jax.lax.dot_general(
            x,
            self.kernel,
            (((x.ndim - 1,), (0,)), ((), ())),
        )
        return z

    def decode(self, z: ArrayLike) -> ArrayLike:
        x = jax.lax.dot_general(
            z,
            self.kernel.T,
            (((z.ndim - 1,), (0,)), ((), ())),
        )

        if len(x.shape) == 2:
            x = rearrange(x, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)
        elif len(x.shape) == 1:
            x = rearrange(x, "(w c) -> w c", w=self.d_model, c=self.d_vars)

        return x

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        # convert to complex and back
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        z = jnp.concatenate([z.real, z.imag], axis=-1)
        return z

    def __call__(
        self,
        inputs: ArrayLike,  # (T, W, C) or (T, W * C)
    ) -> ArrayLike:
        z = self.encode(inputs[0])
        z = self.advance(z)
        x_hat = self.decode(z)
        return x_hat


BatchedTiedAutoencoder = nn.vmap(
    TiedAutoencoder,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)

:::

::: {#cell-22 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

B, T, W, C = 5, 16, 101, 3

dummy = jnp.zeros((B, T, W * C))
target = jnp.zeros((B, T, W * C))

d_input = W * C
d_hidden = 128
n_steps = 16

model = BatchedTiedAutoencoder(
    d_vars=C,
    d_model=W,
    d_hidden=d_hidden * 2,
    dynamics_model=partial(
        LRUDynamicsVarying,
        d_hidden=d_hidden,
        r_min=0.9,
        r_max=1.0,
        max_phase=jnp.pi * 2,
        model=nn.Dense(
            features=d_hidden * 2,
            kernel_init=nn.initializers.orthogonal(),
        ),
        clip_eigs=False,
        prepend_ones=False,
    ),
    n_steps=n_steps,
)

output, variables = model.init_with_output(jax.random.PRNGKey(0), dummy)

assert output.shape == (B, T, W, C)

:::

from physmodjax.models.recurrent import DiscreteModalDynamics

::: {#cell-24 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

B, T, W, C = 5, 16, 101, 3

dummy = jnp.zeros((B, T, W * C))
target = jnp.zeros((B, T, W * C))

d_input = W * C
d_hidden = 128
n_steps = 16

model = BatchedTiedAutoencoder(
    d_vars=C,
    d_model=W,
    d_hidden=d_hidden * 2,
    n_steps=n_steps,
    dynamics_model=partial(
        DiscreteModalDynamics,
        d_hidden=d_hidden,
        clip=True,
        r_min=0.3,
        r_max=1,
        min_phase=0.01,
        max_phase=jnp.pi / 2,
        prepend_ones=False,
    ),
)

output, variables = model.init_with_output(jax.random.PRNGKey(0), dummy)

jax.value_and_grad(lambda x: model.apply(variables, x).mean())(dummy)
assert output.shape == (B, T, W, C)

:::

::: {#cell-25 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class JAESAutoencoder1D(nn.Module):
    """
    JAES Autoencoder (1D)
    """

    encoder_model: nn.Module
    decoder_model: nn.Module
    dynamics_model: nn.Module
    d_vars: int
    d_model: int
    n_steps: int
    modulation_hidden_channels: Sequence[
        int
    ]  # hidden channel sizes for the modulation networks, including the output layer
    norm: str = "layer"
    training: bool = True

    def setup(self):
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, C)
    ) -> jnp.ndarray:
        z = self.encode(x[0])

        x0 = x[0].reshape(-1)

        t_grid = jnp.linspace(0, 1, self.n_steps)[..., None]

        varying_amp = ModulatedSiren(
            hidden_channels=self.modulation_hidden_channels,
            w0_first_layer=1.0,
            name="varying_amp",
        )(t_grid, x0)
        varying_amp = nn.sigmoid(varying_amp) * 3.0

        varying_phase = ModulatedSiren(
            hidden_channels=self.modulation_hidden_channels,
            w0_first_layer=1.0,
            name="varying_phase",
        )(t_grid, x0)

        modulation = varying_amp * jnp.exp(1j * varying_phase)

        z = self.advance(z)
        z = z * modulation

        z = jnp.concatenate([z.real, z.imag], axis=-1)
        x_hat = self.decode(z)

        return x_hat

    def encode(
        self,
        x: jnp.ndarray,  # (W, C) or (T, W, C)
    ) -> jnp.ndarray:  # (hidden_dim,) or (T, hidden_dim)
        if len(x.shape) == 2:
            x = rearrange(x, "w c -> (w c)")
        elif len(x.shape) == 3:
            x = rearrange(x, "t w c -> t (w c)")
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim,) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, C) or (T, W, C)
        z = self.decoder(z)
        if len(z.shape) == 2:
            z = rearrange(z, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)
        elif len(z.shape) == 1:
            z = rearrange(z, "(w c) -> w c", w=self.d_model, c=self.d_vars)
        return z

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:  # (T, hidden_dim)
        # convert to complex and return
        # IMPORTANT: we are returning a complex array unlike the other models
        # this is because we are modulating the latent space with a complex number
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        return z


BatchedJAESAutoencoder1D = nn.vmap(
    JAESAutoencoder1D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
)

:::
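
`JAESAutoencoder1D` modulates the advanced latent trajectory with a time-varying complex factor: one `ModulatedSiren` predicts a sigmoid-bounded amplitude and another an unconstrained phase, both conditioned on the flattened initial frame. A minimal sketch of just the modulation step, with random stand-ins for the SIREN outputs (sizes are hypothetical):

import flax.linen as nn
import jax
import jax.numpy as jnp

T, d_hidden = 16, 8  # hypothetical sizes
key = jax.random.PRNGKey(0)

# stand-ins for the outputs of the two ModulatedSiren networks
raw_amp = jax.random.normal(key, (T, d_hidden))
raw_phase = jax.random.normal(jax.random.fold_in(key, 1), (T, d_hidden))

varying_amp = nn.sigmoid(raw_amp) * 3.0  # amplitude bounded to (0, 3)
modulation = varying_amp * jnp.exp(1j * raw_phase)  # complex, (T, d_hidden)

z = jnp.ones((T, d_hidden), dtype=jnp.complex64)  # advanced latent trajectory
z_mod = z * modulation  # element-wise amplitude/phase modulation
assert z_mod.shape == (T, d_hidden)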

::: {#cell-26 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

B, T, W, C = 5, 16, 101, 2
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros((B, T, W, C))

encoder_model = partial(
    nn.Dense,
    features=d_hidden * 2,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
decoder_model = partial(
    nn.Dense,
    features=W * C,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
dynamics_model = partial(
    DiscreteModalDynamicsAngleMag,
    d_hidden=d_hidden,
    r_min=0.5,
    r_max=1.0,
    min_phase=0.01,
    max_phase=jnp.pi * 2,
    clip=False,
    prepend_ones=False,
)

model = BatchedJAESAutoencoder1D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=W,
    n_steps=T,
    norm="layer",
    training=True,
    modulation_hidden_channels=[128, d_hidden],
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, W, C)

:::

::: {#cell-27 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class JAESAutoencoder2D(nn.Module):
    """
    JAES Autoencoder
    """

    encoder_model: ConvEncoder
    decoder_model: ConvDecoder
    dynamics_model: LRUDynamics
    d_vars: int
    d_model: tuple[int, int]
    n_steps: int
    modulation_hidden_channels: Sequence[
        int
    ]  # hidden channel sizes for the modulation networks, including the output layer
    norm: str = "layer"
    training: bool = True

    def setup(self):
        self.encoder = self.encoder_model()
        self.decoder = self.decoder_model()
        self.dynamics = self.dynamics_model()

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,  # (T, H, W, C)
    ) -> jnp.ndarray:
        z = self.encode(x[0])

        x0 = x[0].reshape(-1)

        t_grid = jnp.linspace(0, 1, self.n_steps)[..., None]

        varying_amp = ModulatedSiren(
            hidden_channels=self.modulation_hidden_channels,
            w0_first_layer=1.0,
            name="varying_amp",
        )(t_grid, x0)
        varying_amp = nn.sigmoid(varying_amp) * 3.0

        varying_phase = ModulatedSiren(
            hidden_channels=self.modulation_hidden_channels,
            w0_first_layer=1.0,
            name="varying_phase",
        )(t_grid, x0)

        modulation = varying_amp * jnp.exp(1j * varying_phase)

        z = self.advance(z)
        z = z * modulation

        z = jnp.concatenate([z.real, z.imag], axis=-1)
        x_hat = self.decode(z)

        return x_hat

    def encode(
        self,
        x: jnp.ndarray,  # (W, H, C) or (T, W, H, C)
    ) -> jnp.ndarray:  # (hidden_dim) or (T, hidden_dim)
        if len(x.shape) == 3:
            x = rearrange(x, "w h c -> (w h c)")
        elif len(x.shape) == 4:
            x = rearrange(x, "t w h c -> t (w h c)")
        return self.encoder(x)

    def decode(
        self,
        z: jnp.ndarray,  # (hidden_dim) or (T, hidden_dim)
    ) -> jnp.ndarray:  # (W, H, C) or (T, W, H, C)
        z = self.decoder(z)
        if len(z.shape) == 1:
            z = rearrange(
                z,
                "(w h c) -> w h c",
                w=self.d_model[0],
                h=self.d_model[1],
                c=self.d_vars,
            )
        elif len(z.shape) == 2:
            z = rearrange(
                z,
                "t (w h c) -> t w h c",
                w=self.d_model[0],
                h=self.d_model[1],
                c=self.d_vars,
            )
        return z

    def advance(
        self,
        z: jnp.ndarray,  # (hidden_dim,)
    ) -> jnp.ndarray:
        # convert to complex and back
        z = z[: z.shape[0] // 2] + 1j * z[z.shape[0] // 2 :]
        z = self.dynamics(z, self.n_steps)
        return z


BatchedJAESAutoencoder2D = nn.vmap(
    JAESAutoencoder2D,
    in_axes=0,  # map over the leading (batch) axis of the inputs
    out_axes=0,
    variable_axes={"params": None, "batch_stats": None, "cache": 0, "prime": None},
    split_rngs={"params": False},
    methods=["__call__", "decode", "encode", "advance"],
    axis_name="batch",
)

:::