"jax_platform_name", "cpu") jax.config.update(
FNO embedded in a recurrent neural network
This notebook adapts the FNO for use in a recurrent neural network. The idea is to let the FNO learn the one-step dynamics of a system and then use it as the recurrent cell, so that rolling the cell forward models the dynamics of the system over time.
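In terms of the cell defined below, each recurrent step advances the hidden state with the spectral (FNO) layers and reads out the observable state with a linear down-lifting layer. Roughly, writing $\mathcal{F}_\theta$ for the stacked spectral layers and $W, b$ for the down-lifting dense layer (notation introduced here for illustration):

$$
h_{t+1} = \mathcal{F}_\theta(h_t), \qquad y_{t+1} = W h_{t+1} + b,
$$

where the initial hidden state $h_0$ is obtained by up-lifting the initial condition from a single state variable to `hidden_channels` channels.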
::: {#cell-4 .cell}
```python
#| export
import jax
import jax.numpy as jnp
from flax import linen as nn

from physmodjax.models.fno import SpectralLayers1d
```
:::
::: {#cell-6 .cell}
```python
#| export
class FNOCell(nn.Module):
    """
    Parker's ARMA without input
    """

    hidden_channels: int
    grid_size: int
    layers: int = 4
    out_channels: int = 1
    activation: nn.Module = nn.relu

    @nn.compact
    def __call__(
        self,
        h,  # hidden state (grid_size, hidden_channels)
        x,  # input (grid_size, 1)
    ):
        down_lifting = nn.Dense(features=self.out_channels)
        spectral_layers = SpectralLayers1d(
            n_channels=self.hidden_channels,
            n_modes=self.grid_size,
            linear_conv=True,
            n_layers=self.layers,
            activation=self.activation,
        )

        h = spectral_layers(h)

        # the output is the down lifted hidden state
        # (grid_size, hidden_channels) -> (grid_size, 1)
        y = down_lifting(h)

        return h, y


class FNORNN(nn.Module):
    hidden_channels: int  # number of hidden channels
    grid_size: int  # number of grid points
    n_spectral_layers: int = 4  # number of spectral layers
    out_channels: int = 1
    length: int = None  # length of the sequence. If None, the length is inferred from the input
    activation: nn.Module = nn.relu

    @nn.compact
    def __call__(
        self,
        h0: jnp.ndarray,  # initial hidden state (grid_size, statevars)
        x: jnp.ndarray = None,  # input sequence (timesteps, grid_size, 1)
    ) -> jnp.ndarray:
        ScanFNOCell = nn.scan(
            FNOCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            length=self.length,
        )
        scan = ScanFNOCell(
            hidden_channels=self.hidden_channels,
            grid_size=self.grid_size,
            layers=self.n_spectral_layers,
            out_channels=self.out_channels,
            activation=self.activation,
        )
        up_lifting = nn.Dense(features=self.hidden_channels)

        # We up lift the initial condition from (grid_size, 1) -> (grid_size, hidden_channels)
        h0 = up_lifting(h0)
        h, y = scan(h0, x)
        return y
```
:::
::: {#cell-7 .cell}
```python
#| test
jax.config.update("jax_platform_name", "cpu")

hidden_channels = 6
grid_size = 101
time_steps = 10

fno_rnn = FNORNN(
    hidden_channels=hidden_channels,
    grid_size=grid_size,
    length=time_steps,
)

h0 = jnp.ones((grid_size, 1))
x = jnp.ones((time_steps, grid_size, 1))

params = fno_rnn.init(jax.random.PRNGKey(0), h0, x)
y = fno_rnn.apply(params, h0, x)

assert y.shape == x.shape
```
:::
::: {#cell-8 .cell}
```python
#| export
class BatchFNORNN(nn.Module):
    hidden_channels: int  # number of hidden channels
    grid_size: int  # number of grid points
    n_spectral_layers: int = 4  # number of spectral layers
    out_channels: int = 1
    length: int = None  # length of the sequence. If None, the length is inferred from the input
    activation: nn.Module = nn.relu

    @nn.compact
    def __call__(
        self,
        h0: jnp.ndarray,  # initial hidden state (batch_size, grid_size, statevars)
        x: jnp.ndarray = None,  # input sequence (batch_size, timesteps, grid_size, 1)
    ) -> jnp.ndarray:
        fnornn = nn.vmap(
            FNORNN,
            in_axes=0,
            variable_axes={"params": None},
            split_rngs={"params": False},
        )
        return fnornn(
            hidden_channels=self.hidden_channels,
            grid_size=self.grid_size,
            n_spectral_layers=self.n_spectral_layers,
            out_channels=self.out_channels,
            length=self.length,
            activation=self.activation,
        )(h0, x)
```
:::
::: {#cell-9 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
= 3
batch_size = 10
time_steps = jnp.ones((batch_size, time_steps, grid_size, 1))
x = jnp.ones((batch_size, grid_size, 1))
h0
= BatchFNORNN(
batch_fno_rnn =hidden_channels,
hidden_channels=grid_size,
grid_size=time_steps,
length
)
= batch_fno_rnn.init(
params 0), h0, x
jax.random.PRNGKey(# Why does it need to be initialised with the number of timesteps?
) = batch_fno_rnn.apply(params, h0, x)
y # Print the shape of the output
print(y.shape)
:::