Recurrent Models

LRU dynamics

Linear dynamics with eigenvalues initialised as in the LRU paper (Orvieto et al., 2023).
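
The LRU initialisation samples diagonal complex eigenvalues lambda_j = exp(-exp(nu_j) + i * exp(theta_j)) so that their magnitudes are uniform on the ring r_min <= |lambda| <= r_max and their phases uniform on [0, max_phase]. Below is a minimal JAX sketch of that scheme; the function name is illustrative, not this library's API.

import jax
import jax.numpy as jnp

def init_lru_eigenvalues(key, d_hidden, r_min=0.99, r_max=1.0, max_phase=2 * jnp.pi):
    # Sample |lambda| uniformly (in squared magnitude) on [r_min, r_max]
    # and the phase uniformly on [0, max_phase], as in the LRU paper.
    k1, k2 = jax.random.split(key)
    u1 = jax.random.uniform(k1, (d_hidden,))
    u2 = jax.random.uniform(k2, (d_hidden,))
    nu_log = jnp.log(-0.5 * jnp.log(u1 * (r_max**2 - r_min**2) + r_min**2))
    theta_log = jnp.log(max_phase * u2)
    # lambda = exp(-exp(nu) + i * exp(theta)); |lambda| = exp(-exp(nu)) <= 1
    return jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))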

LRU with MLP


source

LRUDynamicsVarying

 LRUDynamicsVarying (d_hidden:int, r_min:float, r_max:float,
                     max_phase:float, clip_eigs:bool,
                     model:flax.linen.Module,
                     parent:Union[flax.linen.Module, flax.core.Scope,
                     NoneType]=<default>, name:Optional[str]=None)
import jax.numpy as jnp

from physmodjax.models.mlp import MLP

d_hidden = 64  # state dimension
steps = 50     # rollout length

# MLP used to make the dynamics time-varying (cf. "LRU with MLP" above)
model = MLP(hidden_channels=[64, 64, 64])

dyn = LRUDynamicsVarying(
    d_hidden=d_hidden,
    r_min=0.99,            # minimum eigenvalue magnitude at initialisation
    r_max=1.0,             # maximum eigenvalue magnitude at initialisation
    max_phase=jnp.pi * 2,  # maximum eigenvalue phase at initialisation
    model=model,
    clip_eigs=False,       # whether to clip eigenvalue magnitudes
)
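
For intuition, the purely linear part of such a module rolls the state forward as x_{k+1} = lambda * x_k, which can be written with jax.lax.scan. This is only a hedged sketch of that linear rollout, reusing the variables and the init sketch above; the varying version additionally lets the MLP modulate the dynamics, and the module's actual call signature may differ.

import jax
import jax.numpy as jnp

def rollout_diagonal(lmbda, x0, steps):
    # x_{k+1} = lambda * x_k; returns the trajectory of shape (steps, d_hidden)
    def step(x, _):
        x_next = lmbda * x
        return x_next, x_next
    _, traj = jax.lax.scan(step, x0, None, length=steps)
    return traj

lmbda = init_lru_eigenvalues(jax.random.PRNGKey(0), d_hidden)
traj = rollout_diagonal(lmbda, jnp.ones(d_hidden, dtype=jnp.complex64), steps)
assert traj.shape == (steps, d_hidden)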

Deep GRU


source

DeepRNN

 DeepRNN (d_model:int, d_vars:int, n_layers:int,
          cell:flax.linen.Module, training:bool=True, norm:str='layer',
          parent:Union[flax.linen.Module, flax.core.Scope,
          NoneType]=<default>, name:Optional[str]=None)

A deep RNN model that applies an RNN cell over the last dimension of the input. Works with nn.GRUCell, nn.RNNCell, nn.SimpleCell, and nn.MGUCell.
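
Conceptually, "deep" means the cells are stacked: each layer scans its cell over the time axis and feeds its output sequence to the next layer. A minimal flax sketch of that idea follows; it is not the actual DeepRNN implementation, which also handles the extra variable dimension, the initial state, and normalisation.

import flax.linen as nn

class StackedRNNSketch(nn.Module):
    d_model: int
    n_layers: int

    @nn.compact
    def __call__(self, x):  # x: (batch, time, features)
        for _ in range(self.n_layers):
            # nn.RNN scans the cell over the time axis
            x = nn.RNN(nn.GRUCell(features=self.d_model))(x)
        return x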

from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as nn

B, T, W, C = 10, 50, 20, 3  # batch, time steps, spatial width, variables

# batched variant of DeepRNN (vmapped over the leading batch axis)
deep_rnn = BatchedDeepRNN(d_model=W, d_vars=C, n_layers=2, cell=partial(nn.GRUCell))
x = jnp.ones((B, T, W, C))  # input sequence
x0 = jnp.ones((B, W, C))    # initial state
variables = deep_rnn.init(jax.random.PRNGKey(65), x0, x)
out = deep_rnn.apply(variables, x0, x)

assert out.shape == (B, T, W, C)
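
As noted above, other flax recurrent cells with the same interface can be swapped in, e.g. with the same shapes as the example above:

# a minimal gated unit instead of a GRU
deep_rnn_mgu = BatchedDeepRNN(d_model=W, d_vars=C, n_layers=2, cell=partial(nn.MGUCell))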