SSM Models

S5 Model

adapted from https://github.com/lindermanlab/S5

::: {#cell-5 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’ 6=‘i’}

from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp
import jax.numpy as np
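# Note: jax.numpy is imported under two aliases on purpose: the S5 code below keeps
# the upstream convention (np), while the LRU code uses jnp.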
from flax import linen as nn
from jax import random
from jax.nn.initializers import lecun_normal, normal
from jax.numpy.linalg import eigh

from physmodjax.models.recurrent import gamma_log_init, matrix_init, nu_init, theta_init

:::

::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def make_HiPPO(N):
    """Create a HiPPO-LegS matrix.
    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Args:
        N (int32): state size
    Returns:
        N x N HiPPO LegS matrix
    """
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    A = np.tril(A) - np.diag(np.arange(N))
    return -A


def make_NPLR_HiPPO(N):
    """
    Makes components needed for NPLR representation of HiPPO-LegS
     From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Args:
        N (int32): state size

    Returns:
        N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B

    """
    # Make -HiPPO
    hippo = make_HiPPO(N)

    # Add in a rank 1 term. Makes it Normal.
    P = np.sqrt(np.arange(N) + 0.5)

    # HiPPO also specifies the B matrix
    B = np.sqrt(2 * np.arange(N) + 1.0)
    return hippo, P, B


def make_DPLR_HiPPO(N):
    """
    Makes components needed for DPLR representation of HiPPO-LegS
     From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Note, we will only use the diagonal part
    Args:
        N:

    Returns:
        eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B,
        eigenvectors V, HiPPO B pre-conjugation

    """
    A, P, B = make_NPLR_HiPPO(N)

    S = A + P[:, np.newaxis] * P[np.newaxis, :]

    S_diag = np.diagonal(S)
    Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)

    # Diagonalize S to V \Lambda V^*
    Lambda_imag, V = eigh(S * -1j)

    P = V.conj().T @ P
    B_orig = B
    B = V.conj().T @ B
    return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig


def log_step_initializer(dt_min=0.001, dt_max=0.1):
    """Initialize the learnable timescale Delta by sampling
    uniformly between dt_min and dt_max.
    Args:
        dt_min (float32): minimum value
        dt_max (float32): maximum value
    Returns:
        init function
    """

    def init(key, shape):
        """Init function
        Args:
            key: jax random key
            shape tuple: desired shape
        Returns:
            sampled log_step (float32)
        """
        return random.uniform(key, shape) * (np.log(dt_max) - np.log(dt_min)) + np.log(
            dt_min
        )

    return init


def init_log_steps(key, input):
    """Initialize an array of learnable timescale parameters
    Args:
        key: jax random key
        input: tuple containing the array shape H and
               dt_min and dt_max
    Returns:
        initialized array of timescales (float32): (H,)
    """
    H, dt_min, dt_max = input
    log_steps = []
    for i in range(H):
        key, skey = random.split(key)
        log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,))
        log_steps.append(log_step)

    return np.array(log_steps)


def init_VinvB(init_fun, rng, shape, Vinv):
    """Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B.
    Note we will parameterize this with two different matrices for complex
    numbers.
     Args:
         init_fun:  the initialization function to use, e.g. lecun_normal()
         rng:       jax random key to be used with init function.
         shape (tuple): desired shape  (P,H)
         Vinv: (complex64)     the inverse eigenvectors used for initialization
     Returns:
         B_tilde (complex64) of shape (P,H,2)
    """
    B = init_fun(rng, shape)
    VinvB = Vinv @ B
    VinvB_real = VinvB.real
    VinvB_imag = VinvB.imag
    return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1)


def trunc_standard_normal(key, shape):
    """Sample C with a truncated normal distribution with standard deviation 1.
    Args:
        key: jax random key
        shape (tuple): desired shape, of length 3, (H,P,_)
    Returns:
        sampled C matrix (float32) of shape (H,P,2) (for complex parameterization)
    """
    H, P, _ = shape
    Cs = []
    for i in range(H):
        key, skey = random.split(key)
        C = lecun_normal()(skey, shape=(1, P, 2))
        Cs.append(C)
    return np.array(Cs)[:, 0]


def init_CV(init_fun, rng, shape, V):
    """Initialize C_tilde=CV. First sample C. Then compute CV.
    Note we will parameterize this with two different matrices for complex
    numbers.
     Args:
         init_fun:  the initialization function to use, e.g. lecun_normal()
         rng:       jax random key to be used with init function.
         shape (tuple): desired shape  (H,P)
         V: (complex64)     the eigenvectors used for initialization
     Returns:
         C_tilde (complex64) of shape (H,P,2)
    """
    C_ = init_fun(rng, shape)
    C = C_[..., 0] + 1j * C_[..., 1]
    CV = C @ V
    CV_real = CV.real
    CV_imag = CV.imag
    return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1)

:::
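
As a quick sanity check on the initializers above (not part of the exported module), the DPLR decomposition places every eigenvalue of the normal part of HiPPO-LegS on the line Re(lambda) = -1/2, and V diagonalizes that normal part. A minimal sketch, assuming the helpers defined above are in scope:

```python
import jax.numpy as np

N = 8
A, P_lr, B = make_NPLR_HiPPO(N)
S = A + P_lr[:, np.newaxis] * P_lr[np.newaxis, :]  # normal part of HiPPO-LegS

Lambda, _, _, V, _ = make_DPLR_HiPPO(N)

# all eigenvalues have real part -1/2
assert np.allclose(Lambda.real, -0.5, atol=1e-4)
# V diagonalizes the normal part: S = V diag(Lambda) V^*
assert np.allclose(V @ np.diag(Lambda) @ V.conj().T, S, atol=1e-4)
```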

::: {#cell-7 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

# Discretization functions
def discretize_bilinear(Lambda, B_tilde, Delta):
    """Discretize a diagonalized, continuous-time linear SSM
    using bilinear transform method.
    Args:
        Lambda (complex64): diagonal state matrix              (P,)
        B_tilde (complex64): input matrix                      (P, H)
        Delta (float32): discretization step sizes             (P,)
    Returns:
        discretized Lambda_bar (complex64), B_bar (complex64)  (P,), (P,H)
    """
    Identity = np.ones(Lambda.shape[0])

    BL = 1 / (Identity - (Delta / 2.0) * Lambda)
    Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda)
    B_bar = (BL * Delta)[..., None] * B_tilde
    return Lambda_bar, B_bar


def discretize_zoh(Lambda, B_tilde, Delta):
    """Discretize a diagonalized, continuous-time linear SSM
    using zero-order hold method.
    Args:
        Lambda (complex64): diagonal state matrix              (P,)
        B_tilde (complex64): input matrix                      (P, H)
        Delta (float32): discretization step sizes             (P,)
    Returns:
        discretized Lambda_bar (complex64), B_bar (complex64)  (P,), (P,H)
    """
    Identity = np.ones(Lambda.shape[0])
    Lambda_bar = np.exp(Lambda * Delta)
    B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde
    return Lambda_bar, B_bar


# Parallel scan operations
@jax.vmap
def binary_operator(q_i, q_j):
    """Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A.
    Args:
        q_i: tuple containing A_i and Bu_i at position i       (P,), (P,)
        q_j: tuple containing A_j and Bu_j at position j       (P,), (P,)
    Returns:
        new element ( A_out, Bu_out )
    """
    A_i, b_i = q_i
    A_j, b_j = q_j
    return A_j * A_i, A_j * b_i + b_j


def apply_dynamics(
    x0,
    steps,
    Lambda_bar,
    B_bar,
    C_tilde,
    conj_sym,
    bidirectional,
):
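    """Roll out the discretized SSM autonomously from an initial condition.
    With h0 = B_bar @ x0, returns the projections of h_k = Lambda_bar^k * h0
    for k = 1, ..., steps, computed with a cumulative-product parallel scan.
    Args:
        x0 (float32): initial condition                               (H,)
        steps (int32): number of steps to roll out
        Lambda_bar (complex64): discretized diagonal state matrix     (P,)
        B_bar      (complex64): discretized input matrix              (P, H)
        C_tilde    (complex64): output matrix                         (H, P)
        conj_sym (bool):        whether conjugate symmetry is enforced
        bidirectional (bool):   whether the bidirectional setup is used
    Returns:
        ys (float32): the SSM outputs                                 (steps, H)
    """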
    Lambda_elements = Lambda_bar * np.ones((steps, Lambda_bar.shape[0]))
    h0 = B_bar @ x0
    xs = jax.lax.associative_scan(np.multiply, Lambda_elements) * h0

    if bidirectional:
        xs2 = jax.lax.associative_scan(np.multiply, Lambda_elements, reverse=True) * h0
        xs = np.concatenate((xs, xs2), axis=-1)

    if conj_sym:
        return jax.vmap(lambda x: 2 * (C_tilde @ x).real)(xs)
    else:
        return jax.vmap(lambda x: (C_tilde @ x).real)(xs)


def apply_ssm(
    Lambda_bar,
    B_bar,
    C_tilde,
    input_sequence,
    conj_sym,
    bidirectional,
):
    """Compute the LxH output of discretized SSM given an LxH input.
    Args:
        Lambda_bar (complex64): discretized diagonal state matrix    (P,)
        B_bar      (complex64): discretized input matrix             (P, H)
        C_tilde    (complex64): output matrix                        (H, P)
        input_sequence (float32): input sequence of features         (L, H)
        conj_sym (bool):         whether conjugate symmetry is enforced
        bidirectional (bool):    whether bidirectional setup is used,
                              Note for this case C_tilde will have 2P cols
    Returns:
        ys (float32): the SSM outputs (S5 layer preactivations)      (L, H)
    """
    Lambda_elements = Lambda_bar * np.ones(
        (input_sequence.shape[0], Lambda_bar.shape[0])
    )

    Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence)

    _, xs = jax.lax.associative_scan(binary_operator, (Lambda_elements, Bu_elements))

    if bidirectional:
        _, xs2 = jax.lax.associative_scan(
            binary_operator, (Lambda_elements, Bu_elements), reverse=True
        )
        xs = np.concatenate((xs, xs2), axis=-1)

    if conj_sym:
        return jax.vmap(lambda x: 2 * (C_tilde @ x).real)(xs)
    else:
        return jax.vmap(lambda x: (C_tilde @ x).real)(xs)

:::
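
The parallel scan computes exactly the sequential recurrence x_k = Lambda_bar * x_{k-1} + B_bar @ u_k. A minimal sanity check against a naive loop (the toy Lambda, B_tilde, C_tilde and Delta below are made up for illustration, assuming the functions above are in scope):

```python
import jax
import jax.numpy as jnp

P, H, L = 4, 3, 16  # toy state size, feature size, sequence length

# made-up continuous-time parameters for illustration
Lambda = -0.5 + 1j * jnp.arange(1.0, P + 1)
B_tilde = jax.random.normal(jax.random.PRNGKey(0), (P, H)) + 0j
C_tilde = jax.random.normal(jax.random.PRNGKey(1), (H, P)) + 0j
Delta = 0.1 * jnp.ones(P)

Lambda_bar, B_bar = discretize_zoh(Lambda, B_tilde, Delta)

u = jax.random.normal(jax.random.PRNGKey(2), (L, H))
ys_scan = apply_ssm(Lambda_bar, B_bar, C_tilde, u, conj_sym=False, bidirectional=False)

# sequential reference: x_k = Lambda_bar * x_{k-1} + B_bar @ u_k, x_0 = 0
x = jnp.zeros(P, dtype=Lambda_bar.dtype)
ys_loop = []
for k in range(L):
    x = Lambda_bar * x + B_bar @ u[k]
    ys_loop.append((C_tilde @ x).real)

assert jnp.allclose(ys_scan, jnp.stack(ys_loop), atol=1e-4)
```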

::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class S5SSM(nn.Module):
    d_model: int
    d_hidden: int
    C_init: str = "lecun_normal"
    discretization: str = "zoh"
    dt_min: float = 0.0001
    dt_max: float = 0.1
    conj_sym: bool = True
    clip_eigs: bool = False
    bidirectional: bool = False
    step_rescale: float = 1.0
    blocks: int = 16
    n_steps: Optional[int] = None

    """ The S5 SSM
        Args:
            Lambda_re_init (complex64): Real part of init diag state matrix  (P,)
            Lambda_im_init (complex64): Imag part of init diag state matrix  (P,)
            V           (complex64): Eigenvectors used for init           (P,P)
            Vinv        (complex64): Inverse eigenvectors used for init   (P,P)
            d_model     (int32):     Number of features of input seq 
            d_hidden    (int32):     state size
            C_init      (string):    Specifies How C is initialized
                         Options: [trunc_standard_normal: sample from truncated standard normal 
                                                        and then multiply by V, i.e. C_tilde=CV.
                                   lecun_normal: sample from Lecun_normal and then multiply by V.
                                   complex_normal: directly sample a complex valued output matrix 
                                                    from standard normal, does not multiply by V]
            conj_sym    (bool):    Whether conjugate symmetry is enforced
            clip_eigs   (bool):    Whether to enforce left-half plane condition, i.e.
                                   constrain real part of eigenvalues to be negative. 
                                   True recommended for autoregressive task/unbounded sequence lengths
                                   Discussed in https://arxiv.org/pdf/2206.11893.pdf.
            bidirectional (bool):  Whether model is bidirectional, if True, uses two C matrices
            discretization: (string) Specifies discretization method 
                             options: [zoh: zero-order hold method,
                                       bilinear: bilinear transform]
            dt_min:      (float32): minimum value to draw timescale values from when 
                                    initializing log_step
            dt_max:      (float32): maximum value to draw timescale values from when 
                                    initializing log_step
            step_rescale:  (float32): allows for uniformly changing the timescale parameter, e.g. after training 
                                    on a different resolution for the speech commands benchmark
    """

    def setup(self):
        """Initializes parameters once and performs discretization each time
        the SSM is applied to a sequence
        """
        self.H = self.d_model
        self.P = self.d_hidden

        # Initialize state matrix A using approximation to HiPPO-LegS matrix
        block_size = int(self.P / self.blocks)
        Lambda, _, B, V, B_orig = make_DPLR_HiPPO(block_size)

        if self.conj_sym:
            # Need to account for case where we actually sample real B and C, and then multiply
            # by the half sized Vinv and possibly V
            block_size = block_size // 2
            P = self.P // 2
            local_P = 2 * P
        else:
            P = self.P
            local_P = P

        Lambda = Lambda[:block_size]
        V = V[:, :block_size]
        Vc = V.conj().T

        # If initializing state matrix A as block-diagonal, put HiPPO approximation
        # on each block
        Lambda = (Lambda * np.ones((self.blocks, block_size))).ravel()
        self.V = jax.scipy.linalg.block_diag(*([V] * self.blocks))
        self.Vinv = jax.scipy.linalg.block_diag(*([Vc] * self.blocks))

        # Initialize diagonal state to state matrix Lambda (eigenvalues)
        self.Lambda_re = self.param(
            "Lambda_re", lambda rng, shape: Lambda.real, (None,)
        )
        self.Lambda_im = self.param(
            "Lambda_im", lambda rng, shape: Lambda.imag, (None,)
        )
        if self.clip_eigs:
            self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im
        else:
            self.Lambda = self.Lambda_re + 1j * self.Lambda_im

        # Initialize input to state (B) matrix
        B_init = lecun_normal()
        B_shape = (local_P, self.H)
        self.B = self.param(
            "B", lambda rng, shape: init_VinvB(B_init, rng, shape, self.Vinv), B_shape
        )
        B_tilde = self.B[..., 0] + 1j * self.B[..., 1]

        # Initialize state to output (C) matrix
        if self.C_init in ["trunc_standard_normal"]:
            C_init = trunc_standard_normal
            C_shape = (self.H, local_P, 2)
        elif self.C_init in ["lecun_normal"]:
            C_init = lecun_normal()
            C_shape = (self.H, local_P, 2)
        elif self.C_init in ["complex_normal"]:
            C_init = normal(stddev=0.5**0.5)
        else:
            raise NotImplementedError(
                "C_init method {} not implemented".format(self.C_init)
            )

        if self.C_init in ["complex_normal"]:
            if self.bidirectional:
                C = self.param("C", C_init, (self.H, 2 * P, 2))
                self.C_tilde = C[..., 0] + 1j * C[..., 1]

            else:
                C = self.param("C", C_init, (self.H, P, 2))
                self.C_tilde = C[..., 0] + 1j * C[..., 1]

        else:
            if self.bidirectional:
                self.C1 = self.param(
                    "C1",
                    lambda rng, shape: init_CV(C_init, rng, shape, self.V),
                    C_shape,
                )
                self.C2 = self.param(
                    "C2",
                    lambda rng, shape: init_CV(C_init, rng, shape, self.V),
                    C_shape,
                )

                C1 = self.C1[..., 0] + 1j * self.C1[..., 1]
                C2 = self.C2[..., 0] + 1j * self.C2[..., 1]
                self.C_tilde = np.concatenate((C1, C2), axis=-1)

            else:
                self.C = self.param(
                    "C", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape
                )

                self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1]

        # Initialize feedthrough (D) matrix
        self.D = self.param("D", normal(stddev=1.0), (self.H,))

        # Initialize learnable discretization timescale value
        self.log_step = self.param(
            "log_step", init_log_steps, (P, self.dt_min, self.dt_max)
        )
        step = self.step_rescale * np.exp(self.log_step[:, 0])

        # Discretize
        if self.discretization in ["zoh"]:
            self.Lambda_bar, self.B_bar = discretize_zoh(self.Lambda, B_tilde, step)
        elif self.discretization in ["bilinear"]:
            self.Lambda_bar, self.B_bar = discretize_bilinear(
                self.Lambda, B_tilde, step
            )
        else:
            raise NotImplementedError(
                "Discretization method {} not implemented".format(self.discretization)
            )

    def __call__(self, input_sequence):
        """
        Compute the LxH output of the S5 SSM given an LxH input sequence
        using a parallel scan.
        Args:
             input_sequence (float32): input sequence (L, H)
        Returns:
            output sequence (float32): (L, H)
        """

        if self.n_steps is not None:
            ys = apply_dynamics(
                input_sequence[0],
                self.n_steps,
                self.Lambda_bar,
                self.B_bar,
                self.C_tilde,
                self.conj_sym,
                self.bidirectional,
            )
            return ys
        else:
            ys = apply_ssm(
                self.Lambda_bar,
                self.B_bar,
                self.C_tilde,
                input_sequence,
                self.conj_sym,
                self.bidirectional,
            )
            # Add feedthrough matrix output Du;
            Du = jax.vmap(lambda u: self.D * u)(input_sequence)
            return ys + Du

:::
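
The S5SSM module can also be used directly on a single (L, H) sequence; a minimal sketch with arbitrary sizes (d_hidden must be divisible by blocks, and the resulting block size must be even when conj_sym=True):

```python
import jax
import jax.numpy as jnp

L_seq, H = 32, 8
model = S5SSM(d_model=H, d_hidden=16, blocks=4)  # 16 / 4 = 4 states per HiPPO block

x = jax.random.normal(jax.random.PRNGKey(0), (L_seq, H))
params = model.init(jax.random.PRNGKey(1), x)
y = model.apply(params, x)
assert y.shape == (L_seq, H)
```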

LRU Model

adapted from https://github.com/NicolasZucchet/minimal-LRU

::: {#cell-11 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

parallel_scan = jax.lax.associative_scan

:::

::: {#cell-12 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def apply_lru_dynamics(
    inputs: jnp.ndarray,  # (time, d_model)
    discrete_lambda: jnp.ndarray,  # (d_hidden,)
    B_norm: jnp.ndarray,  # (d_hidden, d_model)
    C: jnp.ndarray,  # (d_model, d_hidden)
    D: jnp.ndarray,  # (d_model,)
):
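    """Apply the discretized LRU recurrence h_k = lambda * h_{k-1} + B_norm @ u_k
    to the full input sequence with a parallel scan, then project the hidden states
    with C and add the feedthrough term D * u_k.
    """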
    Lambda_elements = jnp.repeat(discrete_lambda[None, ...], inputs.shape[0], axis=0)

    Bu_elements = jax.vmap(lambda u: B_norm @ u)(inputs)
    _, hidden_states = jax.lax.associative_scan(
        binary_operator, (Lambda_elements, Bu_elements)
    )
    return jax.vmap(lambda h, x: (C @ h).real + D * x)(hidden_states, inputs)


def apply_lru_dynamics_from_ic(
    ic: jnp.ndarray,  # (1, d_model)
    n_steps: int,
    discrete_lambda: jnp.ndarray,  # (d_hidden,)
    B_norm: jnp.ndarray,  # (d_hidden, d_model)
    C: jnp.ndarray,  # (d_model, d_hidden)
):
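    """Roll out the LRU autonomously from the initial condition ic: with
    h0 = B_norm @ ic[0], returns the projections C @ (lambda^k * h0) for
    k = 1, ..., n_steps (no input drive and no feedthrough term).
    """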
    Lambda_elements = jnp.repeat(discrete_lambda[None, ...], n_steps, axis=0)
    h0 = B_norm @ ic[0]
    hidden_states = jax.lax.associative_scan(jnp.multiply, Lambda_elements) * h0
    return jax.vmap(lambda h: (C @ h).real)(hidden_states)

:::

::: {#cell-13 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class LRU(nn.Module):
    """
    LRU module in charge of the recurrent processing.
    Implementation following the one of Orvieto et al. 2023.
    """

    d_hidden: int  # hidden state dimension
    d_model: int  # input and output dimensions
    r_min: float = 0.0  # smallest lambda norm
    r_max: float = 1.0  # largest lambda norm
    max_phase: float = 6.28  # largest phase of lambda (radians)
    n_steps: Optional[int] = None  # number of steps to advance

    def setup(self):
        theta_log = self.param(
            "theta_log", partial(theta_init, max_phase=self.max_phase), (self.d_hidden,)
        )
        nu_log = self.param(
            "nu_log",
            partial(nu_init, r_min=self.r_min, r_max=self.r_max),
            (self.d_hidden,),
        )
        gamma_log = self.param("gamma_log", gamma_log_init, (nu_log, theta_log))

        # Glorot initialized Input/Output projection matrices
        B_re = self.param(
            "B_re",
            partial(matrix_init, normalization=jnp.sqrt(2 * self.d_model)),
            (self.d_hidden, self.d_model),
        )
        B_im = self.param(
            "B_im",
            partial(matrix_init, normalization=jnp.sqrt(2 * self.d_model)),
            (self.d_hidden, self.d_model),
        )
        C_re = self.param(
            "C_re",
            partial(matrix_init, normalization=jnp.sqrt(self.d_hidden)),
            (self.d_model, self.d_hidden),
        )
        C_im = self.param(
            "C_im",
            partial(matrix_init, normalization=jnp.sqrt(self.d_hidden)),
            (self.d_model, self.d_hidden),
        )
        self.D = self.param("D", matrix_init, (self.d_model,))

        self.C = C_re + 1j * C_im
        B = B_re + 1j * B_im
        self.B_norm = B * jnp.exp(gamma_log)[..., None]

        self.discrete_diag_lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))

    def __call__(
        self,
        inputs: jnp.ndarray,  # (time, d_model)
    ):
        if self.n_steps is not None:
            return apply_lru_dynamics_from_ic(
                inputs,
                self.n_steps,
                self.discrete_diag_lambda,
                self.B_norm,
                self.C,
            )
        else:
            return apply_lru_dynamics(
                inputs,
                self.discrete_diag_lambda,
                self.B_norm,
                self.C,
                self.D,
            )

:::
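
As with the S5 SSM, the LRU can be applied on its own to a single (time, d_model) sequence. A minimal sketch with arbitrary sizes, assuming the initializers imported from physmodjax.models.recurrent behave like their minimal-LRU counterparts:

```python
import jax
import jax.numpy as jnp

T, d_model, d_hidden = 32, 8, 16
lru = LRU(d_hidden=d_hidden, d_model=d_model)

x = jax.random.normal(jax.random.PRNGKey(0), (T, d_model))
params = lru.init(jax.random.PRNGKey(1), x)
y = lru.apply(params, x)
assert y.shape == (T, d_model)
```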

Deep (Stacked) and Batched versions

::: {#cell-15 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’ 6=‘i’}

from einops import rearrange
from physmodjax.models.mlp import MLP

:::

::: {#cell-16 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class SequenceLayer(nn.Module):
    """Single layer, with one SSM module, GLU, dropout and batch/layer norm"""

    ssm: nn.Module  # ssm module
    d_model: int  # model size
    dropout: float = 0.0  # dropout probability
    norm: str = "layer"  # which normalization to use
    training: bool = True  # in training mode (dropout in training mode only)
    activation: str = "half_glu1"  # activation function
    prenorm: bool = True  # whether to use pre or post normalization

    def setup(self):
        """Initializes the ssm, layer norm and dropout"""
        self.seq = self.ssm()
        self.out1 = nn.Dense(self.d_model)
        self.out2 = nn.Dense(self.d_model)

        # self.d_model -> self.d_model * 4 -> self.d_model
        # GPT mlp
        self.mlp = MLP(
            hidden_channels=[self.d_model * 4, self.d_model],
            activation=nn.gelu,
        )

        if self.norm in ["layer"]:
            self.normalization = nn.LayerNorm()
        else:
            self.normalization = nn.BatchNorm(
                use_running_average=not self.training, axis_name="batch"
            )
        self.drop = nn.Dropout(
            self.dropout,
            broadcast_dims=[0],
            deterministic=not self.training,
        )

    def __call__(self, x):
        skip = x
        if self.prenorm:
            x = self.normalization(x)  # pre normalization
        x = self.seq(x)  # apply the SSM (S5 or LRU)
        if self.activation in ["full_glu"]:
            x = self.drop(nn.gelu(x))
            x = self.out1(x) * jax.nn.sigmoid(self.out2(x))
            x = self.drop(x)
        elif self.activation in ["half_glu1"]:
            x = self.drop(nn.gelu(x))
            x = x * jax.nn.sigmoid(self.out2(x))
            x = self.drop(x)
        elif self.activation in ["gelu"]:
            x = self.drop(nn.gelu(x))
        elif self.activation in ["mlp"]:
            x = self.drop(self.mlp(x))
        else:
            raise NotImplementedError(f"Activation {self.activation} not implemented")
        x = skip + x  # skip connection
        if not self.prenorm:
            x = self.normalization(x)
        return x

:::
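
A minimal sketch of a single SequenceLayer wrapping an LRU; training=False keeps dropout deterministic for this standalone check, and the sizes are arbitrary:

```python
from functools import partial

import jax
import jax.numpy as jnp

d_model, T = 8, 32
layer = SequenceLayer(
    ssm=partial(LRU, d_hidden=16, d_model=d_model),
    d_model=d_model,
    training=False,  # deterministic dropout for this sketch
)

x = jax.random.normal(jax.random.PRNGKey(0), (T, d_model))
variables = layer.init(jax.random.PRNGKey(1), x)
y = layer.apply(variables, x)
assert y.shape == (T, d_model)
```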

::: {#cell-17 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class StackedSSM(nn.Module):
    ssm: nn.Module  # ssm module
    d_model: int  # model size
    d_vars: int  # number of variables
    n_layers: int  # number of layers
    ssm_first_layer: Optional[nn.Module] = None  # optional first layer, usually for one-to-many
    n_steps: Optional[int] = None  # number of steps to advance
    dropout: float = 0.0  # dropout probability
    training: bool = True
    norm: str = "layer"
    activation: str = "half_glu1"
    prenorm: bool = True

    def setup(self):
        if self.ssm_first_layer is not None:
            self.first_layer = self.ssm_first_layer(
                d_model=self.d_model * self.d_vars,
                n_steps=self.n_steps,
            )
        self.layers = [
            SequenceLayer(
                ssm=partial(self.ssm, d_model=self.d_model * self.d_vars),
                d_model=self.d_model * self.d_vars,
                dropout=self.dropout,
                training=self.training,
                norm=self.norm,
                activation=self.activation,
                prenorm=self.prenorm,
            )
            for _ in range(self.n_layers)
        ]

    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, C) input
    ):
        x = rearrange(x, "t w c -> t (w c)")

        if self.ssm_first_layer is not None:
            x = self.first_layer(x)
        else:
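            # keep only the first frame as the initial condition and zero the remaining steps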
            x = jnp.concatenate(
                [x[0:1], jnp.zeros((x.shape[0] - 1, x.shape[1]))], axis=0
            )

        for layer in self.layers:
            x = layer(x)  # apply each layer

        return rearrange(x, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)


BatchStackedSSMModel = nn.vmap(
    StackedSSM,
    in_axes=0,
    out_axes=0,
    variable_axes={
        "params": None,
        "dropout": None,
        "batch_stats": None,
        "cache": 0,
        "prime": None,
    },
    split_rngs={"params": False, "dropout": True},
    axis_name="batch",
)

:::

::: {#cell-18 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

B, T, W, C = 10, 50, 20, 3
d_hidden = 64
deep_ssm = BatchStackedSSMModel(
    ssm_first_layer=partial(S5SSM, d_hidden=d_hidden, n_steps=50),
    ssm=partial(S5SSM, d_hidden=d_hidden),
    d_model=W,
    d_vars=C,
    n_layers=2,
)
x = jnp.empty((B, T, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, W, C)

:::

::: {#cell-19 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

deep_ssm = BatchStackedSSMModel(
    ssm_first_layer=partial(LRU, d_hidden=d_hidden, n_steps=50),
    ssm=partial(LRU, d_hidden=d_hidden),
    d_model=W,
    d_vars=C,
    n_layers=2,
)
x = jnp.empty((B, T, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, W, C)

:::

::: {#cell-20 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class StackedSSM2D(nn.Module):
    ssm: nn.Module  # ssm module
    d_model: tuple[int, int]
    d_vars: int  # number of variables
    n_layers: int  # number of layers
    ssm_first_layer: Optional[nn.Module] = None  # optional first layer, usually for one-to-many
    n_steps: Optional[int] = None  # number of steps to advance
    dropout: float = 0.0  # dropout probability
    training: bool = True
    norm: str = "layer"
    activation: str = "half_glu1"
    prenorm: bool = True

    def setup(self):
        if self.ssm_first_layer is not None:
            self.first_layer = self.ssm_first_layer(
                d_model=self.d_model[0] * self.d_model[1] * self.d_vars,
                n_steps=self.n_steps,
            )
        self.layers = [
            SequenceLayer(
                ssm=partial(
                    self.ssm, d_model=self.d_model[0] * self.d_model[1] * self.d_vars
                ),
                d_model=self.d_model[0] * self.d_model[1] * self.d_vars,
                dropout=self.dropout,
                training=self.training,
                norm=self.norm,
                activation=self.activation,
                prenorm=self.prenorm,
            )
            for _ in range(self.n_layers)
        ]

    def __call__(
        self,
        x: jnp.ndarray,  # (T, H, W, C) input
    ):
        x = rearrange(x, "t h w c -> t (h w c)")

        if self.ssm_first_layer is not None:
            x = self.first_layer(x)
        else:
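            # keep only the first frame as the initial condition and zero the remaining steps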
            x = jnp.concatenate(
                [x[0:1], jnp.zeros((x.shape[0] - 1, x.shape[1]))], axis=0
            )

        for layer in self.layers:
            x = layer(x)  # apply each layer

        return rearrange(
            x,
            "t (h w c) -> t h w c",
            h=self.d_model[0],
            w=self.d_model[1],
            c=self.d_vars,
        )


BatchStackedSSM2DModel = nn.vmap(
    StackedSSM2D,
    in_axes=0,
    out_axes=0,
    variable_axes={
        "params": None,
        "dropout": None,
        "batch_stats": None,
        "cache": 0,
        "prime": None,
    },
    split_rngs={"params": False, "dropout": True},
    axis_name="batch",
)

:::

::: {#cell-21 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

B, T, H, W, C = 10, 50, 20, 20, 3
deep_ssm = BatchStackedSSM2DModel(
    ssm_first_layer=partial(LRU, d_hidden=d_hidden, n_steps=T),
    ssm=partial(LRU, d_hidden=d_hidden),
    d_model=(H, W),
    d_vars=C,
    n_layers=2,
)

x = jnp.empty((B, T, H, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, H, W, C)

:::

::: {#cell-22 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

B, T, H, W, C = 10, 50, 20, 20, 3
deep_ssm = BatchStackedSSM2DModel(
    ssm_first_layer=partial(S5SSM, d_hidden=d_hidden, n_steps=T),
    ssm=partial(S5SSM, d_hidden=d_hidden),
    d_model=(H, W),
    d_vars=C,
    n_layers=2,
)

x = jnp.empty((B, T, H, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, H, W, C)

:::