Recurrent Models

::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’ 6=‘i’}

from abc import ABC, abstractmethod
from functools import partial

import flax.linen as nn
import jax
import jax.numpy as jnp
from einops import rearrange
from jax.typing import ArrayLike

from physmodjax.models.mlp import MLP
from physmodjax.utils.clamp import magic_clamp
from physmodjax.utils.eigenvalues import (
    ensure_positive_imaginary_parts,
    multiply_eigenvalues,
)

:::

::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def matrix_init(key, shape, dtype=jnp.float32, normalization=1):
    return jax.random.normal(key=key, shape=shape, dtype=dtype) / normalization


def nu_init(key, shape, r_min, r_max, dtype=jnp.float32):
    u = jax.random.uniform(key=key, shape=shape, dtype=dtype)
    return jnp.log(-0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2))


def theta_init(key, shape, max_phase, dtype=jnp.float32):
    u = jax.random.uniform(key, shape=shape, dtype=dtype)
    return jnp.log(max_phase * u)


def gamma_log_init(key, lamb):
    nu, theta = lamb
    diag_lambda = jnp.exp(-jnp.exp(nu) + 1j * jnp.exp(theta))
    return jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2))

:::
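
`gamma_log_init` stores the log of the LRU normalisation factor sqrt(1 - |λ|²) given the `(nu, theta)` parametrisation. A quick sanity check, as a sketch (the `key` argument is unused by `gamma_log_init`, so `None` is passed here):

nu = jnp.log(jnp.array([0.05]))  # so exp(nu) = 0.05
theta = jnp.log(jnp.array([1.0]))  # so exp(theta) = 1.0
lam = jnp.exp(-0.05 + 1j * 1.0)
gamma_log = gamma_log_init(None, (nu, theta))
assert jnp.allclose(jnp.exp(gamma_log), jnp.sqrt(1 - jnp.abs(lam) ** 2))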

Base modal dynamics class

::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class BaseModalDynamics(nn.Module, ABC):
    d_hidden: int
    prepend_ones: bool

    @abstractmethod
    def get_eigenvalues(self) -> ArrayLike:
        """
        Abstract method that must be implemented by subclasses.

        Returns:
            jnp.ndarray: Eigenvalues as a JAX array.
        """
        pass

    def apply_eigenvalue_transform(
        self,
        eigenvalues: ArrayLike,
        n_steps: int,
    ) -> ArrayLike:
        """Compute dynamics over n_steps based on eigenvalues."""
        z = jnp.repeat(eigenvalues[None], n_steps, axis=0)

        if self.prepend_ones:
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]

        return jax.lax.associative_scan(jnp.multiply, z)

    def compute_dynamics(
        self,
        x: ArrayLike,
        n_steps: int,
    ) -> ArrayLike:
        """Advance the dynamics for `n_steps` using the initial state `x`."""
        eigenvalues = self.get_eigenvalues()
        dynamics = self.apply_eigenvalue_transform(eigenvalues, n_steps)
        return dynamics * x

:::
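
`apply_eigenvalue_transform` builds the cumulative eigenvalue powers λ, λ^2, …, λ^n (or 1, λ, …, λ^{n-1} when `prepend_ones` is set) with a parallel associative scan. A minimal sketch of the equivalence:

lam = jnp.array([0.5 + 0.5j])
z = jnp.repeat(lam[None], 4, axis=0)
powers = jax.lax.associative_scan(jnp.multiply, z)
assert jnp.allclose(powers[:, 0], lam[0] ** jnp.arange(1, 5))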

LRU dynamics

Linear dynamics using an initialisation of the eigenvalues based on the LRU paper.
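
As a sketch, the initialisers above place the eigenvalue radii in [r_min, r_max] and the phases in [0, max_phase):

key_nu, key_theta = jax.random.split(jax.random.PRNGKey(0))
nu_log = nu_init(key_nu, (64,), r_min=0.99, r_max=1.0)
theta_log = theta_init(key_theta, (64,), max_phase=jnp.pi * 2)
lam = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
assert jnp.all((jnp.abs(lam) >= 0.99 - 1e-6) & (jnp.abs(lam) <= 1.0))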

::: {#cell-9 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class LRUDynamics(BaseModalDynamics):
    """
    This class implements only the dynamics of the LRU model.
    x_{k+1} = A x_k
    """

    r_min: float  # smallest eigenvalue radius
    r_max: float  # largest eigenvalue radius
    max_phase: float  # largest phase
    clip_eigs: bool  # whether to clip the eigenvalues

    def setup(self):
        self.theta_log = self.param(
            "theta_log", partial(theta_init, max_phase=self.max_phase), (self.d_hidden,)
        )
        self.nu_log = self.param(
            "nu_log",
            partial(nu_init, r_min=self.r_min, r_max=self.r_max),
            (self.d_hidden,),
        )

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        steps: int,  # number of steps to advance
    ) -> jnp.ndarray:  # advanced state (steps, d_hidden) complex
        A_real = -jnp.exp(self.nu_log)
        A_imag = jnp.exp(self.theta_log)

        # clip the real part to keep it strictly negative (not strictly necessary, since -exp(nu_log) is already negative)
        if self.clip_eigs:
            A_real = jnp.clip(A_real, None, -1e-5)

        A_diag = jnp.exp(A_real + 1j * A_imag)
        z = jnp.repeat(A_diag[None, :], steps, axis=0)

        if self.prepend_ones:
            # prepend ones to the beginning and slice the last element
            # this is needed to start from the initial state
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]

        # advance the state
        x = jax.lax.associative_scan(jnp.multiply, z) * x

        return x

    def get_eigenvalues(self) -> ArrayLike:
        nu = jnp.exp(self.nu_log)
        theta = jnp.exp(self.theta_log)
        return jnp.exp(-nu + 1j * theta)

    def set_eigenvalues(self, eigenvalues: ArrayLike):
        # convert to continuous
        eigenvalues_mod = jnp.log(eigenvalues)
        eigenvalues_mod = ensure_positive_imaginary_parts(eigenvalues_mod)

        # take the log of each part and assign it to the params
        # note that the sign of the real part is flipped before taking the log
        # otherwise we would get nan values
        self.put_variable("params", "nu_log", jnp.log(-eigenvalues_mod.real))
        self.put_variable("params", "theta_log", jnp.log(eigenvalues_mod.imag))

    # def scale_dynamics(
    #     self,
    #     angle_factor: float,
    #     radius_factor: float,
    # ) -> ArrayLike:
    #     discrete_eigenvalues = self.get_eigenvalues()

    #     discrete_eigenvalues_mod = multiply_eigenvalues(
    #         discrete_eigenvalues,
    #         angle_factor,
    #         radius_factor,
    #     )

    #     # convert to continuous
    #     eigenvalues_mod = jnp.log(discrete_eigenvalues_mod)
    #     eigenvalues_mod = ensure_positive_imaginary_parts(eigenvalues_mod)

    #     # take the log of each part and assign it to the params
    #     # note that the sign of the real part is flipped before taking the log
    #     # otherwise we would get nan values
    #     self.put_variable("params", "nu_log", jnp.log(-eigenvalues_mod.real))
    #     self.put_variable("params", "theta_log", jnp.log(eigenvalues_mod.imag))

    #     return self.get_eigenvalues()

:::

::: {#cell-10 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

d_hidden = 64
steps = 50
dyn = LRUDynamics(
    d_hidden=d_hidden,
    r_min=0.99,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
    prepend_ones=False,
)
variables = dyn.init(jax.random.PRNGKey(0), jnp.ones(d_hidden), steps)
out = dyn.apply(variables, jnp.ones(d_hidden), steps)

assert out.shape == (steps, d_hidden)

:::
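
The eigenvalues can be read back from an initialised module through `apply`; a usage sketch with the module above:

eigs = dyn.apply(variables, method="get_eigenvalues")
assert eigs.shape == (d_hidden,)
assert jnp.all(jnp.abs(eigs) <= 1.0 + 1e-6)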

LRU with MLP

::: {#cell-12 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class LRUDynamicsVarying(LRUDynamics):
    model: nn.Module  # model to process the linear state

    def setup(self):
        super().setup()

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        steps: int,  # number of steps to advance
    ) -> jnp.ndarray:  # advanced state (steps, d_hidden) complex
        x = super().__call__(x, steps)
        x_hat = self.model(x.real**2 + x.imag**2)
        x_hat = x_hat[..., : self.d_hidden] + 1j * x_hat[..., self.d_hidden :]
        x = x * x_hat
        return x

:::

::: {#cell-13 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

d_hidden = 64
steps = 50
# the MLP's final layer must have 2 * d_hidden channels (real and imaginary parts)
model = MLP(hidden_channels=[64, 64, 128])
dyn = LRUDynamicsVarying(
    d_hidden=d_hidden,
    r_min=0.99,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    model=model,
    clip_eigs=False,
    prepend_ones=False,
)

out, variables = dyn.init_with_output(jax.random.PRNGKey(0), jnp.ones(d_hidden), steps)

assert out.shape == (steps, d_hidden)

:::

Deep GRU

::: {#cell-15 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class DeepRNN(nn.Module):
    """
    A deep RNN model that applies an RNN cell over the last dimension of the input,
    stacking a first RNN layer followed by `n_layers` additional layers.
    Works with nn.GRUCell, nn.RNNCell, nn.SimpleCell, nn.MGUCell.
    """

    d_model: int
    d_vars: int
    n_layers: int
    cell: nn.Module
    training: bool = True
    norm: str = "layer"

    def setup(self):
        # nn.RNN unrolls the cell over the time dimension (equivalent to an nn.scan)
        self.first_layer = nn.RNN(
            self.cell(features=self.d_model * self.d_vars),
        )

        self.layers = [
            nn.RNN(
                self.cell(features=self.d_model * self.d_vars),
            )
            for _ in range(self.n_layers)
        ]

    def __call__(
        self,
        x0: jnp.ndarray,  # (W, C) # initial state
        x: jnp.ndarray,  # (T, W, C) # empty state
    ) -> jnp.ndarray:  # (T, W, C) # advanced state
        # the rnn works over the last dimension
        # we need to reshape the input to (T, d_model * C)
        x0 = rearrange(x0, "w c -> (w c)")
        x = rearrange(x, "t w c -> t (w c)")
        x = self.first_layer(x, initial_carry=x0)
        for layer in self.layers:
            x = layer(x)
        return rearrange(x, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)


BatchedDeepRNN = nn.vmap(
    DeepRNN,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    axis_name="batch",
)

:::

::: {#cell-16 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

B, T, W, C = 10, 50, 20, 3
deep_rnn = BatchedDeepRNN(d_model=W, d_vars=C, n_layers=2, cell=partial(nn.GRUCell))
x = jnp.ones((B, T, W, C))
x0 = jnp.ones((B, W, C))
variables = deep_rnn.init(jax.random.PRNGKey(65), x0, x)
out = deep_rnn.apply(variables, x0, x)

assert out.shape == (B, T, W, C)

:::
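
`nn.vmap` with `variable_axes={"params": None}` and `split_rngs={"params": False}` shares a single set of weights across the batch. A sketch checking that the batched and unbatched parameter shapes match:

unbatched = DeepRNN(d_model=W, d_vars=C, n_layers=2, cell=partial(nn.GRUCell))
unbatched_variables = unbatched.init(jax.random.PRNGKey(65), x0[0], x[0])
assert jax.tree_util.tree_all(
    jax.tree_util.tree_map(
        lambda a, b: a.shape == b.shape,
        variables["params"],
        unbatched_variables["params"],
    )
)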

# TODO
# the LSTMCell carries a (cell, hidden) tuple, so the initial state needs an extra element
# deep_rnn = BatchedDeepRNN(d_model=W, d_vars=C, n_layers=2, cell=partial(nn.LSTMCell))
# x = jnp.ones((B, T, W, C))
# x0 = jnp.ones((B, W, C))
# variables = deep_rnn.init(jax.random.PRNGKey(65), (x0, x0), x)
# out = deep_rnn.apply(variables, (x0, x0), x)

Complex oscillator

::: {#cell-19 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def constrain_to_unit_circle(x):
    # clamp the magnitude to (0, 1] while preserving the phase
    mag = jnp.abs(x)
    angle = jnp.angle(x)
    mag = magic_clamp(mag, 1e-8, 1.0)
    return mag * jnp.exp(1j * angle)

:::
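
Assuming `magic_clamp`'s forward pass matches `jnp.clip` (the name suggests a clamp with modified gradients), `constrain_to_unit_circle` limits the magnitude to (0, 1] while preserving the phase; a sketch:

z = jnp.array([2.0 * jnp.exp(1j * 0.3), 0.5 * jnp.exp(1j * 1.2)])
zc = constrain_to_unit_circle(z)
assert jnp.allclose(jnp.abs(zc), jnp.array([1.0, 0.5]))
assert jnp.allclose(jnp.angle(zc), jnp.array([0.3, 1.2]))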

::: {#cell-20 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def init_complex(
    r_min=0.0,
    r_max=1.0,
    min_phase=0.0,
    max_phase=2 * jnp.pi,
    dtype=jnp.float32,
):
    def init(key, shape, dtype=dtype) -> ArrayLike:
        dtype = jax.dtypes.canonicalize_dtype(dtype)
        # use independent keys for the radius and the phase
        key_radius, key_phase = jax.random.split(key)
        radius = jax.random.uniform(
            key_radius,
            shape,
            dtype,
            minval=r_min,
            maxval=r_max,
        )
        phase = jax.random.uniform(
            key_phase,
            shape,
            dtype,
            minval=min_phase,
            maxval=max_phase,
        )
        return radius * jnp.exp(1j * phase)

    return init


def init_real_imag(
    r_min=0.0,
    r_max=1.0,
    min_phase=0.0,
    max_phase=2 * jnp.pi,
    dtype=jnp.float32,
):
    def init(key, shape, dtype=dtype) -> ArrayLike:
        dtype = jax.dtypes.canonicalize_dtype(dtype)
        # use independent keys for the radius and the phase
        key_radius, key_phase = jax.random.split(key)
        radius = jax.random.uniform(
            key_radius,
            shape,
            dtype,
            minval=r_min,
            maxval=r_max,
        )
        phase = jax.random.uniform(
            key_phase,
            shape,
            dtype,
            minval=min_phase,
            maxval=max_phase,
        )
        return jnp.stack([radius * jnp.cos(phase), radius * jnp.sin(phase)], axis=-1)

    return init

:::
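
Both initialisers draw the radius and phase uniformly; a sketch checking the annulus bounds for `init_complex`:

init = init_complex(r_min=0.5, r_max=1.0)
z = init(jax.random.PRNGKey(0), (8,))
assert z.dtype == jnp.complex64
assert jnp.all((jnp.abs(z) >= 0.5) & (jnp.abs(z) <= 1.0))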

::: {#cell-21 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class DiscreteModalDynamics(BaseModalDynamics):
    clip: bool
    r_min: float
    r_max: float
    min_phase: float
    max_phase: float

    def setup(self):
        self.real_imag = self.param(
            "real_imag",
            init_real_imag(
                r_min=self.r_min,
                r_max=self.r_max,
                min_phase=self.min_phase,
                max_phase=self.max_phase,
            ),
            (self.d_hidden,),
        )

        self.z = self.real_imag[..., 0] + 1j * self.real_imag[..., 1]

        if self.clip:
            self.z = constrain_to_unit_circle(self.z)

    def __call__(
        self,
        x: ArrayLike,  # initial complex state (d_hidden,)
        n_steps: int,  # number of steps to advance
    ):
        z = jnp.repeat(self.z[None, :], n_steps, axis=0)

        if self.prepend_ones:
            # prepend ones to the beginning and slice the last element
            # this is needed to start from the initial state
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]
        return jax.lax.associative_scan(jnp.multiply, z) * x

    def get_eigenvalues(self) -> ArrayLike:
        z = self.z
        if self.clip:
            z = constrain_to_unit_circle(z)

        return z

    def set_eigenvalues(self, eigenvalues: ArrayLike):
        eigenvalues_as_real_imag = jnp.stack(
            [jnp.real(eigenvalues), jnp.imag(eigenvalues)],
            axis=-1,
        )
        self.put_variable("params", "real_imag", eigenvalues_as_real_imag)

:::

::: {#cell-22 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

d_hidden = 10
clip = True
n_steps = 100

model = DiscreteModalDynamics(
    d_hidden=d_hidden,
    clip=clip,
    r_min=0.3,
    r_max=1,
    min_phase=0.01,
    max_phase=jnp.pi / 2,
    prepend_ones=False,
)

output, variables = model.init_with_output(
    jax.random.PRNGKey(0),
    jnp.ones((d_hidden,)),
    n_steps=n_steps,
)

assert output.shape == (
    n_steps,
    d_hidden,
)
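
# sketch: read the clipped eigenvalues back through apply
# (assumes magic_clamp's forward pass matches jnp.clip)
eigs = model.apply(variables, method="get_eigenvalues")
assert jnp.all(jnp.abs(eigs) <= 1.0 + 1e-6)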

:::

::: {#cell-23 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class DiscreteModalDynamicsVarying(DiscreteModalDynamics):
    model: nn.Module  # model to process the linear state

    def setup(self):
        super().setup()

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        n_steps: int,  # number of steps to advance
    ) -> jnp.ndarray:  # advanced state (steps, d_hidden) complex
        x = super().__call__(x, n_steps)
        x_hat = self.model(x.real**2 + x.imag**2)
        x_hat = x_hat[..., : self.d_hidden] + 1j * x_hat[..., self.d_hidden :]
        x = x * x_hat
        return x

:::

::: {#cell-24 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

d_hidden = 10
clip = True
n_steps = 100

model = DiscreteModalDynamicsVarying(
    d_hidden=d_hidden,
    clip=clip,
    r_min=0.3,
    r_max=1,
    min_phase=0.01,
    max_phase=jnp.pi / 2,
    model=MLP(hidden_channels=[64, 64, 20]),
    prepend_ones=False,
)

output, variables = model.init_with_output(
    jax.random.PRNGKey(0),
    jnp.ones((d_hidden,)),
    n_steps=n_steps,
)
assert output.shape == (
    n_steps,
    d_hidden,
)

:::

::: {#cell-25 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class DiscreteModalDynamicsAngleMag(BaseModalDynamics):
    clip: bool
    r_min: float
    r_max: float
    min_phase: float
    max_phase: float

    def setup(self):
        self.angles = self.param(
            "angles",
            lambda key, shape: jax.random.uniform(
                key,
                shape,
                minval=self.min_phase,
                maxval=self.max_phase,
            ),
            (self.d_hidden,),
        )

        self.magnitudes = self.param(
            "magnitudes",
            lambda key, shape: jax.random.uniform(
                key,
                shape,
                minval=self.r_min,
                maxval=self.r_max,
            ),
            (self.d_hidden,),
        )

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        n_steps: int,  # number of steps to advance
    ):
        if self.clip:
            magnitudes = magic_clamp(self.magnitudes, 1e-8, 1.0)
        else:
            magnitudes = self.magnitudes

        z = magnitudes * jnp.exp(1j * self.angles)

        z = jnp.repeat(z[None], n_steps, axis=0)

        if self.prepend_ones:
            # prepend ones to the beginning and slice the last element
            # this is needed to start from the initial state
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]
        return jax.lax.associative_scan(jnp.multiply, z) * x

    def get_eigenvalues(self) -> ArrayLike:
        if self.clip:
            # use magic_clamp for consistency with __call__
            magnitudes = magic_clamp(self.magnitudes, 1e-8, 1.0)
        else:
            magnitudes = self.magnitudes

        return magnitudes * jnp.exp(1j * self.angles)

    def set_eigenvalues(self, eigenvalues: ArrayLike):
        self.put_variable("params", "angles", jnp.angle(eigenvalues))
        self.put_variable("params", "magnitudes", jnp.abs(eigenvalues))

:::
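
A usage sketch for `DiscreteModalDynamicsAngleMag`, mirroring the tests above:

model = DiscreteModalDynamicsAngleMag(
    d_hidden=10,
    clip=True,
    r_min=0.3,
    r_max=1.0,
    min_phase=0.01,
    max_phase=jnp.pi / 2,
    prepend_ones=False,
)
output, variables = model.init_with_output(
    jax.random.PRNGKey(0),
    jnp.ones((10,)),
    n_steps=100,
)
assert output.shape == (100, 10)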