Recurrent Models
LRU dynamics
Linear dynamics with initialisation of the eigenvalues based on the LRU paper.
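For reference, the eigenvalue initialisation from the LRU paper (Orvieto et al., 2023) can be sketched as follows. This is a minimal illustration; the function name and internals are not the module's actual code:

import jax
import jax.numpy as jnp

def init_lru_eigenvalues(key, d_hidden, r_min=0.99, r_max=1.0, max_phase=jnp.pi * 2):
    # Sample complex eigenvalues with magnitude in [r_min, r_max]
    # and phase in [0, max_phase], as in the LRU initialisation.
    k1, k2 = jax.random.split(key)
    u1 = jax.random.uniform(k1, (d_hidden,))
    u2 = jax.random.uniform(k2, (d_hidden,))
    nu = -0.5 * jnp.log(u1 * (r_max**2 - r_min**2) + r_min**2)  # |lambda| = exp(-nu)
    theta = max_phase * u2
    return jnp.exp(-nu + 1j * theta)

lam = init_lru_eigenvalues(jax.random.PRNGKey(0), 64)
# jnp.abs(lam) lies in [r_min, r_max]: stable, slowly decaying linear dynamics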
LRU with MLP
LRUDynamicsVarying
LRUDynamicsVarying (d_hidden:int, r_min:float, r_max:float, max_phase:float, clip_eigs:bool, model:flax.linen.module.Module, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x11a1a75e0>, name:Optional[str]=None)
import jax.numpy as jnp

from physmodjax.models.mlp import MLP

d_hidden = 64
steps = 50
model = MLP(hidden_channels=[64, 64, 64])
dyn = LRUDynamicsVarying(
    d_hidden=d_hidden,
    r_min=0.99,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    model=model,
    clip_eigs=False,
)
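A minimal usage sketch; the (x0, steps) call signature and the complex initial state are assumptions, not taken from this page:

x0 = jnp.ones((d_hidden,), dtype=jnp.complex64)  # hypothetical initial hidden state
variables = dyn.init(jax.random.PRNGKey(0), x0, steps)
states = dyn.apply(variables, x0, steps)  # assumed to roll the dynamics forward for `steps` steps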
Deep GRU
DeepRNN
DeepRNN (d_model:int, d_vars:int, n_layers:int, cell:flax.linen.module.Module, training:bool=True, norm:str='layer', parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x11a1a75e0>, name:Optional[str]=None)
A deep RNN model that applies an RNN cell over the last dimension of the input. Works with nn.GRUCell, nn.RNNCell, nn.SimpleCell, nn.MGUCell.
from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as nn

B, T, W, C = 10, 50, 20, 3
deep_rnn = BatchedDeepRNN(d_model=W, d_vars=C, n_layers=2, cell=partial(nn.GRUCell))
x = jnp.ones((B, T, W, C))   # input sequence
x0 = jnp.ones((B, W, C))     # initial state
variables = deep_rnn.init(jax.random.PRNGKey(65), x0, x)
out = deep_rnn.apply(variables, x0, x)
assert out.shape == (B, T, W, C)
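The stacked-cell idea can be illustrated with plain Flax primitives; this is a simplified sketch, not the DeepRNN implementation (class name and structure are hypothetical):

import jax
import jax.numpy as jnp
import flax.linen as nn

class StackedGRU(nn.Module):
    # Hypothetical minimal analogue: each layer scans a GRU cell over the
    # time axis, followed by layer normalisation.
    d_model: int
    n_layers: int

    @nn.compact
    def __call__(self, x):  # x: (batch, time, d_model)
        for _ in range(self.n_layers):
            x = nn.RNN(nn.GRUCell(features=self.d_model))(x)
            x = nn.LayerNorm()(x)
        return x

mdl = StackedGRU(d_model=20, n_layers=2)
xb = jnp.ones((4, 10, 20))
params = mdl.init(jax.random.PRNGKey(0), xb)
y = mdl.apply(params, xb)
assert y.shape == (4, 10, 20)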