Fourier Neural Operator in 1D and 2D
Neural Operator
We are interested in learning mappings between function spaces. In particular, we consider functions \(u: \Omega \rightarrow \Lambda\), where the domain \(\Omega\) is a subset of \(\mathbb{R}^d\) and the codomain \(\Lambda\) is a subset of \(\mathbb{R}^m\). We assume that \(u\) is smooth enough that the derivatives we need exist, and we denote them by \(u_{x_i}\), \(u_{x_i x_j}\), and so on. We also assume that \(u\) satisfies a partial differential equation (PDE) \(\mathcal{L} u = 0\) for some linear differential operator \(\mathcal{L}\); the goal is to learn the operator that maps the problem data (e.g. initial conditions) to the corresponding solution \(u\).
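For instance, for the heat equation \(u_t = \alpha\, u_{xx}\), the operator of interest maps the initial condition to the solution at a later time,
\[ \mathcal{G} : u(\cdot, 0) \mapsto u(\cdot, T), \]
and the neural operator \(\mathcal{G}_\theta\) defined below is trained to approximate \(\mathcal{G}\) from pairs of input and output functions sampled on a grid.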
A single layer of the neural operator is defined as follows:
\[ \begin{aligned} \mathcal{F} &:= \sigma \left(W +\mathcal{K} + b \right) \\ \mathcal{G}_\theta &:= \mathcal{Q} \circ \mathcal{F} \circ \mathcal{P} \end{aligned} \]
where

- \(\mathcal{P} : \mathbb{R}^{\text{in}} \to \mathbb{R}^{\text{hidden}}\) is a lifting layer
- \(\mathcal{Q} : \mathbb{R}^{\text{hidden}} \to \mathbb{R}^{\text{out}}\) is a projection layer
- \(\mathcal{F} \colon \mathbb{R}^{\text{hidden}} \to \mathbb{R}^{\text{hidden}}\) is the Neural Operator layer, where
  - \(\mathcal{K}\) is one of several kernel integral operators
  - \(W\) is a matrix (a local linear operator), acting as a “skip connection” inspired by ResNet
  - \(b\) is a bias term (a function rather than a fixed vector)
The “Fourier” Neural Operator

The kernel operator \(\mathcal{K}\) takes the form of a linear transformation (a convolution) applied to the Fourier coefficients of the input function \(v\): the coefficients are multiplied by the kernel \(R_\phi\), and the result is transformed back with the inverse Fourier transform.

\[ \mathcal{K}(v) = \mathcal{F}^{-1} \left( R_\phi \cdot \mathcal{F} (v) \right) \]
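The snippet below is a minimal, illustrative sketch of one such layer on a 1D grid, written directly with `jax.numpy`. The per-mode weights `R`, the pointwise matrix `W`, and the bias `b` are random placeholders here rather than trained parameters; the Flax modules that follow implement the same idea with proper parameter handling.

```python
import jax
import jax.numpy as jnp


def fourier_layer(v, R, W, b, n_modes):
    """One FNO layer: sigma(W v + F^-1(R . F(v)) + b) on a regular 1D grid."""
    V = jnp.fft.rfft(v, axis=0, norm="ortho")  # Fourier coefficients, (n // 2 + 1, c)
    V = jnp.einsum("ki,kio->ko", V[:n_modes], R)  # per-mode linear map R_phi
    Kv = jnp.fft.irfft(V, n=v.shape[0], axis=0, norm="ortho")  # back to the grid
    return jax.nn.gelu(v @ W + Kv + b)  # local skip W, bias b, nonlinearity


key_r, key_i, key_w, key_v = jax.random.split(jax.random.PRNGKey(0), 4)
n, c, m = 64, 8, 16  # grid points, channels, retained modes
v = jax.random.normal(key_v, (n, c))  # input function sampled on the grid
R = (jax.random.normal(key_r, (m, c, c)) + 1j * jax.random.normal(key_i, (m, c, c))) * 0.02
W = jax.random.normal(key_w, (c, c)) * 0.02
b = jnp.zeros((c,))

out = fourier_layer(v, R, W, b, n_modes=m)
assert out.shape == (n, c)
```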
::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
import flax.linen as nn
from flax.linen.initializers import uniform
import jax.numpy as jnp
from einops import rearrange
from typing import Tuple
import jax
from physmodjax.utils.data import create_grid
:::
::: {#cell-9 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class SpectralConv1d(nn.Module):
    """Spectral Convolution Layer for 1D inputs.

    The n_modes parameter should be set to the length of the output for now,
    as it is not clear that the truncation is done correctly.
    """

    in_channels: int  # number of input channels (last dimension of input)
    d_vars: int  # number of output channels (last dimension of output)
    n_modes: int  # number of fourier modes to use
    linear_conv: bool = True  # whether to use linear convolution or circular convolution

    def setup(self):
        weight_shape = (self.in_channels, self.d_vars, self.n_modes)
        scale = 1 / (self.in_channels * self.d_vars)

        self.weight_real = self.param(
            "weight_real",
            uniform(scale=scale),  # can't use complex64
            weight_shape,
        )
        self.weight_imag = self.param(
            "weight_imag",
            uniform(scale=scale),  # can't use complex64
            weight_shape,
        )

    def __call__(
        self,
        x: jnp.ndarray,  # (w, c)
    ):
        W, C = x.shape

        # get the fourier coefficients along the spatial dimension
        # we pad the inputs so that we perform a linear convolution
        X = jnp.fft.rfft(x, n=W * 2 - 1, axis=-2, norm="ortho")

        # truncate to the first n_modes coefficients
        X = X[: self.n_modes, :]

        # multiply by the fourier coefficients of the kernel
        complex_weight = self.weight_real + 1j * self.weight_imag
        X = jnp.einsum("ki,iok->ko", X, complex_weight)

        # inverse fourier transform along the spatial dimension and remove padding
        x = jnp.fft.irfft(X, axis=-2, norm="ortho")[:W]

        return x
:::
Test
::: {#cell-11 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
batch_size = 1
in_channels = 2
d_vars = 2
length = 10  # length of the input signal, also can be seen as the grid size
n_modes = length

conv = SpectralConv1d(
    in_channels=in_channels,
    d_vars=d_vars,
    n_modes=n_modes,
    linear_conv=True,
)

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(length, in_channels))

params = conv.init(jax.random.PRNGKey(0), x=x)
y = conv.apply(params, x)

assert y.shape == (length, d_vars)
:::
::: {#cell-12 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class SpectralLayers1d(nn.Module):
    """Stack of 1D Spectral Convolution Layers"""

    n_channels: int  # number of hidden channels
    n_modes: int  # number of fourier modes to keep
    linear_conv: bool = True  # whether to use linear convolution
    n_layers: int = 4  # number of layers
    activation: nn.Module = nn.relu  # activation function

    def setup(self):
        self.layers_conv = [
            SpectralConv1d(
                in_channels=self.n_channels,
                d_vars=self.n_channels,
                n_modes=self.n_modes,
                linear_conv=self.linear_conv,
            )
            for _ in range(self.n_layers)
        ]
        self.layers_w = [
            nn.Conv(features=self.n_channels, kernel_size=(1,))
            for _ in range(self.n_layers)
        ]

    def __call__(
        self,
        x,  # (grid_points, channels)
    ) -> jnp.ndarray:  # (grid_points, channels)
        for conv, w in zip(self.layers_conv, self.layers_w):
            x1 = conv(x)
            x2 = w(x)
            x = self.activation(x1 + x2)

        return x
:::
::: {#cell-13 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
hidden_channels = 6
grid_size = 101

spectral_layers = SpectralLayers1d(
    n_channels=hidden_channels,
    n_modes=grid_size,
    linear_conv=True,
    n_layers=4,
    activation=nn.relu,
)
params = spectral_layers.init(
    jax.random.PRNGKey(0), jnp.ones((grid_size, hidden_channels))
)

x = jnp.ones((grid_size, hidden_channels))
y = spectral_layers.apply(params, x)
assert y.shape == x.shape
:::
Fourier Neural Operator in 1 Dimension
::: {#cell-15 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class FNO1D(nn.Module):
    hidden_channels: int  # number of hidden channels
    n_modes: int  # number of fourier modes to keep
    d_vars: int = 1  # number of output channels
    linear_conv: bool = True  # whether to use linear convolution
    n_layers: int = 4  # number of layers
    n_steps: int = None  # number of steps to output
    activation: nn.Module = nn.gelu  # activation function
    norm: str = "layer"  # normalization layer
    training: bool = True  # whether to train the model

    @nn.compact
    def __call__(
        self,
        x,  # input (T, W, C)
    ):
        """
        The input to the FNO1D model is a 1D signal of shape (t, w, c),
        where w is the spatial dimension and c is the number of channels.
        The channel dimension is typically 1 for scalar fields. However, it
        can also contain multiple time steps as channels, or multiple scalar fields.
        """
        # we need to make time a channel dimension for the spectral layers
        x = rearrange(x, "t w c -> w (t c)")

        spectral_layers = SpectralLayers1d(
            n_channels=self.hidden_channels,
            n_modes=self.n_modes,
            linear_conv=True,
            n_layers=self.n_layers,
            activation=self.activation,
        )

        # lift the input to the hidden state
        h = nn.Dense(features=self.hidden_channels)(x)
        h = spectral_layers(h)

        # project the hidden state down to the output using a tiny MLP
        y = nn.Sequential(
            [
                nn.Dense(features=128),
                self.activation,
                nn.Dense(features=self.d_vars * self.n_steps),
            ]
        )(h)

        # rearrange the output to the original shape
        y = rearrange(y, "w (t c) -> t w c", t=self.n_steps, c=self.d_vars)

        return y


BatchedFNO1D = nn.vmap(
    FNO1D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
)
:::
Test
::: {#cell-17 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
time = 1
hidden_channels = 6
grid_size = 101
in_channels = 1
d_vars = 5
n_layers = 2
batch_size = 10

batch_fno = BatchedFNO1D(
    hidden_channels=hidden_channels,
    n_modes=grid_size,
    d_vars=d_vars,
    n_layers=n_layers,
    n_steps=1,
)

rng = jax.random.PRNGKey(0)
x = jnp.ones((batch_size, time, grid_size, in_channels))

params = batch_fno.init(jax.random.PRNGKey(0), x)
y = batch_fno.apply(params, x)

# assert y.shape == x.shape
assert y.shape[-1] == d_vars
assert y.shape == (batch_size, time, grid_size, d_vars)
display(jax.tree_util.tree_map(jnp.shape, params["params"]))
:::
Fourier Neural Operator in 2D
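In two dimensions the real FFT is taken over both spatial axes. Note that `jnp.fft.rfft2` only halves the last axis, so the first axis still carries both its positive and negative frequencies; the layer below therefore keeps two blocks of low modes, one at the start and one at the end of the first axis, each with its own complex weight tensor.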
::: {#cell-19 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class SpectralConv2d(nn.Module):
    in_channels: int
    out_channels: int
    n_modes1: int  # modes along the columns
    n_modes2: int  # modes along the rows

    def setup(self):
        weight_shape = (
            self.in_channels,
            self.out_channels,
            self.n_modes1,
            self.n_modes2,
        )
        scale = 1 / (self.in_channels * self.out_channels)

        self.weight_1_real = self.param(
            "weight_1_real",
            uniform(scale=scale),
            weight_shape,
        )
        self.weight_1_imag = self.param(
            "weight_1_imag",
            uniform(scale=scale),
            weight_shape,
        )
        self.weight_2_real = self.param(
            "weight_2_real",
            uniform(scale=scale),
            weight_shape,
        )
        self.weight_2_imag = self.param(
            "weight_2_imag",
            uniform(scale=scale),
            weight_shape,
        )

        self.complex_weight_1 = self.weight_1_real + 1j * self.weight_1_imag
        self.complex_weight_2 = self.weight_2_real + 1j * self.weight_2_imag

    def __call__(
        self,
        x: jnp.ndarray,  # (H, W, C)
    ):
        """
        The input x is of shape (H, W, C), and we always perform a linear convolution
        """
        H, W, C = x.shape

        # get the fourier transform of the input
        # along the first two dimensions
        X = jnp.fft.rfft2(x, s=(H * 2 - 1, W * 2 - 1), axes=(0, 1), norm="ortho")

        # truncate the fourier transform
        # to the first n_modes1, n_modes2 modes
        # X -> (n_modes1, n_modes2, C)
        # X = X[:self.n_modes1, :self.n_modes2, :]

        # multiply the weights with the fourier transform.
        # This is a bit tricky: as in the original implementation,
        # we multiply with two different weights along the first
        # dimension, covering -n_modes1:n_modes1. This is necessary
        # to cover the entire height, and differs from parker's implementation.
        out_ft_up = jnp.einsum(
            "xyi,ioxy->xyo",
            X[: self.n_modes1, : self.n_modes2, :],
            self.complex_weight_1,
        )
        out_ft_down = jnp.einsum(
            "xyi,ioxy->xyo",
            X[-self.n_modes1 :, : self.n_modes2, :],
            self.complex_weight_2,
        )

        out_ft = jnp.concatenate((out_ft_up, out_ft_down), axis=0)

        # inverse fourier transform
        # along the first two dimensions
        x = jnp.fft.irfft2(out_ft, s=(H, W), axes=(0, 1))

        return x
:::
::: {#cell-20 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
batch_size = 1
in_channels = 2
out_channels = 2
height = 10
width = 10
n_modes = width // 2 + 1

conv = SpectralConv2d(
    in_channels=in_channels,
    out_channels=out_channels,
    n_modes1=n_modes,
    n_modes2=n_modes,
)

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(height, width, in_channels))

params = conv.init(rng, x=x)
y = conv.apply(params, x)

assert y.shape == (height, width, out_channels)
:::
::: {#cell-21 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class FNO2D(nn.Module):
    hidden_channels: int  # number of hidden channels
    n_modes: int  # number of fourier modes to keep
    d_vars: int = 1  # number of output channels
    linear_conv: bool = True  # whether to use linear convolution
    n_layers: int = 4  # number of layers
    n_steps: int = None  # number of steps to output
    activation: nn.Module = nn.gelu  # activation function
    d_model: Tuple[int, int] = (41, 37)  # (H, W) of the input for the grid
    use_positions: bool = False  # whether to use positions in the input
    norm: str = "layer"  # normalization layer
    training: bool = True

    def setup(self):
        self.conv_layers = [
            SpectralConv2d(
                in_channels=self.hidden_channels,
                out_channels=self.hidden_channels,
                n_modes1=self.n_modes,
                n_modes2=self.n_modes,
            )
            for _ in range(self.n_layers)
        ]

        # dense layers
        # we use conv so that we don't have to shuffle the dimensions
        self.w_layers = [
            nn.Conv(features=self.hidden_channels, kernel_size=(1,))
            for _ in range(self.n_layers)
        ]

        self.P = nn.Dense(
            features=self.hidden_channels,
        )

        # TODO: in the original implementation this is a tiny mlp
        # self.Q = nn.Dense(
        #     features=self.out_channels,
        # )
        self.Q = nn.Sequential(
            [
                nn.Dense(features=128),
                self.activation,
                nn.Dense(features=self.d_vars * self.n_steps),
            ]
        )

        if self.use_positions:
            self.grid = create_grid(self.d_model[1], self.d_model[0])

    def advance(
        self,
        x: jnp.ndarray,  # (h, w, (t c))
    ) -> jnp.ndarray:
        """
        The input x is of shape (H, W, C), and we always perform a linear convolution
        """
        if self.use_positions:
            x = jnp.concatenate((x, self.grid), axis=-1)

        # lifting layer works on the last dimension
        x = self.P(x)

        for conv, w in zip(self.conv_layers, self.w_layers):
            x1 = conv(x)
            x2 = w(x)
            x = self.activation(x1 + x2)

        x = self.Q(x)

        return x

    def __call__(
        self,
        x,  # (t, h, w, c)
    ) -> jnp.ndarray:
        """
        The input x is of shape (T, H, W, C).
        We always map from a single timestep to one or more timesteps.
        The FNO2D can map from many-to-many timesteps, in which case these
        are concatenated along the channel dimension.
        """
        # we need to rearrange the dimensions
        # will work only with 1 variable
        # this is equivalent to the temporal bundling trick
        x = rearrange(x, "t h w c -> h w (t c)")

        x = self.advance(x)

        x = rearrange(x, "h w (t c) -> t h w c", t=self.n_steps, c=self.d_vars)

        return x


BatchedFNO2D = nn.vmap(
    FNO2D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
)
:::
Test
::: {#cell-23 .cell 0=‘t’ 1=‘e’ 2=‘s’ 3=‘t’}
T_in = 1
T_out = 9
B, T, H, W, C = 10, T_in, 32, 32, 2
hidden_channels = 10
n_modes = 16

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(B, T, H, W, C))

batched_fno = BatchedFNO2D(
    hidden_channels=hidden_channels,
    d_vars=C,
    n_steps=T_out,
    n_modes=n_modes,
    use_positions=False,
    d_model=(H, W),
)
params = batched_fno.init(rng, x)

y = batched_fno.apply(params, x)
print(y.shape)
assert y.shape == (B, T_out, H, W, C)
display(jax.tree_util.tree_map(jnp.shape, params["params"]))
:::