# FNO embedded in a recurrent neural network

This notebook adapts the FNO for use as the cell of a recurrent neural network: the FNO learns the one-step dynamics of a system, and the recurrence rolls that step out to model the dynamics over time.
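Concretely, the recurrence implemented below is (a sketch of the intended update, where $u_0$ is the initial state on the grid, $W_\text{up}$ and $W_\text{down}$ are the up- and down-lifting `Dense` layers, and $\mathcal{F}$ is the stack of spectral layers):

$$
h_0 = W_\text{up}\, u_0, \qquad h_{t+1} = \mathcal{F}(h_t), \qquad y_t = W_\text{down}\, h_{t+1}.
$$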

::: {#cell-4 .cell .export}

from typing import Callable, Optional

import jax
import jax.numpy as jnp
from flax import linen as nn

from physmodjax.models.fno import SpectralLayers1d

:::

jax.config.update("jax_platform_name", "cpu")

::: {#cell-6 .cell .export}

class FNOCell(nn.Module):
    """
    Parker's ARMA without input
    """

    hidden_channels: int
    grid_size: int
    layers: int = 4
    out_channels: int = 1
    activation: Callable = nn.relu

    @nn.compact
    def __call__(
        self,
        h,  # hidden state (grid_size, hidden_channels)
        x,  # per-step input, unused by this cell (grid_size, 1)
    ):
        down_lifting = nn.Dense(features=self.out_channels)
        spectral_layers = SpectralLayers1d(
            n_channels=self.hidden_channels,
            n_modes=self.grid_size,
            linear_conv=True,
            n_layers=self.layers,
            activation=self.activation,
        )

        # advance the hidden state with the stack of spectral layers
        h = spectral_layers(h)

        # the output is the down lifted hidden state
        # (grid_size, hidden_channels) -> (grid_size, 1)
        y = down_lifting(h)

        return h, y


class FNORNN(nn.Module):
    hidden_channels: int  # number of hidden channels
    grid_size: int  # number of grid points
    n_spectral_layers: int = 4  # number of spectral layers
    out_channels: int = 1
    length: Optional[int] = None  # length of the sequence; if None, inferred from the input
    activation: Callable = nn.relu

    @nn.compact
    def __call__(
        self,
        h0: jnp.ndarray,  # initial hidden state (grid_size, statevars)
        x: Optional[jnp.ndarray] = None,  # input sequence (timesteps, grid_size, 1)
    ) -> jnp.ndarray:
        ScanFNOCell = nn.scan(
            FNOCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            length=self.length,
        )

        scan = ScanFNOCell(
            hidden_channels=self.hidden_channels,
            grid_size=self.grid_size,
            layers=self.n_spectral_layers,
            out_channels=self.out_channels,
            activation=self.activation,
        )

        up_lifting = nn.Dense(features=self.hidden_channels)

        # We up lift the initial condition from (grid_size, 1) -> (grid_size, hidden_channels)
        h0 = up_lifting(h0)
        _, y = scan(h0, x)  # the final hidden state is discarded
        return y

:::
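Before scanning, it can help to sanity-check a single `FNOCell` step in isolation. A minimal sketch (the shapes follow the comments in `FNOCell.__call__`; the hidden state is assumed to be already up-lifted):

::: {.cell}

cell = FNOCell(hidden_channels=6, grid_size=101)
h = jnp.ones((101, 6))  # hidden state, already up-lifted
x_t = jnp.ones((101, 1))  # per-step input (ignored by this cell)
cell_params = cell.init(jax.random.PRNGKey(0), h, x_t)
h_next, y_t = cell.apply(cell_params, h, x_t)
assert h_next.shape == (101, 6)
assert y_t.shape == (101, 1)

:::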

::: {#cell-7 .cell .test}

hidden_channels = 6
grid_size = 101
time_steps = 10

fno_rnn = FNORNN(
    hidden_channels=hidden_channels,
    grid_size=grid_size,
    length=time_steps,
)

h0 = jnp.ones((grid_size, 1))
x = jnp.ones((time_steps, grid_size, 1))

params = fno_rnn.init(jax.random.PRNGKey(0), h0, x)
y = fno_rnn.apply(params, h0, x)

assert y.shape == x.shape

:::
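Since the cell ignores its per-step input, the same model can also be rolled out autonomously by passing `x=None` and letting `length` fix the number of steps. A sketch, assuming `nn.scan` accepts an empty input pytree the way `jax.lax.scan` accepts `xs=None`:

::: {.cell}

y_free = fno_rnn.apply(params, h0, None)
assert y_free.shape == (time_steps, grid_size, 1)

:::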

::: {#cell-8 .cell .export}

class BatchFNORNN(nn.Module):
    hidden_channels: int  # number of hidden channels
    grid_size: int  # number of grid points
    n_spectral_layers: int = 4  # number of spectral layers
    out_channels: int = 1
    length: Optional[int] = None  # length of the sequence; if None, inferred from the input
    activation: Callable = nn.relu

    @nn.compact
    def __call__(
        self,
        h0: jnp.ndarray,  # initial hidden state (batch_size, grid_size, statevars)
        x: Optional[jnp.ndarray] = None,  # input sequence (batch_size, timesteps, grid_size, 1)
    ) -> jnp.ndarray:
        # map the full rollout over the leading batch axis,
        # sharing parameters across the batch
        fnornn = nn.vmap(
            FNORNN,
            in_axes=0,
            variable_axes={"params": None},
            split_rngs={"params": False},
        )
        return fnornn(
            hidden_channels=self.hidden_channels,
            grid_size=self.grid_size,
            n_spectral_layers=self.n_spectral_layers,
            out_channels=self.out_channels,
            length=self.length,
            activation=self.activation,
        )(h0, x)

:::

::: {#cell-9 .cell .test}

batch_size = 3
time_steps = 10
x = jnp.ones((batch_size, time_steps, grid_size, 1))
h0 = jnp.ones((batch_size, grid_size, 1))

batch_fno_rnn = BatchFNORNN(
    hidden_channels=hidden_channels,
    grid_size=grid_size,
    length=time_steps,
)

params = batch_fno_rnn.init(
    jax.random.PRNGKey(0), h0, x
)  # init traces the full scan, so it needs `x` (or `length`) to fix the number of steps
y = batch_fno_rnn.apply(params, h0, x)
# the output keeps the batch and time axes
assert y.shape == x.shape
print(y.shape)  # (batch_size, time_steps, grid_size, 1)

:::
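For repeated rollouts, the batched apply can be jitted as usual. A minimal sketch reusing `params`, `h0`, and `x` from the test above:

::: {.cell}

apply_fn = jax.jit(batch_fno_rnn.apply)
y_jit = apply_fn(params, h0, x)
assert y_jit.shape == (batch_size, time_steps, grid_size, 1)

:::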