Linear utility functions

::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

import numpy as np
from typing import List, Union, Tuple, Optional, Dict

:::

::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def sort_eigvals_eigvecs(
    eigvals: np.ndarray,  # eigenvalues
    eigvecs: np.ndarray,  # eigenvectors
) -> Tuple[np.ndarray, np.ndarray]:  # (eigenvalues, eigenvectors)
    """
    Reorder an eigendecomposition by decreasing eigenvalue magnitude.

    Column ``k`` of ``eigvecs`` is treated as the eigenvector of ``eigvals[k]``,
    so the same permutation is applied to both.

    Returns
    -------
    (eigvals, eigvecs) permuted so that ``abs(eigvals)`` is non-increasing.
    """
    # Ascending argsort reversed -> descending by |lambda|.
    order = np.flip(np.argsort(np.abs(eigvals)))
    return eigvals[order], eigvecs[:, order]


def get_linear_model(
    X: np.ndarray,  # data matrix (state_vars, time_steps)
    X_prime: np.ndarray,  # shifted data matrix (state_vars, time_steps)
    sorted: bool = True,  # sort eigenvalues and eigenvectors
) -> Tuple[np.ndarray, np.ndarray]:  # (eigenvalues, eigenvectors)
    """
    Fit A by least squares so that X' ~= A X, and return A's eigendecomposition.

    Note that A itself is not returned, only its eigenvalues and eigenvectors.

    Parameters
    ----------
    X, X_prime : snapshot matrices with one time step per column.
    sorted : if True, order the result by decreasing |eigenvalue|
        (via ``sort_eigvals_eigvecs``).

    Returns
    -------
    (eigenvalues, eigenvectors) as produced by ``np.linalg.eig`` (eigenvectors
    in columns), optionally reordered.
    """
    # X' = A X  <=>  X.T @ A.T = X'.T, so lstsq on the transposes yields A.T.
    solution, *_ = np.linalg.lstsq(X.T, X_prime.T, rcond=None)
    A = solution.T
    eigvals, eigvecs = np.linalg.eig(A)

    if not sorted:
        return eigvals, eigvecs
    return sort_eigvals_eigvecs(eigvals, eigvecs)

:::

DMD Utils

::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

import jax.numpy as jnp
from pydmd.utils import pseudo_hankel_matrix
from pydmd import DMD

:::

::: {#cell-7 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def hankelize(
    u: jnp.ndarray,
    d: int = 2,
):
    """
    Delay-embed ``u`` using pydmd's ``pseudo_hankel_matrix`` with ``d`` delays.

    ``u`` is transposed before the call and the result is transposed back.
    NOTE(review): this presumably means ``u`` is laid out (time, features)
    while ``pseudo_hankel_matrix`` works on (features, time); confirm.
    """
    return pseudo_hankel_matrix(u.T, d=d).T


def fit_dmd_to_sample(x: jnp.ndarray, r: int = 50) -> DMD:  # (timesteps, gridpoints)
    """
    Fit a pydmd ``DMD`` model with SVD rank ``r`` to a single sample.

    Singleton axes of ``x`` are squeezed away. The remaining
    (timesteps, gridpoints) array is transposed to (gridpoints, timesteps)
    before fitting.

    Returns
    -------
    The fitted ``DMD`` instance.
    """
    snapshots = x.squeeze().T
    model = DMD(svd_rank=r)
    model.fit(snapshots)
    return model


def fast_predict(
    y: jnp.ndarray,  # gridsize
    inv_modes: jnp.ndarray,
    fwd_modes: jnp.ndarray,
    eigs: jnp.ndarray,
    lenght: int = None,
):
    if lenght is None:
        lenght = y.shape[0]

    # we need to predict the next 3999 timesteps
    lenght += 1

    states = jnp.vander(eigs, lenght, increasing=True)
    x_0 = inv_modes @ y
    pred = fwd_modes @ (states * x_0[..., None])

    # slice from the second timestep and convert to (time, gridsize)
    return pred[:, 1:].T.real

:::

Utilities to replace model weights with values derived from a linear (DMD or analytical) approximation

::: {#cell-9 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

from pydmd.utils import pseudo_hankel_matrix
from pydmd import DMD, BOPDMD

:::

::: {#cell-10 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def get_linear_approximation(
    y: jnp.ndarray,  # (time_steps, grid_size)
    r: int = 50,
    method: str = "dmd",  # "dmd" or "analytical"
    dt: float = 1 / 4000,  # step size applied by the "analytical" method
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Estimate eigenvalues and eigenvectors of the linear dynamics of ``y``.

    Parameters
    ----------
    y : data of shape (time_steps, grid_size). Only used by ``method="dmd"``.
    r : SVD rank for DMD, or number of modes for the analytical method.
    method :
        - ``"dmd"``: fit a pydmd ``DMD`` to ``y.T`` and sort its eigenpairs by
          decreasing discrete-eigenvalue magnitude. Then take the principal
          complex log of the eigenvalues, which gives per-step
          continuous-time exponents.
        - ``"analytical"``: call ``stiff_string_eigendecomposition``, multiply
          its eigenvalues by ``dt`` and transpose its eigenvector matrix.
    dt : scale applied to the analytical eigenvalues. The default reproduces
        the previously hard-coded value of 1/4000.

    Returns
    -------
    (eigvals, eigvecs)

    Raises
    ------
    ValueError
        If ``method`` is not supported. ``"full"`` was listed as an option
        previously but never implemented; it used to fail with
        ``UnboundLocalError``.
    """
    if method == "dmd":
        # Fit the model parameters using DMD before training.
        dmd = DMD(svd_rank=r)
        dmd.fit(y.T)
        eigvals, eigvecs = sort_eigvals_eigvecs(dmd.eigs, dmd.modes)
        # Discrete -> continuous (per step); angles land in (-pi, pi].
        eigvals = jnp.log(eigvals)

    elif method == "analytical":
        # NOTE(review): stiff_string_eigendecomposition is not defined or
        # imported in the visible code; confirm it is in scope.
        eigvals, eigvecs = stiff_string_eigendecomposition(n_max_modes=r)
        eigvals = eigvals * dt
        eigvecs = eigvecs.T

    else:
        raise ValueError(
            f"Unsupported method {method!r}; expected 'dmd' or 'analytical'."
        )

    return eigvals, eigvecs


def set_params_from_linear(
    params: Dict,  # nested parameter dict; modified in place
    single_eigvals: jnp.ndarray,
    single_lstvecs: jnp.ndarray,
    model: str,  # "lru" or "koopman"
) -> Dict:
    """
    Overwrite model parameters with values derived from a linear eigendecomposition.

    ``params`` is mutated in place, and the same object is returned.

    Parameters
    ----------
    params : nested parameter dict. The nested sub-dicts must already exist.
        - For ``"lru"`` the function writes
          ``params["first_layer"][{"nu_log", "theta_log", "B_re", "B_im",
          "C_re", "C_im"}]``.
        - For ``"koopman"`` it writes
          ``params["params"]["batched_koopman"]``: the encoder/decoder
          ``"kernel"`` entries plus ``"weight_real"`` and ``"weight_imag"``.
    single_eigvals : one eigenvalue per mode. NOTE(review): presumably these
        are continuous-time (per-step log) eigenvalues such as those returned
        by ``get_linear_approximation``, with one member of each conjugate
        pair; confirm.
    single_lstvecs : matrix whose columns pair with ``single_eigvals``. This
        pairing is inferred from the axis=1 concatenation below; confirm.
    model : ``"lru"`` or ``"koopman"``.
        NOTE(review): any other value silently leaves ``params`` unchanged.

    Returns
    -------
    The same ``params`` dict.
    """
    if model == "lru":
        # Shift negative values up by 2*pi, so inputs in (-2*pi, 0) land in [0, 2*pi).
        map_to_0_2pi = lambda x: jnp.where(x < 0, x + 2 * jnp.pi, x)

        # Alternative parametrisation (unused), presumably for discrete-time eigenvalues:
        # nu_log = jnp.nan_to_num(jnp.log(-jnp.log(jnp.abs(single_eigvals))))
        # theta_log = jnp.nan_to_num(jnp.log(map_to_0_2pi(jnp.angle(single_eigvals))))

        # nu_log = log(-Re(lambda)).
        # NOTE(review): this is NaN where Re(lambda) > 0 and -inf where it is 0.
        # Unlike theta_log it is not passed through nan_to_num; confirm that is intended.
        nu_log = jnp.log(-single_eigvals.real)
        # theta_log = log(Im(lambda) mapped to [0, 2*pi)). nan_to_num turns
        # log(0) = -inf (Im(lambda) == 0) into a large finite negative value.
        theta_log = jnp.nan_to_num(jnp.log(map_to_0_2pi(single_eigvals.imag)))

        params["first_layer"]["nu_log"] = nu_log
        params["first_layer"]["theta_log"] = theta_log
        # Set eigenvectors: B is the pseudo-inverse of the eigenvector matrix
        # and C is the eigenvector matrix itself, each split into real/imag parts.
        inv_lstvecs = jnp.linalg.pinv(single_lstvecs)

        params["first_layer"]["B_re"] = inv_lstvecs.real
        params["first_layer"]["B_im"] = inv_lstvecs.imag

        params["first_layer"]["C_re"] = single_lstvecs.real
        params["first_layer"]["C_im"] = single_lstvecs.imag

    elif model == "koopman":
        ################
        # single_lstvecs = lstvecs[:, ::2]
        # single_lstvecs = single_lstvecs[:single_lstvecs.shape[0] // 2, :single_lstvecs.shape[1] // 2]

        # Build both members of each conjugate pair: first half +|Im|, second half -|Im|.
        conj_eigenvalues = jnp.concatenate(
            [
                single_eigvals.real + 1j * jnp.abs(single_eigvals.imag),
                single_eigvals.real - 1j * jnp.abs(single_eigvals.imag),
            ]
        )

        # Both halves are the REAL part of single_lstvecs cast to complex. The
        # imaginary part of the input is discarded, and the two halves are identical.
        # NOTE(review): a conjugate pair would normally use single_lstvecs.conj()
        # for the second half. Confirm that real-only modes are intended, since
        # the duplicated columns make the matrix rank-deficient before pinv.
        conj_eigenvecs = jnp.concatenate(
            [
                single_lstvecs.real + 1j * 0,
                single_lstvecs.real - 1j * 0,
            ],
            axis=1,
        )

        inv_eigenvecs = jnp.linalg.pinv(conj_eigenvecs)
        # Encoder kernel: real rows stacked over imag rows of the pseudo-inverse, then transposed.
        inv_eigenvecs_as_real = jnp.concatenate(
            [inv_eigenvecs.real, inv_eigenvecs.imag], axis=0
        ).T
        # Decoder kernel: real columns next to imag columns of the eigenvectors,
        # then transposed. With the construction above, the imag half is all zeros.
        fwd_eigenvecs_as_real = jnp.concatenate(
            [conj_eigenvecs.real, conj_eigenvecs.imag], axis=1
        ).T

        params["params"]["batched_koopman"]["encoder"]["encoder"]["kernel"] = (
            inv_eigenvecs_as_real
        )
        params["params"]["batched_koopman"]["decoder"]["decoder"]["kernel"] = (
            fwd_eigenvecs_as_real
        )

        # The first half of conj_eigenvalues is Re(lambda) + 1j*|Im(lambda)|,
        # so these are the real part and the absolute imaginary part of single_eigvals.
        params["params"]["batched_koopman"]["weight_real"] = conj_eigenvalues.real[
            : conj_eigenvalues.shape[0] // 2
        ]
        params["params"]["batched_koopman"]["weight_imag"] = conj_eigenvalues.imag[
            : conj_eigenvalues.shape[0] // 2
        ]

    return params

:::