SSM Models
S5 Model
adapted from https://github.com/lindermanlab/S5
::: {#cell-5 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’ 6=‘i’}
from functools import partial
from typing import Optional
import jax
import jax.numpy as jnp
import jax.numpy as np  # the S5 code below uses the np alias for jax.numpy
from flax import linen as nn
from jax import random
from jax.nn.initializers import lecun_normal, normal
from jax.numpy.linalg import eigh
from physmodjax.models.recurrent import gamma_log_init, matrix_init, nu_init, theta_init
:::
::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def make_HiPPO(N):
"""Create a HiPPO-LegS matrix.
From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
Args:
N (int32): state size
Returns:
N x N HiPPO LegS matrix
"""
P = np.sqrt(1 + 2 * np.arange(N))
A = P[:, np.newaxis] * P[np.newaxis, :]
A = np.tril(A) - np.diag(np.arange(N))
return -A
def make_NPLR_HiPPO(N):
"""
Makes components needed for NPLR representation of HiPPO-LegS
From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
Args:
N (int32): state size
Returns:
N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B
"""
# Make -HiPPO
hippo = make_HiPPO(N)
# Add in a rank 1 term. Makes it Normal.
P = np.sqrt(np.arange(N) + 0.5)
# HiPPO also specifies the B matrix
B = np.sqrt(2 * np.arange(N) + 1.0)
return hippo, P, B
def make_DPLR_HiPPO(N):
"""
Makes components needed for DPLR representation of HiPPO-LegS
From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
Note, we will only use the diagonal part
Args:
        N (int32): state size
Returns:
eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B,
eigenvectors V, HiPPO B pre-conjugation
"""
A, P, B = make_NPLR_HiPPO(N)
S = A + P[:, np.newaxis] * P[np.newaxis, :]
S_diag = np.diagonal(S)
Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)
# Diagonalize S to V \Lambda V^*
Lambda_imag, V = eigh(S * -1j)
P = V.conj().T @ P
B_orig = B
B = V.conj().T @ B
return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig
def log_step_initializer(dt_min=0.001, dt_max=0.1):
"""Initialize the learnable timescale Delta by sampling
uniformly between dt_min and dt_max.
Args:
dt_min (float32): minimum value
dt_max (float32): maximum value
Returns:
init function
"""
def init(key, shape):
"""Init function
Args:
key: jax random key
shape tuple: desired shape
Returns:
sampled log_step (float32)
"""
return random.uniform(key, shape) * (np.log(dt_max) - np.log(dt_min)) + np.log(
dt_min
)
return init
def init_log_steps(key, input):
"""Initialize an array of learnable timescale parameters
Args:
key: jax random key
input: tuple containing the array shape H and
dt_min and dt_max
Returns:
initialized array of timescales (float32): (H,)
"""
H, dt_min, dt_max = input
log_steps = []
for i in range(H):
key, skey = random.split(key)
log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,))
log_steps.append(log_step)
return np.array(log_steps)
def init_VinvB(init_fun, rng, shape, Vinv):
"""Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B.
Note we will parameterize this with two different matrices for complex
numbers.
Args:
init_fun: the initialization function to use, e.g. lecun_normal()
rng: jax random key to be used with init function.
shape (tuple): desired shape (P,H)
Vinv: (complex64) the inverse eigenvectors used for initialization
Returns:
B_tilde (complex64) of shape (P,H,2)
"""
B = init_fun(rng, shape)
VinvB = Vinv @ B
VinvB_real = VinvB.real
VinvB_imag = VinvB.imag
return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1)
def trunc_standard_normal(key, shape):
"""Sample C with a truncated normal distribution with standard deviation 1.
Args:
key: jax random key
shape (tuple): desired shape, of length 3, (H,P,_)
Returns:
sampled C matrix (float32) of shape (H,P,2) (for complex parameterization)
"""
H, P, _ = shape
Cs = []
for i in range(H):
key, skey = random.split(key)
C = lecun_normal()(skey, shape=(1, P, 2))
Cs.append(C)
return np.array(Cs)[:, 0]
def init_CV(init_fun, rng, shape, V):
"""Initialize C_tilde=CV. First sample C. Then compute CV.
Note we will parameterize this with two different matrices for complex
numbers.
Args:
init_fun: the initialization function to use, e.g. lecun_normal()
rng: jax random key to be used with init function.
shape (tuple): desired shape (H,P)
V: (complex64) the eigenvectors used for initialization
Returns:
C_tilde (complex64) of shape (H,P,2)
"""
C_ = init_fun(rng, shape)
C = C_[..., 0] + 1j * C_[..., 1]
CV = C @ V
CV_real = CV.real
CV_imag = CV.imag
    return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1)
:::
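As a quick sanity check (not part of the original S5 repository), the helpers above can be exercised on a small state size; the DPLR initialization of HiPPO-LegS yields eigenvalues whose real part is constant and negative:

```python
# Minimal sketch, assuming only the helpers defined above.
import jax.numpy as jnp

N = 8
Lambda, P_lr, B, V, B_orig = make_DPLR_HiPPO(N)

assert Lambda.shape == (N,)
assert jnp.all(Lambda.real < 0)  # eigenvalues lie in the left half-plane
assert V.shape == (N, N)         # eigenvectors used to conjugate B and C
```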
::: {#cell-7 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
# Discretization functions
def discretize_bilinear(Lambda, B_tilde, Delta):
"""Discretize a diagonalized, continuous-time linear SSM
using bilinear transform method.
Args:
Lambda (complex64): diagonal state matrix (P,)
B_tilde (complex64): input matrix (P, H)
Delta (float32): discretization step sizes (P,)
Returns:
discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H)
"""
Identity = np.ones(Lambda.shape[0])
BL = 1 / (Identity - (Delta / 2.0) * Lambda)
Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda)
B_bar = (BL * Delta)[..., None] * B_tilde
return Lambda_bar, B_bar
def discretize_zoh(Lambda, B_tilde, Delta):
"""Discretize a diagonalized, continuous-time linear SSM
using zero-order hold method.
Args:
Lambda (complex64): diagonal state matrix (P,)
B_tilde (complex64): input matrix (P, H)
Delta (float32): discretization step sizes (P,)
Returns:
discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H)
"""
Identity = np.ones(Lambda.shape[0])
Lambda_bar = np.exp(Lambda * Delta)
B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde
return Lambda_bar, B_bar
# Parallel scan operations
@jax.vmap
def binary_operator(q_i, q_j):
"""Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A.
Args:
q_i: tuple containing A_i and Bu_i at position i (P,), (P,)
q_j: tuple containing A_j and Bu_j at position j (P,), (P,)
Returns:
new element ( A_out, Bu_out )
"""
A_i, b_i = q_i
A_j, b_j = q_j
return A_j * A_i, A_j * b_i + b_j
def apply_dynamics(
x0,
steps,
Lambda_bar,
B_bar,
C_tilde,
conj_sym,
bidirectional,
):
    """Roll out the discretized SSM autonomously for `steps` steps from the single
    initial input frame x0 (one-to-many), returning the (steps, H) outputs."""
    Lambda_elements = Lambda_bar * np.ones((steps, Lambda_bar.shape[0]))
h0 = B_bar @ x0
xs = jax.lax.associative_scan(np.multiply, Lambda_elements) * h0
if bidirectional:
xs2 = jax.lax.associative_scan(np.multiply, Lambda_elements, reverse=True) * h0
xs = np.concatenate((xs, xs2), axis=-1)
if conj_sym:
return jax.vmap(lambda x: 2 * (C_tilde @ x).real)(xs)
else:
return jax.vmap(lambda x: (C_tilde @ x).real)(xs)
def apply_ssm(
Lambda_bar,
B_bar,
C_tilde,
input_sequence,
conj_sym,
bidirectional,
):
"""Compute the LxH output of discretized SSM given an LxH input.
Args:
Lambda_bar (complex64): discretized diagonal state matrix (P,)
B_bar (complex64): discretized input matrix (P, H)
C_tilde (complex64): output matrix (H, P)
input_sequence (float32): input sequence of features (L, H)
conj_sym (bool): whether conjugate symmetry is enforced
bidirectional (bool): whether bidirectional setup is used,
Note for this case C_tilde will have 2P cols
Returns:
ys (float32): the SSM outputs (S5 layer preactivations) (L, H)
"""
Lambda_elements = Lambda_bar * np.ones(
(input_sequence.shape[0], Lambda_bar.shape[0])
)
Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence)
_, xs = jax.lax.associative_scan(binary_operator, (Lambda_elements, Bu_elements))
if bidirectional:
_, xs2 = jax.lax.associative_scan(
binary_operator, (Lambda_elements, Bu_elements), reverse=True
)
xs = np.concatenate((xs, xs2), axis=-1)
if conj_sym:
return jax.vmap(lambda x: 2 * (C_tilde @ x).real)(xs)
else:
        return jax.vmap(lambda x: (C_tilde @ x).real)(xs)
:::
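The discretization and scan utilities can be used standalone. Below is a hedged sketch (not from the original repo) that discretizes a small diagonal SSM with zero-order hold and runs the parallel scan on a short sequence; all shapes and values are made up for illustration:

```python
import jax
import jax.numpy as jnp

P_dim, H_dim, L = 4, 3, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

Lambda = -0.5 + 1j * jnp.linspace(0.1, 1.0, P_dim)    # stable diagonal state matrix
B_tilde = jax.random.normal(k1, (P_dim, H_dim)) + 0j  # input matrix
C_tilde = jax.random.normal(k2, (H_dim, P_dim)) + 0j  # output matrix
Delta = 0.01 * jnp.ones(P_dim)                        # per-state step sizes

Lambda_bar, B_bar = discretize_zoh(Lambda, B_tilde, Delta)
u = jax.random.normal(k3, (L, H_dim))
ys = apply_ssm(Lambda_bar, B_bar, C_tilde, u, conj_sym=False, bidirectional=False)
assert ys.shape == (L, H_dim)
```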
::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class S5SSM(nn.Module):
d_model: int
d_hidden: int
C_init: str = "lecun_normal"
discretization: str = "zoh"
dt_min: float = 0.0001
dt_max: float = 0.1
conj_sym: bool = True
clip_eigs: bool = False
bidirectional: bool = False
step_rescale: float = 1.0
blocks: int = 16
n_steps: Optional[int] = None
""" The S5 SSM
Args:
Lambda_re_init (complex64): Real part of init diag state matrix (P,)
Lambda_im_init (complex64): Imag part of init diag state matrix (P,)
V (complex64): Eigenvectors used for init (P,P)
Vinv (complex64): Inverse eigenvectors used for init (P,P)
d_model (int32): Number of features of input seq
d_hidden (int32): state size
C_init (string): Specifies How C is initialized
Options: [trunc_standard_normal: sample from truncated standard normal
and then multiply by V, i.e. C_tilde=CV.
lecun_normal: sample from Lecun_normal and then multiply by V.
complex_normal: directly sample a complex valued output matrix
from standard normal, does not multiply by V]
conj_sym (bool): Whether conjugate symmetry is enforced
clip_eigs (bool): Whether to enforce left-half plane condition, i.e.
constrain real part of eigenvalues to be negative.
True recommended for autoregressive task/unbounded sequence lengths
Discussed in https://arxiv.org/pdf/2206.11893.pdf.
bidirectional (bool): Whether model is bidirectional, if True, uses two C matrices
discretization: (string) Specifies discretization method
options: [zoh: zero-order hold method,
bilinear: bilinear transform]
dt_min: (float32): minimum value to draw timescale values from when
initializing log_step
dt_max: (float32): maximum value to draw timescale values from when
initializing log_step
step_rescale: (float32): allows for uniformly changing the timescale parameter, e.g. after training
on a different resolution for the speech commands benchmark
"""
def setup(self):
"""Initializes parameters once and performs discretization each time
the SSM is applied to a sequence
"""
self.H = self.d_model
self.P = self.d_hidden
        # Initialize state matrix A using approximation to HiPPO-LegS matrix
        block_size = int(self.P / self.blocks)
Lambda, _, B, V, B_orig = make_DPLR_HiPPO(block_size)
if self.conj_sym:
# Need to account for case where we actually sample real B and C, and then multiply
# by the half sized Vinv and possibly V
block_size = block_size // 2
P = self.P // 2
local_P = 2 * P
        else:
            P = self.P
            local_P = P
Lambda = Lambda[:block_size]
V = V[:, :block_size]
Vc = V.conj().T
# If initializing state matrix A as block-diagonal, put HiPPO approximation
# on each block
Lambda = (Lambda * np.ones((self.blocks, block_size))).ravel()
self.V = jax.scipy.linalg.block_diag(*([V] * self.blocks))
self.Vinv = jax.scipy.linalg.block_diag(*([Vc] * self.blocks))
# Initialize diagonal state to state matrix Lambda (eigenvalues)
self.Lambda_re = self.param(
"Lambda_re", lambda rng, shape: Lambda.real, (None,)
)
self.Lambda_im = self.param(
"Lambda_im", lambda rng, shape: Lambda.imag, (None,)
)
if self.clip_eigs:
self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im
else:
self.Lambda = self.Lambda_re + 1j * self.Lambda_im
# Initialize input to state (B) matrix
B_init = lecun_normal()
B_shape = (local_P, self.H)
self.B = self.param(
"B", lambda rng, shape: init_VinvB(B_init, rng, shape, self.Vinv), B_shape
)
B_tilde = self.B[..., 0] + 1j * self.B[..., 1]
# Initialize state to output (C) matrix
if self.C_init in ["trunc_standard_normal"]:
C_init = trunc_standard_normal
C_shape = (self.H, local_P, 2)
elif self.C_init in ["lecun_normal"]:
C_init = lecun_normal()
C_shape = (self.H, local_P, 2)
elif self.C_init in ["complex_normal"]:
C_init = normal(stddev=0.5**0.5)
else:
raise NotImplementedError(
"C_init method {} not implemented".format(self.C_init)
)
if self.C_init in ["complex_normal"]:
if self.bidirectional:
C = self.param("C", C_init, (self.H, 2 * P, 2))
self.C_tilde = C[..., 0] + 1j * C[..., 1]
else:
C = self.param("C", C_init, (self.H, P, 2))
self.C_tilde = C[..., 0] + 1j * C[..., 1]
else:
if self.bidirectional:
self.C1 = self.param(
"C1",
lambda rng, shape: init_CV(C_init, rng, shape, self.V),
C_shape,
)
self.C2 = self.param(
"C2",
lambda rng, shape: init_CV(C_init, rng, shape, self.V),
C_shape,
)
C1 = self.C1[..., 0] + 1j * self.C1[..., 1]
C2 = self.C2[..., 0] + 1j * self.C2[..., 1]
self.C_tilde = np.concatenate((C1, C2), axis=-1)
else:
self.C = self.param(
"C", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape
)
self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1]
# Initialize feedthrough (D) matrix
self.D = self.param("D", normal(stddev=1.0), (self.H,))
# Initialize learnable discretization timescale value
self.log_step = self.param(
"log_step", init_log_steps, (P, self.dt_min, self.dt_max)
)
step = self.step_rescale * np.exp(self.log_step[:, 0])
# Discretize
if self.discretization in ["zoh"]:
self.Lambda_bar, self.B_bar = discretize_zoh(self.Lambda, B_tilde, step)
elif self.discretization in ["bilinear"]:
self.Lambda_bar, self.B_bar = discretize_bilinear(
self.Lambda, B_tilde, step
)
else:
raise NotImplementedError(
"Discretization method {} not implemented".format(self.discretization)
)
def __call__(self, input_sequence):
"""
Compute the LxH output of the S5 SSM given an LxH input sequence
using a parallel scan.
Args:
input_sequence (float32): input sequence (L, H)
Returns:
output sequence (float32): (L, H)
"""
        if self.n_steps is not None:
ys = apply_dynamics(
input_sequence[0],
self.n_steps,
self.Lambda_bar,
self.B_bar,
self.C_tilde,
self.conj_sym,
self.bidirectional,
)
return ys
else:
ys = apply_ssm(
self.Lambda_bar,
self.B_bar,
self.C_tilde,
input_sequence,
self.conj_sym,
self.bidirectional,
)
# Add feedthrough matrix output Du;
Du = jax.vmap(lambda u: self.D * u)(input_sequence)
            return ys + Du
:::
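A minimal usage sketch for a single S5SSM layer (not part of the original notebook): with the default 16 blocks and conj_sym=True, d_hidden must be divisible by 2 * blocks, so 64 hidden states is a convenient choice.

```python
import jax
import jax.numpy as jnp

L, H = 32, 8
layer = S5SSM(d_model=H, d_hidden=64)  # defaults: blocks=16, conj_sym=True, zoh
u = jnp.ones((L, H))

params = layer.init(jax.random.PRNGKey(0), u)
y = layer.apply(params, u)
assert y.shape == (L, H)
```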
LRU Model
adapted from https://github.com/NicolasZucchet/minimal-LRU
::: {#cell-11 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
parallel_scan = jax.lax.associative_scan
:::
::: {#cell-12 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def apply_lru_dynamics(
inputs: jnp.ndarray, # (time, d_model)
discrete_lambda: jnp.ndarray, # (d_hidden,)
B_norm: jnp.ndarray, # (d_hidden, d_model)
C: jnp.ndarray, # (d_model, d_hidden)
D: jnp.ndarray, # (d_model,)
):
Lambda_elements = jnp.repeat(discrete_lambda[None, ...], inputs.shape[0], axis=0)
Bu_elements = jax.vmap(lambda u: B_norm @ u)(inputs)
_, hidden_states = jax.lax.associative_scan(
binary_operator, (Lambda_elements, Bu_elements)
)
return jax.vmap(lambda h, x: (C @ h).real + D * x)(hidden_states, inputs)
def apply_lru_dynamics_from_ic(
ic: jnp.ndarray, # (1, d_model)
n_steps: int,
discrete_lambda: jnp.ndarray, # (d_hidden,)
B_norm: jnp.ndarray, # (d_hidden, d_model)
C: jnp.ndarray, # (d_model, d_hidden)
):
Lambda_elements = jnp.repeat(discrete_lambda[None, ...], n_steps, axis=0)
h0 = B_norm @ ic[0]
hidden_states = jax.lax.associative_scan(jnp.multiply, Lambda_elements) * h0
    return jax.vmap(lambda h: (C @ h).real)(hidden_states)
:::
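For clarity, here is a hedged toy example (random, made-up parameters) of the one-to-many helper: it rolls the diagonal recurrence out for n_steps starting from a single initial frame.

```python
import jax
import jax.numpy as jnp

d_model, d_hidden, n_steps = 3, 5, 12
k1, k2 = jax.random.split(jax.random.PRNGKey(0))

discrete_lambda = 0.9 * jnp.exp(1j * jnp.linspace(0.0, 1.0, d_hidden))  # |lambda| < 1
B_norm = jax.random.normal(k1, (d_hidden, d_model)) + 0j
C = jax.random.normal(k2, (d_model, d_hidden)) + 0j
ic = jnp.ones((1, d_model))  # single initial frame

ys = apply_lru_dynamics_from_ic(ic, n_steps, discrete_lambda, B_norm, C)
assert ys.shape == (n_steps, d_model)
```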
::: {#cell-13 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class LRU(nn.Module):
"""
LRU module in charge of the recurrent processing.
Implementation following the one of Orvieto et al. 2023.
"""
d_hidden: int # hidden state dimension
d_model: int # input and output dimensions
r_min: float = 0.0 # smallest lambda norm
r_max: float = 1.0 # largest lambda norm
    max_phase: float = 6.28  # largest phase of lambda
n_steps: Optional[int] = None # number of steps to advance
def setup(self):
theta_log = self.param(
"theta_log", partial(theta_init, max_phase=self.max_phase), (self.d_hidden,)
)
nu_log = self.param(
"nu_log",
partial(nu_init, r_min=self.r_min, r_max=self.r_max),
(self.d_hidden,),
)
gamma_log = self.param("gamma_log", gamma_log_init, (nu_log, theta_log))
# Glorot initialized Input/Output projection matrices
B_re = self.param(
"B_re",
partial(matrix_init, normalization=jnp.sqrt(2 * self.d_model)),
(self.d_hidden, self.d_model),
)
B_im = self.param(
"B_im",
partial(matrix_init, normalization=jnp.sqrt(2 * self.d_model)),
(self.d_hidden, self.d_model),
)
C_re = self.param(
"C_re",
partial(matrix_init, normalization=jnp.sqrt(self.d_hidden)),
(self.d_model, self.d_hidden),
)
C_im = self.param(
"C_im",
partial(matrix_init, normalization=jnp.sqrt(self.d_hidden)),
(self.d_model, self.d_hidden),
)
self.D = self.param("D", matrix_init, (self.d_model,))
self.C = C_re + 1j * C_im
B = B_re + 1j * B_im
self.B_norm = B * jnp.exp(gamma_log)[..., None]
self.discrete_diag_lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
def __call__(
self,
inputs: jnp.ndarray, # (time, d_model)
):
if self.n_steps is not None:
return apply_lru_dynamics_from_ic(
inputs,
self.n_steps,
self.discrete_diag_lambda,
self.B_norm,
self.C,
)
else:
return apply_lru_dynamics(
inputs,
self.discrete_diag_lambda,
self.B_norm,
self.C,
self.D,
            )
:::
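A hedged usage sketch for a standalone LRU layer (shapes chosen arbitrarily); it maps an (L, d_model) sequence to an (L, d_model) sequence.

```python
import jax
import jax.numpy as jnp

L, d_model = 32, 8
lru = LRU(d_hidden=16, d_model=d_model)
u = jnp.ones((L, d_model))

params = lru.init(jax.random.PRNGKey(0), u)
y = lru.apply(params, u)
assert y.shape == (L, d_model)
```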
Deep (Stacked) and Batched versions
::: {#cell-15 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’ 6=‘i’}
from einops import rearrange
from physmodjax.models.mlp import MLP
:::
::: {#cell-16 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class SequenceLayer(nn.Module):
"""Single layer, with one SSM module, GLU, dropout and batch/layer norm"""
ssm: nn.Module # ssm module
d_model: int # model size
dropout: float = 0.0 # dropout probability
norm: str = "layer" # which normalization to use
    training: bool = True  # in training mode (dropout active in training mode only)
activation: str = "half_glu1" # activation function
prenorm: bool = True # whether to use pre or post normalization
def setup(self):
"""Initializes the ssm, layer norm and dropout"""
self.seq = self.ssm()
self.out1 = nn.Dense(self.d_model)
self.out2 = nn.Dense(self.d_model)
# self.d_model -> self.d_model * 4 -> self.d_model
# GPT mlp
self.mlp = MLP(
hidden_channels=[self.d_model * 4, self.d_model],
activation=nn.gelu,
)
if self.norm in ["layer"]:
self.normalization = nn.LayerNorm()
else:
self.normalization = nn.BatchNorm(
use_running_average=not self.training, axis_name="batch"
)
self.drop = nn.Dropout(
self.dropout,
broadcast_dims=[0],
deterministic=not self.training,
)
def __call__(self, x):
skip = x
if self.prenorm:
x = self.normalization(x) # pre normalization
        x = self.seq(x)  # apply the SSM (e.g. S5 or LRU)
if self.activation in ["full_glu"]:
x = self.drop(nn.gelu(x))
x = self.out1(x) * jax.nn.sigmoid(self.out2(x))
x = self.drop(x)
elif self.activation in ["half_glu1"]:
x = self.drop(nn.gelu(x))
x = x * jax.nn.sigmoid(self.out2(x))
x = self.drop(x)
elif self.activation in ["gelu"]:
x = self.drop(nn.gelu(x))
elif self.activation in ["mlp"]:
x = self.drop(self.mlp(x))
else:
raise NotImplementedError(f"Activation {self.activation} not implemented")
x = skip + x # skip connection
if not self.prenorm:
x = self.normalization(x)
        return x
:::
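The block below is a hedged sketch (not from the original notebook) showing how SequenceLayer wraps an SSM: the ssm field must be a zero-argument constructor, hence the partial. With training=False the dropout layers are deterministic, so no dropout RNG is needed.

```python
from functools import partial
import jax
import jax.numpy as jnp

L, d_model = 32, 8
layer = SequenceLayer(
    ssm=partial(S5SSM, d_model=d_model, d_hidden=64),
    d_model=d_model,
    dropout=0.1,
    training=False,  # deterministic dropout for this example
)
u = jnp.ones((L, d_model))

params = layer.init(jax.random.PRNGKey(0), u)
y = layer.apply(params, u)
assert y.shape == (L, d_model)
```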
::: {#cell-17 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class StackedSSM(nn.Module):
ssm: nn.Module # ssm module
d_model: int # model size
d_vars: int # number of variables
n_layers: int # number of layers
    ssm_first_layer: Optional[nn.Module] = None  # optional first layer, usually for one-to-many
n_steps: Optional[int] = None # number of steps to advance
dropout: float = 0.0 # dropout probability
training: bool = True
norm: str = "layer"
activation: str = "half_glu1"
prenorm: bool = True
def setup(self):
if self.ssm_first_layer is not None:
self.first_layer = self.ssm_first_layer(
d_model=self.d_model * self.d_vars,
n_steps=self.n_steps,
)
self.layers = [
SequenceLayer(
ssm=partial(self.ssm, d_model=self.d_model * self.d_vars),
d_model=self.d_model * self.d_vars,
dropout=self.dropout,
training=self.training,
norm=self.norm,
activation=self.activation,
prenorm=self.prenorm,
)
for _ in range(self.n_layers)
]
def __call__(
self,
        x: jnp.ndarray,  # (T, W, C) input
):
x = rearrange(x, "t w c -> t (w c)")
if self.ssm_first_layer is not None:
x = self.first_layer(x)
else:
x = jnp.concatenate(
[x[0:1], jnp.zeros((x.shape[0] - 1, x.shape[1]))], axis=0
)
for layer in self.layers:
x = layer(x) # apply each layer
return rearrange(x, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)
BatchStackedSSMModel = nn.vmap(
StackedSSM,
in_axes=0,
out_axes=0,
variable_axes={
"params": None,
"dropout": None,
"batch_stats": None,
"cache": 0,
"prime": None,
},
split_rngs={"params": False, "dropout": True},
axis_name="batch",
)
:::
::: {#cell-18 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
B, T, W, C = 10, 50, 20, 3
d_hidden = 64
deep_ssm = BatchStackedSSMModel(
ssm_first_layer=partial(S5SSM, d_hidden=d_hidden, n_steps=50),
ssm=partial(S5SSM, d_hidden=d_hidden),
d_model=W,
d_vars=C,
n_layers=2,
)
x = jnp.empty((B, T, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)
assert out.shape == (B, T, W, C)
:::
::: {#cell-19 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
deep_ssm = BatchStackedSSMModel(
ssm_first_layer=partial(LRU, d_hidden=d_hidden, n_steps=50),
ssm=partial(LRU, d_hidden=d_hidden),
d_model=W,
d_vars=C,
n_layers=2,
)
x = jnp.empty((B, T, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)
assert out.shape == (B, T, W, C)
:::
::: {#cell-20 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class StackedSSM2D(nn.Module):
ssm: nn.Module # ssm module
d_model: tuple[int, int]
d_vars: int # number of variables
n_layers: int # number of layers
    ssm_first_layer: Optional[nn.Module] = None  # optional first layer, usually for one-to-many
n_steps: Optional[int] = None # number of steps to advance
dropout: float = 0.0 # dropout probability
training: bool = True
norm: str = "layer"
activation: str = "half_glu1"
prenorm: bool = True
def setup(self):
if self.ssm_first_layer is not None:
self.first_layer = self.ssm_first_layer(
d_model=self.d_model[0] * self.d_model[1] * self.d_vars,
n_steps=self.n_steps,
)
self.layers = [
SequenceLayer(
ssm=partial(
self.ssm, d_model=self.d_model[0] * self.d_model[1] * self.d_vars
),
d_model=self.d_model[0] * self.d_model[1] * self.d_vars,
dropout=self.dropout,
training=self.training,
norm=self.norm,
activation=self.activation,
prenorm=self.prenorm,
)
for _ in range(self.n_layers)
]
def __call__(
self,
        x: jnp.ndarray,  # (T, H, W, C) input
):
x = rearrange(x, "t h w c -> t (h w c)")
if self.ssm_first_layer is not None:
x = self.first_layer(x)
else:
x = jnp.concatenate(
[x[0:1], jnp.zeros((x.shape[0] - 1, x.shape[1]))], axis=0
)
for layer in self.layers:
x = layer(x) # apply each layer
return rearrange(
x,
"t (h w c) -> t h w c",
h=self.d_model[0],
w=self.d_model[1],
c=self.d_vars,
)
BatchStackedSSM2DModel = nn.vmap(
StackedSSM2D,
in_axes=0,
out_axes=0,
variable_axes={
"params": None,
"dropout": None,
"batch_stats": None,
"cache": 0,
"prime": None,
},
split_rngs={"params": False, "dropout": True},
axis_name="batch",
)
:::
::: {#cell-21 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
B, T, H, W, C = 10, 50, 20, 20, 3
deep_ssm = BatchStackedSSM2DModel(
ssm_first_layer=partial(LRU, d_hidden=d_hidden, n_steps=T),
ssm=partial(LRU, d_hidden=d_hidden),
d_model=(H, W),
d_vars=C,
n_layers=2,
)
x = jnp.empty((B, T, H, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)
assert out.shape == (B, T, H, W, C)
:::
::: {#cell-22 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
B, T, H, W, C = 10, 50, 20, 20, 3
deep_ssm = BatchStackedSSM2DModel(
ssm_first_layer=partial(S5SSM, d_hidden=d_hidden, n_steps=T),
ssm=partial(S5SSM, d_hidden=d_hidden),
d_model=(H, W),
d_vars=C,
n_layers=2,
)
x = jnp.empty((B, T, H, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)
assert out.shape == (B, T, H, W, C)
:::