Fourier Neural Operator in 1D and 2D
Neural Operator
We are interested in learning mappings between function spaces. In particular, we are interested in learning the mapping between the input space \(\Omega\) and the output space \(\Lambda\) of a function \(u: \Omega \rightarrow \Lambda\). We will assume that the input space is a subset of \(\mathbb{R}^d\) and the output space is a subset of \(\mathbb{R}^m\). We will also assume that the function \(u\) is sufficiently smooth, i.e., it has continuous derivatives up to whatever order we need; we denote the derivatives of \(u\) by \(u_{x_i}\), \(u_{x_i x_j}\), etc. Finally, we will assume that \(u\) satisfies a partial differential equation (PDE) \(\mathcal{L} u = 0\) for some linear differential operator \(\mathcal{L}\).
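For example, for the heat equation \(\mathcal{L} u = \partial_t u - \Delta u = 0\), the operator of interest maps an initial condition \(u_0\) to the solution \(u(\cdot, t)\) at a later time; the neural operators below are meant to approximate such solution maps from data.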
A single layer of the neural operator is defined as follows:
\[ \begin{aligned} \mathcal{F} &:= \sigma \left(W +\mathcal{K} + b \right) \\ \mathcal{G}_\theta &:= \mathcal{Q} \circ \mathcal{F} \circ \mathcal{P} \end{aligned} \]
where
- \(\mathcal{P} : \mathbb{R}^{\text{in}} \to \mathbb{R}^{\text{hidden}}\) is a lifting layer
- \(\mathcal{Q} : \mathbb{R}^{\text{hidden}} \to \mathbb{R}^{\text{out}}\) is a projection layer
- \(\mathcal{F} : \mathbb{R}^{\text{hidden}} \to \mathbb{R}^{\text{hidden}}\) is the Neural Operator layer (a minimal sketch follows this list), in which
  - \(\mathcal{K}\) is one of several kernel operators
  - \(W\) is a matrix (local linear operator); a “skip connection” inspired by ResNet
  - \(b\) is a “function” bias
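To make the shapes concrete, here is a minimal Flax sketch of one such layer. Both the class and the `kernel_op` argument (a callable standing in for \(\mathcal{K}\)) are hypothetical; the library's actual layers appear below.

```python
import flax.linen as nn

class NeuralOperatorLayer(nn.Module):
    hidden_channels: int

    @nn.compact
    def __call__(self, v, kernel_op):
        # v: (grid_size, hidden_channels). Dense applies the local linear
        # operator W (plus the bias b) pointwise at every grid location.
        w = nn.Dense(self.hidden_channels, name="W")
        return nn.gelu(w(v) + kernel_op(v))  # sigma(W v + K v + b)
```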
The “Fourier” Neural Operator
The Fourier kernel operator takes the form of a linear transformation (convolution) of the Fourier coefficients of the input function \(v\) by the kernel \(R_\phi\); the result is then transformed back using the inverse Fourier transform.
\[ \mathcal{K}(v) = \mathcal{F}^{-1} \left( R_\phi \cdot \mathcal{F}(v) \right) \]
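As a rough illustration of what this means on a discrete grid, here is a minimal JAX sketch of a 1D spectral convolution. The `weights` argument is a hypothetical stand-in for \(R_\phi\); the library's `SpectralConv1d` below stores it as separate real and imaginary parameters.

```python
import jax.numpy as jnp

def spectral_conv_1d(v, weights, n_modes):
    """v: (length, in_channels); weights: complex, (in_channels, out_channels, n_modes)."""
    length = v.shape[0]
    v_hat = jnp.fft.rfft(v, axis=0)    # Fourier coefficients: (length // 2 + 1, in_channels)
    v_hat = v_hat[:n_modes]            # keep only the lowest n_modes frequencies
    out_hat = jnp.einsum("mi,iom->mo", v_hat, weights)  # per-mode channel mixing by R_phi
    pad = length // 2 + 1 - n_modes    # zero-fill the truncated high frequencies
    out_hat = jnp.pad(out_hat, ((0, pad), (0, 0)))
    return jnp.fft.irfft(out_hat, n=length, axis=0)     # back to physical space
```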
SpectralConv1d
SpectralConv1d (in_channels:int, d_vars:int, n_modes:int, linear_conv:bool=True, parent:Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>, name:Optional[str]=None)
Spectral Convolution Layer for 1D inputs. For now, the n_modes parameter should be set to the length of the output, as it is not clear that mode truncation is handled correctly.
Test
```python
batch_size = 1
in_channels = 2
d_vars = 2
length = 10  # length of the input signal, also can be seen as the grid size
n_modes = length

conv = SpectralConv1d(
    in_channels=in_channels,
    d_vars=d_vars,
    n_modes=n_modes,
    linear_conv=True,
)

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(length, in_channels))

params = conv.init(jax.random.PRNGKey(0), x=x)
y = conv.apply(params, x)

assert y.shape == (length, d_vars)
```
SpectralLayers1d
SpectralLayers1d (n_channels:int, n_modes:int, linear_conv:bool=True, n_layers:int=4, activation:flax.linen.module.Module=<jax._src.custom_derivatives.custom_jvp object>, parent:Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>, name:Optional[str]=None)
Stack of 1D Spectral Convolution Layers
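A rough sketch of what the stack computes, assuming each layer pairs a spectral convolution with a pointwise linear map as in the layer formula above (the helper and its arguments are hypothetical; the real module holds the `layers_conv_i` / `layers_w_i` pairs visible in the parameter tree further down):

```python
def apply_spectral_stack(v, conv_layers, w_layers, activation):
    # v: (grid_size, n_channels); conv_layers / w_layers: equal-length lists of callables
    for conv, w in zip(conv_layers, w_layers):
        v = activation(conv(v) + w(v))  # sigma(K v + W v + b); b lives inside w
    return v
```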
```python
hidden_channels = 6
grid_size = 101

spectral_layers = SpectralLayers1d(
    n_channels=hidden_channels,
    n_modes=grid_size,
    linear_conv=True,
    n_layers=4,
    activation=nn.relu,
)
params = spectral_layers.init(
    jax.random.PRNGKey(0), jnp.ones((grid_size, hidden_channels))
)

x = jnp.ones((grid_size, hidden_channels))
y = spectral_layers.apply(params, x)
assert y.shape == x.shape
```
Fourier Neural Operator in 1 Dimension
FNO1D
FNO1D (hidden_channels:int, n_modes:int, d_vars:int=1, linear_conv:bool=True, n_layers:int=4, n_steps:int=None, activation:flax.linen.module.Module=<function gelu>, norm:str=('layer',), training:bool=True, parent:Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>, name:Optional[str]=None)
Test
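The test uses BatchedFNO1D, which is not documented above; presumably it is FNO1D vectorized over the leading batch axis. A sketch of how such a wrapper can be built with Flax's lifted `nn.vmap` (an assumption, not the library's verbatim definition):

```python
import flax.linen as nn

# Hypothetical reconstruction of BatchedFNO1D as a vmapped FNO1D.
BatchedFNO1D = nn.vmap(
    FNO1D,
    in_axes=0, out_axes=0,            # map over the leading batch dimension
    variable_axes={"params": None},   # share one set of parameters across the batch
    split_rngs={"params": False},     # and a single RNG at initialization
)
```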
```python
time = 1
hidden_channels = 6
grid_size = 101
in_channels = 1
d_vars = 5
n_layers = 2
batch_size = 10

batch_fno = BatchedFNO1D(
    hidden_channels=hidden_channels,
    n_modes=grid_size,
    d_vars=d_vars,
    n_layers=n_layers,
    n_steps=1,
)

rng = jax.random.PRNGKey(0)
x = jnp.ones((batch_size, time, grid_size, in_channels))

params = batch_fno.init(jax.random.PRNGKey(0), x)
y = batch_fno.apply(params, x)

# assert y.shape == x.shape
assert y.shape[-1] == d_vars
assert y.shape == (batch_size, time, grid_size, d_vars)

display(jax.tree_util.tree_map(jnp.shape, params["params"]))
```
{'Dense_0': {'bias': (6,), 'kernel': (1, 6)},
'Dense_1': {'bias': (128,), 'kernel': (6, 128)},
'Dense_2': {'bias': (5,), 'kernel': (128, 5)},
'SpectralLayers1d_0': {'layers_conv_0': {'weight_imag': (6, 6, 101),
'weight_real': (6, 6, 101)},
'layers_conv_1': {'weight_imag': (6, 6, 101), 'weight_real': (6, 6, 101)},
'layers_w_0': {'bias': (6,), 'kernel': (1, 6, 6)},
'layers_w_1': {'bias': (6,), 'kernel': (1, 6, 6)}}}
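Note how these shapes line up with the hyperparameters above: Dense_0 (kernel of shape (1, 6)) is presumably the lifting layer \(\mathcal{P}\) from in_channels=1 to hidden_channels=6, while Dense_1 and Dense_2 form the projection \(\mathcal{Q}\) down to d_vars=5. Each of the n_layers=2 spectral layers pairs a spectral convolution, whose complex weights of shape (hidden_channels, hidden_channels, n_modes) = (6, 6, 101) are stored as separate real and imaginary parts, with a pointwise linear map layers_w_i.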
Fourier Neural Operator in 2D
SpectralConv2d
SpectralConv2d (in_channels:int, out_channels:int, n_modes1:int, n_modes2:int, parent:Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>, name:Optional[str]=None)
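A quick sanity check on why the test below sets n_modes to width // 2 + 1: for real inputs, the 2D real FFT keeps all frequencies along the first axis but only width // 2 + 1 along the last. This snippet only illustrates the FFT convention, not the layer itself.

```python
import jax.numpy as jnp

x_hat = jnp.fft.rfft2(jnp.ones((10, 10)))  # real 2D input of shape (height, width)
assert x_hat.shape == (10, 10 // 2 + 1)    # (height, width // 2 + 1) complex modes
```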
```python
batch_size = 1
in_channels = 2
out_channels = 2
height = 10
width = 10
n_modes = width // 2 + 1  # number of rfft modes for a real signal of length width

conv = SpectralConv2d(
    in_channels=in_channels,
    out_channels=out_channels,
    n_modes1=n_modes,
    n_modes2=n_modes,
)

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(height, width, in_channels))

params = conv.init(
    rng,
    x=x,
)
y = conv.apply(params, x)

assert y.shape == (height, width, out_channels)
```
FNO2D
FNO2D (hidden_channels:int, n_modes:int, d_vars:int=1, linear_conv:bool=True, n_layers:int=4, n_steps:int=None, activation:flax.linen.module.Module=<function gelu>, d_model:Tuple[int,int]=(41, 37), use_positions:bool=False, norm:str='layer', training:bool=True, parent:Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]=<flax.linen.module._Sentinel object>, name:Optional[str]=None)
Test
```python
T_in = 1
T_out = 9
B, T, H, W, C = 10, T_in, 32, 32, 2
hidden_channels = 10
n_modes = 16

rng = jax.random.PRNGKey(0)
x = jax.random.uniform(rng, shape=(B, T, H, W, C))

batched_fno = BatchedFNO2D(
    hidden_channels=hidden_channels,
    d_vars=C,
    n_steps=T_out,
    n_modes=n_modes,
    use_positions=False,
    d_model=(H, W),
)
params = batched_fno.init(rng, x)

y = batched_fno.apply(params, x)
print(y.shape)
assert y.shape == (B, T_out, H, W, C)

display(jax.tree_util.tree_map(jnp.shape, params["params"]))
```
(10, 9, 32, 32, 2)
{'P': {'bias': (10,), 'kernel': (2, 10)},
'Q': {'layers_0': {'bias': (128,), 'kernel': (10, 128)},
'layers_2': {'bias': (18,), 'kernel': (128, 18)}},
'conv_layers_0': {'weight_1_imag': (10, 10, 16, 16),
'weight_1_real': (10, 10, 16, 16),
'weight_2_imag': (10, 10, 16, 16),
'weight_2_real': (10, 10, 16, 16)},
'conv_layers_1': {'weight_1_imag': (10, 10, 16, 16),
'weight_1_real': (10, 10, 16, 16),
'weight_2_imag': (10, 10, 16, 16),
'weight_2_real': (10, 10, 16, 16)},
'conv_layers_2': {'weight_1_imag': (10, 10, 16, 16),
'weight_1_real': (10, 10, 16, 16),
'weight_2_imag': (10, 10, 16, 16),
'weight_2_real': (10, 10, 16, 16)},
'conv_layers_3': {'weight_1_imag': (10, 10, 16, 16),
'weight_1_real': (10, 10, 16, 16),
'weight_2_imag': (10, 10, 16, 16),
'weight_2_real': (10, 10, 16, 16)},
'w_layers_0': {'bias': (10,), 'kernel': (1, 10, 10)},
'w_layers_1': {'bias': (10,), 'kernel': (1, 10, 10)},
'w_layers_2': {'bias': (10,), 'kernel': (1, 10, 10)},
'w_layers_3': {'bias': (10,), 'kernel': (1, 10, 10)}}