# Recurrent Models

``` python
# TODO
# the LSTMCell will need an extra dimension for the hidden state
# deep_rnn = BatchedDeepRNN(d_model=W, d_vars=C, n_layers=2, cell=partial(nn.LSTMCell))
# x = jnp.ones((B, T, W, C))
# x0 = jnp.ones((B, W, C))
# variables = deep_rnn.init(jax.random.PRNGKey(65), (x0, x0), x)
# out = deep_rnn.apply(variables, (x0, x0), x)
```
::: {#cell-3 .cell .exporti}
``` python
from abc import ABC, abstractmethod
from functools import partial

import flax.linen as nn
import jax
import jax.numpy as jnp
from einops import rearrange
from jax.typing import ArrayLike

from physmodjax.models.mlp import MLP
from physmodjax.utils.clamp import magic_clamp
from physmodjax.utils.eigenvalues import (
    ensure_positive_imaginary_parts,
    multiply_eigenvalues,
)
```
:::
::: {#cell-4 .cell .export}
``` python
def matrix_init(key, shape, dtype=jnp.float32, normalization=1):
    return jax.random.normal(key=key, shape=shape, dtype=dtype) / normalization


def nu_init(key, shape, r_min, r_max, dtype=jnp.float32):
    u = jax.random.uniform(key=key, shape=shape, dtype=dtype)
    return jnp.log(-0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2))


def theta_init(key, shape, max_phase, dtype=jnp.float32):
    u = jax.random.uniform(key, shape=shape, dtype=dtype)
    return jnp.log(max_phase * u)


def gamma_log_init(key, lamb):
    nu, theta = lamb
    diag_lambda = jnp.exp(-jnp.exp(nu) + 1j * jnp.exp(theta))
    return jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2))
```
:::
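For reference, `nu_init` encodes the ring initialization of the LRU paper: drawing $u \sim \mathcal{U}[0, 1]$ and using the log-log parametrisation below places the eigenvalue radii between `r_min` and `r_max`, uniformly in squared radius:

$$
\nu_{\log} = \log\!\Big(\!-\tfrac{1}{2}\log\big(u\,(r_{\max}^2 - r_{\min}^2) + r_{\min}^2\big)\Big)
\;\;\Rightarrow\;\;
|\lambda| = e^{-e^{\nu_{\log}}} = \sqrt{u\,(r_{\max}^2 - r_{\min}^2) + r_{\min}^2},
$$

so $|\lambda|^2 \sim \mathcal{U}[r_{\min}^2, r_{\max}^2]$. Likewise `theta_init` gives phases $e^{\theta_{\log}} = u \cdot \texttt{max\_phase} \sim \mathcal{U}[0, \texttt{max\_phase}]$.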
## Base modal class
::: {#cell-6 .cell .export}
``` python
class BaseModalDynamics(nn.Module, ABC):
    d_hidden: int
    prepend_ones: bool

    @abstractmethod
    def get_eigenvalues(self) -> ArrayLike:
        """
        Abstract method that must be implemented by subclasses.

        Returns:
            jnp.ndarray: Eigenvalues as a JAX array.
        """
        pass

    def apply_eigenvalue_transform(
        self,
        eigenvalues: ArrayLike,
        n_steps: int,
    ) -> ArrayLike:
        """Compute dynamics over n_steps based on eigenvalues."""
        z = jnp.repeat(eigenvalues[None], n_steps, axis=0)
        if self.prepend_ones:
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]
        return jax.lax.associative_scan(jnp.multiply, z)

    def compute_dynamics(
        self,
        x: ArrayLike,
        n_steps: int,
    ) -> ArrayLike:
        """Advance the dynamics for `n_steps` using the initial state `x`."""
        eigenvalues = self.get_eigenvalues()
        dynamics = self.apply_eigenvalue_transform(eigenvalues, n_steps)
        return dynamics * x
```
:::
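The `jax.lax.associative_scan` over a repeated eigenvalue vector is simply a parallel cumulative product, producing the powers $\lambda, \lambda^2, \dots, \lambda^n$ (shifted to $\lambda^0, \dots, \lambda^{n-1}$ when `prepend_ones` is set). A minimal check with made-up eigenvalues:

``` python
lam = jnp.array([0.9 + 0.1j, 0.5 - 0.5j])  # two arbitrary example eigenvalues
n_steps = 4
z = jnp.repeat(lam[None], n_steps, axis=0)  # (n_steps, d_hidden)
powers = jax.lax.associative_scan(jnp.multiply, z)
expected = jnp.stack([lam ** (k + 1) for k in range(n_steps)])
assert jnp.allclose(powers, expected)
```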
## LRU dynamics

Linear dynamics whose eigenvalues are initialised following the LRU paper.
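Concretely, with a diagonal state matrix the recursion in the docstring below has the closed form

$$
x_k = A^k x_0, \qquad A = \operatorname{diag}(\lambda), \qquad
\lambda = \exp\!\big(-e^{\nu_{\log}} + i\, e^{\theta_{\log}}\big),
$$

and since $e^{\nu_{\log}} > 0$, the radius $|\lambda| = e^{-e^{\nu_{\log}}}$ always lies in $(0, 1)$: the discrete dynamics are stable by construction, which is why the clipping below is "not strictly necessary".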
::: {#cell-9 .cell .export}
``` python
class LRUDynamics(BaseModalDynamics):
    """
    This class implements only the dynamics of the LRU model:

    x_{k+1} = A x_k
    """

    r_min: float  # smallest eigenvalue radius
    r_max: float  # largest eigenvalue radius
    max_phase: float  # largest phase
    clip_eigs: bool  # whether to clip the eigenvalues

    def setup(self):
        self.theta_log = self.param(
            "theta_log", partial(theta_init, max_phase=self.max_phase), (self.d_hidden,)
        )
        self.nu_log = self.param(
            "nu_log",
            partial(nu_init, r_min=self.r_min, r_max=self.r_max),
            (self.d_hidden,),
        )

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        steps: int,  # number of steps to advance
    ) -> jnp.ndarray:  # advanced state (steps, d_hidden) complex
        A_real = -jnp.exp(self.nu_log)
        A_imag = jnp.exp(self.theta_log)

        # clip the eigenvalues to be only negative (not strictly necessary, because of the extra log)
        if self.clip_eigs:
            A_real = jnp.clip(A_real, None, -1e-5)

        A_diag = jnp.exp(A_real + 1j * A_imag)
        z = jnp.repeat(A_diag[None, :], steps, axis=0)

        if self.prepend_ones:
            # prepend ones to the beginning and slice the last element
            # this is needed to start from the initial state
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]

        # advance the state
        x = jax.lax.associative_scan(jnp.multiply, z) * x
        return x

    def get_eigenvalues(self) -> ArrayLike:
        nu = jnp.exp(self.nu_log)
        theta = jnp.exp(self.theta_log)
        return jnp.exp(-nu + 1j * theta)

    def set_eigenvalues(self, eigenvalues: ArrayLike):
        # convert to continuous
        eigenvalues_mod = jnp.log(eigenvalues)
        eigenvalues_mod = ensure_positive_imaginary_parts(eigenvalues_mod)
        # take the log of each part and assign it to the params
        # note that the sign of the real part is flipped before taking the log
        # otherwise we would get nan values
        self.put_variable("params", "nu_log", jnp.log(-eigenvalues_mod.real))
        self.put_variable("params", "theta_log", jnp.log(eigenvalues_mod.imag))

    # def scale_dynamics(
    #     self,
    #     angle_factor: float,
    #     radius_factor: float,
    # ) -> ArrayLike:
    #     discrete_eigenvalues = self.get_eigenvalues()
    #     discrete_eigenvalues_mod = multiply_eigenvalues(
    #         discrete_eigenvalues,
    #         angle_factor,
    #         radius_factor,
    #     )
    #     # convert to continuous
    #     eigenvalues_mod = jnp.log(discrete_eigenvalues_mod)
    #     eigenvalues_mod = ensure_positive_imaginary_parts(eigenvalues_mod)
    #     # take the log of each part and assign it to the params
    #     # note that the sign of the real part is flipped before taking the log
    #     # otherwise we would get nan values
    #     self.put_variable("params", "nu_log", jnp.log(-eigenvalues_mod.real))
    #     self.put_variable("params", "theta_log", jnp.log(eigenvalues_mod.imag))
    #     return self.get_eigenvalues()
```
:::
::: {#cell-10 .cell .test}
``` python
d_hidden = 64
steps = 50
dyn = LRUDynamics(
    d_hidden=d_hidden,
    r_min=0.99,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
    prepend_ones=False,
)
variables = dyn.init(jax.random.PRNGKey(0), jnp.ones(d_hidden), steps)
out = dyn.apply(variables, jnp.ones(d_hidden), steps)
assert out.shape == (steps, d_hidden)
```
:::
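`set_eigenvalues` writes back into the parameter collection through `put_variable`, which is only legal inside `apply` with the `params` collection marked mutable. A minimal round-trip sketch reusing `dyn` and `variables` from the test above; only the magnitudes are compared, since `ensure_positive_imaginary_parts` may move a phase onto a different branch:

``` python
# read the current eigenvalues
eigs = dyn.apply(variables, method=dyn.get_eigenvalues)

# write them back; put_variable requires mutable=["params"]
_, updated = dyn.apply(
    variables, eigs, method=dyn.set_eigenvalues, mutable=["params"]
)

new_eigs = dyn.apply(updated, method=dyn.get_eigenvalues)
assert jnp.allclose(jnp.abs(eigs), jnp.abs(new_eigs))
```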
## LRU with MLP
::: {#cell-12 .cell .export}
``` python
class LRUDynamicsVarying(LRUDynamics):
    model: nn.Module  # model to process the linear state

    def setup(self):
        super().setup()

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        steps: int,  # number of steps to advance
    ) -> jnp.ndarray:  # advanced state (steps, d_hidden) complex
        x = super().__call__(x, steps)
        x_hat = self.model(x.real**2 + x.imag**2)
        x_hat = x_hat[..., : self.d_hidden] + 1j * x_hat[..., self.d_hidden :]
        x = x * x_hat
        return x
```
:::
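In other words, the linear trajectory is modulated elementwise by a complex gain predicted from its instantaneous power, which is why `model` must output $2\,d_{\text{hidden}}$ channels (real and imaginary halves):

$$
\hat{x}_k = g\big(\operatorname{Re}(x_k)^2 + \operatorname{Im}(x_k)^2\big) \in \mathbb{R}^{2 d_{\text{hidden}}},
\qquad
x_k \leftarrow x_k \odot \big(\hat{x}_k^{(:d)} + i\,\hat{x}_k^{(d:)}\big),
$$

where $g$ is the MLP.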
::: {#cell-13 .cell .test}
``` python
d_hidden = 64
steps = 50
model = MLP(hidden_channels=[64, 64, 2 * d_hidden])  # output must be 2 * d_hidden
dyn = LRUDynamicsVarying(
    d_hidden=d_hidden,
    r_min=0.99,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    model=model,
    clip_eigs=False,
    prepend_ones=False,
)
out, variables = dyn.init_with_output(jax.random.PRNGKey(0), jnp.ones(d_hidden), steps)
assert out.shape == (steps, d_hidden)
```
:::
## Deep GRU
::: {#cell-15 .cell .export}
``` python
class DeepRNN(nn.Module):
    """
    A deep RNN model that applies an RNN cell over the last dimension of the input.

    Works with nn.GRUCell, nn.SimpleCell, and nn.MGUCell.
    """

    d_model: int
    d_vars: int
    n_layers: int
    cell: nn.Module
    training: bool = True
    norm: str = "layer"

    def setup(self):
        # nn.RNN unrolls the cell over the time dimension (like a scan)
        self.first_layer = nn.RNN(
            self.cell(features=self.d_model * self.d_vars),
        )
        # n_layers further RNNs are stacked on top of the first layer
        self.layers = [
            nn.RNN(
                self.cell(features=self.d_model * self.d_vars),
            )
            for _ in range(self.n_layers)
        ]

    def __call__(
        self,
        x0: jnp.ndarray,  # (W, C) initial state
        x: jnp.ndarray,  # (T, W, C) empty state
    ) -> jnp.ndarray:  # (T, W, C) advanced state
        # the rnn works over the last dimension,
        # so we reshape the input to (T, W * C)
        x0 = rearrange(x0, "w c -> (w c)")
        x = rearrange(x, "t w c -> t (w c)")
        x = self.first_layer(x, initial_carry=x0)
        for layer in self.layers:
            x = layer(x)
        return rearrange(x, "t (w c) -> t w c", w=self.d_model, c=self.d_vars)


BatchedDeepRNN = nn.vmap(
    DeepRNN,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
    axis_name="batch",
)
```
:::
::: {#cell-16 .cell .test}
``` python
B, T, W, C = 10, 50, 20, 3
deep_rnn = BatchedDeepRNN(d_model=W, d_vars=C, n_layers=2, cell=partial(nn.GRUCell))
x = jnp.ones((B, T, W, C))
x0 = jnp.ones((B, W, C))
variables = deep_rnn.init(jax.random.PRNGKey(65), x0, x)
out = deep_rnn.apply(variables, x0, x)
assert out.shape == (B, T, W, C)
```
:::
## Complex oscillator
::: {#cell-19 .cell .export}
``` python
def constrain_to_unit_circle(x):
    # clamp the magnitude into (0, 1] while keeping the phase
    mag = jnp.abs(x)
    angle = jnp.angle(x)
    # return x * nn.tanh(mag) / mag
    mag = magic_clamp(mag, 1e-8, 1.0)
    return mag * jnp.exp(1j * angle)
```
:::
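A quick sanity check (assuming `magic_clamp` clamps to the given range on the forward pass, as its use here suggests): a point outside the unit circle is pulled back onto it with its phase intact.

``` python
z = jnp.array([2.0 + 2.0j])  # magnitude ~2.83, well outside the unit circle
z_c = constrain_to_unit_circle(z)
assert jnp.allclose(jnp.abs(z_c), 1.0)  # magnitude clamped to 1
assert jnp.allclose(jnp.angle(z_c), jnp.angle(z))  # phase preserved
```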
::: {#cell-20 .cell .export}
``` python
def init_complex(
    r_min=0.0,
    r_max=1.0,
    min_phase=0.0,
    max_phase=2 * jnp.pi,
    dtype=jnp.float32,
):
    def init(key, shape, dtype=dtype) -> ArrayLike:
        dtype = jax.dtypes.canonicalize_dtype(dtype)
        radius = jax.random.uniform(
            key,
            shape,
            dtype,
            minval=r_min,
            maxval=r_max,
        )
        phase = jax.random.uniform(
            jax.random.split(key)[0],
            shape,
            dtype,
            minval=min_phase,
            maxval=max_phase,
        )
        return radius * jnp.exp(1j * phase)

    return init


def init_real_imag(
    r_min=0.0,
    r_max=1.0,
    min_phase=0.0,
    max_phase=2 * jnp.pi,
    dtype=jnp.float32,
):
    def init(key, shape, dtype=dtype) -> ArrayLike:
        dtype = jax.dtypes.canonicalize_dtype(dtype)
        radius = jax.random.uniform(
            key,
            shape,
            dtype,
            minval=r_min,
            maxval=r_max,
        )
        phase = jax.random.uniform(
            jax.random.split(key)[0],
            shape,
            dtype,
            minval=min_phase,
            maxval=max_phase,
        )
        return jnp.stack([radius * jnp.cos(phase), radius * jnp.sin(phase)], axis=-1)

    return init
```
:::
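Both factories return Flax-style initializers with signature `(key, shape) -> array`. A small usage sketch drawing eigenvalues inside a ring (hypothetical radii):

``` python
key = jax.random.PRNGKey(0)
z = init_complex(r_min=0.5, r_max=1.0)(key, (8,))
assert z.shape == (8,)
assert jnp.all(jnp.abs(z) <= 1.0 + 1e-5) and jnp.all(jnp.abs(z) >= 0.5 - 1e-5)

ri = init_real_imag(r_min=0.5, r_max=1.0)(key, (8,))
assert ri.shape == (8, 2)  # real and imaginary parts stacked on the last axis
```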
::: {#cell-21 .cell .export}
``` python
class DiscreteModalDynamics(BaseModalDynamics):
    clip: bool
    r_min: float
    r_max: float
    min_phase: float
    max_phase: float

    def setup(self):
        self.real_imag = self.param(
            "real_imag",
            init_real_imag(
                r_min=self.r_min,
                r_max=self.r_max,
                min_phase=self.min_phase,
                max_phase=self.max_phase,
            ),
            (self.d_hidden,),
        )
        self.z = self.real_imag[..., 0] + 1j * self.real_imag[..., 1]
        if self.clip:
            self.z = constrain_to_unit_circle(self.z)

    def __call__(
        self,
        x: ArrayLike,  # initial complex state (d_hidden,)
        n_steps: int,  # number of steps to advance
    ):
        z = jnp.repeat(self.z[None, :], n_steps, axis=0)
        if self.prepend_ones:
            # prepend ones to the beginning and slice the last element
            # this is needed to start from the initial state
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]
        return jax.lax.associative_scan(jnp.multiply, z) * x

    def get_eigenvalues(self) -> ArrayLike:
        # self.z is already constrained to the unit circle in setup when clip is set
        return self.z

    def set_eigenvalues(self, eigenvalues: ArrayLike):
        eigenvalues_as_real_imag = jnp.stack(
            [jnp.real(eigenvalues), jnp.imag(eigenvalues)],
            axis=-1,
        )
        self.put_variable("params", "real_imag", eigenvalues_as_real_imag)
```
:::
::: {#cell-22 .cell .test}
``` python
d_hidden = 10
clip = True
n_steps = 100
model = DiscreteModalDynamics(
    d_hidden=d_hidden,
    clip=clip,
    r_min=0.3,
    r_max=1.0,
    min_phase=0.01,
    max_phase=jnp.pi / 2,
    prepend_ones=False,
)
output, variables = model.init_with_output(
    jax.random.PRNGKey(0),
    jnp.ones((d_hidden,)),
    n_steps=n_steps,
)
assert output.shape == (
    n_steps,
    d_hidden,
)
```
:::
::: {#cell-23 .cell .export}
``` python
class DiscreteModalDynamicsVarying(DiscreteModalDynamics):
    model: nn.Module  # model to process the linear state

    def setup(self):
        super().setup()

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        n_steps: int,  # number of steps to advance
    ) -> jnp.ndarray:  # advanced state (n_steps, d_hidden) complex
        x = super().__call__(x, n_steps)
        x_hat = self.model(x.real**2 + x.imag**2)
        x_hat = x_hat[..., : self.d_hidden] + 1j * x_hat[..., self.d_hidden :]
        x = x * x_hat
        return x
```
:::
::: {#cell-24 .cell .test}
``` python
d_hidden = 10
clip = True
n_steps = 100
model = DiscreteModalDynamicsVarying(
    d_hidden=d_hidden,
    clip=clip,
    r_min=0.3,
    r_max=1.0,
    min_phase=0.01,
    max_phase=jnp.pi / 2,
    model=MLP(hidden_channels=[64, 64, 2 * d_hidden]),  # output must be 2 * d_hidden
    prepend_ones=False,
)
output, variables = model.init_with_output(
    jax.random.PRNGKey(0),
    jnp.ones((d_hidden,)),
    n_steps=n_steps,
)
assert output.shape == (
    n_steps,
    d_hidden,
)
```
:::
::: {#cell-25 .cell .export}
``` python
class DiscreteModalDynamicsAngleMag(BaseModalDynamics):
    clip: bool
    r_min: float
    r_max: float
    min_phase: float
    max_phase: float

    def setup(self):
        self.angles = self.param(
            "angles",
            lambda key, shape: jax.random.uniform(
                key,
                shape,
                minval=self.min_phase,
                maxval=self.max_phase,
            ),
            (self.d_hidden,),
        )
        self.magnitudes = self.param(
            "magnitudes",
            lambda key, shape: jax.random.uniform(
                key,
                shape,
                minval=self.r_min,
                maxval=self.r_max,
            ),
            (self.d_hidden,),
        )

    def __call__(
        self,
        x: jnp.ndarray,  # initial complex state flattened (d_hidden,) complex
        n_steps: int,  # number of steps to advance
    ):
        if self.clip:
            magnitudes = magic_clamp(self.magnitudes, 1e-8, 1.0)
        else:
            magnitudes = self.magnitudes
        z = magnitudes * jnp.exp(1j * self.angles)
        z = jnp.repeat(z[None], n_steps, axis=0)
        if self.prepend_ones:
            # prepend ones to the beginning and slice the last element
            # this is needed to start from the initial state
            z = jnp.concatenate(
                [jnp.ones((1, self.d_hidden), dtype=jnp.complex64), z], axis=0
            )[:-1, :]
        return jax.lax.associative_scan(jnp.multiply, z) * x

    def get_eigenvalues(self) -> ArrayLike:
        if self.clip:
            magnitudes = jnp.clip(self.magnitudes, 1e-8, 1.0)
        else:
            magnitudes = self.magnitudes
        return magnitudes * jnp.exp(1j * self.angles)

    def set_eigenvalues(self, eigenvalues: ArrayLike):
        self.put_variable("params", "angles", jnp.angle(eigenvalues))
        self.put_variable("params", "magnitudes", jnp.abs(eigenvalues))
```
:::
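No test cell follows this variant in the original; a smoke test mirroring the `DiscreteModalDynamics` one above might look like this:

``` python
d_hidden = 10
n_steps = 100
model = DiscreteModalDynamicsAngleMag(
    d_hidden=d_hidden,
    clip=True,
    r_min=0.3,
    r_max=1.0,
    min_phase=0.01,
    max_phase=jnp.pi / 2,
    prepend_ones=False,
)
output, variables = model.init_with_output(
    jax.random.PRNGKey(0),
    jnp.ones((d_hidden,)),
    n_steps=n_steps,
)
assert output.shape == (n_steps, d_hidden)
```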