Fourier Neural Operator in 1D and 2D

Neural Operator

We are interested in learning mappings between function spaces. In particular, we consider functions \(u : \Omega \rightarrow \Lambda\), where the domain \(\Omega\) is a subset of \(\mathbb{R}^d\) and the codomain \(\Lambda\) is a subset of \(\mathbb{R}^m\). We assume that \(u\) is sufficiently smooth, i.e., it has as many derivatives as we need, and we denote its partial derivatives by \(u_{x_i}\), \(u_{x_i x_j}\), etc. We also assume that \(u\) satisfies a partial differential equation (PDE) \(\mathcal{L} u = 0\) for some linear differential operator \(\mathcal{L}\). The operator we want to learn maps an input function (for example, an initial condition or a coefficient field) to the corresponding solution \(u\).

A neural operator \(\mathcal{G}_\theta\) with a single hidden layer \(\mathcal{F}\) is defined as follows:

\[ \begin{aligned} \mathcal{F}(v) &:= \sigma \left( W v + \mathcal{K}(v) + b \right) \\ \mathcal{G}_\theta &:= \mathcal{Q} \circ \mathcal{F} \circ \mathcal{P} \end{aligned} \]

where

  • \(\mathcal{P} : \mathbb{R}^{\text{in}} \to \mathbb{R}^{\text{hidden}}\) is a lifting layer
  • \(\mathcal{Q} : \mathbb{R}^{\text{hidden}} \to \mathbb{R}^{\text{out}}\) is a projection layer
  • \(\mathcal{F} \colon \mathbb{R}^{\text{hidden}} \to \mathbb{R}^{\text{hidden}}\) is the neural operator layer (a sketch follows this list), where
    • \(\mathcal{K}\) is one of several kernel operators
    • \(W\) is a matrix acting pointwise (a local linear operator); a “skip connection” inspired by ResNet
    • \(b\) is a “function” bias
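To make this concrete, below is a minimal sketch of a single layer \(\mathcal{F}\) as a Flax module. It is illustrative only: kernel_op is a hypothetical stand-in for \(\mathcal{K}\) (in this library, the spectral convolutions defined below), not part of this library's API.

import flax.linen as nn
from typing import Callable

class NeuralOperatorLayer(nn.Module):
    # Sketch of F(v) = sigma(W v + K(v) + b), with v sampled on a grid
    # as an array of shape (n_points, channels).
    channels: int
    kernel_op: Callable  # hypothetical stand-in for K

    @nn.compact
    def __call__(self, v):
        w = nn.Dense(self.channels, use_bias=False)  # W, applied pointwise
        b = self.param("b", nn.initializers.zeros, (self.channels,))
        return nn.gelu(w(v) + self.kernel_op(v) + b)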

The “Fourier” Neural Operator

The Fourier neural operator instantiates the kernel operator \(\mathcal{K}\) as a linear transformation (a convolution) of the Fourier coefficients of the input function \(v\): the coefficients are multiplied by a learned kernel \(R_\phi\), and the result is transformed back with the inverse Fourier transform.

\[ \mathcal{K}(v) = \mathcal{F}^{-1} \left( R_\phi \cdot \mathcal{F} (v) \right) \]

where \(\mathcal{F}\) here denotes the Fourier transform (not the layer above) and \(R_\phi\) acts on the retained Fourier modes.
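On a uniform grid this is implemented with the FFT: transform, keep the lowest n_modes frequencies, multiply each retained mode by a learned complex matrix, and transform back. A minimal 1D sketch under those assumptions, with illustrative names rather than this library's internals, and assuming n_modes ≤ length//2 + 1 (this library's 1D layer appears to keep weights for all length modes instead; see the parameter shapes below):

import jax.numpy as jnp

def spectral_conv_1d(v, r_phi, n_modes):
    # v: (length, channels) real signal
    # r_phi: (n_modes, channels, channels) complex per-mode weights
    length = v.shape[0]
    v_hat = jnp.fft.rfft(v, axis=0)   # one-sided spectrum: (length//2 + 1, channels)
    v_hat = v_hat[:n_modes]           # truncate high frequencies
    out_hat = jnp.einsum("mi,mio->mo", v_hat, r_phi)  # per-mode linear map
    # zero-pad the truncated high modes before the inverse transform
    pad = jnp.zeros((length // 2 + 1 - n_modes, out_hat.shape[1]), dtype=out_hat.dtype)
    return jnp.fft.irfft(jnp.concatenate([out_hat, pad], axis=0), n=length, axis=0)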


source

SpectralConv1d

 SpectralConv1d (in_channels:int, d_vars:int, n_modes:int,
                 linear_conv:bool=True,
                 parent:Union[flax.linen.module.Module, flax.core.scope.Scope,
                 flax.linen.module._Sentinel, NoneType]=<_Sentinel>,
                 name:Optional[str]=None)

Spectral convolution layer for 1D inputs. For now, n_modes should be set to the length of the signal, since it is not yet clear that mode truncation is handled correctly.

Test

batch_size = 1
in_channels = 2
d_vars = 2
length = 10  # length of the input signal, also can be seen as the grid size
n_modes = length

conv = SpectralConv1d(
    in_channels=in_channels,
    d_vars=d_vars,
    n_modes=n_modes,
    linear_conv=True,
)

rng = jax.random.PRNGKey(0)
# a random input function sampled on the grid: (length, in_channels)
x = jax.random.uniform(rng, shape=(length, in_channels))

params = conv.init(jax.random.PRNGKey(0), x=x)

y = conv.apply(params, x)

# the layer maps (length, in_channels) -> (length, d_vars)
assert y.shape == (length, d_vars)

source

SpectralLayers1d

 SpectralLayers1d (n_channels:int, n_modes:int, linear_conv:bool=True,
                   n_layers:int=4,
                   activation:flax.linen.module.Module=<jax custom_jvp>,
                   parent:Union[flax.linen.module.Module, flax.core.scope.Scope,
                   flax.linen.module._Sentinel, NoneType]=<_Sentinel>,
                   name:Optional[str]=None)

Stack of 1D Spectral Convolution Layers

hidden_channels = 6
grid_size = 101

spectral_layers = SpectralLayers1d(
    n_channels=hidden_channels,
    n_modes=grid_size,
    linear_conv=True,
    n_layers=4,
    activation=nn.relu,
)
params = spectral_layers.init(
    jax.random.PRNGKey(0), jnp.ones((grid_size, hidden_channels))
)

# the stack preserves the (grid_size, n_channels) shape
x = jnp.ones((grid_size, hidden_channels))
y = spectral_layers.apply(params, x)
assert y.shape == x.shape
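A stack like this plausibly just iterates the neural operator layer from above, \(v \mapsto \sigma(Wv + \mathcal{K}(v))\), n_layers times. A sketch of that composition, using the SpectralConv1d defined above as \(\mathcal{K}\) (the actual internals of SpectralLayers1d may differ):

import flax.linen as nn

class SpectralStackSketch(nn.Module):
    # Sketch: n_layers repetitions of v <- activation(W v + K(v)).
    n_channels: int
    n_modes: int
    n_layers: int = 4

    @nn.compact
    def __call__(self, v):
        for _ in range(self.n_layers):
            k = SpectralConv1d(in_channels=self.n_channels,
                               d_vars=self.n_channels,
                               n_modes=self.n_modes)   # K: spectral convolution
            w = nn.Dense(self.n_channels)              # W: pointwise linear map
            v = nn.relu(w(v) + k(v))
        return v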

Fourier Neural Operator in 1 Dimension


source

FNO1D

 FNO1D (hidden_channels:int, n_modes:int, d_vars:int=1,
        linear_conv:bool=True, n_layers:int=4, n_steps:int=None,
        activation:flax.linen.module.Module=<function gelu>,
        norm:str=('layer',), training:bool=True,
        parent:Union[flax.linen.module.Module, flax.core.scope.Scope,
        flax.linen.module._Sentinel, NoneType]=<_Sentinel>,
        name:Optional[str]=None)

Test

time = 1
hidden_channels = 6
grid_size = 101
in_channels = 1
d_vars = 5
n_layers = 2
batch_size = 10

# BatchedFNO1D: the batched variant of FNO1D
batch_fno = BatchedFNO1D(
    hidden_channels=hidden_channels,
    n_modes=grid_size,
    d_vars=d_vars,
    n_layers=n_layers,
    n_steps=1,
)

rng = jax.random.PRNGKey(0)
x = jnp.ones((batch_size, time, grid_size, in_channels))

params = batch_fno.init(jax.random.PRNGKey(0), x)
y = batch_fno.apply(params, x)

# the model maps in_channels -> d_vars at each grid point
assert y.shape[-1] == d_vars
assert y.shape == (batch_size, time, grid_size, d_vars)
display(jax.tree_util.tree_map(jnp.shape, params["params"]))
{'Dense_0': {'bias': (6,), 'kernel': (1, 6)},
 'Dense_1': {'bias': (128,), 'kernel': (6, 128)},
 'Dense_2': {'bias': (5,), 'kernel': (128, 5)},
 'SpectralLayers1d_0': {'layers_conv_0': {'weight_imag': (6, 6, 101),
   'weight_real': (6, 6, 101)},
  'layers_conv_1': {'weight_imag': (6, 6, 101), 'weight_real': (6, 6, 101)},
  'layers_w_0': {'bias': (6,), 'kernel': (1, 6, 6)},
  'layers_w_1': {'bias': (6,), 'kernel': (1, 6, 6)}}}
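Reading the shapes: each spectral layer stores separate real and imaginary weights of shape (in_channels, out_channels, n_modes) = (6, 6, 101); the layers_w_* entries are the pointwise W maps; and the Dense layers are the lifting \(\mathcal{P}\) (1 → 6) and the projection \(\mathcal{Q}\) (6 → 128 → 5 = d_vars).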

Fourier Neural Operator in 2D


source

SpectralConv2d

 SpectralConv2d (in_channels:int, out_channels:int, n_modes1:int,
                 n_modes2:int,
                 parent:Union[flax.linen.module.Module, flax.core.scope.Scope,
                 flax.linen.module._Sentinel, NoneType]=<_Sentinel>,
                 name:Optional[str]=None)

batch_size = 1
in_channels = 2
out_channels = 2
height = 10
width = 10
n_modes = width // 2 + 1  # width//2 + 1 is the length of a one-sided (real-input) FFT spectrum

conv = SpectralConv2d(
    in_channels=in_channels,
    out_channels=out_channels,
    n_modes1=n_modes,
    n_modes2=n_modes,
)

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(height, width, in_channels))

params = conv.init(
    rng,
    x=x
)
y = conv.apply(params, x)

assert y.shape == (height, width, out_channels)
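The 2D layer follows the same recipe with a 2D FFT and truncation along both axes. A sketch under the same assumptions as the 1D version; a full implementation also keeps the negative-frequency block along the first axis, which is presumably why the parameter tree below holds weight_1/weight_2 pairs, but that is omitted here for brevity:

import jax.numpy as jnp

def spectral_conv_2d(v, r_phi, n_modes1, n_modes2):
    # v: (H, W, C_in) real; r_phi: (n_modes1, n_modes2, C_in, C_out) complex
    h, w, _ = v.shape
    v_hat = jnp.fft.rfft2(v, axes=(0, 1))    # (H, W//2 + 1, C_in)
    v_hat = v_hat[:n_modes1, :n_modes2]      # truncate high frequencies
    out_hat = jnp.einsum("xyi,xyio->xyo", v_hat, r_phi)  # per-mode linear map
    # scatter the kept modes back into a full-size spectrum of zeros
    full = jnp.zeros((h, w // 2 + 1, out_hat.shape[-1]), dtype=out_hat.dtype)
    full = full.at[:n_modes1, :n_modes2].set(out_hat)
    return jnp.fft.irfft2(full, s=(h, w), axes=(0, 1))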

source

FNO2D

 FNO2D (hidden_channels:int, n_modes:int, d_vars:int=1,
        linear_conv:bool=True, n_layers:int=4, n_steps:int=None,
        activation:flax.linen.module.Module=<function gelu>,
        d_model:Tuple[int,int]=(41, 37), use_positions:bool=False,
        norm:str='layer', training:bool=True,
        parent:Union[flax.linen.module.Module, flax.core.scope.Scope,
        flax.linen.module._Sentinel, NoneType]=<_Sentinel>,
        name:Optional[str]=None)

Test

T_in = 1
T_out = 9
B, T, H, W, C = 10, T_in, 32, 32, 2
hidden_channels = 10
n_modes = 16

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(B, T, H, W, C))

batched_fno = BatchedFNO2D(
    hidden_channels=hidden_channels,
    d_vars=C,
    n_steps=T_out,      # roll out T_out output steps from the T_in input frames
    n_modes=n_modes,
    use_positions=False,
    d_model=(H, W),
)
params = batched_fno.init(rng, x)

y = batched_fno.apply(params, x)
print(y.shape)
assert y.shape == (B, T_out, H, W, C)
display(jax.tree_util.tree_map(jnp.shape, params["params"]))
(10, 9, 32, 32, 2)
{'P': {'bias': (10,), 'kernel': (2, 10)},
 'Q': {'layers_0': {'bias': (128,), 'kernel': (10, 128)},
  'layers_2': {'bias': (18,), 'kernel': (128, 18)}},
 'conv_layers_0': {'weight_1_imag': (10, 10, 16, 16),
  'weight_1_real': (10, 10, 16, 16),
  'weight_2_imag': (10, 10, 16, 16),
  'weight_2_real': (10, 10, 16, 16)},
 'conv_layers_1': {'weight_1_imag': (10, 10, 16, 16),
  'weight_1_real': (10, 10, 16, 16),
  'weight_2_imag': (10, 10, 16, 16),
  'weight_2_real': (10, 10, 16, 16)},
 'conv_layers_2': {'weight_1_imag': (10, 10, 16, 16),
  'weight_1_real': (10, 10, 16, 16),
  'weight_2_imag': (10, 10, 16, 16),
  'weight_2_real': (10, 10, 16, 16)},
 'conv_layers_3': {'weight_1_imag': (10, 10, 16, 16),
  'weight_1_real': (10, 10, 16, 16),
  'weight_2_imag': (10, 10, 16, 16),
  'weight_2_real': (10, 10, 16, 16)},
 'w_layers_0': {'bias': (10,), 'kernel': (1, 10, 10)},
 'w_layers_1': {'bias': (10,), 'kernel': (1, 10, 10)},
 'w_layers_2': {'bias': (10,), 'kernel': (1, 10, 10)},
 'w_layers_3': {'bias': (10,), 'kernel': (1, 10, 10)}}
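Again the shapes tell the story: P lifts the 2 input channels to 10 hidden channels; each of the 4 spectral layers holds two complex weight tensors of shape (hidden, hidden, n_modes, n_modes) = (10, 10, 16, 16), presumably one per retained corner of the 2D spectrum; the w_layers_* entries are the pointwise W maps; and Q projects 10 → 128 → 18, where 18 = n_steps × d_vars = 9 × 2 outputs per grid point.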