FNO embedded in a recurrent neural network
This notebook adapts the FNO for use inside a recurrent neural network. The idea is to let the FNO learn the one-step dynamics of a system and then use it as the recurrent cell, so that rolling the network out over time reproduces the dynamics of the system.
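Concretely, the model below lifts the initial state to a hidden representation, advances the hidden state with a stack of spectral (FNO) layers at each step, and projects it back down to the output:

$$
h_0 = W_{\text{up}}\, u_0, \qquad h_{t+1} = \mathrm{FNO}_\theta(h_t), \qquad y_t = W_{\text{down}}\, h_{t+1},
$$

where $W_{\text{up}}$ and $W_{\text{down}}$ are shorthand for the `up_lifting`/`down_lifting` dense layers and $\mathrm{FNO}_\theta$ for the `SpectralLayers1d` stack defined below.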
::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
import jax.numpy as jnp
from physmodjax.models.fno import SpectralLayers1d
from flax import linen as nn
import jax
:::
::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class FNOCell(nn.Module):
    """
    Parker's ARMA without input
    """

    hidden_channels: int
    grid_size: int
    layers: int = 4
    out_channels: int = 1
    activation: nn.Module = nn.relu

    @nn.compact
    def __call__(
        self,
        h,  # hidden state (grid_size, hidden_channels)
        x,  # input (grid_size, 1)
    ):
        down_lifting = nn.Dense(features=self.out_channels)
        spectral_layers = SpectralLayers1d(
            n_channels=self.hidden_channels,
            n_modes=self.grid_size,
            linear_conv=True,
            n_layers=self.layers,
            activation=self.activation,
        )

        h = spectral_layers(h)

        # the output is the down lifted hidden state
        # (grid_size, hidden_channels) -> (grid_size, 1)
        y = down_lifting(h)

        return h, y
class FNORNN(nn.Module):
    hidden_channels: int  # number of hidden channels
    grid_size: int  # number of grid points
    n_spectral_layers: int = 4  # number of spectral layers
    out_channels: int = 1
    length: int = None  # length of the sequence. If None, the length is inferred from the input
    activation: nn.Module = nn.relu

    @nn.compact
    def __call__(
        self,
        h0: jnp.ndarray,  # initial hidden state (grid_size, statevars)
        x: jnp.ndarray = None,  # input sequence (timesteps, grid_size, 1)
    ) -> jnp.ndarray:
        ScanFNOCell = nn.scan(
            FNOCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            length=self.length,
        )

        scan = ScanFNOCell(
            hidden_channels=self.hidden_channels,
            grid_size=self.grid_size,
            layers=self.n_spectral_layers,
            out_channels=self.out_channels,
            activation=self.activation,
        )

        up_lifting = nn.Dense(features=self.hidden_channels)

        # We up lift the initial condition from (grid_size, 1) -> (grid_size, hidden_channels)
        h0 = up_lifting(h0)

        h, y = scan(h0, x)

        return y
:::
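Note that `nn.scan` with `variable_broadcast="params"` reuses the same `FNOCell` parameters at every time step, and `split_rngs={"params": False}` keeps a single RNG for parameter initialisation, so the unrolled network has one set of weights rather than one per step.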
::: {#cell-7 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
jax.config.update("jax_platform_name", "cpu")
hidden_channels = 6
grid_size = 101
time_steps = 10
fno_rnn = FNORNN(
hidden_channels=hidden_channels,
grid_size=grid_size,
length=time_steps,
)
h0 = jnp.ones((grid_size, 1))
x = jnp.ones((time_steps, grid_size, 1))
params = fno_rnn.init(jax.random.PRNGKey(0), h0, x)
y = fno_rnn.apply(params, h0, x)
assert y.shape == x.shape
:::
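As a sanity check, the cell can also be run on its own for a single step of the recurrence; a minimal sketch, reusing the shapes from the test above:

```python
# Single step of the recurrence: (h, x_t) -> (h_next, y_t).
cell = FNOCell(hidden_channels=hidden_channels, grid_size=grid_size)

h = jnp.ones((grid_size, hidden_channels))  # hidden state
x_t = jnp.ones((grid_size, 1))  # input at one time step (unused by the cell)

cell_params = cell.init(jax.random.PRNGKey(0), h, x_t)
h_next, y_t = cell.apply(cell_params, h, x_t)

assert h_next.shape == (grid_size, hidden_channels)
assert y_t.shape == (grid_size, 1)
```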
::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class BatchFNORNN(nn.Module):
    hidden_channels: int  # number of hidden channels
    grid_size: int  # number of grid points
    n_spectral_layers: int = 4  # number of spectral layers
    out_channels: int = 1
    length: int = None  # length of the sequence. If None, the length is inferred from the input
    activation: nn.Module = nn.relu

    @nn.compact
    def __call__(
        self,
        h0: jnp.ndarray,  # initial hidden state (batch_size, grid_size, statevars)
        x: jnp.ndarray = None,  # input sequence (batch_size, timesteps, grid_size, 1)
    ) -> jnp.ndarray:
        fnornn = nn.vmap(
            FNORNN,
            in_axes=0,
            variable_axes={"params": None},
            split_rngs={"params": False},
        )
        return fnornn(
            hidden_channels=self.hidden_channels,
            grid_size=self.grid_size,
            n_spectral_layers=self.n_spectral_layers,
            out_channels=self.out_channels,
            length=self.length,
            activation=self.activation,
        )(h0, x)
:::
::: {#cell-9 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
batch_size = 3
time_steps = 10
x = jnp.ones((batch_size, time_steps, grid_size, 1))
h0 = jnp.ones((batch_size, grid_size, 1))
batch_fno_rnn = BatchFNORNN(
hidden_channels=hidden_channels,
grid_size=grid_size,
length=time_steps,
)
params = batch_fno_rnn.init(
jax.random.PRNGKey(0), h0, x
) # Why does it need to be initialised with the number of timesteps?
y = batch_fno_rnn.apply(params, h0, x)
# Print the shape of the output
print(y.shape)
:::