jax.config.update("jax_platform_name", "cpu")
FNO embedded in a recurrent neural network
This notebook adapts the Fourier neural operator (FNO) for use inside a recurrent neural network. The idea is to use the FNO to learn the dynamics of a system, and then to use it as a layer in a recurrent neural network that models how the system evolves over time. This is a reimplementation of https://github.com/julian-parker/DAFX22_FNO in JAX.
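The recurrent idea described above can be sketched in a few lines. This is a minimal illustration, not the FNORNN implementation documented below: `spectral_step`, `rollout`, the number of retained `modes`, and the `tanh` nonlinearity are all hypothetical choices made for the example, and the sketch ignores any driving input. It applies one Fourier-space layer repeatedly with `jax.lax.scan` and stacks the resulting states over time.

```python
import jax
import jax.numpy as jnp


def spectral_step(weights, h):
    """One hypothetical Fourier-layer update of a state h of shape (grid_size, channels)."""
    h_hat = jnp.fft.rfft(h, axis=0)  # transform along the spatial grid
    modes = weights.shape[0]
    # Mix channels independently for each of the lowest `modes` frequencies.
    mixed = jnp.einsum("mio,mi->mo", weights, h_hat[:modes])
    # Higher frequencies are left at zero (truncation is an assumption of this sketch).
    out_hat = jnp.zeros((h_hat.shape[0], weights.shape[2]), dtype=h_hat.dtype)
    out_hat = out_hat.at[:modes].set(mixed)
    return jnp.tanh(jnp.fft.irfft(out_hat, n=h.shape[0], axis=0))


def rollout(weights, h0, length):
    """Apply the same step `length` times and stack the states over time."""

    def body(h, _):
        h_next = spectral_step(weights, h)
        return h_next, h_next

    _, states = jax.lax.scan(body, h0, xs=None, length=length)
    return states  # shape: (length, grid_size, channels)


grid_size, channels, modes, length = 101, 6, 16, 10
k_re, k_im = jax.random.split(jax.random.PRNGKey(0))
weights = (
    jax.random.normal(k_re, (modes, channels, channels))
    + 1j * jax.random.normal(k_im, (modes, channels, channels))
) / channels
h0 = jnp.ones((grid_size, channels))
states = rollout(weights, h0, length)
```

Because the same weights are reused at every step, the number of parameters does not depend on how long the rollout is; only the `length` argument of the scan changes.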
FNORNN
FNORNN (hidden_channels:int, grid_size:int, n_spectral_layers:int=4, out_channels:int=1, length:int=None, activation:flax.linen.module.Module=<jax._src.custom_derivatives.custom_jvp object at 0x1257873d0>, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x1345271f0>, name:Optional[str]=None)
FNOCell
FNOCell (hidden_channels:int, grid_size:int, layers:int=4, out_channels:int=1, activation:flax.linen.module.Module=<jax._src.custom_derivatives.custom_jvp object at 0x1257873d0>, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x1345271f0>, name:Optional[str]=None)
Parker’s ARMA without input
jax.config.update("jax_platform_name", "cpu")
hidden_channels = 6
grid_size = 101
time_steps = 10

fno_rnn = FNORNN(
    hidden_channels=hidden_channels,
    grid_size=grid_size,
    length=time_steps,
)
h0 = jnp.ones((grid_size, 1))
x = jnp.ones((time_steps, grid_size, 1))

params = fno_rnn.init(jax.random.PRNGKey(0), h0, x)
y = fno_rnn.apply(params, h0, x)

assert y.shape == x.shape
BatchFNORNN
BatchFNORNN (hidden_channels:int, grid_size:int, n_spectral_layers:int=4, out_channels:int=1, length:int=None, activation:flax.linen.module.Module=<jax._src.custom_derivatives.custom_jvp object at 0x1257873d0>, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x1345271f0>, name:Optional[str]=None)
batch_size = 3
time_steps = 10
x = jnp.ones((batch_size, time_steps, grid_size, 1))
h0 = jnp.ones((batch_size, grid_size, 1))

batch_fno_rnn = BatchFNORNN(
    hidden_channels=hidden_channels,
    grid_size=grid_size,
    length=time_steps,
)
params = batch_fno_rnn.init(
    jax.random.PRNGKey(0), h0, x
)  # Why does it need to be initialised with the number of timesteps?
y = batch_fno_rnn.apply(params, h0, x)
# Print the shape of the output
print(y.shape)
(3, 10, 101, 1)
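The shapes above are what one would get by mapping an unbatched recurrent model over a leading batch axis. The sketch below shows that pattern with `jax.vmap`. Whether BatchFNORNN is built this way is an assumption, and `toy_rollout` is a hypothetical stand-in with a trivial update rule, not the FNO cell.

```python
import jax
import jax.numpy as jnp


def toy_rollout(h0, x):
    """Hypothetical unbatched recurrent model (not the FNO).

    h0: (grid_size, 1) initial state; x: (time_steps, grid_size, 1) input sequence.
    """

    def body(h, x_t):
        # Placeholder update: blend the previous state with the current input.
        h_next = jnp.tanh(0.5 * h + 0.5 * x_t)
        return h_next, h_next

    _, ys = jax.lax.scan(body, h0, x)
    return ys  # shape: (time_steps, grid_size, 1)


# Map over axis 0 of both arguments to add a batch dimension.
batched_rollout = jax.vmap(toy_rollout, in_axes=(0, 0))

batch_size, time_steps, grid_size = 3, 10, 101
h0 = jnp.ones((batch_size, grid_size, 1))
x = jnp.ones((batch_size, time_steps, grid_size, 1))
y = batched_rollout(h0, x)
print(y.shape)  # (3, 10, 101, 1)
```

The unbatched function never sees the batch axis, so the same code serves both the single-sequence and the batched case.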