# SSM Models
## S5 Model

Adapted from https://github.com/lindermanlab/S5.
::: {#cell-5 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’ 6=‘i’}
```python
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp
import jax.numpy as np
from flax import linen as nn
from jax import random
from jax.nn.initializers import lecun_normal, normal
from jax.numpy.linalg import eigh

from physmodjax.models.recurrent import gamma_log_init, matrix_init, nu_init, theta_init
```
:::
::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
def make_HiPPO(N):
    """Create a HiPPO-LegS matrix.
    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Args:
        N (int32): state size
    Returns:
        N x N HiPPO LegS matrix
    """
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    A = np.tril(A) - np.diag(np.arange(N))
    return -A


def make_NPLR_HiPPO(N):
    """
    Makes components needed for NPLR representation of HiPPO-LegS
    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Args:
        N (int32): state size
    Returns:
        N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B
    """
    # Make -HiPPO
    hippo = make_HiPPO(N)

    # Add in a rank 1 term. Makes it Normal.
    P = np.sqrt(np.arange(N) + 0.5)

    # HiPPO also specifies the B matrix
    B = np.sqrt(2 * np.arange(N) + 1.0)
    return hippo, P, B


def make_DPLR_HiPPO(N):
    """
    Makes components needed for DPLR representation of HiPPO-LegS
    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Note, we will only use the diagonal part
    Args:
        N:
    Returns:
        eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B,
        eigenvectors V, HiPPO B pre-conjugation
    """
    A, P, B = make_NPLR_HiPPO(N)

    S = A + P[:, np.newaxis] * P[np.newaxis, :]

    S_diag = np.diagonal(S)
    Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)

    # Diagonalize S to V \Lambda V^*
    Lambda_imag, V = eigh(S * -1j)

    P = V.conj().T @ P
    B_orig = B
    B = V.conj().T @ B
    return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig


def log_step_initializer(dt_min=0.001, dt_max=0.1):
    """Initialize the learnable timescale Delta by sampling
    uniformly between dt_min and dt_max.
    Args:
        dt_min (float32): minimum value
        dt_max (float32): maximum value
    Returns:
        init function
    """

    def init(key, shape):
        """Init function
        Args:
            key: jax random key
            shape tuple: desired shape
        Returns:
            sampled log_step (float32)
        """
        return random.uniform(key, shape) * (np.log(dt_max) - np.log(dt_min)) + np.log(
            dt_min
        )

    return init


def init_log_steps(key, input):
    """Initialize an array of learnable timescale parameters
    Args:
        key: jax random key
        input: tuple containing the array shape H and
               dt_min and dt_max
    Returns:
        initialized array of timescales (float32): (H,)
    """
    H, dt_min, dt_max = input
    log_steps = []
    for i in range(H):
        key, skey = random.split(key)
        log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,))
        log_steps.append(log_step)

    return np.array(log_steps)


def init_VinvB(init_fun, rng, shape, Vinv):
    """Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B.
    Note we will parameterize this with two different matrices for complex
    numbers.
    Args:
        init_fun: the initialization function to use, e.g. lecun_normal()
        rng: jax random key to be used with init function.
        shape (tuple): desired shape (P,H)
        Vinv: (complex64) the inverse eigenvectors used for initialization
    Returns:
        B_tilde (complex64) of shape (P,H,2)
    """
    B = init_fun(rng, shape)
    VinvB = Vinv @ B
    VinvB_real = VinvB.real
    VinvB_imag = VinvB.imag
    return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1)


def trunc_standard_normal(key, shape):
    """Sample C with a truncated normal distribution with standard deviation 1.
    Args:
        key: jax random key
        shape (tuple): desired shape, of length 3, (H,P,_)
    Returns:
        sampled C matrix (float32) of shape (H,P,2) (for complex parameterization)
    """
    H, P, _ = shape
    Cs = []
    for i in range(H):
        key, skey = random.split(key)
        C = lecun_normal()(skey, shape=(1, P, 2))
        Cs.append(C)
    return np.array(Cs)[:, 0]


def init_CV(init_fun, rng, shape, V):
    """Initialize C_tilde=CV. First sample C. Then compute CV.
    Note we will parameterize this with two different matrices for complex
    numbers.
    Args:
        init_fun: the initialization function to use, e.g. lecun_normal()
        rng: jax random key to be used with init function.
        shape (tuple): desired shape (H,P)
        V: (complex64) the eigenvectors used for initialization
    Returns:
        C_tilde (complex64) of shape (H,P,2)
    """
    C_ = init_fun(rng, shape)
    C = C_[..., 0] + 1j * C_[..., 1]
    CV = C @ V
    CV_real = CV.real
    CV_imag = CV.imag
    return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1)
```
:::
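To see what the DPLR initialization produces, here is a small illustrative check (not part of the exported module); the state size `N` is an arbitrary choice for the example.

```python
# Illustrative only: inspect the DPLR-HiPPO decomposition for a small state size.
N = 8
Lambda, P_lr, B, V, B_orig = make_DPLR_HiPPO(N)

# Lambda holds the complex eigenvalues used to initialize the diagonal state
# matrix; V holds the eigenvectors used to conjugate B (and later C).
assert Lambda.shape == (N,)
assert P_lr.shape == (N,)
assert B.shape == (N,)
assert V.shape == (N, N)
```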
::: {#cell-7 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
# Discretization functions
def discretize_bilinear(Lambda, B_tilde, Delta):
    """Discretize a diagonalized, continuous-time linear SSM
    using bilinear transform method.
    Args:
        Lambda (complex64): diagonal state matrix (P,)
        B_tilde (complex64): input matrix (P, H)
        Delta (float32): discretization step sizes (P,)
    Returns:
        discretized Lambda_bar (complex64), B_bar (complex64)  (P,), (P,H)
    """
    Identity = np.ones(Lambda.shape[0])

    BL = 1 / (Identity - (Delta / 2.0) * Lambda)
    Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda)
    B_bar = (BL * Delta)[..., None] * B_tilde
    return Lambda_bar, B_bar


def discretize_zoh(Lambda, B_tilde, Delta):
    """Discretize a diagonalized, continuous-time linear SSM
    using zero-order hold method.
    Args:
        Lambda (complex64): diagonal state matrix (P,)
        B_tilde (complex64): input matrix (P, H)
        Delta (float32): discretization step sizes (P,)
    Returns:
        discretized Lambda_bar (complex64), B_bar (complex64)  (P,), (P,H)
    """
    Identity = np.ones(Lambda.shape[0])
    Lambda_bar = np.exp(Lambda * Delta)
    B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde
    return Lambda_bar, B_bar


# Parallel scan operations
@jax.vmap
def binary_operator(q_i, q_j):
    """Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A.
    Args:
        q_i: tuple containing A_i and Bu_i at position i  (P,), (P,)
        q_j: tuple containing A_j and Bu_j at position j  (P,), (P,)
    Returns:
        new element ( A_out, Bu_out )
    """
    A_i, b_i = q_i
    A_j, b_j = q_j
    return A_j * A_i, A_j * b_i + b_j


def apply_dynamics(
    x0,
    steps,
    Lambda_bar,
    B_bar,
    C_tilde,
    conj_sym,
    bidirectional,
):
    Lambda_elements = Lambda_bar * np.ones((steps, Lambda_bar.shape[0]))
    h0 = B_bar @ x0
    xs = jax.lax.associative_scan(np.multiply, Lambda_elements) * h0

    if bidirectional:
        xs2 = jax.lax.associative_scan(np.multiply, Lambda_elements, reverse=True) * h0
        xs = np.concatenate((xs, xs2), axis=-1)

    if conj_sym:
        return jax.vmap(lambda x: 2 * (C_tilde @ x).real)(xs)
    else:
        return jax.vmap(lambda x: (C_tilde @ x).real)(xs)


def apply_ssm(
    Lambda_bar,
    B_bar,
    C_tilde,
    input_sequence,
    conj_sym,
    bidirectional,
):
    """Compute the LxH output of discretized SSM given an LxH input.
    Args:
        Lambda_bar (complex64): discretized diagonal state matrix (P,)
        B_bar (complex64): discretized input matrix (P, H)
        C_tilde (complex64): output matrix (H, P)
        input_sequence (float32): input sequence of features (L, H)
        conj_sym (bool): whether conjugate symmetry is enforced
        bidirectional (bool): whether bidirectional setup is used,
            Note for this case C_tilde will have 2P cols
    Returns:
        ys (float32): the SSM outputs (S5 layer preactivations) (L, H)
    """
    Lambda_elements = Lambda_bar * np.ones(
        (input_sequence.shape[0], Lambda_bar.shape[0])
    )
    Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence)

    _, xs = jax.lax.associative_scan(binary_operator, (Lambda_elements, Bu_elements))

    if bidirectional:
        _, xs2 = jax.lax.associative_scan(
            binary_operator, (Lambda_elements, Bu_elements), reverse=True
        )
        xs = np.concatenate((xs, xs2), axis=-1)

    if conj_sym:
        return jax.vmap(lambda x: 2 * (C_tilde @ x).real)(xs)
    else:
        return jax.vmap(lambda x: (C_tilde @ x).real)(xs)
```
:::
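A minimal sketch of how these pieces fit together, with arbitrary shapes `P`, `H`, `L` chosen purely for illustration: discretize a random stable diagonal SSM with `discretize_zoh` and run the parallel scan with `apply_ssm`.

```python
# Illustrative only: discretize a random diagonal SSM and run the scan.
P, H, L = 4, 3, 12
key = random.PRNGKey(0)
Lambda = -0.5 + 1j * random.normal(key, (P,))  # stable diagonal state matrix
B_tilde = random.normal(key, (P, H)) + 0j      # input matrix
C_tilde = random.normal(key, (H, P)) + 0j      # output matrix
Delta = 0.01 * np.ones(P)                      # per-state step sizes

Lambda_bar, B_bar = discretize_zoh(Lambda, B_tilde, Delta)
u = random.normal(key, (L, H))
ys = apply_ssm(Lambda_bar, B_bar, C_tilde, u, conj_sym=False, bidirectional=False)
assert ys.shape == (L, H)
```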
::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class S5SSM(nn.Module):
    d_model: int
    d_hidden: int
    C_init: str = "lecun_normal"
    discretization: str = "zoh"
    dt_min: float = 0.0001
    dt_max: float = 0.1
    conj_sym: bool = True
    clip_eigs: bool = False
    bidirectional: bool = False
    step_rescale: float = 1.0
    blocks: int = 16
    n_steps: Optional[int] = None

    """ The S5 SSM
    Args:
        Lambda_re_init (complex64): Real part of init diag state matrix (P,)
        Lambda_im_init (complex64): Imag part of init diag state matrix (P,)
        V (complex64): Eigenvectors used for init (P,P)
        Vinv (complex64): Inverse eigenvectors used for init (P,P)
        d_model (int32): Number of features of input seq
        d_hidden (int32): state size
        C_init (string): Specifies how C is initialized
            Options: [trunc_standard_normal: sample from truncated standard normal
                        and then multiply by V, i.e. C_tilde=CV.
                      lecun_normal: sample from Lecun_normal and then multiply by V.
                      complex_normal: directly sample a complex valued output matrix
                        from standard normal, does not multiply by V]
        conj_sym (bool): Whether conjugate symmetry is enforced
        clip_eigs (bool): Whether to enforce left-half plane condition, i.e.
            constrain real part of eigenvalues to be negative.
            True recommended for autoregressive task/unbounded sequence lengths.
            Discussed in https://arxiv.org/pdf/2206.11893.pdf.
        bidirectional (bool): Whether model is bidirectional, if True, uses two C matrices
        discretization (string): Specifies discretization method
            options: [zoh: zero-order hold method,
                      bilinear: bilinear transform]
        dt_min (float32): minimum value to draw timescale values from when
            initializing log_step
        dt_max (float32): maximum value to draw timescale values from when
            initializing log_step
        step_rescale (float32): allows for uniformly changing the timescale parameter, e.g. after training
            on a different resolution for the speech commands benchmark
    """

    def setup(self):
        """Initializes parameters once and performs discretization each time
        the SSM is applied to a sequence
        """
        self.H = self.d_model
        self.P = self.d_hidden

        # Initialize state matrix A using approximation to HiPPO-LegS matrix
        block_size = int(self.P / self.blocks)
        Lambda, _, B, V, B_orig = make_DPLR_HiPPO(block_size)

        if self.conj_sym:
            # Need to account for case where we actually sample real B and C, and then multiply
            # by the half sized Vinv and possibly V
            block_size = block_size // 2
            P = self.P // 2
            local_P = 2 * P
        else:
            P = self.P  # needed below for the C and log_step shapes
            local_P = P

        Lambda = Lambda[:block_size]
        V = V[:, :block_size]
        Vc = V.conj().T

        # If initializing state matrix A as block-diagonal, put HiPPO approximation
        # on each block
        Lambda = (Lambda * np.ones((self.blocks, block_size))).ravel()
        self.V = jax.scipy.linalg.block_diag(*([V] * self.blocks))
        self.Vinv = jax.scipy.linalg.block_diag(*([Vc] * self.blocks))

        # Initialize diagonal state to state matrix Lambda (eigenvalues)
        self.Lambda_re = self.param(
            "Lambda_re", lambda rng, shape: Lambda.real, (None,)
        )
        self.Lambda_im = self.param(
            "Lambda_im", lambda rng, shape: Lambda.imag, (None,)
        )
        if self.clip_eigs:
            self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im
        else:
            self.Lambda = self.Lambda_re + 1j * self.Lambda_im

        # Initialize input to state (B) matrix
        B_init = lecun_normal()
        B_shape = (local_P, self.H)
        self.B = self.param(
            "B", lambda rng, shape: init_VinvB(B_init, rng, shape, self.Vinv), B_shape
        )
        B_tilde = self.B[..., 0] + 1j * self.B[..., 1]

        # Initialize state to output (C) matrix
        if self.C_init in ["trunc_standard_normal"]:
            C_init = trunc_standard_normal
            C_shape = (self.H, local_P, 2)
        elif self.C_init in ["lecun_normal"]:
            C_init = lecun_normal()
            C_shape = (self.H, local_P, 2)
        elif self.C_init in ["complex_normal"]:
            C_init = normal(stddev=0.5**0.5)
        else:
            raise NotImplementedError(
                "C_init method {} not implemented".format(self.C_init)
            )

        if self.C_init in ["complex_normal"]:
            if self.bidirectional:
                C = self.param("C", C_init, (self.H, 2 * P, 2))
                self.C_tilde = C[..., 0] + 1j * C[..., 1]
            else:
                C = self.param("C", C_init, (self.H, P, 2))
                self.C_tilde = C[..., 0] + 1j * C[..., 1]
        else:
            if self.bidirectional:
                self.C1 = self.param(
                    "C1",
                    lambda rng, shape: init_CV(C_init, rng, shape, self.V),
                    C_shape,
                )
                self.C2 = self.param(
                    "C2",
                    lambda rng, shape: init_CV(C_init, rng, shape, self.V),
                    C_shape,
                )

                C1 = self.C1[..., 0] + 1j * self.C1[..., 1]
                C2 = self.C2[..., 0] + 1j * self.C2[..., 1]
                self.C_tilde = np.concatenate((C1, C2), axis=-1)
            else:
                self.C = self.param(
                    "C", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape
                )
                self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1]

        # Initialize feedthrough (D) matrix
        self.D = self.param("D", normal(stddev=1.0), (self.H,))

        # Initialize learnable discretization timescale value
        self.log_step = self.param(
            "log_step", init_log_steps, (P, self.dt_min, self.dt_max)
        )
        step = self.step_rescale * np.exp(self.log_step[:, 0])

        # Discretize
        if self.discretization in ["zoh"]:
            self.Lambda_bar, self.B_bar = discretize_zoh(self.Lambda, B_tilde, step)
        elif self.discretization in ["bilinear"]:
            self.Lambda_bar, self.B_bar = discretize_bilinear(
                self.Lambda, B_tilde, step
            )
        else:
            raise NotImplementedError(
                "Discretization method {} not implemented".format(self.discretization)
            )

    def __call__(self, input_sequence):
        """
        Compute the LxH output of the S5 SSM given an LxH input sequence
        using a parallel scan.
        Args:
            input_sequence (float32): input sequence (L, H)
        Returns:
            output sequence (float32): (L, H)
        """
        if self.n_steps:
            ys = apply_dynamics(
                input_sequence[0],
                self.n_steps,
                self.Lambda_bar,
                self.B_bar,
                self.C_tilde,
                self.conj_sym,
                self.bidirectional,
            )
            return ys
        else:
            ys = apply_ssm(
                self.Lambda_bar,
                self.B_bar,
                self.C_tilde,
                input_sequence,
                self.conj_sym,
                self.bidirectional,
            )
            # Add feedthrough matrix output Du
            Du = jax.vmap(lambda u: self.D * u)(input_sequence)
            return ys + Du
```
:::
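As a standalone usage sketch (the hyperparameters below are arbitrary example values), a single `S5SSM` maps an `(L, H)` sequence to an `(L, H)` sequence; with the default `n_steps=None` it runs `apply_ssm` and adds the `D` feedthrough.

```python
# Illustrative only: run one S5SSM layer on a random sequence.
L, H = 32, 8
layer = S5SSM(d_model=H, d_hidden=64)
u = jax.random.normal(jax.random.PRNGKey(0), (L, H))
params = layer.init(jax.random.PRNGKey(1), u)
y = layer.apply(params, u)
assert y.shape == (L, H)
```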
## LRU Model

Adapted from https://github.com/NicolasZucchet/minimal-LRU.
::: {#cell-11 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
parallel_scan = jax.lax.associative_scan
```
:::
::: {#cell-12 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
def apply_lru_dynamics(
    inputs: jnp.ndarray,  # (time, d_model)
    discrete_lambda: jnp.ndarray,  # (d_hidden,)
    B_norm: jnp.ndarray,  # (d_hidden, d_model)
    C: jnp.ndarray,  # (d_model, d_hidden)
    D: jnp.ndarray,  # (d_model,)
):
    Lambda_elements = jnp.repeat(discrete_lambda[None, ...], inputs.shape[0], axis=0)
    Bu_elements = jax.vmap(lambda u: B_norm @ u)(inputs)
    _, hidden_states = jax.lax.associative_scan(
        binary_operator, (Lambda_elements, Bu_elements)
    )
    return jax.vmap(lambda h, x: (C @ h).real + D * x)(hidden_states, inputs)


def apply_lru_dynamics_from_ic(
    ic: jnp.ndarray,  # (1, d_model)
    n_steps: int,
    discrete_lambda: jnp.ndarray,  # (d_hidden,)
    B_norm: jnp.ndarray,  # (d_hidden, d_model)
    C: jnp.ndarray,  # (d_model, d_hidden)
):
    Lambda_elements = jnp.repeat(discrete_lambda[None, ...], n_steps, axis=0)
    h0 = B_norm @ ic[0]
    hidden_states = jax.lax.associative_scan(jnp.multiply, Lambda_elements) * h0
    return jax.vmap(lambda h: (C @ h).real)(hidden_states)
```
:::
::: {#cell-13 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class LRU(nn.Module):
    """
    LRU module in charge of the recurrent processing.
    Implementation following the one of Orvieto et al. 2023.
    """

    d_hidden: int  # hidden state dimension
    d_model: int  # input and output dimensions
    r_min: float = 0.0  # smallest lambda norm
    r_max: float = 1.0  # largest lambda norm
    max_phase: float = 6.28  # max phase lambda
    n_steps: Optional[int] = None  # number of steps to advance

    def setup(self):
        theta_log = self.param(
            "theta_log", partial(theta_init, max_phase=self.max_phase), (self.d_hidden,)
        )
        nu_log = self.param(
            "nu_log",
            partial(nu_init, r_min=self.r_min, r_max=self.r_max),
            (self.d_hidden,),
        )
        gamma_log = self.param("gamma_log", gamma_log_init, (nu_log, theta_log))

        # Glorot initialized Input/Output projection matrices
        B_re = self.param(
            "B_re",
            partial(matrix_init, normalization=jnp.sqrt(2 * self.d_model)),
            (self.d_hidden, self.d_model),
        )
        B_im = self.param(
            "B_im",
            partial(matrix_init, normalization=jnp.sqrt(2 * self.d_model)),
            (self.d_hidden, self.d_model),
        )
        C_re = self.param(
            "C_re",
            partial(matrix_init, normalization=jnp.sqrt(self.d_hidden)),
            (self.d_model, self.d_hidden),
        )
        C_im = self.param(
            "C_im",
            partial(matrix_init, normalization=jnp.sqrt(self.d_hidden)),
            (self.d_model, self.d_hidden),
        )
        self.D = self.param("D", matrix_init, (self.d_model,))

        self.C = C_re + 1j * C_im
        B = B_re + 1j * B_im
        self.B_norm = B * jnp.exp(gamma_log)[..., None]
        self.discrete_diag_lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))

    def __call__(
        self,
        inputs: jnp.ndarray,  # (time, d_model)
    ):
        if self.n_steps is not None:
            return apply_lru_dynamics_from_ic(
                inputs,
                self.n_steps,
                self.discrete_diag_lambda,
                self.B_norm,
                self.C,
            )
        else:
            return apply_lru_dynamics(
                inputs,
                self.discrete_diag_lambda,
                self.B_norm,
                self.C,
                self.D,
            )
```
:::
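The sketch below (illustrative sizes only) exercises both modes of the `LRU`: sequence-to-sequence with `n_steps=None`, and one-to-many with `n_steps` set, where only the initial condition is passed and the recurrence is unrolled.

```python
# Illustrative only: LRU in sequence-to-sequence and one-to-many modes.
T, D = 16, 4
x = jax.random.normal(jax.random.PRNGKey(0), (T, D))

lru = LRU(d_hidden=32, d_model=D)
params = lru.init(jax.random.PRNGKey(1), x)
y = lru.apply(params, x)
assert y.shape == (T, D)

lru_ic = LRU(d_hidden=32, d_model=D, n_steps=T)
params_ic = lru_ic.init(jax.random.PRNGKey(1), x[:1])  # only the initial condition
y_ic = lru_ic.apply(params_ic, x[:1])
assert y_ic.shape == (T, D)
```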
## Deep (Stacked) and Batched versions
::: {#cell-15 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’ 6=‘i’}
```python
from einops import rearrange

from physmodjax.models.mlp import MLP
```
:::
::: {#cell-16 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class SequenceLayer(nn.Module):
    """Single layer, with one SSM module, GLU, dropout and batch/layer norm"""

    ssm: nn.Module  # ssm module
    d_model: int  # model size
    dropout: float = 0.0  # dropout probability
    norm: str = "layer"  # which normalization to use
    training: bool = True  # in training mode (dropout in training mode only)
    activation: str = "half_glu1"  # activation function
    prenorm: bool = True  # whether to use pre or post normalization

    def setup(self):
        """Initializes the ssm, layer norm and dropout"""
        self.seq = self.ssm()
        self.out1 = nn.Dense(self.d_model)
        self.out2 = nn.Dense(self.d_model)

        # self.d_model -> self.d_model * 4 -> self.d_model
        # GPT mlp
        self.mlp = MLP(
            hidden_channels=[self.d_model * 4, self.d_model],
            activation=nn.gelu,
        )

        if self.norm in ["layer"]:
            self.normalization = nn.LayerNorm()
        else:
            self.normalization = nn.BatchNorm(
                use_running_average=not self.training, axis_name="batch"
            )
        self.drop = nn.Dropout(
            self.dropout,
            broadcast_dims=[0],
            deterministic=not self.training,
        )

    def __call__(self, x):
        skip = x
        if self.prenorm:
            x = self.normalization(x)  # pre normalization
        x = self.seq(x)  # call the SSM
        if self.activation in ["full_glu"]:
            x = self.drop(nn.gelu(x))
            x = self.out1(x) * jax.nn.sigmoid(self.out2(x))
            x = self.drop(x)
        elif self.activation in ["half_glu1"]:
            x = self.drop(nn.gelu(x))
            x = x * jax.nn.sigmoid(self.out2(x))
            x = self.drop(x)
        elif self.activation in ["gelu"]:
            x = self.drop(nn.gelu(x))
        elif self.activation in ["mlp"]:
            x = self.drop(self.mlp(x))
        else:
            raise NotImplementedError(f"Activation {self.activation} not implemented")

        x = skip + x  # skip connection
        if not self.prenorm:
            x = self.normalization(x)  # post normalization
        return x
```
:::
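A short sketch of how `SequenceLayer` is meant to be constructed: the wrapped SSM is passed as a zero-argument constructor, hence the `partial`. The LRU size and `training=False` are arbitrary choices for this example.

```python
# Illustrative only: a SequenceLayer wrapping an LRU.
T, D = 16, 8
layer = SequenceLayer(
    ssm=partial(LRU, d_hidden=32, d_model=D),
    d_model=D,
    training=False,  # deterministic dropout for this example
)
x = jax.random.normal(jax.random.PRNGKey(0), (T, D))
params = layer.init(jax.random.PRNGKey(1), x)
y = layer.apply(params, x)
assert y.shape == (T, D)
```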
::: {#cell-17 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class StackedSSM(nn.Module):
    ssm: nn.Module  # ssm module
    d_model: int  # model size
    d_vars: int  # number of variables
    n_layers: int  # number of layers
    ssm_first_layer: nn.Module = None  # optional first layer usually for one-to-many
    n_steps: Optional[int] = None  # number of steps to advance
    dropout: float = 0.0  # dropout probability
    training: bool = True
    norm: str = "layer"
    activation: str = "half_glu1"
    prenorm: bool = True

    def setup(self):
        if self.ssm_first_layer is not None:
            self.first_layer = self.ssm_first_layer(
                d_model=self.d_model * self.d_vars,
                n_steps=self.n_steps,
            )
        self.layers = [
            SequenceLayer(
                ssm=partial(self.ssm, d_model=self.d_model * self.d_vars),
                d_model=self.d_model * self.d_vars,
                dropout=self.dropout,
                training=self.training,
                norm=self.norm,
                activation=self.activation,
                prenorm=self.prenorm,
            )
            for _ in range(self.n_layers)
        ]

    def __call__(
        self,
        x: jnp.ndarray,  # (T, W, C) or (W, C) input
    ):
        x = rearrange(x, "t w c -> t (w c)")

        if self.ssm_first_layer is not None:
            x = self.first_layer(x)
        else:
            x = jnp.concatenate(
                [x[0:1], jnp.zeros((x.shape[0] - 1, x.shape[1]))], axis=0
            )

        for layer in self.layers:
            x = layer(x)  # apply each layer

        return rearrange(x, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)


BatchStackedSSMModel = nn.vmap(
    StackedSSM,
    in_axes=0,
    out_axes=0,
    variable_axes={
        "params": None,
        "dropout": None,
        "batch_stats": None,
        "cache": 0,
        "prime": None,
    },
    split_rngs={"params": False, "dropout": True},
    axis_name="batch",
)
```
:::
::: {#cell-18 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
B, T, W, C = 10, 50, 20, 3
d_hidden = 64

deep_ssm = BatchStackedSSMModel(
    ssm_first_layer=partial(S5SSM, d_hidden=d_hidden, n_steps=50),
    ssm=partial(S5SSM, d_hidden=d_hidden),
    d_model=W,
    d_vars=C,
    n_layers=2,
)

x = jnp.empty((B, T, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, W, C)
```
:::
::: {#cell-19 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
deep_ssm = BatchStackedSSMModel(
    ssm_first_layer=partial(LRU, d_hidden=d_hidden, n_steps=50),
    ssm=partial(LRU, d_hidden=d_hidden),
    d_model=W,
    d_vars=C,
    n_layers=2,
)

x = jnp.empty((B, T, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, W, C)
```
:::
::: {#cell-20 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
```python
class StackedSSM2D(nn.Module):
    ssm: nn.Module  # ssm module
    d_model: tuple[int, int]
    d_vars: int  # number of variables
    n_layers: int  # number of layers
    ssm_first_layer: nn.Module = None  # optional first layer usually for one-to-many
    n_steps: Optional[int] = None  # number of steps to advance
    dropout: float = 0.0  # dropout probability
    training: bool = True
    norm: str = "layer"
    activation: str = "half_glu1"
    prenorm: bool = True

    def setup(self):
        if self.ssm_first_layer is not None:
            self.first_layer = self.ssm_first_layer(
                d_model=self.d_model[0] * self.d_model[1] * self.d_vars,
                n_steps=self.n_steps,
            )
        self.layers = [
            SequenceLayer(
                ssm=partial(
                    self.ssm, d_model=self.d_model[0] * self.d_model[1] * self.d_vars
                ),
                d_model=self.d_model[0] * self.d_model[1] * self.d_vars,
                dropout=self.dropout,
                training=self.training,
                norm=self.norm,
                activation=self.activation,
                prenorm=self.prenorm,
            )
            for _ in range(self.n_layers)
        ]

    def __call__(
        self,
        x: jnp.ndarray,  # (T, H, W, C) or (H, W, C) input
    ):
        x = rearrange(x, "t h w c -> t (h w c)")

        if self.ssm_first_layer is not None:
            x = self.first_layer(x)
        else:
            x = jnp.concatenate(
                [x[0:1], jnp.zeros((x.shape[0] - 1, x.shape[1]))], axis=0
            )

        for layer in self.layers:
            x = layer(x)  # apply each layer

        return rearrange(
            x,
            "t (h w c) -> t h w c",
            h=self.d_model[0],
            w=self.d_model[1],
            c=self.d_vars,
        )


BatchStackedSSM2DModel = nn.vmap(
    StackedSSM2D,
    in_axes=0,
    out_axes=0,
    variable_axes={
        "params": None,
        "dropout": None,
        "batch_stats": None,
        "cache": 0,
        "prime": None,
    },
    split_rngs={"params": False, "dropout": True},
    axis_name="batch",
)
```
:::
::: {#cell-21 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
B, T, H, W, C = 10, 50, 20, 20, 3

deep_ssm = BatchStackedSSM2DModel(
    ssm_first_layer=partial(LRU, d_hidden=d_hidden, n_steps=T),
    ssm=partial(LRU, d_hidden=d_hidden),
    d_model=(H, W),
    d_vars=C,
    n_layers=2,
)

x = jnp.empty((B, T, H, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, H, W, C)
```
:::
::: {#cell-22 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
```python
B, T, H, W, C = 10, 50, 20, 20, 3

deep_ssm = BatchStackedSSM2DModel(
    ssm_first_layer=partial(S5SSM, d_hidden=d_hidden, n_steps=T),
    ssm=partial(S5SSM, d_hidden=d_hidden),
    d_model=(H, W),
    d_vars=C,
    n_layers=2,
)

x = jnp.empty((B, T, H, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, H, W, C)
```
:::