FNO embedded in a recurrent neural network

This notebook embeds the FNO in a recurrent neural network. The idea is to use the FNO to learn the single-step dynamics of a system, and then use it as the recurrent cell so that the network rolls those dynamics out over time. This is a reimplementation of https://github.com/julian-parker/DAFX22_FNO in JAX.
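At the heart of each FNO layer is a spectral convolution: transform the signal to Fourier space, mix a truncated set of low-frequency modes with learned complex weights, and transform back. The sketch below is illustrative only, in pure JAX; the function name, shapes, and mode count are assumptions, not the notebook's actual implementation.

```python
import jax
import jax.numpy as jnp

def spectral_conv(x, weights, n_modes):
    """One spectral convolution: mix the lowest n_modes Fourier modes.

    x:       (grid_size, channels) real signal on a 1-D grid
    weights: (n_modes, channels, channels) complex mixing weights
    """
    x_ft = jnp.fft.rfft(x, axis=0)                      # (grid//2 + 1, channels)
    out_ft = jnp.zeros_like(x_ft)
    # Mix channels per retained mode; higher modes are truncated to zero.
    low = jnp.einsum("mi,mio->mo", x_ft[:n_modes], weights)
    out_ft = out_ft.at[:n_modes].set(low)
    return jnp.fft.irfft(out_ft, n=x.shape[0], axis=0)  # back to (grid_size, channels)

grid_size, channels, n_modes = 101, 6, 16
key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (n_modes, channels, channels)) + 0j
x = jnp.ones((grid_size, channels))
y = spectral_conv(x, w, n_modes)
print(y.shape)  # (101, 6)
```

A full FNO layer would add a pointwise linear path and a nonlinearity around this spectral step.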

jax.config.update("jax_platform_name", "cpu")

source

FNORNN

 FNORNN (hidden_channels:int, grid_size:int, n_spectral_layers:int=4,
         out_channels:int=1, length:int=None,
         activation:flax.linen.Module=<jax custom_jvp function>,
         parent:Union[Module, Scope, _Sentinel, None]=<flax sentinel>,
         name:Optional[str]=None)

source

FNOCell

 FNOCell (hidden_channels:int, grid_size:int, layers:int=4,
          out_channels:int=1,
          activation:flax.linen.Module=<jax custom_jvp function>,
          parent:Union[Module, Scope, _Sentinel, None]=<flax sentinel>,
          name:Optional[str]=None)
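`FNOCell` applies one FNO step to the hidden state, and `FNORNN` rolls it over time. That unrolling is typically done with `jax.lax.scan`; the sketch below shows the pattern with a toy linear cell standing in for the real FNO cell (all names and shapes here are illustrative assumptions):

```python
import jax
import jax.numpy as jnp

def rnn_unroll(cell_fn, params, h0, xs):
    # h_{t+1} = cell_fn(params, h_t, x_t); scan carries h and stacks outputs.
    def step(h, x):
        h_next = cell_fn(params, h, x)
        return h_next, h_next
    _, ys = jax.lax.scan(step, h0, xs)
    return ys

# Toy "cell": a linear mix plus input, standing in for the FNO cell.
def toy_cell(params, h, x):
    return jnp.tanh(h @ params["W"] + x)

grid_size, time_steps = 101, 10
params = {"W": 0.1 * jnp.eye(1)}
h0 = jnp.ones((grid_size, 1))
xs = jnp.ones((time_steps, grid_size, 1))
ys = rnn_unroll(toy_cell, params, h0, xs)
print(ys.shape)  # (10, 101, 1)
```

Using `scan` rather than a Python loop keeps compilation time independent of the number of time steps.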

Parker’s ARMA without input

jax.config.update("jax_platform_name", "cpu")
hidden_channels = 6
grid_size = 101
time_steps = 10

fno_rnn = FNORNN(
    hidden_channels=hidden_channels,
    grid_size=grid_size,
    length=time_steps,
)

h0 = jnp.ones((grid_size, 1))
x = jnp.ones((time_steps, grid_size, 1))

params = fno_rnn.init(jax.random.PRNGKey(0), h0, x)
y = fno_rnn.apply(params, h0, x)

assert y.shape == x.shape
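With the output shape verified, the model could be trained by differentiating a loss through the unrolled recurrence. The sketch below uses a toy scalar-parameter cell in place of `fno_rnn.apply`; the loss, target, and parameter names are illustrative assumptions, not the notebook's training code:

```python
import jax
import jax.numpy as jnp

def unroll(params, h0, time_steps):
    # Autonomous rollout: h_{t+1} = tanh(a * h_t), collecting each state.
    def step(h, _):
        h = jnp.tanh(h * params["a"])
        return h, h
    _, ys = jax.lax.scan(step, h0, None, length=time_steps)
    return ys

def loss_fn(params, h0, target):
    ys = unroll(params, h0, target.shape[0])
    return jnp.mean((ys - target) ** 2)

params = {"a": jnp.array(0.5)}
h0 = jnp.ones((101, 1))
target = jnp.zeros((10, 101, 1))  # illustrative target trajectory
loss, grads = jax.value_and_grad(loss_fn)(params, h0, target)
```

`jax.value_and_grad` backpropagates through the whole `scan`, so gradients flow across all time steps at once.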

source

BatchFNORNN

 BatchFNORNN (hidden_channels:int, grid_size:int, n_spectral_layers:int=4,
              out_channels:int=1, length:int=None,
              activation:flax.linen.Module=<jax custom_jvp function>,
              parent:Union[Module, Scope, _Sentinel, None]=<flax sentinel>,
              name:Optional[str]=None)

batch_size = 3
time_steps = 10
x = jnp.ones((batch_size, time_steps, grid_size, 1))
h0 = jnp.ones((batch_size, grid_size, 1))

batch_fno_rnn = BatchFNORNN(
    hidden_channels=hidden_channels,
    grid_size=grid_size,
    length=time_steps,
)

params = batch_fno_rnn.init(
    jax.random.PRNGKey(0), h0, x
)  # Why does it need to be initialised with the number of timesteps?
y = batch_fno_rnn.apply(params, h0, x)
# Print the shape of the output
print(y.shape)
(3, 10, 101, 1)
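`BatchFNORNN` presumably wraps the unbatched module to map over a leading batch axis; in JAX this is typically done with `jax.vmap` (or `flax.linen.vmap` for modules). An illustrative sketch, with a toy function standing in for the real module's `apply`:

```python
import jax
import jax.numpy as jnp

# Toy unbatched "rnn": (h0, x) -> y with x of shape (time, grid, 1).
def toy_rnn(h0, x):
    return x + h0[None]  # broadcast h0 over time; placeholder for FNORNN.apply

batched_rnn = jax.vmap(toy_rnn)  # map over the leading batch axis of both args

batch, time_steps, grid_size = 3, 10, 101
h0 = jnp.ones((batch, grid_size, 1))
x = jnp.ones((batch, time_steps, grid_size, 1))
y = batched_rnn(h0, x)
print(y.shape)  # (3, 10, 101, 1)
```

`vmap` vectorizes the single-example code without any per-batch loop, which is why the batched output above simply gains a leading axis of size 3.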