# Fourier Neural Operator in 1D and 2D

## Neural Operator
We are interested in learning mappings between function spaces (for example, from the initial state of a system to its state at later times). In particular, we consider functions \(u: \Omega \rightarrow \Lambda\), where the domain \(\Omega\) is a subset of \(\mathbb{R}^d\) and the codomain \(\Lambda\) is a subset of \(\mathbb{R}^m\). We assume that \(u\) is smooth enough that all the derivatives we need exist; we denote them by \(u_{x_i}\), \(u_{x_i x_j}\), etc. Finally, we assume that \(u\) satisfies a partial differential equation (PDE) \(\mathcal{L} u = 0\) for some linear differential operator \(\mathcal{L}\).
A single layer of the neural operator is defined as follows:
\[ \begin{aligned} \mathcal{F} &:= \sigma \left(W +\mathcal{K} + b \right) \\ \mathcal{G}_\theta &:= \mathcal{Q} \circ \mathcal{F} \circ \mathcal{P} \end{aligned} \]
where
- \(\mathcal{P} : \mathbb{R}^{\text{in}} \to \mathbb{R}^{\text{hidden}}\) is a lifting layer
- \(\mathcal{Q} : \mathbb{R}^{\text{hidden}} \to \mathbb{R}^{\text{out}}\) is a projection layer
- \(\mathcal{F} : \mathbb{R}^{\text{hidden}} \to \mathbb{R}^{\text{hidden}}\) is the neural operator layer, where
  - \(\mathcal{K}\) is one of several kernel operators
  - \(W\) is a matrix (local linear operator); a “skip connection” inspired by ResNet
  - \(b\) is a “function” bias
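Stacking several such layers between the lifting and projection maps gives the full operator. Written out with the layer index and the action on a function \(v\) made explicit (a restatement of the definitions above, with \(L\) layers):

\[ \begin{aligned} \mathcal{F}_\ell(v) &= \sigma \left( W_\ell v + \mathcal{K}_\ell(v) + b_\ell \right), \quad \ell = 1, \dots, L \\ \mathcal{G}_\theta &= \mathcal{Q} \circ \mathcal{F}_L \circ \cdots \circ \mathcal{F}_1 \circ \mathcal{P} \end{aligned} \]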
## The “Fourier” Neural Operator
The kernel operator \(\mathcal{K}\) applies a learned linear transformation \(R_\phi\) to the Fourier coefficients of the input function \(v\) (i.e., a convolution in physical space); the result is then mapped back with the inverse Fourier transform. Writing \(\mathcal{F}\) here for the Fourier transform (not to be confused with the layer \(\mathcal{F}\) above),

\[ \mathcal{K}(v) = \mathcal{F}^{-1} \left( R_\phi \cdot \mathcal{F}(v) \right) \]
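Before the full Flax implementation below, here is a minimal sketch of this operation for a single channel, assuming a periodic (circular) convolution with no padding; the names `fourier_kernel`, `v`, and `R` are hypothetical and used only for illustration.

``` python
import jax.numpy as jnp

def fourier_kernel(v, R):
    # v: (n,) real signal on a uniform grid; R: (n_modes,) complex weights
    V = jnp.fft.rfft(v)                    # Fourier coefficients, shape (n // 2 + 1,)
    n_modes = R.shape[0]
    V_kept = V[:n_modes] * R               # multiply the lowest n_modes coefficients by R
    V_rest = jnp.zeros_like(V[n_modes:])   # zero out (truncate) the remaining modes
    return jnp.fft.irfft(jnp.concatenate([V_kept, V_rest]), n=v.shape[0])
```

The `SpectralConv1d` module below generalizes this to multiple input and output channels (an `einsum` over the channel dimensions) and zero-pads the input so that the operation corresponds to a linear rather than circular convolution.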
::: {#cell-8 .cell .export}

``` python
import flax.linen as nn
from flax.linen.initializers import uniform
import jax.numpy as jnp
from einops import rearrange
from typing import Tuple
import jax

from physmodjax.utils.data import create_grid
```

:::
::: {#cell-9 .cell .export}

``` python
class SpectralConv1d(nn.Module):
    """Spectral Convolution Layer for 1D inputs.

    The n_modes parameter should be set to the length of the output for now,
    as it is not clear that the truncation is done correctly.
    """

    in_channels: int  # number of input channels (last dimension of input)
    d_vars: int  # number of output channels (last dimension of output)
    n_modes: int  # number of fourier modes to use
    linear_conv: bool = True  # whether to use linear convolution or circular convolution

    def setup(self):
        weight_shape = (self.in_channels, self.d_vars, self.n_modes)
        scale = 1 / (self.in_channels * self.d_vars)
        # store real and imaginary parts separately (can't use complex64 params)
        self.weight_real = self.param(
            "weight_real",
            uniform(scale=scale),
            weight_shape,
        )
        self.weight_imag = self.param(
            "weight_imag",
            uniform(scale=scale),
            weight_shape,
        )

    def __call__(
        self,
        x: jnp.ndarray,  # (w, c)
    ):
        W, C = x.shape

        # get the fourier coefficients along the spatial dimension
        # we pad the inputs so that we perform a linear convolution
        X = jnp.fft.rfft(x, n=W * 2 - 1, axis=-2, norm="ortho")

        # truncate to the first n_modes coefficients
        X = X[: self.n_modes, :]

        # multiply by the fourier coefficients of the kernel
        complex_weight = self.weight_real + 1j * self.weight_imag
        X = jnp.einsum("ki,iok->ko", X, complex_weight)

        # inverse fourier transform along dimension N and remove padding
        x = jnp.fft.irfft(X, axis=-2, norm="ortho")[:W]
        return x
```

:::
### Test
::: {#cell-11 .cell .test}

``` python
batch_size = 1
in_channels = 2
d_vars = 2
length = 10  # length of the input signal, also can be seen as the grid size
n_modes = length

conv = SpectralConv1d(
    in_channels=in_channels,
    d_vars=d_vars,
    n_modes=n_modes,
    linear_conv=True,
)

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(length, in_channels))
params = conv.init(jax.random.PRNGKey(0), x=x)
y = conv.apply(params, x)

assert y.shape == (length, d_vars)
```

:::
::: {#cell-12 .cell .export}

``` python
class SpectralLayers1d(nn.Module):
    """Stack of 1D Spectral Convolution Layers"""

    n_channels: int  # number of hidden channels
    n_modes: int  # number of fourier modes to keep
    linear_conv: bool = True  # whether to use linear convolution
    n_layers: int = 4  # number of layers
    activation: nn.Module = nn.relu  # activation function

    def setup(self):
        self.layers_conv = [
            SpectralConv1d(
                in_channels=self.n_channels,
                d_vars=self.n_channels,
                n_modes=self.n_modes,
                linear_conv=self.linear_conv,
            )
            for _ in range(self.n_layers)
        ]
        # pointwise (kernel size 1) convolutions acting as the local linear operator W
        self.layers_w = [
            nn.Conv(features=self.n_channels, kernel_size=(1,))
            for _ in range(self.n_layers)
        ]

    def __call__(
        self,
        x,  # (grid_points, channels)
    ) -> jnp.ndarray:  # (grid_points, channels)
        for conv, w in zip(self.layers_conv, self.layers_w):
            x1 = conv(x)
            x2 = w(x)
            x = self.activation(x1 + x2)
        return x
```

:::
::: {#cell-13 .cell .test}

``` python
hidden_channels = 6
grid_size = 101

spectral_layers = SpectralLayers1d(
    n_channels=hidden_channels,
    n_modes=grid_size,
    linear_conv=True,
    n_layers=4,
    activation=nn.relu,
)

params = spectral_layers.init(
    jax.random.PRNGKey(0), jnp.ones((grid_size, hidden_channels))
)
x = jnp.ones((grid_size, hidden_channels))
y = spectral_layers.apply(params, x)

assert y.shape == x.shape
```

:::
## Fourier Neural Operator in 1 Dimension
::: {#cell-15 .cell .export}

``` python
class FNO1D(nn.Module):
    hidden_channels: int  # number of hidden channels
    n_modes: int  # number of fourier modes to keep
    d_vars: int = 1  # number of output channels
    linear_conv: bool = True  # whether to use linear convolution
    n_layers: int = 4  # number of layers
    n_steps: int = None  # number of steps to output
    activation: nn.Module = nn.gelu  # activation function
    norm: str = "layer"  # normalization layer
    training: bool = True  # whether to train the model

    @nn.compact
    def __call__(
        self,
        x,  # input (T, W, C)
    ):
        """
        The input to the FNO1D model is a 1D signal of shape (t, w, c),
        where w is the spatial dimension and c is the number of channels.
        The channel dimension is typically 1 for scalar fields, but it can
        also contain multiple time steps as channels, or multiple scalar fields.
        """
        # fold time into the channel dimension for the spectral layers
        x = rearrange(x, "t w c -> w (t c)")

        spectral_layers = SpectralLayers1d(
            n_channels=self.hidden_channels,
            n_modes=self.n_modes,
            linear_conv=self.linear_conv,
            n_layers=self.n_layers,
            activation=self.activation,
        )

        # lift the input to the hidden state
        h = nn.Dense(features=self.hidden_channels)(x)
        h = spectral_layers(h)

        # project the hidden state to the output using a tiny mlp
        y = nn.Sequential(
            [
                nn.Dense(features=128),
                self.activation,
                nn.Dense(features=self.d_vars * self.n_steps),
            ]
        )(h)

        # rearrange the output to the original shape
        y = rearrange(y, "w (t c) -> t w c", t=self.n_steps, c=self.d_vars)
        return y


BatchedFNO1D = nn.vmap(
    FNO1D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
)
```

:::
### Test
::: {#cell-17 .cell .test}

``` python
time = 1
hidden_channels = 6
grid_size = 101
in_channels = 1
d_vars = 5
n_layers = 2
batch_size = 10

batch_fno = BatchedFNO1D(
    hidden_channels=hidden_channels,
    n_modes=grid_size,
    d_vars=d_vars,
    n_layers=n_layers,
    n_steps=1,
)

rng = jax.random.PRNGKey(0)
x = jnp.ones((batch_size, time, grid_size, in_channels))
params = batch_fno.init(jax.random.PRNGKey(0), x)
y = batch_fno.apply(params, x)

# assert y.shape == x.shape
assert y.shape[-1] == d_vars
assert y.shape == (batch_size, time, grid_size, d_vars)

display(jax.tree_util.tree_map(jnp.shape, params["params"]))
```

:::
## Fourier Neural Operator in 2D
::: {#cell-19 .cell .export}

``` python
class SpectralConv2d(nn.Module):
    in_channels: int
    out_channels: int
    n_modes1: int  # modes along the columns
    n_modes2: int  # modes along the rows

    def setup(self):
        weight_shape = (
            self.in_channels,
            self.out_channels,
            self.n_modes1,
            self.n_modes2,
        )
        scale = 1 / (self.in_channels * self.out_channels)
        self.weight_1_real = self.param(
            "weight_1_real",
            uniform(scale=scale),
            weight_shape,
        )
        self.weight_1_imag = self.param(
            "weight_1_imag",
            uniform(scale=scale),
            weight_shape,
        )
        self.weight_2_real = self.param(
            "weight_2_real",
            uniform(scale=scale),
            weight_shape,
        )
        self.weight_2_imag = self.param(
            "weight_2_imag",
            uniform(scale=scale),
            weight_shape,
        )
        self.complex_weight_1 = self.weight_1_real + 1j * self.weight_1_imag
        self.complex_weight_2 = self.weight_2_real + 1j * self.weight_2_imag

    def __call__(
        self,
        x: jnp.ndarray,  # (H, W, C)
    ):
        """
        The input x is of shape (H, W, C), and we always perform a linear convolution.
        """
        H, W, C = x.shape

        # get the fourier transform of the input along the first two dimensions,
        # padded so that we perform a linear convolution
        X = jnp.fft.rfft2(x, s=(H * 2 - 1, W * 2 - 1), axes=(0, 1), norm="ortho")

        # multiply the weights with the truncated fourier transform.
        # This is a bit tricky: as in the original implementation, we use two
        # different weights along the first dimension (modes -n_modes1:n_modes1),
        # which is necessary to cover both positive and negative frequencies over
        # the entire height. This differs from parker's implementation.
        out_ft_up = jnp.einsum(
            "xyi,ioxy->xyo",
            X[: self.n_modes1, : self.n_modes2, :],
            self.complex_weight_1,
        )
        out_ft_down = jnp.einsum(
            "xyi,ioxy->xyo",
            X[-self.n_modes1 :, : self.n_modes2, :],
            self.complex_weight_2,
        )
        out_ft = jnp.concatenate((out_ft_up, out_ft_down), axis=0)

        # inverse fourier transform along the first two dimensions,
        # cropping back to the original (H, W) grid
        x = jnp.fft.irfft2(out_ft, s=(H, W), axes=(0, 1))
        return x
```

:::
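A quick illustration (not part of the module) of why both ends of the first frequency axis are used: `rfft2` halves only the last transformed axis, so the first axis still carries both positive and negative frequencies, with the negative ones stored at the end (as in `fftfreq`).

``` python
import jax.numpy as jnp

H, W = 10, 10
X = jnp.fft.rfft2(jnp.ones((H, W)), axes=(0, 1))
print(X.shape)             # (10, 6): full height, but only W // 2 + 1 columns
print(jnp.fft.fftfreq(H))  # [0, 0.1, ..., 0.4, -0.5, -0.4, ..., -0.1]
```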
::: {#cell-20 .cell .test}

``` python
batch_size = 1
in_channels = 2
out_channels = 2
height = 10
width = 10
n_modes = width // 2 + 1

conv = SpectralConv2d(
    in_channels=in_channels,
    out_channels=out_channels,
    n_modes1=n_modes,
    n_modes2=n_modes,
)

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(height, width, in_channels))
params = conv.init(rng, x=x)
y = conv.apply(params, x)

assert y.shape == (height, width, out_channels)
```

:::
::: {#cell-21 .cell .export}

``` python
class FNO2D(nn.Module):
    hidden_channels: int  # number of hidden channels
    n_modes: int  # number of fourier modes to keep
    d_vars: int = 1  # number of output channels
    linear_conv: bool = True  # whether to use linear convolution
    n_layers: int = 4  # number of layers
    n_steps: int = None  # number of steps to output
    activation: nn.Module = nn.gelu  # activation function
    d_model: Tuple[int, int] = (41, 37)  # (H, W) of the input for the grid
    use_positions: bool = False  # whether to use positions in the input
    norm: str = "layer"  # normalization layer
    training: bool = True

    def setup(self):
        self.conv_layers = [
            SpectralConv2d(
                in_channels=self.hidden_channels,
                out_channels=self.hidden_channels,
                n_modes1=self.n_modes,
                n_modes2=self.n_modes,
            )
            for _ in range(self.n_layers)
        ]
        # dense layers
        # we use conv so that we don't have to shuffle the dimensions
        self.w_layers = [
            nn.Conv(features=self.hidden_channels, kernel_size=(1,))
            for _ in range(self.n_layers)
        ]
        self.P = nn.Dense(
            features=self.hidden_channels,
        )
        # TODO: in the original implementation this is a tiny mlp
        # self.Q = nn.Dense(
        #     features=self.out_channels,
        # )
        self.Q = nn.Sequential(
            [
                nn.Dense(features=128),
                self.activation,
                nn.Dense(features=self.d_vars * self.n_steps),
            ]
        )
        if self.use_positions:
            self.grid = create_grid(self.d_model[1], self.d_model[0])

    def advance(
        self,
        x: jnp.ndarray,  # (h, w, (t c))
    ) -> jnp.ndarray:
        """
        The input x is of shape (H, W, C), and we always perform a linear convolution.
        """
        if self.use_positions:
            x = jnp.concatenate((x, self.grid), axis=-1)

        # lifting layer works on the last dimension
        x = self.P(x)

        for conv, w in zip(self.conv_layers, self.w_layers):
            x1 = conv(x)
            x2 = w(x)
            x = self.activation(x1 + x2)

        x = self.Q(x)
        return x

    def __call__(
        self,
        x,  # (t, h, w, c)
    ) -> jnp.ndarray:
        """
        The input x is of shape (T, H, W, C).
        We always map from a single timestep to one or more timesteps.
        The FNO2D can also map from many-to-many timesteps, in which case these
        are concatenated along the channel dimension.
        """
        # we need to rearrange the dimensions (works only with 1 variable);
        # this is equivalent to the temporal bundling trick
        x = rearrange(x, "t h w c -> h w (t c)")
        x = self.advance(x)
        x = rearrange(x, "h w (t c) -> t h w c", t=self.n_steps, c=self.d_vars)
        return x


BatchedFNO2D = nn.vmap(
    FNO2D,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
)
```

:::
### Test
::: {#cell-23 .cell .test}

``` python
T_in = 1
T_out = 9
B, T, H, W, C = 10, T_in, 32, 32, 2
hidden_channels = 10
n_modes = 16

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(B, T, H, W, C))

batched_fno = BatchedFNO2D(
    hidden_channels=hidden_channels,
    d_vars=C,
    n_steps=T_out,
    n_modes=n_modes,
    use_positions=False,
    d_model=(H, W),
)

params = batched_fno.init(rng, x)
y = batched_fno.apply(params, x)
print(y.shape)

assert y.shape == (B, T_out, H, W, C)

display(jax.tree_util.tree_map(jnp.shape, params["params"]))
```

:::