# TODO
# the LSTMCell will need an extra dimension for the hidden state
# deep_rnn = BatchedDeepRNN(d_model=W, d_vars=C, n_layers=2, cell=partial(nn.LSTMCell))
# x = jnp.ones((B, T, W, C))
# x0 = jnp.ones((B, W, C))
# variables = deep_rnn.init(jax.random.PRNGKey(65), (x0, x0), x)
# out = deep_rnn.apply(variables, (x0, x0), x)
# Recurrent Models
::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’ 6=‘i’}
from abc import ABC, abstractmethod
from functools import partial
import flax.linen as nn
import jax
import jax.numpy as jnp
from einops import rearrange
from jax.typing import ArrayLike
from physmodjax.models.mlp import MLP
from physmodjax.utils.clamp import magic_clamp
from physmodjax.utils.eigenvalues import (
    ensure_positive_imaginary_parts,
    multiply_eigenvalues,
)
:::
::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def matrix_init(key, shape, dtype=jnp.float32, normalization=1):
    return jax.random.normal(key=key, shape=shape, dtype=dtype) / normalization


def nu_init(key, shape, r_min, r_max, dtype=jnp.float32):
    u = jax.random.uniform(key=key, shape=shape, dtype=dtype)
    return jnp.log(-0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2))


def theta_init(key, shape, max_phase, dtype=jnp.float32):
    u = jax.random.uniform(key, shape=shape, dtype=dtype)
    return jnp.log(max_phase * u)


def gamma_log_init(key, lamb):
    nu, theta = lamb
    diag_lambda = jnp.exp(-jnp.exp(nu) + 1j * jnp.exp(theta))
    return jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2))
:::
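These initialisers follow the LRU parameterisation: because of the extra `log`, the eigenvalue radius is recovered as `exp(-exp(nu_log))` and the phase as `exp(theta_log)`. A quick sanity check of the ranges (not part of the exported module):

```python
key = jax.random.PRNGKey(0)
nu_log = nu_init(key, (1000,), r_min=0.9, r_max=1.0)
theta_log = theta_init(key, (1000,), max_phase=jnp.pi * 2)

radius = jnp.exp(-jnp.exp(nu_log))  # should lie in [r_min, r_max]
phase = jnp.exp(theta_log)  # should lie in [0, max_phase]

assert jnp.all((radius > 0.9 - 1e-5) & (radius < 1.0 + 1e-5))
assert jnp.all((phase >= 0.0) & (phase <= jnp.pi * 2))
```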
## Base modal class
::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class BaseModalDynamics(nn.Module, ABC):
    d_hidden: int
    prepend_ones: bool

    @abstractmethod
    def get_eigenvalues(self) -> ArrayLike:
        """
        Abstract method that must be implemented by subclasses.

        Returns:
            jnp.ndarray: Eigenvalues as a JAX array.
        """
        pass

    def apply_eigenvalue_transform(
        self,
        eigenvalues: ArrayLike,
        n_steps: int,
    ) -> ArrayLike:
        """Compute dynamics over n_steps based on eigenvalues."""
        z = jnp.repeat(eigenvalues[None], n_steps, axis=0)

        if self.prepend_ones:
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]

        return jax.lax.associative_scan(jnp.multiply, z)

    def compute_dynamics(
        self,
        x: ArrayLike,
        n_steps: int,
    ) -> ArrayLike:
        """Advance the dynamics for `n_steps` using the initial state `x`."""
        eigenvalues = self.get_eigenvalues()
        dynamics = self.apply_eigenvalue_transform(eigenvalues, n_steps)
        return dynamics * x
:::
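The scan in `apply_eigenvalue_transform` is just a cumulative product along the time axis, so for a constant row of eigenvalues it applies the powers $\lambda^1, \dots, \lambda^n$ to the state; `prepend_ones` shifts this to $\lambda^0, \dots, \lambda^{n-1}$ so that the first output step reproduces the initial state exactly. A small illustration (nothing beyond `jax.numpy` is assumed):

```python
# The associative scan over jnp.multiply is a cumulative product,
# so each row k holds lambda**(k+1).
lam = jnp.array([0.9 + 0.1j, 0.5 - 0.2j])
z = jnp.repeat(lam[None], 4, axis=0)

scanned = jax.lax.associative_scan(jnp.multiply, z)
assert jnp.allclose(scanned, jnp.cumprod(z, axis=0))

# With prepend_ones the sequence starts at lambda**0, so step 0
# leaves the initial state untouched.
z1 = jnp.concatenate([jnp.ones((1, 2), dtype=jnp.complex64), z], axis=0)[:-1]
assert jnp.allclose(jax.lax.associative_scan(jnp.multiply, z1)[0], 1.0)
```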
## LRU dynamics

Linear dynamics using an initialisation of the eigenvalues based on the LRU paper.
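Concretely, both the radius and the phase are stored through an extra logarithm, which keeps them positive and guarantees a stable spectrum:

$$
\lambda = \exp\!\left(-e^{\nu_{\log}} + i\, e^{\theta_{\log}}\right),
\qquad
|\lambda| = e^{-e^{\nu_{\log}}} \in (0, 1).
$$

With `nu_init` and `theta_init` above, $|\lambda|$ is distributed uniformly over the annulus between `r_min` and `r_max`, and the phase uniformly in $[0, \text{max\_phase}]$.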
::: {#cell-9 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class LRUDynamics(BaseModalDynamics):
    """
    This class implements only the dynamics of the LRU model.

    x_{k+1} = A x_k
    """

    r_min: float  # smallest eigenvalue radius
    r_max: float  # largest eigenvalue radius
    max_phase: float  # largest phase
    clip_eigs: bool  # whether to clip the eigenvalues

    def setup(self):
        self.theta_log = self.param(
            "theta_log", partial(theta_init, max_phase=self.max_phase), (self.d_hidden,)
        )
        self.nu_log = self.param(
            "nu_log",
            partial(nu_init, r_min=self.r_min, r_max=self.r_max),
            (self.d_hidden,),
        )

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        steps: int,  # number of steps to advance
    ) -> jnp.ndarray:  # advanced state (steps, d_hidden) complex
        A_real = -jnp.exp(self.nu_log)
        A_imag = jnp.exp(self.theta_log)

        # clip the eigenvalues to be only negative (not strictly necessary, because of the extra log)
        if self.clip_eigs:
            A_real = jnp.clip(A_real, None, -1e-5)

        A_diag = jnp.exp(A_real + 1j * A_imag)
        z = jnp.repeat(A_diag[None, :], steps, axis=0)

        if self.prepend_ones:
            # prepend ones to the beginning and slice the last element
            # this is needed to start from the initial state
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]

        # advance the state
        x = jax.lax.associative_scan(jnp.multiply, z) * x

        return x

    def get_eigenvalues(self) -> ArrayLike:
        nu = jnp.exp(self.nu_log)
        theta = jnp.exp(self.theta_log)
        return jnp.exp(-nu + 1j * theta)

    def set_eigenvalues(self, eigenvalues: ArrayLike):
        # convert to continuous
        eigenvalues_mod = jnp.log(eigenvalues)
        eigenvalues_mod = ensure_positive_imaginary_parts(eigenvalues_mod)

        # take the log of each part and assign it to the params
        # note that the sign of the real part is flipped before taking the log
        # otherwise we would get nan values
        self.put_variable("params", "nu_log", jnp.log(-eigenvalues_mod.real))
        self.put_variable("params", "theta_log", jnp.log(eigenvalues_mod.imag))

    # def scale_dynamics(
    #     self,
    #     angle_factor: float,
    #     radius_factor: float,
    # ) -> ArrayLike:
    #     discrete_eigenvalues = self.get_eigenvalues()
    #     discrete_eigenvalues_mod = multiply_eigenvalues(
    #         discrete_eigenvalues,
    #         angle_factor,
    #         radius_factor,
    #     )
    #     # convert to continuous
    #     eigenvalues_mod = jnp.log(discrete_eigenvalues_mod)
    #     eigenvalues_mod = ensure_positive_imaginary_parts(eigenvalues_mod)
    #     # take the log of each part and assign it to the params
    #     # note that the sign of the real part is flipped before taking the log
    #     # otherwise we would get nan values
    #     self.put_variable("params", "nu_log", jnp.log(-eigenvalues_mod.real))
    #     self.put_variable("params", "theta_log", jnp.log(eigenvalues_mod.imag))
    #     return self.get_eigenvalues()
:::
::: {#cell-10 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
d_hidden = 64
steps = 50
dyn = LRUDynamics(
    d_hidden=d_hidden,
    r_min=0.99,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
    prepend_ones=False,
)
variables = dyn.init(jax.random.PRNGKey(0), jnp.ones(d_hidden), steps)
out = dyn.apply(variables, jnp.ones((1, d_hidden)), steps)

assert out.shape == (steps, d_hidden)
:::
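`set_eigenvalues` writes into the `params` collection with `put_variable`, so it has to be invoked through `apply` with that collection marked as mutable. A minimal sketch, reusing `dyn` and `variables` from the test above (the eigenvalues here are arbitrary stable values):

```python
# Arbitrary stable eigenvalues: radius just below 1, positive phases.
new_eigs = jnp.exp(-0.01 + 1j * jnp.linspace(0.1, 1.0, d_hidden))

# put_variable mutates "params", so the collection must be mutable;
# the updated variables come back as the second return value.
_, updated = dyn.apply(
    variables,
    new_eigs,
    method="set_eigenvalues",
    mutable=["params"],
)
```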
## LRU with MLP
::: {#cell-12 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class LRUDynamicsVarying(LRUDynamics):
    # model to process the linear state
    # (its final layer must output 2 * d_hidden channels: real and imaginary parts)
    model: nn.Module

    def setup(self):
        super().setup()

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        steps: int,  # number of steps to advance
    ) -> jnp.ndarray:  # advanced state (steps, d_hidden) complex
        x = super().__call__(x, steps)
        x_hat = self.model(x.real**2 + x.imag**2)
        x_hat = x_hat[..., : self.d_hidden] + 1j * x_hat[..., self.d_hidden :]
        x = x * x_hat
        return x
:::
::: {#cell-13 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
d_hidden = 64
steps = 50
# the final layer must have 2 * d_hidden channels (real and imaginary parts)
model = MLP(hidden_channels=[64, 64, 2 * d_hidden])
dyn = LRUDynamicsVarying(
    d_hidden=d_hidden,
    r_min=0.99,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    model=model,
    clip_eigs=False,
    prepend_ones=False,
)
:::
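The cell above only constructs the module. A minimal sketch of actually advancing a state with it (assuming the `MLP` from `physmodjax.models.mlp` maps the last axis to its final `hidden_channels` entry):

```python
# Run the varying dynamics once; the MLP output is split into real
# and imaginary corrections, so its final layer is 2 * d_hidden wide.
output, variables = dyn.init_with_output(
    jax.random.PRNGKey(0),
    jnp.ones(d_hidden, dtype=jnp.complex64),
    steps,
)
assert output.shape == (steps, d_hidden)
```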
## Deep GRU
::: {#cell-15 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class DeepRNN(nn.Module):
    """
    A deep RNN model that applies an RNN cell over the last dimension of the input.
    Works with nn.GRUCell, nn.RNNCell, nn.SimpleCell, nn.MGUCell.
    """

    d_model: int
    d_vars: int
    n_layers: int
    cell: nn.Module
    training: bool = True
    norm: str = "layer"

    def setup(self):
        # nn.scan does the same thing as nn.RNN (unrolls the cell over the time dimension)
        self.first_layer = nn.RNN(
            self.cell(features=self.d_model * self.d_vars),
        )
        self.layers = [
            nn.RNN(
                self.cell(features=self.d_model * self.d_vars),
            )
            for _ in range(self.n_layers)
        ]

    def __call__(
        self,
        x0: jnp.ndarray,  # (W, C) initial state
        x: jnp.ndarray,  # (T, W, C) empty state
    ) -> jnp.ndarray:  # (T, W, C) advanced state
        # the rnn works over the last dimension
        # we need to reshape the input to (T, d_model * C)
        x0 = rearrange(x0, "w c -> (w c)")
        x = rearrange(x, "t w c -> t (w c)")
        x = self.first_layer(x, initial_carry=x0)
        for layer in self.layers:
            x = layer(x)
        return rearrange(x, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)


BatchedDeepRNN = nn.vmap(
    DeepRNN,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    axis_name="batch",
)
:::
::: {#cell-16 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
B, T, W, C = 10, 50, 20, 3
deep_rnn = BatchedDeepRNN(d_model=W, d_vars=C, n_layers=2, cell=partial(nn.GRUCell))
x = jnp.ones((B, T, W, C))
x0 = jnp.ones((B, W, C))
variables = deep_rnn.init(jax.random.PRNGKey(65), x0, x)
out = deep_rnn.apply(variables, x0, x)

assert out.shape == (B, T, W, C)
:::
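The commented-out test at the top of this notebook points at a limitation: `nn.LSTMCell` carries a `(c, h)` tuple rather than a single array, which is why that test passes `(x0, x0)` as the initial state; `DeepRNN` would have to apply its `rearrange` to every element of the carry. A standalone sketch of the tuple carry with `nn.RNN` directly (illustrative, not exported):

```python
# nn.LSTMCell keeps a (c, h) tuple carry, so the initial state must
# be a matching tuple rather than a single array.
features = W * C
cell = nn.LSTMCell(features=features)
rnn = nn.RNN(cell)

x_flat = jnp.ones((T, features))
carry = cell.initialize_carry(jax.random.PRNGKey(0), x_flat[0].shape)  # (c, h)

variables = rnn.init(jax.random.PRNGKey(65), x_flat, initial_carry=carry)
out = rnn.apply(variables, x_flat, initial_carry=carry)
assert out.shape == (T, features)
```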
## Complex oscillator
::: {#cell-19 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def constrain_to_unit_circle(x):
    mag = jnp.abs(x)
    angle = jnp.angle(x)
    # return x * nn.tanh(mag) / mag
    mag = magic_clamp(mag, 1e-8, 1.0)
    return mag * jnp.exp(1j * angle)
:::
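As a quick illustration: eigenvalues with magnitude above one are pulled radially back onto the unit circle while the phase is untouched (this assumes `magic_clamp` behaves like `jnp.clip` in the forward pass, presumably differing only in its gradient):

```python
z = jnp.array([2.0 * jnp.exp(1j * 0.3), 0.5 * jnp.exp(1j * 1.2)])
z_c = constrain_to_unit_circle(z)

assert jnp.allclose(jnp.abs(z_c), jnp.array([1.0, 0.5]))  # magnitudes clamped
assert jnp.allclose(jnp.angle(z_c), jnp.angle(z))  # phases preserved
```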
::: {#cell-20 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def init_complex(
    r_min=0.0,
    r_max=1.0,
    min_phase=0.0,
    max_phase=2 * jnp.pi,
    dtype=jnp.float32,
):
    def init(key, shape, dtype=dtype) -> ArrayLike:
        dtype = jax.dtypes.canonicalize_dtype(dtype)
        radius = jax.random.uniform(
            key,
            shape,
            dtype,
            minval=r_min,
            maxval=r_max,
        )
        phase = jax.random.uniform(
            jax.random.split(key)[0],
            shape,
            dtype,
            minval=min_phase,
            maxval=max_phase,
        )
        return radius * jnp.exp(1j * phase)

    return init


def init_real_imag(
    r_min=0.0,
    r_max=1.0,
    min_phase=0.0,
    max_phase=2 * jnp.pi,
    dtype=jnp.float32,
):
    def init(key, shape, dtype=dtype) -> ArrayLike:
        dtype = jax.dtypes.canonicalize_dtype(dtype)
        radius = jax.random.uniform(
            key,
            shape,
            dtype,
            minval=r_min,
            maxval=r_max,
        )
        phase = jax.random.uniform(
            jax.random.split(key)[0],
            shape,
            dtype,
            minval=min_phase,
            maxval=max_phase,
        )
        return jnp.stack([radius * jnp.cos(phase), radius * jnp.sin(phase)], axis=-1)

    return init
:::
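Both factories return closures with the `(key, shape)` signature that `self.param` expects. Note that `init_real_imag` stacks the real and imaginary parts on a new trailing axis, so requesting shape `(d_hidden,)` yields a `(d_hidden, 2)` parameter, as used by `DiscreteModalDynamics` below:

```python
w = init_real_imag(r_min=0.3, r_max=1.0)(jax.random.PRNGKey(0), (8,))
assert w.shape == (8, 2)  # trailing axis holds [real, imag]

zs = init_complex(r_min=0.3, r_max=1.0)(jax.random.PRNGKey(0), (8,))
assert zs.shape == (8,) and jnp.iscomplexobj(zs)
```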
::: {#cell-21 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class DiscreteModalDynamics(BaseModalDynamics):
    clip: bool
    r_min: float
    r_max: float
    min_phase: float
    max_phase: float

    def setup(self):
        self.real_imag = self.param(
            "real_imag",
            init_real_imag(
                r_min=self.r_min,
                r_max=self.r_max,
                min_phase=self.min_phase,
                max_phase=self.max_phase,
            ),
            (self.d_hidden,),
        )
        self.z = self.real_imag[..., 0] + 1j * self.real_imag[..., 1]
        if self.clip:
            self.z = constrain_to_unit_circle(self.z)

    def __call__(
        self,
        x: ArrayLike,  # initial complex state (d_hidden,)
        n_steps: int,  # number of steps to advance
    ):
        z = jnp.repeat(self.z[None, :], n_steps, axis=0)

        if self.prepend_ones:
            # prepend ones to the beginning and slice the last element
            # this is needed to start from the initial state
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]

        return jax.lax.associative_scan(jnp.multiply, z) * x

    def get_eigenvalues(self) -> ArrayLike:
        z = self.z
        if self.clip:
            z = constrain_to_unit_circle(z)
        return z

    def set_eigenvalues(self, eigenvalues: ArrayLike):
        eigenvalues_as_real_imag = jnp.stack(
            [jnp.real(eigenvalues), jnp.imag(eigenvalues)],
            axis=-1,
        )
        self.put_variable("params", "real_imag", eigenvalues_as_real_imag)
:::
::: {#cell-22 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
d_hidden = 10
clip = True
reduce = False
n_steps = 100

model = DiscreteModalDynamics(
    d_hidden=d_hidden,
    clip=clip,
    r_min=0.3,
    r_max=1,
    min_phase=0.01,
    max_phase=jnp.pi / 2,
    prepend_ones=False,
)

output, variables = model.init_with_output(
    jax.random.PRNGKey(0),
    jnp.ones((d_hidden,)),
    n_steps=n_steps,
)

assert output.shape == (
    n_steps,
    d_hidden,
)
:::
::: {#cell-23 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class DiscreteModalDynamicsVarying(DiscreteModalDynamics):
    # model to process the linear state
    # (its final layer must output 2 * d_hidden channels: real and imaginary parts)
    model: nn.Module

    def setup(self):
        super().setup()

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        n_steps: int,  # number of steps to advance
    ) -> jnp.ndarray:  # advanced state (steps, d_hidden) complex
        x = super().__call__(x, n_steps)
        x_hat = self.model(x.real**2 + x.imag**2)
        x_hat = x_hat[..., : self.d_hidden] + 1j * x_hat[..., self.d_hidden :]
        x = x * x_hat
        return x
:::
::: {#cell-24 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
d_hidden = 10
clip = True
reduce = False
n_steps = 100

model = DiscreteModalDynamicsVarying(
    d_hidden=d_hidden,
    clip=clip,
    r_min=0.3,
    r_max=1,
    min_phase=0.01,
    max_phase=jnp.pi / 2,
    model=MLP(hidden_channels=[64, 64, 20]),  # final layer = 2 * d_hidden
    prepend_ones=False,
)

output, variables = model.init_with_output(
    jax.random.PRNGKey(0),
    jnp.ones((d_hidden,)),
    n_steps=n_steps,
)

assert output.shape == (
    n_steps,
    d_hidden,
)
:::
::: {#cell-25 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class DiscreteModalDynamicsAngleMag(BaseModalDynamics):
    clip: bool
    r_min: float
    r_max: float
    min_phase: float
    max_phase: float

    def setup(self):
        self.angles = self.param(
            "angles",
            lambda key, shape: jax.random.uniform(
                key,
                shape,
                minval=self.min_phase,
                maxval=self.max_phase,
            ),
            (self.d_hidden,),
        )
        self.magnitudes = self.param(
            "magnitudes",
            lambda key, shape: jax.random.uniform(
                key,
                shape,
                minval=self.r_min,
                maxval=self.r_max,
            ),
            (self.d_hidden,),
        )

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        n_steps: int,  # number of steps to advance
    ):
        if self.clip:
            magnitudes = magic_clamp(self.magnitudes, 1e-8, 1.0)
        else:
            magnitudes = self.magnitudes

        z = magnitudes * jnp.exp(1j * self.angles)
        z = jnp.repeat(z[None], n_steps, axis=0)

        if self.prepend_ones:
            # prepend ones to the beginning and slice the last element
            # this is needed to start from the initial state
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]

        return jax.lax.associative_scan(jnp.multiply, z) * x

    def get_eigenvalues(self) -> ArrayLike:
        if self.clip:
            magnitudes = jnp.clip(self.magnitudes, 1e-8, 1.0)
        else:
            magnitudes = self.magnitudes

        return magnitudes * jnp.exp(1j * self.angles)

    def set_eigenvalues(self, eigenvalues: ArrayLike):
        self.put_variable("params", "angles", jnp.angle(eigenvalues))
        self.put_variable("params", "magnitudes", jnp.abs(eigenvalues))
:::
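This variant has no test cell in the notebook; a minimal smoke test mirroring the `DiscreteModalDynamics` one above could look like this (a sketch, not exported):

```python
# Smoke test for the angle/magnitude parameterisation.
d_hidden = 10
n_steps = 100

model = DiscreteModalDynamicsAngleMag(
    d_hidden=d_hidden,
    clip=True,
    r_min=0.3,
    r_max=1.0,
    min_phase=0.01,
    max_phase=jnp.pi / 2,
    prepend_ones=False,
)

output, variables = model.init_with_output(
    jax.random.PRNGKey(0),
    jnp.ones((d_hidden,)),
    n_steps=n_steps,
)
assert output.shape == (n_steps, d_hidden)
```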