d_hidden = 64
steps = 50
dyn = LRUDynamics(d_hidden=d_hidden, r_min=0.99, r_max=1.0, max_phase=jnp.pi * 2, clip_eigs=False)
vars = dyn.init(jax.random.PRNGKey(0), jnp.ones((d_hidden)), 50)
out = dyn.apply(vars, jnp.ones((1, d_hidden)), 50)
assert out.shape == (steps, d_hidden)
SSM Models
S5 Model
adapted from https://github.com/lindermanlab/S5
init_CV
init_CV (init_fun, rng, shape, V)
Initialize C_tilde = CV. First sample C, then compute CV. Note that we parameterize this with two different (real) matrices for the complex numbers.
Args:
  init_fun: the initialization function to use, e.g. lecun_normal()
  rng: jax random key to be used with the init function
  shape (tuple): desired shape (H, P)
  V (complex64): the eigenvectors used for initialization
Returns: C_tilde (complex64) of shape (H, P, 2)
trunc_standard_normal
trunc_standard_normal (key, shape)
Sample C with a truncated normal distribution with standard deviation 1.
Args:
  key: jax random key
  shape (tuple): desired shape, of length 3, (H, P, _)
Returns: sampled C matrix (float32) of shape (H, P, 2) (for complex parameterization)
init_VinvB
init_VinvB (init_fun, rng, shape, Vinv)
Initialize B_tilde = V^{-1} B. First sample B, then compute V^{-1} B. Note that we parameterize this with two different (real) matrices for the complex numbers.
Args:
  init_fun: the initialization function to use, e.g. lecun_normal()
  rng: jax random key to be used with the init function
  shape (tuple): desired shape (P, H)
  Vinv (complex64): the inverse eigenvectors used for initialization
Returns: B_tilde (complex64) of shape (P, H, 2)
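Both init_CV and init_VinvB share the same pattern: sample a real array with a trailing axis of size 2, view it as a complex matrix, rotate it into (or out of) the eigenbasis, and store the real and imaginary parts again in a trailing axis of size 2. The helper below (init_complex_in_basis is a hypothetical name, not part of this module) is only a sketch of that pattern:

import jax
import jax.numpy as jnp
from jax.nn.initializers import lecun_normal

def init_complex_in_basis(init_fun, rng, shape, basis, left=False):
    # Sample a real (..., 2) array and view it as a complex matrix.
    M = init_fun(rng, shape + (2,))
    M = M[..., 0] + 1j * M[..., 1]
    # Vinv @ B for init_VinvB (left=True), C @ V for init_CV (left=False).
    out = basis @ M if left else M @ basis
    # Store real and imaginary parts in a trailing axis of size 2.
    return jnp.stack([out.real, out.imag], axis=-1)

key = jax.random.PRNGKey(0)
H, P = 3, 8
V = jnp.eye(P, dtype=jnp.complex64)  # stand-in eigenvector matrix for illustration
C_tilde = init_complex_in_basis(lecun_normal(), key, (H, P), V)              # (H, P, 2)
B_tilde = init_complex_in_basis(lecun_normal(), key, (P, H), V, left=True)   # (P, H, 2)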
init_log_steps
init_log_steps (key, input)
Initialize an array of learnable timescale parameters.
Args:
  key: jax random key
  input: tuple containing the array shape H and dt_min and dt_max
Returns: initialized array of timescales (float32): (H,)
log_step_initializer
log_step_initializer (dt_min=0.001, dt_max=0.1)
Initialize the learnable timescale Delta by sampling uniformly between dt_min and dt_max.
Args:
  dt_min (float32): minimum value
  dt_max (float32): maximum value
Returns: init function
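For intuition, the sampling is uniform in log-space, i.e. log(Delta) ~ Uniform(log(dt_min), log(dt_max)). A minimal sketch following the usual S5 convention (the module's exact code may differ):

import jax
import jax.numpy as jnp

def log_step_initializer_sketch(dt_min=0.001, dt_max=0.1):
    # Returns an init function sampling log(Delta) uniformly in [log(dt_min), log(dt_max)].
    def init(key, shape):
        u = jax.random.uniform(key, shape)
        return u * (jnp.log(dt_max) - jnp.log(dt_min)) + jnp.log(dt_min)
    return init

# init_log_steps then typically applies this per feature, yielding an (H,) array of log-timescales.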
make_DPLR_HiPPO
make_DPLR_HiPPO (N)
Makes the components needed for the DPLR representation of HiPPO-LegS. From https://github.com/srush/annotated-s4/blob/main/s4/s4.py. Note, we will only use the diagonal part.
Args:
  N: state size
Returns: eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B, eigenvectors V, HiPPO B pre-conjugation
make_NPLR_HiPPO
make_NPLR_HiPPO (N)
Makes the components needed for the NPLR representation of HiPPO-LegS. From https://github.com/srush/annotated-s4/blob/main/s4/s4.py.
Args:
  N (int32): state size
Returns: N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B
make_HiPPO
make_HiPPO (N)
Create a HiPPO-LegS matrix. From https://github.com/srush/annotated-s4/blob/main/s4/s4.py.
Args:
  N (int32): state size
Returns: N x N HiPPO LegS matrix
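The HiPPO-LegS matrix has the closed form A[n, k] = -sqrt(2n+1)·sqrt(2k+1) for n > k, -(n+1) for n = k, and 0 for n < k. A minimal sketch along the lines of the annotated-S4 reference (the function name is illustrative):

import jax.numpy as jnp

def make_HiPPO_sketch(N):
    # Build sqrt(2n+1) and form the outer product, keep the lower triangle,
    # then adjust the diagonal so that A[n, n] = n + 1 before negating.
    P = jnp.sqrt(1 + 2 * jnp.arange(N))
    A = P[:, None] * P[None, :]
    A = jnp.tril(A) - jnp.diag(jnp.arange(N))
    return -A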
apply_ssm
apply_ssm (Lambda_bar, B_bar, C_tilde, input_sequence, conj_sym, bidirectional)
Compute the LxH output of the discretized SSM given an LxH input.
Args:
  Lambda_bar (complex64): discretized diagonal state matrix (P,)
  B_bar (complex64): discretized input matrix (P, H)
  C_tilde (complex64): output matrix (H, P)
  input_sequence (float32): input sequence of features (L, H)
  conj_sym (bool): whether conjugate symmetry is enforced
  bidirectional (bool): whether the bidirectional setup is used; note that in this case C_tilde has 2P columns
Returns: ys (float32): the SSM outputs (S5 layer preactivations) (L, H)
apply_dynamics
apply_dynamics (x0, steps, Lambda_bar, B_bar, C_tilde, conj_sym, bidirectional)
binary_operator
binary_operator (q_i, q_j)
Binary operator for the parallel scan of the linear recurrence. Assumes a diagonal matrix A.
Args:
  q_i: tuple containing A_i and Bu_i at position i (P,), (P,)
  q_j: tuple containing A_j and Bu_j at position j (P,), (P,)
Returns: new element (A_out, Bu_out)
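To make the scan concrete, here is a minimal sketch of how such a binary operator combines elements and how apply_ssm can use it with jax.lax.associative_scan (shapes as documented above; the bidirectional branch is omitted, and the _sketch names are illustrative, not this module's API):

import jax
import jax.numpy as jnp

def binary_operator_sketch(q_i, q_j):
    # Combine ((A_i, Bu_i), (A_j, Bu_j)) -> (A_j * A_i, A_j * Bu_i + Bu_j), elementwise since A is diagonal.
    A_i, b_i = q_i
    A_j, b_j = q_j
    return A_j * A_i, A_j * b_i + b_j

def apply_ssm_sketch(Lambda_bar, B_bar, C_tilde, input_sequence, conj_sym=True):
    # Lift each input u_k into the state space: Bu_k = B_bar @ u_k.
    Bu = jax.vmap(lambda u: B_bar @ u)(input_sequence)                   # (L, P), complex
    # Repeat the diagonal transition across time and run the parallel linear-recurrence scan.
    Lambda_elems = jnp.repeat(Lambda_bar[None, :], Bu.shape[0], axis=0)  # (L, P)
    _, xs = jax.lax.associative_scan(binary_operator_sketch, (Lambda_elems, Bu))
    # Project states back to outputs; with conj_sym only half of each conjugate pair is stored, hence the 2x.
    scale = 2.0 if conj_sym else 1.0
    return jax.vmap(lambda x: scale * (C_tilde @ x).real)(xs)            # (L, H)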
discretize_zoh
discretize_zoh (Lambda, B_tilde, Delta)
Discretize a diagonalized, continuous-time linear SSM using the zero-order hold method.
Args:
  Lambda (complex64): diagonal state matrix (P,)
  B_tilde (complex64): input matrix (P, H)
  Delta (float32): discretization step sizes (P,)
Returns: discretized Lambda_bar (complex64) (P,), B_bar (complex64) (P, H)
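Under zero-order hold the update has the closed form Lambda_bar = exp(Lambda * Delta) and B_bar = Lambda^{-1} (Lambda_bar - I) B_tilde, applied elementwise because Lambda is diagonal. A minimal sketch (illustrative name, not this module's exact code):

import jax.numpy as jnp

def discretize_zoh_sketch(Lambda, B_tilde, Delta):
    # Elementwise ZOH for a diagonal system.
    Lambda_bar = jnp.exp(Lambda * Delta)
    B_bar = ((1 / Lambda) * (Lambda_bar - 1.0))[:, None] * B_tilde
    return Lambda_bar, B_bar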
discretize_bilinear
discretize_bilinear (Lambda, B_tilde, Delta)
Discretize a diagonalized, continuous-time linear SSM using the bilinear transform method.
Args:
  Lambda (complex64): diagonal state matrix (P,)
  B_tilde (complex64): input matrix (P, H)
  Delta (float32): discretization step sizes (P,)
Returns: discretized Lambda_bar (complex64) (P,), B_bar (complex64) (P, H)
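The bilinear (Tustin) transform instead uses Lambda_bar = (I + Delta/2 · Lambda) / (I - Delta/2 · Lambda) and B_bar = (I - Delta/2 · Lambda)^{-1} Delta · B_tilde, again elementwise for a diagonal system. A minimal sketch (illustrative name):

import jax.numpy as jnp

def discretize_bilinear_sketch(Lambda, B_tilde, Delta):
    # Elementwise bilinear (Tustin) transform for a diagonal system.
    BL = 1 / (1.0 - (Delta / 2.0) * Lambda)
    Lambda_bar = BL * (1.0 + (Delta / 2.0) * Lambda)
    B_bar = (BL * Delta)[:, None] * B_tilde
    return Lambda_bar, B_bar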
S5SSM
S5SSM (d_model:int, d_hidden:int, C_init:str='lecun_normal', discretization:str='zoh', dt_min:float=0.0001, dt_max:float=0.1, conj_sym:bool=True, clip_eigs:bool=False, bidirectional:bool=False, step_rescale:float=1.0, blocks:int=16, n_steps:Optional[int]=None, parent:Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>, name:Optional[str]=None)
LRU Model
adapted from https://github.com/NicolasZucchet/minimal-LRU
gamma_log_init
gamma_log_init (key, lamb)
theta_init
theta_init (key, shape, max_phase, dtype=<class 'jax.numpy.float32'>)
nu_init
nu_init (key, shape, r_min, r_max, dtype=<class 'jax.numpy.float32'>)
matrix_init
matrix_init (key, shape, dtype=<class 'jax.numpy.float32'>, normalization=1)
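These initializers implement the exponential parameterization of the diagonal recurrence from Orvieto et al. 2023, lambda = exp(-exp(nu) + i·exp(theta)), while matrix_init typically just draws a Gaussian matrix scaled by 1/normalization. A sketch along the lines of the minimal-LRU reference implementation (the _sketch names are illustrative and details here may differ):

import jax
import jax.numpy as jnp

def nu_init_sketch(key, shape, r_min, r_max, dtype=jnp.float32):
    # nu = log(-log |lambda|), with |lambda|^2 uniform in [r_min^2, r_max^2].
    u = jax.random.uniform(key, shape=shape, dtype=dtype)
    return jnp.log(-0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2))

def theta_init_sketch(key, shape, max_phase, dtype=jnp.float32):
    # theta = log(phase), with the phase uniform in [0, max_phase].
    u = jax.random.uniform(key, shape=shape, dtype=dtype)
    return jnp.log(max_phase * u)

def gamma_log_init_sketch(key, lamb):
    # gamma = sqrt(1 - |lambda|^2), stored in log-space; lamb packs (nu, theta).
    nu, theta = lamb
    diag_lambda = jnp.exp(-jnp.exp(nu) + 1j * jnp.exp(theta))
    return jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2))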
LRUDynamics
LRUDynamics (d_hidden:int, r_min:float, r_max:float, max_phase:float, clip_eigs:bool, parent:Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>, name:Optional[str]=None)
This class implements only the dynamics of the LRU model. x_{k+1} = A x_k
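Since A is diagonal with entries lambda_j, iterating x_{k+1} = A x_k just raises each entry to a power, x_k = lambda^k ⊙ x_0. A minimal illustration of such a rollout (not necessarily the class's exact implementation):

import jax.numpy as jnp

def rollout_diagonal(x0, diag_lambda, n_steps):
    # x_k = diag_lambda**k * x0 for k = 1..n_steps, vectorized over k.
    ks = jnp.arange(1, n_steps + 1)[:, None]           # (n_steps, 1)
    return (diag_lambda[None, :] ** ks) * x0[None, :]  # (n_steps, d_hidden)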
apply_lru_dynamics_from_ic
apply_lru_dynamics_from_ic (ic:jax.Array, n_steps:int, discrete_lambda:jax.Array, B_norm:jax.Array, C:jax.Array)
| | Type | Details |
|---|---|---|
| ic | Array | (1, d_model) |
| n_steps | int | |
| discrete_lambda | Array | (d_hidden,) |
| B_norm | Array | (d_hidden, d_model) |
| C | Array | (d_model, d_hidden) |
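Given these shapes, a rollout from the initial condition can be sketched as: encode the IC once with B_norm, iterate the diagonal recurrence, and decode every step with C. The sketch below is hypothetical (e.g. whether the IC itself appears as step 0 may differ in the actual function):

import jax
import jax.numpy as jnp

def apply_lru_dynamics_from_ic_sketch(ic, n_steps, discrete_lambda, B_norm, C):
    x0 = B_norm @ ic[0]                                   # encode the IC: (d_hidden,)
    ks = jnp.arange(1, n_steps + 1)[:, None]              # (n_steps, 1)
    xs = (discrete_lambda[None, :] ** ks) * x0[None, :]   # (n_steps, d_hidden)
    return jax.vmap(lambda x: (C @ x).real)(xs)           # decode: (n_steps, d_model)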
apply_lru_dynamics
apply_lru_dynamics (inputs:jax.Array, discrete_lambda:jax.Array, B_norm:jax.Array, C:jax.Array, D:jax.Array)
| | Type | Details |
|---|---|---|
| inputs | Array | (time, d_model) |
| discrete_lambda | Array | (d_hidden,) |
| B_norm | Array | (d_hidden, d_model) |
| C | Array | (d_model, d_hidden) |
| D | Array | (d_model,) |
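With an input sequence, the recurrence x_k = lambda ⊙ x_{k-1} + B_norm u_k can be run with the same parallel scan used in the S5 section, followed by the readout y_k = Re(C x_k) + D ⊙ u_k. A minimal sketch (illustrative names, not this module's exact code):

import jax
import jax.numpy as jnp

def binary_operator_diag(q_i, q_j):
    # Associative operator for the diagonal linear recurrence.
    A_i, b_i = q_i
    A_j, b_j = q_j
    return A_j * A_i, A_j * b_i + b_j

def apply_lru_dynamics_sketch(inputs, discrete_lambda, B_norm, C, D):
    Lambda_elems = jnp.repeat(discrete_lambda[None, :], inputs.shape[0], axis=0)
    Bu_elems = jax.vmap(lambda u: B_norm @ u)(inputs)                     # (time, d_hidden)
    _, xs = jax.lax.associative_scan(binary_operator_diag, (Lambda_elems, Bu_elems))
    return jax.vmap(lambda x, u: (C @ x).real + D * u)(xs, inputs)        # (time, d_model)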
LRU
LRU (d_hidden:int, d_model:int, r_min:float=0.0, r_max:float=1.0, max_phase:float=6.28, n_steps:Optional[int]=None, parent:Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>, name:Optional[str]=None)
LRU module in charge of the recurrent processing. The implementation follows Orvieto et al. 2023.
Deep (Stacked) and Batched versions
SequenceLayer
SequenceLayer (ssm:flax.linen.module.Module, d_model:int, dropout:float=0.0, norm:str='layer', training:bool=True, activation:str='half_glu1', prenorm:bool=True, parent:Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>, name:Optional[str]=None)
Single layer, with one SSM module, GLU, dropout and batch/layer norm
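As a rough Flax sketch of that block structure (prenorm, inner SSM, gated activation, dropout, residual), assuming a 'half_glu1'-style sigmoid gate; the module's actual activation and norm options may differ, and the class/attribute names below are illustrative:

from typing import Callable
import flax.linen as nn

class SequenceLayerSketch(nn.Module):
    ssm: Callable[..., nn.Module]   # constructor (e.g. a functools.partial) for the inner SSM
    d_model: int
    dropout: float = 0.0
    training: bool = True

    @nn.compact
    def __call__(self, x):
        skip = x
        x = nn.LayerNorm()(x)                        # prenorm
        x = self.ssm()(x)                            # inner SSM over the sequence
        drop = nn.Dropout(self.dropout, deterministic=not self.training)
        x = drop(nn.gelu(x))
        # "half_glu1"-style gating: multiply by a learned sigmoid gate.
        x = x * nn.sigmoid(nn.Dense(self.d_model)(x))
        x = drop(x)
        return skip + x                              # residual connection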
StackedSSM
StackedSSM (ssm:flax.linen.module.Module, d_model:int, d_vars:int, n_layers:int, ssm_first_layer:flax.linen.module.Module=None, n_steps:Optional[int]=None, dropout:float=0.0, training:bool=True, norm:str='layer', activation:str='half_glu1', prenorm:bool=True, parent:Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>, name:Optional[str]=None)
B, T, W, C = 10, 50, 20, 3
d_hidden = 64
deep_ssm = BatchStackedSSMModel(
    ssm_first_layer=partial(S5SSM, d_hidden=d_hidden, n_steps=50),
    ssm=partial(S5SSM, d_hidden=d_hidden),
    d_model=W,
    d_vars=C,
    n_layers=2,
)
x = jnp.empty((B, T, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)
assert out.shape == (B, T, W, C)
deep_ssm = BatchStackedSSMModel(
    ssm_first_layer=partial(LRU, d_hidden=d_hidden, n_steps=50),
    ssm=partial(LRU, d_hidden=d_hidden),
    d_model=W,
    d_vars=C,
    n_layers=2,
)
x = jnp.empty((B, T, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)
assert out.shape == (B, T, W, C)
StackedSSM2D
StackedSSM2D (ssm:flax.linen.module.Module, d_model:Tuple[int,int], d_vars:int, n_layers:int, ssm_first_layer:flax.linen.module.Module=None, n_steps:Optional[int]=None, dropout:float=0.0, training:bool=True, norm:str='layer', activation:str='half_glu1', prenorm:bool=True, parent:Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>, name:Optional[str]=None)
B, T, H, W, C = 10, 50, 20, 20, 3
deep_ssm = BatchStackedSSM2DModel(
    ssm_first_layer=partial(LRU, d_hidden=d_hidden, n_steps=T),
    ssm=partial(LRU, d_hidden=d_hidden),
    d_model=(H, W),
    d_vars=C,
    n_layers=2,
)
x = jnp.empty((B, T, H, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)
assert out.shape == (B, T, H, W, C)
B, T, H, W, C = 10, 50, 20, 20, 3
deep_ssm = BatchStackedSSM2DModel(
    ssm_first_layer=partial(S5SSM, d_hidden=d_hidden, n_steps=T),
    ssm=partial(S5SSM, d_hidden=d_hidden),
    d_model=(H, W),
    d_vars=C,
    n_layers=2,
)
x = jnp.empty((B, T, H, W, C))
variables = deep_ssm.init(jax.random.PRNGKey(65), x)
out = deep_ssm.apply(variables, x)
assert out.shape == (B, T, H, W, C)