# Smoke test for BatchedFourierAutoencoder2D: initialise the model on an
# all-zeros input and check that the forward pass preserves the input shape.

# NOTE(review): d_hidden is not referenced again in this example; the dynamics
# below hard-code their own hidden size.
d_hidden = 128
# Presumably (batch, time, height, width, channels) -- TODO confirm axis meaning.
B, T, H, W, C = 5, 16, 41, 37, 2
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))

# The autoencoder is given a factory (partial), not an instantiated module.
# NOTE(review): 20 * 20 * 2 looks tied to n_modes=20 below -- confirm.
dynamics_model = partial(
    LRUDynamics,
    d_hidden=(20 * 20 * 2),
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
)

model = BatchedFourierAutoencoder2D(
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=(H, W),
    norm="layer",
    training=True,
    n_modes=20,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, H, W, C)
Autoencoder models
2D Convolutional Autoencoder with Linear Dynamics
FourierAutoencoder2D
FourierAutoencoder2D (dynamics_model:flax.linen.module.Module, d_vars:int, d_model:Tuple[int,int], norm:str='layer', training:bool=True, use_positions:bool=False, n_modes:int=20, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x124a34f40>, name:Optional[str]=None)
DenseKoopmanAutoencoder2D
DenseKoopmanAutoencoder2D (encoder_model:flax.linen.module.Module, decoder_model:flax.linen.module.Module, dynamics_model:flax.linen.module.Module, d_vars:int, d_model:Tuple[int,int], n_steps:int, norm:str='layer', training:bool=True, use_positions:bool=False, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x124a34f40>, name:Optional[str]=None)
Koopman Dense Autoencoder
# Smoke test for BatchedDenseKoopmanAutoencoder2D: forward pass plus the
# separate "encode" / "decode" entry points, all on zero-valued inputs.

n_steps = 16
d_hidden = 128
# Presumably (batch, time, height, width, channels) -- TODO confirm axis meaning.
# The time axis length is set to n_steps.
B, T, H, W, C = 5, n_steps, 41, 37, 3
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))

# Factories (partials) are passed to the autoencoder, not module instances.
# NOTE(review): the encoder emits d_hidden * 2 features while the dynamics use
# d_hidden -- presumably a real/imaginary split; confirm against the model.
encoder_model = partial(nn.Dense, features=d_hidden * 2)
# The decoder maps back to one flattened H * W * C frame.
decoder_model = partial(nn.Dense, features=H * W * C)
dynamics_model = partial(
    LRUDynamics,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
)

model = BatchedDenseKoopmanAutoencoder2D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    n_steps=n_steps,
    d_vars=C,
    d_model=(H, W),
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, H, W, C)

# Round-trip the input through the encoder and decoder methods.
encoded = model.apply(vars, dummy, method="encode")
decoded = model.apply(vars, encoded, method="decode")

# Same round trip for the target sequence.
enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")
KoopmanAutoencoder2D
KoopmanAutoencoder2D (encoder_model:physmodjax.models.conv.ConvEncoder, decoder_model:physmodjax.models.conv.ConvDecoder, dynamics_model:physmodjax.models.ssm.LRUDynamics, d_latent_channels:int, d_latent_dims:Tuple[int,int], n_steps:int, norm:str='layer', training:bool=True, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x124a34f40>, name:Optional[str]=None)
Koopman Autoencoder
# Smoke test for BatchedKoopmanAutoencoder2D with convolutional encoder and
# decoder: forward pass plus the "encode" / "decode" entry points.

n_steps = 16
# NOTE(review): d_hidden is not referenced again in this example; the dynamics
# below hard-code their own hidden size.
d_hidden = 128
# Presumably (batch, time, height, width, channels) -- TODO confirm axis meaning.
B, T, H, W, C = 5, n_steps, 40, 40, 3
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))

# Factories (partials) are passed to the autoencoder, not module instances.
encoder_model = partial(ConvEncoder, block_size=(8, 16, 32))
decoder_model = partial(ConvDecoder, block_size=(8, 16, 32))
# NOTE(review): 16 * 5 * 5 = 400 is half of d_latent_channels * 5 * 5 = 800
# configured below -- presumably a real/imaginary split; confirm.
dynamics_model = partial(
    LRUDynamics,
    d_hidden=16 * 5 * 5,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
)

model = BatchedKoopmanAutoencoder2D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_latent_channels=32,
    d_latent_dims=(5, 5),
    n_steps=16,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, H, W, C)

# Round-trip the input through the encoder and decoder methods.
encoded = model.apply(vars, dummy, method="encode")
decoded = model.apply(vars, encoded, method="decode")

# Same round trip for the target sequence.
enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")
Koopman Autoencoder 1D
KoopmanAutoencoder1D
KoopmanAutoencoder1D (encoder_model:flax.linen.module.Module, decoder_model:flax.linen.module.Module, dynamics_model:flax.linen.module.Module, d_vars:int, d_model:int, n_steps:int, norm:str='layer', training:bool=True, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x124a34f40>, name:Optional[str]=None)
Koopman Autoencoder
# Smoke test for BatchedKoopmanAutoencoder1D: initialise the model on an
# all-zeros input and check that the forward pass preserves the input shape.

# Presumably (batch, time, width, channels) -- TODO confirm axis meaning.
B, T, W, C = 5, 16, 101, 3
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros((B, T, W, C))

# Factories (partials) are passed to the autoencoder, not module instances.
# NOTE(review): the encoder emits d_hidden * 2 features while the dynamics use
# d_hidden -- presumably a real/imaginary split; confirm against the model.
encoder_model = partial(
    nn.Dense,
    features=d_hidden * 2,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
# The decoder maps back to one flattened W * C frame.
decoder_model = partial(
    nn.Dense,
    features=W * C,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
# Unlike the encoder/decoder, `model` here is an instantiated nn.Dense module
# handed to LRUDynamicsVarying.
dynamics_model = partial(
    LRUDynamicsVarying,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    model=nn.Dense(features=d_hidden * 2, kernel_init=nn.initializers.orthogonal()),
    clip_eigs=False,
)

model = BatchedKoopmanAutoencoder1D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=W,
    n_steps=T,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, W, C)
x (303,)
x (303,)
# Exercise the "encode" / "decode" entry points of the model built in the
# previous cell (relies on `model`, `vars`, `dummy`, `target`, B, T, W, C).
print(dummy.shape)

encoded = model.apply(vars, dummy, method="encode")
print("encoded", encoded.shape)

# Decoding the encoding must restore the input shape.
decoded = model.apply(vars, encoded, method="decode")
assert decoded.shape == dummy.shape

# Same round trip for the target sequence.
enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")

assert dec_sequence.shape == (B, T, W, C)
(5, 16, 101, 3)
x (16, 303)
encoded (5, 16, 256)
x (16, 303)
KoopmanAutoencoder1DReal
KoopmanAutoencoder1DReal (encoder_model:flax.linen.module.Module, decoder_model:flax.linen.module.Module, dynamics_model:flax.linen.module.Module, d_vars:int, d_model:int, n_steps:int, norm:str='layer', training:bool=True, parent:Union[flax.linen.module.Module,flax.core.scope.Scope,flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x124a34f40>, name:Optional[str]=None)
Koopman Autoencoder but with real encoding and decoding
# Smoke test for BatchedKoopmanAutoencoder1DReal: initialise the model on an
# all-zeros input and check that the forward pass preserves the input shape.

# Presumably (batch, time, width, channels) -- TODO confirm axis meaning.
B, T, W, C = 5, 16, 101, 3
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros((B, T, W, C))

# Factories (partials) are passed to the autoencoder, not module instances.
# Here the encoder emits d_hidden features (not d_hidden * 2).
encoder_model = partial(
    nn.Dense,
    features=d_hidden,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
# The decoder maps back to one flattened W * C frame.
decoder_model = partial(
    nn.Dense,
    features=W * C,
    kernel_init=nn.initializers.orthogonal(),
    use_bias=False,
)
# `model` is an instantiated nn.Dense module handed to LRUDynamicsVarying.
dynamics_model = partial(
    LRUDynamicsVarying,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    model=nn.Dense(
        features=d_hidden * 2,
        kernel_init=nn.initializers.orthogonal(),
    ),
    clip_eigs=False,
)

model = BatchedKoopmanAutoencoder1DReal(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=W,
    n_steps=T,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, W, C)