Autoencoder models

2D Convolutional Autoencoder with Linear Dynamics


FourierAutoencoder2D

 FourierAutoencoder2D (dynamics_model:flax.linen.module.Module,
                       d_vars:int, d_model:Tuple[int,int],
                       norm:str='layer', training:bool=True,
                       use_positions:bool=False, n_modes:int=20,
                       parent:Union[flax.linen.module.Module,flax.core.scope.Scope,
                                    flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x124a34f40>,
                       name:Optional[str]=None)
# Smoke test: the batched Fourier autoencoder must map an all-zeros
# (B, T, H, W, C) batch back to an output of the same shape.
d_hidden = 128
n_modes = 20
B, T, H, W, C = 5, 16, 41, 37, 2
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))

# The LRU latent size is derived from n_modes so the dynamics and the
# autoencoder cannot silently disagree when n_modes is changed.
# NOTE(review): the trailing factor 2 is assumed to be the per-mode
# multiplicity (real/imag parts, or C, which is also 2 here) — confirm against
# FourierAutoencoder2D before changing C or n_modes.
dynamics_model = partial(
    LRUDynamics,
    d_hidden=n_modes * n_modes * 2,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
)

model = BatchedFourierAutoencoder2D(
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=(H, W),
    norm="layer",
    training=True,
    n_modes=n_modes,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == (B, T, H, W, C)

DenseKoopmanAutoencoder2D

 DenseKoopmanAutoencoder2D (encoder_model:flax.linen.module.Module,
                            decoder_model:flax.linen.module.Module,
                            dynamics_model:flax.linen.module.Module,
                            d_vars:int, d_model:Tuple[int,int],
                            n_steps:int, norm:str='layer',
                            training:bool=True, use_positions:bool=False,
                            parent:Union[flax.linen.module.Module,flax.core.scope.Scope,
                                         flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x124a34f40>,
                            name:Optional[str]=None)

Koopman Dense Autoencoder

# Smoke test for the dense Koopman autoencoder: run the full forward pass and
# the standalone encode/decode methods on an all-zeros batch.
n_steps = 16
d_hidden = 128
B, T, H, W, C = 5, n_steps, 41, 37, 3
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros_like(dummy)

# Encoder emits 2 * d_hidden features; decoder emits H * W * C features.
encoder_model = partial(nn.Dense, features=2 * d_hidden)
decoder_model = partial(nn.Dense, features=H * W * C)
dynamics_model = partial(
    LRUDynamics,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=2 * jnp.pi,
    clip_eigs=False,
)

model = BatchedDenseKoopmanAutoencoder2D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=(H, W),
    n_steps=n_steps,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)
assert out.shape == dummy.shape

# Exercise encode/decode separately, first on `dummy`, then on `target`.
encoded = model.apply(vars, dummy, method="encode")
decoded = model.apply(vars, encoded, method="decode")

enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")

KoopmanAutoencoder2D

 KoopmanAutoencoder2D (encoder_model:physmodjax.models.conv.ConvEncoder,
                       decoder_model:physmodjax.models.conv.ConvDecoder,
                       dynamics_model:physmodjax.models.ssm.LRUDynamics,
                       d_latent_channels:int,
                       d_latent_dims:Tuple[int,int], n_steps:int,
                       norm:str='layer', training:bool=True,
                       parent:Union[flax.linen.module.Module,flax.core.scope.Scope,
                                    flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x124a34f40>,
                       name:Optional[str]=None)

Koopman Autoencoder

# Smoke test for the convolutional Koopman autoencoder: full forward pass plus
# the standalone encode/decode methods on an all-zeros batch.
n_steps = 16
d_hidden = 128  # not used by this example
B, T, H, W, C = 5, n_steps, 40, 40, 3
dummy = jnp.zeros((B, T, H, W, C))
target = jnp.zeros((B, T, H, W, C))

encoder_model = partial(ConvEncoder, block_size=(8, 16, 32))
decoder_model = partial(ConvDecoder, block_size=(8, 16, 32))
# NOTE(review): 16 * 5 * 5 presumably pairs the 32 latent channels into 16
# complex channels over the 5x5 latent grid — confirm against
# KoopmanAutoencoder2D before changing any of these sizes independently.
dynamics_model = partial(
    LRUDynamics,
    d_hidden=16 * 5 * 5,
    r_min=0.9,
    r_max=1.0,
    max_phase=jnp.pi * 2,
    clip_eigs=False,
)

model = BatchedKoopmanAutoencoder2D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_latent_channels=32,
    # assumes three 2x downsampling stages (40 / 2**3 == 5) — TODO confirm
    d_latent_dims=(5, 5),
    # Reuse the shared n_steps (which also sets T) instead of a second literal.
    n_steps=n_steps,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)
assert out.shape == (B, T, H, W, C)
encoded = model.apply(vars, dummy, method="encode")
decoded = model.apply(vars, encoded, method="decode")

enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")

Koopman Autoencoder 1D


KoopmanAutoencoder1D

 KoopmanAutoencoder1D (encoder_model:flax.linen.module.Module,
                       decoder_model:flax.linen.module.Module,
                       dynamics_model:flax.linen.module.Module,
                       d_vars:int, d_model:int, n_steps:int,
                       norm:str='layer', training:bool=True,
                       parent:Union[flax.linen.module.Module,flax.core.scope.Scope,
                                    flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x124a34f40>,
                       name:Optional[str]=None)

Koopman Autoencoder

# Smoke test for the 1D Koopman autoencoder driven by LRUDynamicsVarying.
B, T, W, C = 5, 16, 101, 3
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros_like(dummy)

# Both maps are bias-free dense layers with orthogonal kernel initialisation;
# the encoder emits 2 * d_hidden features, the decoder W * C features.
encoder_model = partial(
    nn.Dense, features=2 * d_hidden, kernel_init=nn.initializers.orthogonal(), use_bias=False
)
decoder_model = partial(
    nn.Dense, features=W * C, kernel_init=nn.initializers.orthogonal(), use_bias=False
)

# The dynamics are given their own dense sub-model with 2 * d_hidden outputs.
dynamics_model = partial(
    LRUDynamicsVarying,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=2 * jnp.pi,
    model=nn.Dense(features=2 * d_hidden, kernel_init=nn.initializers.orthogonal()),
    clip_eigs=False,
)

model = BatchedKoopmanAutoencoder1D(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=W,
    n_steps=T,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == dummy.shape
x (303,)
x (303,)
# Encode/decode round trip for the 1D Koopman autoencoder defined above.
# The prints are kept so the recorded cell output below stays accurate.
print(dummy.shape)
encoded = model.apply(vars, dummy, method="encode")
print("encoded", encoded.shape)
# The encoder is a Dense layer with 2 * d_hidden features, applied per time step
# (recorded output: encoded (5, 16, 256)).
assert encoded.shape == (B, T, 2 * d_hidden)
decoded = model.apply(vars, encoded, method="decode")
assert decoded.shape == dummy.shape

enc_sequence = model.apply(vars, target, method="encode")
dec_sequence = model.apply(vars, enc_sequence, method="decode")

assert dec_sequence.shape == (B, T, W, C)
(5, 16, 101, 3)
x (16, 303)
encoded (5, 16, 256)
x (16, 303)

KoopmanAutoencoder1DReal

 KoopmanAutoencoder1DReal (encoder_model:flax.linen.module.Module,
                           decoder_model:flax.linen.module.Module,
                           dynamics_model:flax.linen.module.Module,
                           d_vars:int, d_model:int, n_steps:int,
                           norm:str='layer', training:bool=True,
                           parent:Union[flax.linen.module.Module,flax.core.scope.Scope,
                                        flax.linen.module._Sentinel,NoneType]=<flax.linen.module._Sentinel object at 0x124a34f40>,
                           name:Optional[str]=None)

Koopman autoencoder, but with real-valued encoding and decoding.

# Smoke test for the real-valued 1D Koopman autoencoder. Unlike the complex
# variant, the encoder here emits d_hidden features rather than 2 * d_hidden.
B, T, W, C = 5, 16, 101, 3
d_hidden = 128
dummy = jnp.zeros((B, T, W, C))
target = jnp.zeros_like(dummy)

# Bias-free, orthogonally initialised dense encoder and decoder.
encoder_model = partial(
    nn.Dense, features=d_hidden, kernel_init=nn.initializers.orthogonal(), use_bias=False
)
decoder_model = partial(
    nn.Dense, features=W * C, kernel_init=nn.initializers.orthogonal(), use_bias=False
)

dynamics_model = partial(
    LRUDynamicsVarying,
    d_hidden=d_hidden,
    r_min=0.9,
    r_max=1.0,
    max_phase=2 * jnp.pi,
    model=nn.Dense(features=2 * d_hidden, kernel_init=nn.initializers.orthogonal()),
    clip_eigs=False,
)

model = BatchedKoopmanAutoencoder1DReal(
    encoder_model=encoder_model,
    decoder_model=decoder_model,
    dynamics_model=dynamics_model,
    d_vars=C,
    d_model=W,
    n_steps=T,
    norm="layer",
    training=True,
)

vars = model.init(jax.random.PRNGKey(0), dummy)
out = model.apply(vars, dummy)

assert out.shape == dummy.shape