Fourier Neural Operator in 1D and 2D

Neural Operator

We are interested in learning mappings between function spaces. Concretely, we consider functions \(u: \Omega \rightarrow \Lambda\), where the domain \(\Omega\) is a subset of \(\mathbb{R}^d\) and the codomain \(\Lambda\) is a subset of \(\mathbb{R}^m\), and we want to learn an operator that maps one such function (e.g. the state of a system at one time step) to another (e.g. its state at later time steps). We assume that \(u\) is sufficiently smooth, i.e., it has as many continuous derivatives as needed, and we denote its partial derivatives by \(u_{x_i}\), \(u_{x_i x_j}\), etc. We also assume that \(u\) satisfies a partial differential equation (PDE) \(\mathcal{L} u = 0\) for some linear differential operator \(\mathcal{L}\).

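A single neural operator layer \(\mathcal{F}\), and the full operator \(\mathcal{G}_\theta\) that stacks \(L\) such layers between a lifting and a projection layer, are defined as follows:

\[ \begin{aligned} \mathcal{F}(v) &:= \sigma \left( W v + \mathcal{K}(v) + b \right) \\ \mathcal{G}_\theta &:= \mathcal{Q} \circ \mathcal{F}_L \circ \cdots \circ \mathcal{F}_1 \circ \mathcal{P} \end{aligned} \]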

where

  • \(\mathcal{P} : \mathbb{R}^{d_{\text{in}}} \to \mathbb{R}^{d_{\text{hidden}}}\) is a lifting layer
  • \(\mathcal{Q} : \mathbb{R}^{d_{\text{hidden}}} \to \mathbb{R}^{d_{\text{out}}}\) is a projection layer
  • \(\mathcal{F} \colon \mathbb{R}^{d_{\text{hidden}}} \to \mathbb{R}^{d_{\text{hidden}}}\) is the neural operator layer, where
    • \(\mathcal{K}\) is one of several kernel operators
    • \(W\) is a matrix (local linear operator), a “skip connection” inspired by ResNet
    • \(b\) is a “function” bias

The “Fourier” Neural Operator

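The Fourier kernel operator applies a linear transformation to the Fourier coefficients of the input function \(v\). Only the lowest Fourier modes are kept. Each kept mode is multiplied by a learned complex matrix \(R_\phi\), which also mixes the channels. Multiplication in Fourier space is equivalent to a convolution in physical space. The result is then transformed back using the inverse Fourier transform. (Here \(\mathcal{F}\) denotes the Fourier transform, not the neural operator layer defined above.)

\[ \mathcal{K}(v) = \mathcal{F}^{-1} \left( R_\phi \cdot \mathcal{F}(v) \right) \]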

::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

import flax.linen as nn
from flax.linen.initializers import uniform
import jax.numpy as jnp
from einops import rearrange
from typing import Tuple
import jax
from physmodjax.utils.data import create_grid

:::

::: {#cell-9 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class SpectralConv1d(nn.Module):
    """Spectral convolution layer for 1D inputs of shape (W, in_channels).

    The input is transformed to Fourier space along the spatial axis. The
    lowest ``n_modes`` coefficients are multiplied by a learned complex weight
    that mixes ``in_channels`` into ``d_vars`` channels. All higher modes are
    discarded, and the result is transformed back to a signal of length W.

    With ``linear_conv=True`` the signal is zero-padded to ``2 * W - 1`` points
    before the FFT, so the spectral product is a linear (non-periodic)
    convolution. With ``linear_conv=False`` no padding is applied and the
    convolution is circular.

    At most ``W`` modes (``linear_conv=True``) or ``W // 2 + 1`` modes
    (``linear_conv=False``) exist. If ``n_modes`` is larger, only the
    available modes (and the matching slice of the weights) are used.
    """

    in_channels: int  # number of input channels (last dimension of input)
    d_vars: int  # number of output channels (last dimension of output)
    n_modes: int  # number of fourier modes to use
    linear_conv: bool = (
        True  # whether to use linear convolution or circular convolution
    )

    def setup(self):
        weight_shape = (self.in_channels, self.d_vars, self.n_modes)
        scale = 1 / (self.in_channels * self.d_vars)

        # The complex kernel is stored as separate real and imaginary float
        # parameters (complex64 parameters can't be used here).
        self.weight_real = self.param(
            "weight_real",
            uniform(scale=scale),
            weight_shape,
        )
        self.weight_imag = self.param(
            "weight_imag",
            uniform(scale=scale),
            weight_shape,
        )

    def __call__(
        self,
        x: jnp.ndarray,  # (w, c)
    ) -> jnp.ndarray:  # (w, d_vars)
        W, _ = x.shape

        # Padding to 2W - 1 points turns the circular convolution implied by
        # the DFT into a linear one; without padding the convolution is circular.
        n_fft = 2 * W - 1 if self.linear_conv else W

        # Fourier coefficients along the spatial dimension: (n_fft // 2 + 1, c)
        X = jnp.fft.rfft(x, n=n_fft, axis=-2, norm="ortho")

        # Keep the lowest modes, clipped to the number that actually exist,
        # so the einsum shapes agree for short or unpadded inputs.
        n_keep = min(self.n_modes, X.shape[0])
        X = X[:n_keep, :]

        # Per-mode channel mixing with the complex kernel.
        complex_weight = (self.weight_real + 1j * self.weight_imag)[..., :n_keep]
        X = jnp.einsum("ki,iok->ko", X, complex_weight)

        # Invert at the full (padded) FFT length. Passing n treats the
        # discarded high modes as zeros; then crop the padding away.
        x = jnp.fft.irfft(X, n=n_fft, axis=-2, norm="ortho")[:W]

        return x

:::

Test

::: {#cell-11 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

# Smoke test: one SpectralConv1d maps (length, in_channels) to
# (length, d_vars) without changing the spatial size.
batch_size, in_channels, d_vars = 1, 2, 2
length = 10  # length of the input signal, i.e. the grid size
n_modes = length

conv = SpectralConv1d(
    in_channels=in_channels, d_vars=d_vars, n_modes=n_modes, linear_conv=True
)

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(length, in_channels))
params = conv.init(jax.random.PRNGKey(0), x=x)
y = conv.apply(params, x)

assert y.shape == (length, d_vars)

:::

::: {#cell-12 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class SpectralLayers1d(nn.Module):
    """Stack of 1D Fourier layers.

    Each layer computes ``activation(K(h) + W(h))``. ``K`` is a
    SpectralConv1d (global, spectral branch). ``W`` is a kernel-size-1
    convolution (local linear "skip" branch). The input and output shapes
    are both (grid_points, n_channels).
    """

    n_channels: int  # number of hidden channels (input and output)
    n_modes: int  # number of fourier modes to keep
    linear_conv: bool = True  # whether to use linear convolution
    n_layers: int = 4  # number of layers
    activation: nn.Module = nn.relu  # activation function (any callable)

    def setup(self):
        spectral_kwargs = {
            "in_channels": self.n_channels,
            "d_vars": self.n_channels,
            "n_modes": self.n_modes,
            "linear_conv": self.linear_conv,
        }
        # Spectral branch of every layer.
        self.layers_conv = [
            SpectralConv1d(**spectral_kwargs) for _ in range(self.n_layers)
        ]
        # Pointwise linear branch of every layer.
        self.layers_w = [
            nn.Conv(features=self.n_channels, kernel_size=(1,))
            for _ in range(self.n_layers)
        ]

    def __call__(
        self,
        x,  # (grid_points, channels)
    ) -> jnp.ndarray:  # (grid_points, channels)
        h = x
        for spectral, pointwise in zip(self.layers_conv, self.layers_w):
            h = self.activation(spectral(h) + pointwise(h))
        return h

:::

::: {#cell-13 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

# Smoke test: a stack of Fourier layers preserves the
# (grid_size, hidden_channels) shape.
hidden_channels, grid_size = 6, 101

spectral_layers = SpectralLayers1d(
    n_channels=hidden_channels,
    n_modes=grid_size,
    linear_conv=True,
    n_layers=4,
    activation=nn.relu,
)

x = jnp.ones((grid_size, hidden_channels))
params = spectral_layers.init(jax.random.PRNGKey(0), x)
y = spectral_layers.apply(params, x)

assert y.shape == x.shape

:::

Fourier Neural Operator in 1 Dimension

::: {#cell-15 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class FNO1D(nn.Module):
    """Fourier Neural Operator for 1D spatial fields.

    Maps a single sample of shape (T, W, C) to an output of shape
    (n_steps, W, d_vars). The processing steps are:

    1. Stack the T input time steps into the channel dimension.
    2. Lift to ``hidden_channels`` with a Dense layer.
    3. Apply ``n_layers`` Fourier layers (SpectralLayers1d).
    4. Project to ``n_steps * d_vars`` channels with a small MLP.
    """

    hidden_channels: int  # number of hidden channels
    n_modes: int  # number of fourier modes to keep
    d_vars: int = 1  # number of output channels (variables) per time step
    linear_conv: bool = True  # whether to use linear convolution
    n_layers: int = 4  # number of layers
    n_steps: int = None  # number of output time steps; must be set
    activation: nn.Module = nn.gelu  # activation function (any callable)
    norm: str = "layer"  # normalization layer (not used by this module)
    training: bool = True  # whether to train the model (not used by this module)

    @nn.compact
    def __call__(
        self,
        x,  # input (T, W, C)
    ):
        """Apply the operator to one (unbatched) sample.

        The input is a 1D signal of shape (t, w, c), where w is the spatial
        dimension and c is the number of channels. The channel dimension is
        typically 1 for scalar fields, but it can also contain multiple scalar
        fields. Multiple input time steps along t are stacked into the
        channels.

        Returns an array of shape (n_steps, w, d_vars).

        Raises:
            ValueError: if ``n_steps`` has not been set.
        """
        if self.n_steps is None:
            # The output projection width d_vars * n_steps would be undefined.
            raise ValueError(
                "FNO1D requires `n_steps` (number of output time steps) to be set"
            )

        # we need to make time as a channel dimension for the spectral layers
        x = rearrange(x, "t w c -> w (t c)")

        spectral_layers = SpectralLayers1d(
            n_channels=self.hidden_channels,
            n_modes=self.n_modes,
            linear_conv=self.linear_conv,
            n_layers=self.n_layers,
            activation=self.activation,
        )

        h = nn.Dense(features=self.hidden_channels)(
            x
        )  # lift the input to the hidden state
        h = spectral_layers(h)

        # Down lift the hidden state to the output using a tiny mlp
        y = nn.Sequential(
            [
                nn.Dense(features=128),
                self.activation,
                nn.Dense(features=self.d_vars * self.n_steps),
            ]
        )(h)

        # unfold the channels back into (n_steps, d_vars)
        y = rearrange(y, "w (t c) -> t w c", t=self.n_steps, c=self.d_vars)

        return y


# FNO1D mapped over a leading batch axis of inputs and outputs. The params are
# not mapped (variable_axes None) and the init rng is not split, so every
# sample in the batch shares one set of weights.
BatchedFNO1D = nn.vmap(
    FNO1D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
)

:::

Test

::: {#cell-17 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

# Smoke test: BatchedFNO1D maps (batch, time, grid, in_channels) to
# (batch, time, grid, d_vars) when n_steps equals the input time length.
time, hidden_channels, grid_size = 1, 6, 101
in_channels, d_vars = 1, 5
n_layers, batch_size = 2, 10

batch_fno = BatchedFNO1D(
    hidden_channels=hidden_channels,
    n_modes=grid_size,
    d_vars=d_vars,
    n_layers=n_layers,
    n_steps=1,
)

rng = jax.random.PRNGKey(0)
x = jnp.ones((batch_size, time, grid_size, in_channels))
params = batch_fno.init(jax.random.PRNGKey(0), x)
y = batch_fno.apply(params, x)

assert y.shape[-1] == d_vars
assert y.shape == (batch_size, time, grid_size, d_vars)
display(jax.tree_util.tree_map(jnp.shape, params["params"]))

:::

Fourier Neural Operator in 2D

::: {#cell-19 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class SpectralConv2d(nn.Module):
    """Spectral convolution layer for 2D inputs of shape (H, W, in_channels).

    The input is zero-padded to (2H - 1, 2W - 1) and transformed with a real
    2D FFT, so the spectral product is a linear (non-periodic) convolution.

    Along axis 0 the full spectrum is available (positive and negative
    frequencies). The lowest ``n_modes1`` positive and the ``n_modes1`` most
    negative frequencies are each multiplied by their own complex weight.
    Along axis 1 only the rfft half-spectrum exists, and the first
    ``n_modes2`` modes are kept.

    All other modes are set to zero. The result is transformed back and
    cropped to (H, W, out_channels).

    Requires ``n_modes1 <= 2H - 1`` and ``n_modes2 <= W`` so the slices match
    the weight shapes.
    """

    in_channels: int
    out_channels: int
    n_modes1: int  # modes kept along axis 0 (height), taken from each end
    n_modes2: int  # modes kept along axis 1 (width), from the rfft half-spectrum

    def setup(self):
        weight_shape = (
            self.in_channels,
            self.out_channels,
            self.n_modes1,
            self.n_modes2,
        )

        scale = 1 / (self.in_channels * self.out_channels)

        # Complex kernels are stored as separate real/imaginary float params.
        self.weight_1_real = self.param(
            "weight_1_real",
            uniform(scale=scale),
            weight_shape,
        )

        self.weight_1_imag = self.param(
            "weight_1_imag",
            uniform(scale=scale),
            weight_shape,
        )

        self.weight_2_real = self.param(
            "weight_2_real",
            uniform(scale=scale),
            weight_shape,
        )

        self.weight_2_imag = self.param(
            "weight_2_imag",
            uniform(scale=scale),
            weight_shape,
        )

        self.complex_weight_1 = self.weight_1_real + 1j * self.weight_1_imag
        self.complex_weight_2 = self.weight_2_real + 1j * self.weight_2_imag

    def __call__(
        self,
        x: jnp.ndarray,  # (H, W, C)
    ) -> jnp.ndarray:  # (H, W, out_channels)
        """
        The input x is of shape (H, W, C), and we always perform a linear convolution
        """

        H, W, _ = x.shape
        # Zero-pad to (2H - 1, 2W - 1) so the product in Fourier space is a
        # linear rather than circular convolution.
        pad_h, pad_w = 2 * H - 1, 2 * W - 1

        # X: (pad_h, pad_w // 2 + 1, C)
        X = jnp.fft.rfft2(x, s=(pad_h, pad_w), axes=(0, 1), norm="ortho")

        m1, m2 = self.n_modes1, self.n_modes2

        # Axis 0 holds both positive (first rows) and negative (last rows)
        # frequencies, so both ends get their own weight. Axis 1 is the rfft
        # half-spectrum and only needs its first m2 modes.
        out_ft_up = jnp.einsum(
            "xyi,ioxy->xyo",
            X[:m1, :m2, :],
            self.complex_weight_1,
        )

        out_ft_down = jnp.einsum(
            "xyi,ioxy->xyo",
            X[-m1:, :m2, :],
            self.complex_weight_2,
        )

        # Scatter both blocks back to their true frequencies in an otherwise
        # zero spectrum of the padded size. Concatenating them would shift
        # the negative frequencies to the wrong rows.
        out_ft = jnp.zeros(
            (pad_h, pad_w // 2 + 1, out_ft_up.shape[-1]), dtype=out_ft_up.dtype
        )
        out_ft = out_ft.at[:m1, :m2, :].set(out_ft_up)
        out_ft = out_ft.at[-m1:, :m2, :].set(out_ft_down)

        # Invert at the padded size with the same "ortho" normalization as
        # the forward transform, then crop the padding away.
        x = jnp.fft.irfft2(out_ft, s=(pad_h, pad_w), axes=(0, 1), norm="ortho")

        return x[:H, :W, :]

:::

::: {#cell-20 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

# Smoke test: SpectralConv2d maps (height, width, in_channels) to
# (height, width, out_channels).
batch_size, in_channels, out_channels = 1, 2, 2
height = width = 10
n_modes = width // 2 + 1

conv = SpectralConv2d(
    in_channels=in_channels,
    out_channels=out_channels,
    n_modes1=n_modes,
    n_modes2=n_modes,
)

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(height, width, in_channels))
params = conv.init(rng, x=x)
y = conv.apply(params, x)

assert y.shape == (height, width, out_channels)

:::

::: {#cell-21 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class FNO2D(nn.Module):
    """Fourier Neural Operator for 2D spatial fields.

    Maps a single sample of shape (T, H, W, C) to an output of shape
    (n_steps, H, W, d_vars). The processing steps are:

    1. Stack the T input time steps into the channel dimension
       ("temporal bundling").
    2. Optionally append position features.
    3. Lift to ``hidden_channels`` with ``P``.
    4. Apply ``n_layers`` Fourier layers.
    5. Project with the MLP ``Q``.
    """

    hidden_channels: int  # number of hidden channels
    n_modes: int  # number of fourier modes to keep (same for both spatial axes)
    d_vars: int = 1  # number of output channels (variables) per time step
    linear_conv: bool = True  # not used here; SpectralConv2d always pads for a linear convolution
    n_layers: int = 4  # number of layers
    n_steps: int = None  # number of output time steps; must be set
    activation: nn.Module = nn.gelu  # activation function (any callable)
    d_model: Tuple[int, int] = (41, 37)  # (H, W) of the input for the grid
    use_positions: bool = False  # whether to use positions in the input
    norm: str = "layer"  # normalization layer (not used by this module)
    training: bool = True  # not used by this module

    def setup(self):
        if self.n_steps is None:
            # The output projection width d_vars * n_steps would be undefined.
            raise ValueError(
                "FNO2D requires `n_steps` (number of output time steps) to be set"
            )

        self.conv_layers = [
            SpectralConv2d(
                in_channels=self.hidden_channels,
                out_channels=self.hidden_channels,
                n_modes1=self.n_modes,
                n_modes2=self.n_modes,
            )
            for _ in range(self.n_layers)
        ]

        # Pointwise linear branch ("W" term) of every layer: a kernel-size-1
        # convolution mixes channels independently at each grid point.
        self.w_layers = [
            nn.Conv(features=self.hidden_channels, kernel_size=(1,))
            for _ in range(self.n_layers)
        ]

        # Lifting layer (acts on the last, channel dimension).
        self.P = nn.Dense(
            features=self.hidden_channels,
        )

        # Projection: a tiny MLP, as in the original FNO implementation.
        self.Q = nn.Sequential(
            [
                nn.Dense(features=128),
                self.activation,
                nn.Dense(features=self.d_vars * self.n_steps),
            ]
        )

        if self.use_positions:
            # NOTE(review): presumably create_grid returns an (H, W, k)
            # coordinate array matching the input grid; confirm argument order.
            self.grid = create_grid(self.d_model[1], self.d_model[0])

    def advance(
        self,
        x: jnp.ndarray,  # (h, w, (t c))
    ) -> jnp.ndarray:  # (h, w, (n_steps d_vars))
        """
        The input x is of shape (H, W, C), and we always perform a linear convolution
        """
        if self.use_positions:
            x = jnp.concatenate((x, self.grid), axis=-1)

        # lifting layer works on the last dimension
        x = self.P(x)

        for conv, w in zip(self.conv_layers, self.w_layers):
            x1 = conv(x)
            x2 = w(x)
            x = self.activation(x1 + x2)

        x = self.Q(x)

        return x

    def __call__(
        self,
        x,  # (t, h, w, c)
    ) -> jnp.ndarray:  # (n_steps, h, w, d_vars)
        """
        The input x is of shape (T, H, W, C).

        All T input time steps and C channels are concatenated along the
        channel dimension, so the operator can map one or many input steps to
        ``n_steps`` output steps of ``d_vars`` variables each.
        """
        # Temporal bundling: fold time into the channel dimension.
        x = rearrange(x, "t h w c -> h w (t c)")

        x = self.advance(x)

        # Unfold the projected channels into (n_steps, d_vars).
        x = rearrange(x, "h w (t c) -> t h w c", t=self.n_steps, c=self.d_vars)

        return x


# FNO2D mapped over a leading batch axis of inputs and outputs. The params are
# not mapped (variable_axes None) and the init rng is not split, so every
# sample in the batch shares one set of weights.
BatchedFNO2D = nn.vmap(
    FNO2D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
)

:::

Test

::: {#cell-23 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}

# Smoke test: BatchedFNO2D maps T_in input steps to T_out output steps:
# (B, T_in, H, W, C) -> (B, T_out, H, W, C).
T_in, T_out = 1, 9
B, T, H, W, C = 10, T_in, 32, 32, 2
hidden_channels, n_modes = 10, 16

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(B, T, H, W, C))

batched_fno = BatchedFNO2D(
    hidden_channels=hidden_channels,
    d_vars=C,
    n_steps=T_out,
    n_modes=n_modes,
    use_positions=False,
    d_model=(H, W),
)
params = batched_fno.init(rng, x)
y = batched_fno.apply(params, x)

print(y.shape)
assert y.shape == (B, T_out, H, W, C)
display(jax.tree_util.tree_map(jnp.shape, params["params"]))

:::