SSM Models

S5 Model

adapted from https://github.com/lindermanlab/S5



init_CV

 init_CV (init_fun, rng, shape, V)

Initialize C_tilde = CV: first sample C, then compute CV. Note that this is parameterized with two different matrices for the complex numbers.

Args:
  init_fun: the initialization function to use, e.g. lecun_normal()
  rng: jax random key to be used with the init function
  shape (tuple): desired shape (H, P)
  V (complex64): the eigenvectors used for initialization

Returns:
  C_tilde (complex64) of shape (H, P, 2)
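A minimal sketch of this computation, assuming the trailing axis of size 2 holds the real and imaginary parts. `init_CV_sketch` is a hypothetical stand-in, not this module's function. With V set to the identity, C_tilde should simply reproduce the sampled C.

```python
import jax
import jax.numpy as jnp


def init_CV_sketch(init_fun, rng, shape, V):
    """Hypothetical sketch: sample complex C (H, P) and return C @ V split into (H, P, 2)."""
    C_parts = init_fun(rng, shape + (2,))          # real/imag parts of C, shape (H, P, 2)
    C = C_parts[..., 0] + 1j * C_parts[..., 1]     # complex C, shape (H, P)
    CV = C @ V                                     # express C in the eigenbasis, shape (H, P)
    return jnp.stack((CV.real, CV.imag), axis=-1)  # back to two real matrices, (H, P, 2)


H, P = 4, 8
init_fun = jax.nn.initializers.lecun_normal()
V = jnp.eye(P, dtype=jnp.complex64)  # identity eigenvectors, for illustration only
C_tilde = init_CV_sketch(init_fun, jax.random.PRNGKey(0), (H, P), V)
```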



trunc_standard_normal

 trunc_standard_normal (key, shape)

Sample C with a truncated normal distribution with standard deviation 1.

Args:
  key: jax random key
  shape (tuple): desired shape, of length 3, (H, P, _)

Returns:
  sampled C matrix (float32) of shape (H, P, 2) (for the complex parameterization)
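One way to draw truncated standard normal samples in JAX is `jax.random.truncated_normal`. The sketch below assumes truncation at plus or minus 2 standard deviations; the bounds this module actually uses are not stated here, and `trunc_standard_normal_sketch` is hypothetical. (For comparison, `jax.nn.initializers.lecun_normal` also draws from a truncated normal, but rescales it to variance 1/fan_in.)

```python
import jax
import jax.numpy as jnp


def trunc_standard_normal_sketch(key, shape, bound=2.0):
    """Hypothetical sketch: standard normal samples truncated to the open interval (-bound, bound)."""
    assert len(shape) == 3, "expected (H, P, 2) for the complex parameterization"
    return jax.random.truncated_normal(key, -bound, bound, shape, dtype=jnp.float32)


C = trunc_standard_normal_sketch(jax.random.PRNGKey(1), (16, 32, 2))
```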



init_VinvB

 init_VinvB (init_fun, rng, shape, Vinv)

Initialize B_tilde = V^{-1}B: first sample B, then compute V^{-1}B. Note that this is parameterized with two different matrices for the complex numbers.

Args:
  init_fun: the initialization function to use, e.g. lecun_normal()
  rng: jax random key to be used with the init function
  shape (tuple): desired shape (P, H)
  Vinv (complex64): the inverse eigenvectors used for initialization

Returns:
  B_tilde (complex64) of shape (P, H, 2)
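A hedged sketch of the same idea for B, assuming B itself is sampled as a real matrix before being projected with V^{-1}. `init_VinvB_sketch` is hypothetical, and the diagonal unitary V below is an arbitrary stand-in for real HiPPO eigenvectors; it makes the round trip V (V^{-1} B) = B easy to check.

```python
import jax
import jax.numpy as jnp


def init_VinvB_sketch(init_fun, rng, shape, Vinv):
    """Hypothetical sketch: sample real B (P, H) and return Vinv @ B split into (P, H, 2)."""
    B = init_fun(rng, shape)                             # real B, shape (P, H)
    VinvB = Vinv @ B                                     # complex, shape (P, H)
    return jnp.stack((VinvB.real, VinvB.imag), axis=-1)  # (P, H, 2)


P, H = 8, 4
phases = jnp.linspace(0.0, jnp.pi, P)
V = jnp.diag(jnp.exp(1j * phases)).astype(jnp.complex64)      # simple unitary stand-in
Vinv = jnp.diag(jnp.exp(-1j * phases)).astype(jnp.complex64)  # its inverse (conjugate transpose)

init_fun = jax.nn.initializers.lecun_normal()
B_tilde = init_VinvB_sketch(init_fun, jax.random.PRNGKey(0), (P, H), Vinv)

# Mapping back with V should recover the sampled real B.
B_rec = V @ (B_tilde[..., 0] + 1j * B_tilde[..., 1])
```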



init_log_steps

 init_log_steps (key, input)

Initialize an array of learnable timescale parameters.

Args:
  key: jax random key
  input: tuple containing the array shape H, dt_min and dt_max

Returns:
  initialized array of timescales (float32) of shape (H,)



log_step_initializer

 log_step_initializer (dt_min=0.001, dt_max=0.1)

Initialize the learnable timescale Delta by sampling uniformly between dt_min and dt_max in log space (the initializer returns log(Delta)).

Args:
  dt_min (float32): minimum value
  dt_max (float32): maximum value

Returns:
  init function
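A minimal sketch, assuming the sampling happens in log space, i.e. the returned values are log(Delta), uniform between log(dt_min) and log(dt_max). `log_step_initializer_sketch` is a hypothetical name used for illustration.

```python
import jax
import jax.numpy as jnp


def log_step_initializer_sketch(dt_min=0.001, dt_max=0.1):
    """Hypothetical sketch: init fn sampling log(Delta) uniformly in [log dt_min, log dt_max)."""
    def init(key, shape):
        u = jax.random.uniform(key, shape)  # uniform in [0, 1)
        return u * (jnp.log(dt_max) - jnp.log(dt_min)) + jnp.log(dt_min)
    return init


H = 128
log_steps = log_step_initializer_sketch(dt_min=0.001, dt_max=0.1)(jax.random.PRNGKey(0), (H,))
steps = jnp.exp(log_steps)  # the timescales Delta themselves
```

Sampling in log space gives each decade between dt_min and dt_max the same probability, so the features cover both fast and slow timescales.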



make_DPLR_HiPPO

 make_DPLR_HiPPO (N)

Makes the components needed for the DPLR representation of HiPPO-LegS. From https://github.com/srush/annotated-s4/blob/main/s4/s4.py. Note that only the diagonal part is used here.

Args:
  N (int32): state size

Returns:
  eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B, eigenvectors V, HiPPO B pre-conjugation
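The diagonalization step can be sketched as follows, assuming the annotated-s4 definitions: the HiPPO-LegS matrix A (A_nk = -sqrt(2n+1) sqrt(2k+1) for n > k, -(n+1) on the diagonal, 0 above it) and the low-rank factor P_n = sqrt(n + 1/2). Adding P P^T gives S = -I/2 + K with K skew-symmetric, so -iK is Hermitian and S is diagonalized by a unitary V with eigenvalues -1/2 + i w. `dplr_sketch` is hypothetical and omits the projection of P and B onto V.

```python
import jax.numpy as jnp


def dplr_sketch(N):
    """Hypothetical sketch: diagonalize the normal part S of HiPPO-LegS."""
    q = jnp.sqrt(1.0 + 2.0 * jnp.arange(N))
    A = -(jnp.tril(q[:, None] * q[None, :]) - jnp.diag(jnp.arange(N)))  # HiPPO-LegS
    P = jnp.sqrt(jnp.arange(N) + 0.5)                                    # low-rank factor
    S = A + P[:, None] * P[None, :]                                      # -I/2 + skew part
    K = S + 0.5 * jnp.eye(N)                                             # skew-symmetric
    Lambda_imag, V = jnp.linalg.eigh(-1j * K)                            # -iK is Hermitian
    Lambda = -0.5 + 1j * Lambda_imag                                     # eigenvalues of S
    return Lambda, V, S


Lambda, V, S = dplr_sketch(8)
S_rec = (V * Lambda) @ V.conj().T  # V diag(Lambda) V^H
```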



make_NPLR_HiPPO

 make_NPLR_HiPPO (N)

Makes the components needed for the NPLR representation of HiPPO-LegS. From https://github.com/srush/annotated-s4/blob/main/s4/s4.py

Args:
  N (int32): state size

Returns:
  N x N HiPPO-LegS matrix, low-rank factor P, HiPPO input matrix B
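A quick numerical check of why this is a "normal plus low-rank" representation, assuming the annotated-s4 choices P_n = sqrt(n + 1/2) and B_n = sqrt(2n + 1). Writing A = S - P P^T, the matrix S = A + P P^T has symmetric part exactly -I/2, i.e. S + S^T = -I, so S is normal. `nplr_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def nplr_sketch(N):
    """Hypothetical sketch: HiPPO-LegS matrix A, low-rank factor P and input matrix B."""
    q = jnp.sqrt(1.0 + 2.0 * jnp.arange(N))
    A = -(jnp.tril(q[:, None] * q[None, :]) - jnp.diag(jnp.arange(N)))
    P = jnp.sqrt(jnp.arange(N) + 0.5)
    B = jnp.sqrt(2.0 * jnp.arange(N) + 1.0)
    return A, P, B


A, P, B = nplr_sketch(8)
S = A + jnp.outer(P, P)  # A = S - P P^T with S normal
```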



make_HiPPO

 make_HiPPO (N)

Create a HiPPO-LegS matrix. From https://github.com/srush/annotated-s4/blob/main/s4/s4.py

Args:
  N (int32): state size

Returns:
  N x N HiPPO-LegS matrix
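For reference, a sketch of the HiPPO-LegS matrix following the annotated-s4 construction: A_nk = -sqrt(2n+1) sqrt(2k+1) for n > k, -(n+1) for n = k, and 0 for n < k. `make_HiPPO_sketch` is a hypothetical re-implementation, not this module's code.

```python
import jax.numpy as jnp


def make_HiPPO_sketch(N):
    """Hypothetical sketch of the N x N HiPPO-LegS matrix."""
    q = jnp.sqrt(1.0 + 2.0 * jnp.arange(N))
    A = q[:, None] * q[None, :]                # sqrt(2n+1) sqrt(2k+1) everywhere
    A = jnp.tril(A) - jnp.diag(jnp.arange(N))  # keep lower triangle; diagonal becomes n + 1
    return -A


A = make_HiPPO_sketch(4)
```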



apply_ssm

 apply_ssm (Lambda_bar, B_bar, C_tilde, input_sequence, conj_sym,
            bidirectional)

Compute the L x H output of the discretized SSM given an L x H input.

Args:
  Lambda_bar (complex64): discretized diagonal state matrix (P,)
  B_bar (complex64): discretized input matrix (P, H)
  C_tilde (complex64): output matrix (H, P)
  input_sequence (float32): input sequence of features (L, H)
  conj_sym (bool): whether conjugate symmetry is enforced
  bidirectional (bool): whether the bidirectional setup is used; in this case C_tilde has 2P columns

Returns:
  ys (float32): the SSM outputs (S5 layer preactivations) (L, H)
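For the unidirectional case, the computation amounts to the linear recurrence x_k = Lambda_bar * x_{k-1} + B_bar u_k followed by y_k = Re(C_tilde x_k). When conjugate symmetry is enforced, only one eigenvalue of each conjugate pair is kept, so the sketch assumes the real part is doubled. `apply_ssm_sketch` is hypothetical: it uses a sequential `jax.lax.scan` for readability and omits the bidirectional branch.

```python
import jax
import jax.numpy as jnp


def apply_ssm_sketch(Lambda_bar, B_bar, C_tilde, input_sequence, conj_sym):
    """Hypothetical sketch of the unidirectional case (bidirectional omitted)."""
    def step(x, u):
        x = Lambda_bar * x + B_bar @ u  # diagonal state update, shape (P,)
        return x, x

    x0 = jnp.zeros(Lambda_bar.shape, dtype=Lambda_bar.dtype)
    _, xs = jax.lax.scan(step, x0, input_sequence)  # states, shape (L, P)
    ys = (xs @ C_tilde.T).real                      # outputs, shape (L, H)
    return 2.0 * ys if conj_sym else ys


L, H, P = 20, 3, 5
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
Lambda_bar = (0.9 * jnp.exp(1j * jnp.linspace(0.1, 1.0, P))).astype(jnp.complex64)
B_bar = jax.random.normal(k1, (P, H)).astype(jnp.complex64)
C_tilde = jax.random.normal(k2, (H, P)).astype(jnp.complex64)
u = jax.random.normal(k3, (L, H))
ys = apply_ssm_sketch(Lambda_bar, B_bar, C_tilde, u, conj_sym=True)
```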



apply_dynamics

 apply_dynamics (x0, steps, Lambda_bar, B_bar, C_tilde, conj_sym,
                 bidirectional)


binary_operator

 binary_operator (q_i, q_j)

Binary operator for the parallel scan of a linear recurrence. Assumes a diagonal matrix A.

Args:
  q_i: tuple containing A_i and Bu_i at position i, shapes (P,), (P,)
  q_j: tuple containing A_j and Bu_j at position j, shapes (P,), (P,)

Returns:
  new element (A_out, Bu_out)
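The operator composes two affine steps x -> A x + Bu: applying step i and then step j gives x -> (A_j A_i) x + (A_j Bu_i + Bu_j). Because this composition is associative, `jax.lax.associative_scan` can compute all prefixes in logarithmic parallel depth. The check below compares it with a sequential loop; `binary_operator_sketch` is a hypothetical re-implementation in which elementwise products stand in for the diagonal A.

```python
import jax
import jax.numpy as jnp


def binary_operator_sketch(q_i, q_j):
    """Compose two steps of x_k = a_k * x_{k-1} + b_k (elementwise, i.e. diagonal A)."""
    A_i, Bu_i = q_i
    A_j, Bu_j = q_j
    return A_j * A_i, A_j * Bu_i + Bu_j


L, P = 16, 4
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(k1, (L, P), minval=0.5, maxval=0.99)  # diagonal entries per step
b = jax.random.normal(k2, (L, P))                            # per-step inputs Bu_k

_, xs_scan = jax.lax.associative_scan(binary_operator_sketch, (a, b))

# Sequential reference, starting from a zero state.
x = jnp.zeros(P)
xs_loop = []
for k in range(L):
    x = a[k] * x + b[k]
    xs_loop.append(x)
xs_loop = jnp.stack(xs_loop)
```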



discretize_zoh

 discretize_zoh (Lambda, B_tilde, Delta)

Discretize a diagonalized, continuous-time linear SSM using the zero-order hold method.

Args:
  Lambda (complex64): diagonal state matrix (P,)
  B_tilde (complex64): input matrix (P, H)
  Delta (float32): discretization step sizes (P,)

Returns:
  discretized Lambda_bar (complex64) of shape (P,) and B_bar (complex64) of shape (P, H)
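For a diagonal system, the standard zero-order hold formulas are Lambda_bar = exp(Lambda * Delta) and B_bar = (Lambda_bar - 1) / Lambda * B_tilde, applied row by row. The hypothetical `discretize_zoh_sketch` below implements them and checks two properties: eigenvalues with negative real part map inside the unit circle, and for small Delta, B_bar is close to Delta * B_tilde.

```python
import jax.numpy as jnp


def discretize_zoh_sketch(Lambda, B_tilde, Delta):
    """Zero-order hold for a diagonal SSM."""
    Lambda_bar = jnp.exp(Lambda * Delta)                      # (P,)
    B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B_tilde  # (P, H), row-wise scaling
    return Lambda_bar, B_bar


Lambda = jnp.array([-0.5 + 1.0j, -1.0 + 2.0j, -0.1 + 0.0j], dtype=jnp.complex64)
B_tilde = jnp.ones((3, 2), dtype=jnp.complex64)
Delta = jnp.full((3,), 1e-3)
Lambda_bar, B_bar = discretize_zoh_sketch(Lambda, B_tilde, Delta)
```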



discretize_bilinear

 discretize_bilinear (Lambda, B_tilde, Delta)

Discretize a diagonalized, continuous-time linear SSM using the bilinear transform method.

Args:
  Lambda (complex64): diagonal state matrix (P,)
  B_tilde (complex64): input matrix (P, H)
  Delta (float32): discretization step sizes (P,)

Returns:
  discretized Lambda_bar (complex64) of shape (P,) and B_bar (complex64) of shape (P, H)
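A sketch of the bilinear (Tustin) transform for a diagonal system: with BL = 1 / (1 - Delta/2 * Lambda), Lambda_bar = BL * (1 + Delta/2 * Lambda) and B_bar = BL * Delta * B_tilde. `discretize_bilinear_sketch` is hypothetical. The check confirms that the left half-plane maps into the unit disk and that, for small Delta, Lambda_bar agrees closely with the zero-order hold value exp(Lambda * Delta).

```python
import jax.numpy as jnp


def discretize_bilinear_sketch(Lambda, B_tilde, Delta):
    """Bilinear (Tustin) transform for a diagonal SSM."""
    BL = 1.0 / (1.0 - (Delta / 2.0) * Lambda)         # (I - Delta/2 Lambda)^{-1}, diagonal
    Lambda_bar = BL * (1.0 + (Delta / 2.0) * Lambda)  # (P,)
    B_bar = (BL * Delta)[:, None] * B_tilde           # (P, H)
    return Lambda_bar, B_bar


Lambda = jnp.array([-0.5 + 1.0j, -1.0 + 2.0j, -0.1 + 0.0j], dtype=jnp.complex64)
B_tilde = jnp.ones((3, 2), dtype=jnp.complex64)
Delta = jnp.full((3,), 1e-2)
Lambda_bar, B_bar = discretize_bilinear_sketch(Lambda, B_tilde, Delta)
```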



S5SSM

 S5SSM (d_model:int, d_hidden:int, C_init:str='lecun_normal',
        discretization:str='zoh', dt_min:float=0.0001, dt_max:float=0.1,
        conj_sym:bool=True, clip_eigs:bool=False,
        bidirectional:bool=False, step_rescale:float=1.0, blocks:int=16,
        n_steps:Optional[int]=None,
        parent:Union[flax.linen.module.Module,flax.core.scope.Scope,
        flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel
        object at 0x12ff8ca90>,
        name:Optional[str]=None)

LRU Model

adapted from https://github.com/NicolasZucchet/minimal-LRU



gamma_log_init

 gamma_log_init (key, lamb)


theta_init

 theta_init (key, shape, max_phase, dtype=<class 'jax.numpy.float32'>)


nu_init

 nu_init (key, shape, r_min, r_max, dtype=<class 'jax.numpy.float32'>)


matrix_init

 matrix_init (key, shape, dtype=<class 'jax.numpy.float32'>,
              normalization=1)
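The four initializers above carry no docstrings. The following sketch assumes they follow the ring initialization of Orvieto et al. (2023) as done in minimal-LRU: |lambda| is sampled uniformly over the ring area r_min <= |lambda| <= r_max and the phase uniformly in [0, max_phase), both stored in log form (nu_log = log(-log|lambda|), theta_log = log(phase)), and gamma = sqrt(1 - |lambda|^2) rescales the input matrix. All function names below are hypothetical, and B is real here only for brevity.

```python
import jax
import jax.numpy as jnp


def nu_log_init_sketch(key, shape, r_min, r_max):
    """log(nu), with |lambda| = exp(-nu) uniform over the ring area between r_min and r_max."""
    u = jax.random.uniform(key, shape)
    nu = -0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2)
    return jnp.log(nu)


def theta_log_init_sketch(key, shape, max_phase):
    """log of a phase sampled uniformly in [0, max_phase)."""
    return jnp.log(max_phase * jax.random.uniform(key, shape))


def gamma_log_sketch(nu_log, theta_log):
    """log(sqrt(1 - |lambda|^2)), the normalization factor for the input matrix."""
    lam = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
    return jnp.log(jnp.sqrt(1.0 - jnp.abs(lam) ** 2))


def matrix_init_sketch(key, shape, normalization=1.0):
    """Gaussian matrix divided by `normalization`."""
    return jax.random.normal(key, shape) / normalization


d_hidden, d_model = 64, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
nu_log = nu_log_init_sketch(k1, (d_hidden,), r_min=0.4, r_max=0.9)
theta_log = theta_log_init_sketch(k2, (d_hidden,), max_phase=2 * jnp.pi)
gamma_log = gamma_log_sketch(nu_log, theta_log)
B = matrix_init_sketch(k3, (d_hidden, d_model), normalization=jnp.sqrt(2 * d_model))
B_norm = B * jnp.exp(gamma_log)[:, None]  # rows rescaled by gamma

radius = jnp.exp(-jnp.exp(nu_log))  # |lambda| for each hidden unit
```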


LRUDynamics

 LRUDynamics (d_hidden:int, r_min:float, r_max:float, max_phase:float,
              clip_eigs:bool,
              parent:Union[flax.linen.module.Module,flax.core.scope.Scope,
              flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel
              object at 0x12ff8ca90>,
              name:Optional[str]=None)

This class implements only the dynamics of the LRU model: the autonomous recurrence x_{k+1} = A x_k, with no input term.

import jax
import jax.numpy as jnp

d_hidden = 64
steps = 50
dyn = LRUDynamics(d_hidden=d_hidden, r_min=0.99, r_max=1.0, max_phase=jnp.pi * 2, clip_eigs=False)
variables = dyn.init(jax.random.PRNGKey(0), jnp.ones((1, d_hidden)), steps)
out = dyn.apply(variables, jnp.ones((1, d_hidden)), steps)

assert out.shape == (steps, d_hidden)


apply_lru_dynamics_from_ic

 apply_lru_dynamics_from_ic (ic:jax.Array, n_steps:int,
                             discrete_lambda:jax.Array, B_norm:jax.Array,
                             C:jax.Array)

Parameter        Type   Details
ic               Array  (1, d_model)
n_steps          int
discrete_lambda  Array  (d_hidden,)
B_norm           Array  (d_hidden, d_model)
C                Array  (d_model, d_hidden)
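A hedged sketch of an autonomous rollout consistent with the shapes above, assuming the initial condition is lifted into the hidden state with B_norm and then evolved without further input, so x_k = lambda^k * x_0 and y_k = Re(C x_k). Whether this function includes x_0 in its output, or lifts the initial condition this way at all, is not stated here; `rollout_from_ic_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def rollout_from_ic_sketch(ic, n_steps, discrete_lambda, B_norm, C):
    """Hypothetical sketch: x_0 = B_norm @ ic, x_{k+1} = lambda * x_k, y_k = Re(C x_k)."""
    x0 = B_norm @ ic[0]  # hidden state, shape (d_hidden,)

    def step(x, _):
        return discrete_lambda * x, x  # emit the current state, then advance

    _, xs = jax.lax.scan(step, x0, None, length=n_steps)  # (n_steps, d_hidden)
    return (xs @ C.T).real                                # (n_steps, d_model)


d_model, d_hidden, n_steps = 3, 8, 10
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
lam = (0.95 * jnp.exp(1j * jnp.linspace(0.0, 3.0, d_hidden))).astype(jnp.complex64)
B_norm = jax.random.normal(k1, (d_hidden, d_model)).astype(jnp.complex64)
C = jax.random.normal(k2, (d_model, d_hidden)).astype(jnp.complex64)
ic = jax.random.normal(k3, (1, d_model))
ys = rollout_from_ic_sketch(ic, n_steps, lam, B_norm, C)

# Closed form for comparison: x_k = lambda**k * x_0.
x0 = B_norm @ ic[0]
powers = jnp.stack([lam**k for k in range(n_steps)])
ys_closed = ((powers * x0[None, :]) @ C.T).real
```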


apply_lru_dynamics

 apply_lru_dynamics (inputs:jax.Array, discrete_lambda:jax.Array,
                     B_norm:jax.Array, C:jax.Array, D:jax.Array)

Parameter        Type   Details
inputs           Array  (time, d_model)
discrete_lambda  Array  (d_hidden,)
B_norm           Array  (d_hidden, d_model)
C                Array  (d_model, d_hidden)
D                Array  (d_model,)
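A sketch of the input-driven LRU recurrence described by Orvieto et al. (2023), assuming x_k = lambda * x_{k-1} + B_norm u_k and y_k = Re(C x_k) + D * u_k with a zero initial state, evaluated with `jax.lax.associative_scan`. `lru_scan_sketch` is hypothetical and is checked against a plain loop.

```python
import jax
import jax.numpy as jnp


def lru_scan_sketch(inputs, discrete_lambda, B_norm, C, D):
    """Hypothetical sketch: x_k = lambda * x_{k-1} + B_norm u_k, y_k = Re(C x_k) + D * u_k."""
    T = inputs.shape[0]
    Lambda_elements = jnp.broadcast_to(discrete_lambda, (T,) + discrete_lambda.shape)
    Bu_elements = inputs @ B_norm.T  # (T, d_hidden)

    def combine(q_i, q_j):
        A_i, b_i = q_i
        A_j, b_j = q_j
        return A_j * A_i, A_j * b_i + b_j

    _, xs = jax.lax.associative_scan(combine, (Lambda_elements, Bu_elements))
    return (xs @ C.T).real + D * inputs  # (T, d_model)


T, d_model, d_hidden = 12, 3, 6
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
u = jax.random.normal(k1, (T, d_model))
B_norm = jax.random.normal(k2, (d_hidden, d_model)).astype(jnp.complex64)
C = jax.random.normal(k3, (d_model, d_hidden)).astype(jnp.complex64)
D = jax.random.normal(k4, (d_model,))
lam = (0.9 * jnp.exp(1j * jnp.linspace(0.0, 1.0, d_hidden))).astype(jnp.complex64)
y = lru_scan_sketch(u, lam, B_norm, C, D)

# Sequential reference.
x = jnp.zeros(d_hidden, dtype=jnp.complex64)
y_loop = []
for k in range(T):
    x = lam * x + B_norm @ u[k]
    y_loop.append((C @ x).real + D * u[k])
y_loop = jnp.stack(y_loop)
```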


LRU

 LRU (d_hidden:int, d_model:int, r_min:float=0.0, r_max:float=1.0,
      max_phase:float=6.28, n_steps:Optional[int]=None,
      parent:Union[flax.linen.module.Module,flax.core.scope.Scope,
      flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel
      object at 0x12ff8ca90>,
      name:Optional[str]=None)

LRU module in charge of the recurrent processing. The implementation follows Orvieto et al. (2023).

Deep (Stacked) and Batched versions



SequenceLayer

 SequenceLayer (ssm:flax.linen.module.Module, d_model:int,
                dropout:float=0.0, norm:str='layer', training:bool=True,
                activation:str='half_glu1', prenorm:bool=True,
                parent:Union[flax.linen.module.Module,flax.core.scope.Scope,
                flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel
                object at 0x12ff8ca90>, name:Optional[str]=None)

A single layer with one SSM module, a GLU activation, dropout, and batch or layer normalization.
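A functional sketch of such a layer in plain JAX, assuming S5-style choices: pre-normalization with layer norm, a GELU followed by a "half GLU" gate h * sigmoid(W h + b), and a residual connection. Dropout is omitted (inference mode), the norm has no learned scale or shift, and `toy_ssm` is a hypothetical causal stand-in for the SSM module. Since every step is either causal in time or pointwise per time step, changing the last input must leave earlier outputs unchanged.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    """Normalize over the feature axis (no learned scale/shift in this sketch)."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def toy_ssm(x):
    """Hypothetical SSM stand-in: causal running mean over time, (L, H) -> (L, H)."""
    return jnp.cumsum(x, axis=0) / jnp.arange(1, x.shape[0] + 1)[:, None]


def sequence_layer_sketch(x, W_gate, b_gate):
    """Prenorm residual block: x + half_glu(gelu(ssm(norm(x))))."""
    h = toy_ssm(layer_norm(x))                   # sequence mixing
    h = jax.nn.gelu(h)                           # pointwise nonlinearity
    h = h * jax.nn.sigmoid(h @ W_gate + b_gate)  # gate computed from h itself
    return x + h                                 # residual connection


L, H = 10, 8
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (L, H))
W_gate = jax.random.normal(k2, (H, H)) / jnp.sqrt(H)
y = sequence_layer_sketch(x, W_gate, jnp.zeros(H))
```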



StackedSSM

 StackedSSM (ssm:flax.linen.module.Module, d_model:int, d_vars:int,
             n_layers:int, ssm_first_layer:flax.linen.module.Module=None,
             n_steps:Optional[int]=None, dropout:float=0.0,
             training:bool=True, norm:str='layer',
             activation:str='half_glu1', prenorm:bool=True,
             parent:Union[flax.linen.module.Module,flax.core.scope.Scope,
             flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel
             object at 0x12ff8ca90>, name:Optional[str]=None)

from functools import partial
import jax
import jax.numpy as jnp

B, T, W, C = 10, 50, 20, 3
d_hidden = 64
deep_ssm = BatchStackedSSMModel(
    ssm_first_layer=partial(S5SSM, d_hidden=d_hidden, n_steps=T),
    ssm=partial(S5SSM, d_hidden=d_hidden),
    d_model=W,
    d_vars=C,
    n_layers=2,
)
x = jnp.empty((B, T, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, W, C)
deep_ssm = BatchStackedSSMModel(
    ssm_first_layer=partial(LRU, d_hidden=d_hidden, n_steps=T),
    ssm=partial(LRU, d_hidden=d_hidden),
    d_model=W,
    d_vars=C,
    n_layers=2,
)
x = jnp.empty((B, T, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, W, C)


StackedSSM2D

 StackedSSM2D (ssm:flax.linen.module.Module, d_model:Tuple[int,int],
               d_vars:int, n_layers:int,
               ssm_first_layer:flax.linen.module.Module=None,
               n_steps:Optional[int]=None, dropout:float=0.0,
               training:bool=True, norm:str='layer',
               activation:str='half_glu1', prenorm:bool=True,
               parent:Union[flax.linen.module.Module,flax.core.scope.Scope,
               flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel
               object at 0x12ff8ca90>, name:Optional[str]=None)
B, T, H, W, C = 10, 50, 20, 20, 3
deep_ssm = BatchStackedSSM2DModel(
    ssm_first_layer=partial(LRU, d_hidden=d_hidden, n_steps=T),
    ssm=partial(LRU, d_hidden=d_hidden),
    d_model=(H, W),
    d_vars=C,
    n_layers=2,
)

x = jnp.empty((B, T, H, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, H, W, C)
B, T, H, W, C = 10, 50, 20, 20, 3
deep_ssm = BatchStackedSSM2DModel(
    ssm_first_layer=partial(S5SSM, d_hidden=d_hidden, n_steps=T),
    ssm=partial(S5SSM, d_hidden=d_hidden),
    d_model=(H, W),
    d_vars=C,
    n_layers=2,
)

x = jnp.empty((B, T, H, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)

assert out.shape == (B, T, H, W, C)