Tension modulated stiff membrane

This notebook implements a solver based on the functional transformation method (FTM) for a stiff membrane with tension modulation. The membrane is described by the following PDE:

\[ \begin{align} D \nabla^4 u +\rho h \frac{\partial^2 u}{\partial t^2} - T(u) \nabla^2 u + d_{1} \frac{\partial u}{\partial t} - d_{3} \frac{\partial \nabla^2 u}{\partial t} = f^{(ext)} \end{align} \] assuming the bending stiffness to be small, \(D/T \ll 1\). For a rectangular plate \(x \in [0, L_x], y \in [0, L_y]\) of thickness \(h\), Young's modulus \(Q\) and Poisson's ratio \(\nu\), the bending stiffness is \(D = Q h^3 / \left(12 (1 - \nu ^2)\right)\). \(\rho\) is the mass density \((kg/m^3)\), \(T(u)\) is the tension per unit length \((N/m)\), \(d_1\) and \(d_3\) are the frequency-independent and frequency-dependent damping coefficients, and \(f^{(ext)}\) is the external force.
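As a rough numerical check of the small-stiffness assumption, the snippet below computes \(D\) from the default parameters of the solver class defined later (\(Q = 3.5 \cdot 10^9\) Pa, \(h = 1.9 \cdot 10^{-4}\) m, \(\nu = 0.3\), \(T_0 = 2620\) N/m, \(L_x = 0.4\) m, aspect ratio 0.9). It then compares the bending and tension terms \(D\lambda^2\) and \(T_0\lambda\) for the lowest mode, using \(\lambda_{1,1} = \pi^2(1/L_x^2 + 1/L_y^2)\). \(D/T\) has units of m², so reading the assumption as \(D\lambda_{1,1}/T_0 \ll 1\) is our own interpretation.

```python
import numpy as np

# Default membrane parameters of Wave2dSolverTensionModulated (SI units)
E = 3.5e9      # Young's modulus Q [Pa]
h = 1.9e-4     # thickness [m]
nu = 0.3       # Poisson's ratio
T0 = 2620.0    # tension per unit length [N/m]
Lx = 0.4       # [m]
Ly = 0.9 * Lx  # aspect ratio 0.9 [m]

# Bending stiffness D = Q h^3 / (12 (1 - nu^2)) [N m]
D = E * h**3 / (12 * (1 - nu**2))

# Eigenvalue of the lowest mode (n, m) = (1, 1)
lam11 = np.pi**2 * ((1 / Lx) ** 2 + (1 / Ly) ** 2)

# Ratio of the bending term D*lambda^2 to the tension term T0*lambda
stiffness_ratio = D * lam11 / T0
print(f"D = {D:.3e} N m, D*lambda_11/T0 = {stiffness_ratio:.2e}")
```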

The modes for simply supported boundary conditions are \[ K_{n,m}\left(x, y \right) =\sin \left(\frac{n \pi x}{L_x} \right) \sin \left(\frac{m \pi y}{L_y} \right) \] where \(n, m\) are positive integers. They are eigenfunctions of the Laplacian: \[ \begin{aligned} & \nabla^2 K_{n, m}(x, y)=-\lambda_{n, m} K_{n, m}(x, y), \\ & \text { with } \quad \lambda_{n, m}=\pi^2\left[\left(\frac{n}{L_x}\right)^2+\left(\frac{m}{L_y}\right)^2\right] . \end{aligned} \]
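The eigenvalue relation can be checked numerically. The sketch below applies a second-order central-difference Laplacian to \(K_{2,3}\) on interior grid points and compares the result with \(-\lambda_{2,3} K_{2,3}\). The grid resolution and the choice of mode are arbitrary. The relative error should shrink roughly as the square of the grid spacing.

```python
import numpy as np

# Arbitrary test setup: membrane size, mode (n, m) and a fine grid
Lx, Ly = 0.4, 0.36
n, m = 2, 3
Nx, Ny = 401, 361
x = np.linspace(0, Lx, Nx)
y = np.linspace(0, Ly, Ny)
dx, dy = x[1] - x[0], y[1] - y[0]
X, Y = np.meshgrid(x, y, indexing="ij")

K = np.sin(n * np.pi * X / Lx) * np.sin(m * np.pi * Y / Ly)
lam = np.pi**2 * ((n / Lx) ** 2 + (m / Ly) ** 2)

# Second-order central differences, evaluated on interior points only
lap_K = (
    (K[2:, 1:-1] - 2 * K[1:-1, 1:-1] + K[:-2, 1:-1]) / dx**2
    + (K[1:-1, 2:] - 2 * K[1:-1, 1:-1] + K[1:-1, :-2]) / dy**2
)

# Relative error between the discrete Laplacian and -lambda * K
rel_err = np.max(np.abs(lap_K + lam * K[1:-1, 1:-1])) / (lam * np.max(np.abs(K)))
print(f"lambda_{n},{m} = {lam:.2f}, relative error = {rel_err:.1e}")
```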

The tension is given by \[ T(u) = T_0 + T_{NL}(u) \]

We take the nonlinear tension to be proportional to the relative change of the membrane surface area \(S(u)\) with respect to its rest area \(S_0\): \[ T_{N L}(u)=C_{N L} \frac{S(u)-S_0}{S_0} \simeq \frac{1}{2} \frac{C_{N L}}{S_0} \int_{\mathcal{S}}\|\nabla u\|^2 d \mathbf{x} \] For the rectangular plate (\(S_0 = L_x L_y\)), the Berger approximation is used for the tension: \[ \begin{aligned} T_{N L}(u) \simeq & \frac{Q h}{2 L_x L_y\left(1-\nu^2\right)} \\ & \cdot \int_0^{L_x} \int_0^{L_y}\left[\left(\frac{\partial u}{\partial x}\right)^2+\left(\frac{\partial u}{\partial y}\right)^2\right] d x d y . \end{aligned} \] Therefore \[ C_{N L}=\frac{Q h}{\left(1-\nu^2\right)} \]
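For a single-mode displacement \(u = a K_{n,m}\), the Berger integral has a closed form. Because \(K = 0\) on the edges, \(\int \|\nabla K\|^2 = -\int K \nabla^2 K = \lambda \lVert K \rVert_2^2 = \lambda L_x L_y / 4\), so \(T_{NL} = Q h a^2 \lambda_{n,m} / (8(1-\nu^2))\). The sketch compares this closed form with a quadrature of the numerical gradient on a grid. The amplitude \(a = 0.1\) mm and the grid size are hypothetical choices.

```python
import numpy as np
from scipy.integrate import simpson

# Membrane parameters (same defaults as the solver class below)
E, h, nu = 3.5e9, 1.9e-4, 0.3
Lx, Ly = 0.4, 0.36
a = 1e-4  # hypothetical amplitude [m] of a single-mode displacement

x = np.linspace(0, Lx, 401)
y = np.linspace(0, Ly, 361)
X, Y = np.meshgrid(x, y, indexing="ij")
n, m = 1, 1
u = a * np.sin(n * np.pi * X / Lx) * np.sin(m * np.pi * Y / Ly)

# Gradient of u: central differences inside, second-order one-sided at the edges
du_dx, du_dy = np.gradient(u, x, y, edge_order=2)
grad_sq = du_dx**2 + du_dy**2

# Berger approximation of the nonlinear tension, by nested Simpson quadrature
integral = simpson(simpson(grad_sq, x=y, axis=1), x=x)
T_NL_quad = E * h / (2 * Lx * Ly * (1 - nu**2)) * integral

# Closed form for a single mode: int |grad u|^2 = a^2 * lambda * Lx * Ly / 4
lam = np.pi**2 * ((n / Lx) ** 2 + (m / Ly) ** 2)
T_NL_exact = E * h * a**2 * lam / (8 * (1 - nu**2))
print(f"T_NL quadrature = {T_NL_quad:.6e} N/m, closed form = {T_NL_exact:.6e} N/m")
```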

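Projecting the PDE onto each mode \(K_{n,m}\), with \(f^{(ext)} = 0\) and modal coordinates \(\bar{u}_{n,m}(t) = \int_0^{L_x}\int_0^{L_y} u(x, y, t) K_{n,m}(x, y)\, dy\, dx\), gives

\[ \boxed{ \begin{aligned} \rho h \ddot{\bar{u}}_{n, m}(t) + \left(d_3 \lambda_{n, m}+d_1\right) \dot{\bar{u}}_{n, m}( t) + \left(\lambda_{n, m} \left(\lambda_{n, m}D +T_0\right)\right) \bar{u}_{n, m}(t)-f^{(tm)}_{n, m}(u, \bar{u}) = 0 \end{aligned} } \]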
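Without the nonlinear term, each modal equation is a damped oscillator with \(\beta = \lambda(\lambda D + T_0)\). Its decay rate is \(\sigma = (d_3\lambda + d_1)/(2\rho h)\) and its damped angular frequency is \(\omega = \sqrt{\beta/(\rho h) - \sigma^2}\). The sketch evaluates these for the first 3 × 3 modes with the parameters of the example run further down (\(T_0 = 800\) N/m, \(d_1 = 0.5\), \(d_3 = 5 \cdot 10^{-3}\), all other values at the solver defaults). The result gives the linear, small-amplitude partials, which are only a reference point: tension modulation shifts them upward at large amplitudes.

```python
import numpy as np

# Parameters of the example run (Ts0=800, d1=0.5, d3=5e-3), other values at defaults
rho, h, E, nu = 1380.0, 1.9e-4, 3.5e9, 0.3
T0, d1, d3 = 800.0, 0.5, 5e-3
Lx, Ly = 0.4, 0.36
D = E * h**3 / (12 * (1 - nu**2))

n = np.arange(1, 4)
m = np.arange(1, 4)
N, M = np.meshgrid(n, m, indexing="ij")
lam = np.pi**2 * ((N / Lx) ** 2 + (M / Ly) ** 2)

beta = lam * (lam * D + T0)                    # stiffness term of the modal equation
sigma = (d3 * lam + d1) / (2 * rho * h)        # decay rate [1/s]
omega = np.sqrt(beta / (rho * h) - sigma**2)   # damped angular frequency [rad/s]
freq = omega / (2 * np.pi)                     # [Hz]
t60 = 3 * np.log(10) / sigma                   # time for a 60 dB amplitude decay [s]

for (i, j), f in np.ndenumerate(freq):
    print(f"mode ({i + 1},{j + 1}): f = {f:7.1f} Hz, T60 = {t60[i, j]:.3f} s")
```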

Where \[ \begin{aligned} f^{(tm)}_{n, m}(u, \bar{u}) & = -\lambda_{n, m} T_{N L} (u) \bar{u}_{n, m}(t) \\ & = -\lambda_{n, m} \frac{1}{2}\frac{C_{N L}}{S_0} \left[\sum_{\tilde{n}, \tilde{m}} \frac{\lambda_{\tilde{n}, \tilde{m}}\bar{u}_{\tilde{n}, \tilde{m}}^2 (t)}{\lVert K_{\tilde{n}, \tilde{m}} \rVert_{2}^{2}} \right] \bar{u}_{n, m}(t) \end{aligned} \]

For a rectangular plate, the norm of the modes is given by \[\lVert K_{n, m} \rVert_{2}^{2} = \langle K_{n, m} , K_{n, m} \rangle = \int _0^{L_x}\int _0^{L_y}\sin ^2\left(\frac{n \pi x}{L_x}\right) \sin ^2\left(\frac{m \pi y}{L_y}\right)dydx = \frac{L_x L_y}{4} \quad \forall n, m \geq 1 \]
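The norm can be verified with the same nested composite Simpson quadrature that `to_modal` uses. This sketch uses a 41 × 37 grid, which should match the solver's default grid for aspect ratio 0.9, so each axis has an even number of intervals. For these low-order modes the quadrature is expected to reproduce \(L_x L_y/4\) to rounding error.

```python
import numpy as np
from scipy.integrate import simpson

Lx, Ly = 0.4, 0.36
x = np.linspace(0, Lx, 41)   # 40 intervals
y = np.linspace(0, Ly, 37)   # 36 intervals
X, Y = np.meshgrid(x, y, indexing="ij")

norms = {}
for n in range(1, 4):
    for m in range(1, 4):
        K = np.sin(n * np.pi * X / Lx) * np.sin(m * np.pi * Y / Ly)
        # Inner integral over y (axis 1), then outer integral over x
        norms[(n, m)] = simpson(simpson(K**2, x=y, axis=1), x=x)

expected = Lx * Ly / 4
print(f"expected {expected:.6f}", {k: round(float(v), 6) for k, v in norms.items()})
```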

Writing \(\mu\) for the mode index \((n, m)\), the modal equations become a system of first-order ODEs:

\[ \begin{aligned} \dot{\bar{u}}(\mu, t) &= \bar{v}(\mu, t) \\ \dot{\bar{v}}(\mu, t) &= \frac{-\left(d_3 \lambda_\mu+d_1\right)}{\rho h} \bar{v}(\mu, t) - \frac{\beta_\mu}{\rho h} \bar{u}(\mu, t) + \frac{f^{(tm)}_{\mu}(u, \bar{u})}{\rho h} \end{aligned} \]

Where \[ \begin{aligned} \beta_\mu &= \lambda_\mu \left(\lambda_\mu D + T_0\right) = \lambda_\mu^2 D + \lambda_\mu T_0 \\ \bar{b}(\mu, u, \bar{u}) &= \frac{f^{(tm)}_{\mu}(u, \bar{u})}{\rho h} \\ &= -\frac{1}{\rho h} \lambda_{\mu} \frac{1}{2}\frac{Q h}{\left(1-\nu^2\right)}\frac{1}{L_x L_y} \frac{4}{L_x L_y}\left[\sum_{\eta} \lambda_{\eta}\bar{u}_{\eta}^2 (t) \right] \bar{u}_{\mu}(t) \\ &= -\lambda_{\mu} \frac{Q}{\rho \left(1-\nu^2\right)}\frac{2}{L_x^2 L_y^2}\left[\sum_{\eta} \lambda_{\eta}\bar{u}_{\eta}^2 (t) \right] \bar{u}_{\mu}(t) \\ \end{aligned} \]

\(\bar{u}_{n, m}(t)\) is a matrix indexed by \(n\) and \(m\), but to be able to feed it to solve_ivp we flatten it to a vector \(\bar{u}_{\mu}(t)\) with index \(\mu = n N_y + m\), where \(n, m\) are counted from zero (as in the code) and \(N_y\) is the number of modes in the y direction.
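This flattening convention coincides with NumPy's default row-major (C-order) layout, which is what the solver's loops over `i` and `j` assume. A small illustration with a hypothetical 3 × 4 mode grid:

```python
import numpy as np

# Hypothetical small mode grid: N_x = 3 modes in x, N_y = 4 modes in y
N_x, N_y = 3, 4
u_bar_matrix = np.arange(N_x * N_y, dtype=float).reshape(N_x, N_y) * 0.1

# NumPy's default (C, row-major) flattening stores entry (n, m) at n * N_y + m
u_bar_vec = u_bar_matrix.ravel()
n, m = 2, 1
mu = n * N_y + m
print(mu, u_bar_vec[mu] == u_bar_matrix[n, m])  # 9 True

# np.unravel_index recovers the zero-based (n, m) pair from mu
print(np.unravel_index(mu, (N_x, N_y)))
```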

This is a system of first-order ODEs. In matrix form, we can write this as:

\[ \boxed{ \begin{aligned} \mathbf{\dot{\bar{u}}} &= \mathbf{\bar{v}} \\ \mathbf{\dot{\bar{v}}} &= -\mathbf{M_v} \mathbf{\bar{v}} - \mathbf{M_u} \mathbf{\bar{u}} + \mathbf{\bar{b}} \end{aligned} } \]

The matrices \(\mathbf{M_v}\) and \(\mathbf{M_u}\) are diagonal, with order equal to the number of modes \(M\):

\[ \begin{aligned} \mathbf{\Lambda} &= \text{diag}\left(\lambda_\mu\right) \\ \mathbf{M_v} &= \text{diag}\left(\frac{d_3 \lambda_\mu + d_1}{\rho h}\right) \\ \mathbf{M_u} &= \text{diag}\left(\frac{\beta_\mu}{\rho h}\right) \\ \mathbf{\bar{b}}(\mathbf{u}, \mathbf{\bar{u}}) &= - \frac{Q}{\rho \left(1-\nu^2\right)}\frac{2}{L_x^2 L_y^2} \left[ \mathbf{\bar{u}}^{T} \mathbf{\Lambda} \mathbf{\bar{u}} \right] \mathbf{\Lambda} \mathbf{\bar{u}} = - C_b \left[ \mathbf{\bar{u}}^{T} \mathbf{\Lambda} \mathbf{\bar{u}} \right] \mathbf{\Lambda} \mathbf{\bar{u}}, \qquad C_b = \frac{2 Q}{\rho \left(1-\nu^2\right) L_x^2 L_y^2} \end{aligned} \]
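Because \(\mathbf{M_v}\), \(\mathbf{M_u}\) and \(\mathbf{\Lambda}\) are diagonal, the right-hand side can be evaluated with elementwise products on their diagonals instead of dense matrix products. The sketch below does this for a hypothetical 2 × 2 mode problem. It also uses a modal energy that we derive from the system:

\(H = \tfrac12 \lVert\mathbf{\bar v}\rVert^2 + \tfrac12 \mathbf{\bar u}^T\mathbf{M_u}\mathbf{\bar u} + \tfrac{C_b}{4}(\mathbf{\bar u}^T\mathbf{\Lambda}\mathbf{\bar u})^2\), for which \(\dot H = -\mathbf{\bar v}^T\mathbf{M_v}\mathbf{\bar v}\).

The factor 1/4 makes the gradient of the last term equal to \(C_b (\mathbf{\bar u}^T\mathbf{\Lambda}\mathbf{\bar u})\mathbf{\Lambda}\mathbf{\bar u}\). With the damping set to zero, \(H\) should stay constant up to the integrator tolerance. The names `rhs` and `modal_energy` and the initial condition are illustrative only.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical small problem: 2 x 2 modes, solver-default membrane parameters with
# Ts0 = 800 N/m, and no damping (d1 = d3 = 0) so that H should stay constant.
rho, h, Q, nu, T0 = 1380.0, 1.9e-4, 3.5e9, 0.3, 800.0
Lx, Ly, d1, d3 = 0.4, 0.36, 0.0, 0.0
D = Q * h**3 / (12 * (1 - nu**2))

n, m = np.meshgrid(np.arange(1, 3), np.arange(1, 3), indexing="ij")
lam = (np.pi**2 * ((n / Lx) ** 2 + (m / Ly) ** 2)).ravel()  # flattened as mu = n*N_y + m
Mv_diag = (d3 * lam + d1) / (rho * h)        # diagonal of M_v
Mu_diag = lam * (lam * D + T0) / (rho * h)   # diagonal of M_u
C_b = 2 * Q / (rho * (1 - nu**2) * Lx**2 * Ly**2)


def rhs(t, state):
    """Modal right-hand side using only the diagonals (no dense matrix products)."""
    u, v = np.split(state, 2)
    G = np.dot(lam * u, u)      # u^T Lambda u
    b = -C_b * G * lam * u      # nonlinear term b, including its minus sign
    return np.concatenate([v, -Mv_diag * v - Mu_diag * u + b])


def modal_energy(u, v):
    """H = |v|^2/2 + u^T M_u u/2 + C_b (u^T Lambda u)^2/4 for u, v of shape (M, nt)."""
    G = np.sum(lam[:, None] * u**2, axis=0)
    return (
        0.5 * np.sum(v**2, axis=0)
        + 0.5 * np.sum(Mu_diag[:, None] * u**2, axis=0)
        + 0.25 * C_b * G**2
    )


# Hypothetical initial condition: all modal displacements 1e-5, at rest
state0 = np.concatenate([np.full(lam.size, 1e-5), np.zeros(lam.size)])
sol = solve_ivp(rhs, (0.0, 0.02), state0, method="DOP853", rtol=1e-10, atol=1e-14)
H = modal_energy(sol.y[: lam.size], sol.y[lam.size :])
rel_drift = np.max(np.abs(H - H[0])) / H[0]
print(f"relative energy drift: {rel_drift:.1e}")
```

Setting `d1` or `d3` to a positive value should instead make `H` decrease monotonically, since \(\dot H = -\mathbf{\bar v}^T\mathbf{M_v}\mathbf{\bar v} \le 0\).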

::: {#cell-9 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

import numpy as np
from scipy.integrate import solve_ivp, simpson

:::

from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
import time

::: {#cell-11 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

class Wave2dSolverTensionModulated:
    """
    Tension modulated wave equation solver for a rectangular stiff membrane.
    The parameters were taken from (Fletcher, 1991, p.86) and adapted to a rectangular case.
    """

    def __init__(
        self,
        sampling_rate: int = 16000,  # 1/s     Temporal sampling frequency
        final_time: float = 0.5,  # s       Duration of the simulation
        n_gridpoints_x: int = 41,  #         Number of grid points along x
        length_x: float = 0.4,  # m       Length of x dimension
        aspect_ratio: float = 0.9,  #        Aspect ratio of the membrane, Ly/Lx
        rho: float = 1380,  # kg/m**3 Density
        h: float = 1.9e-4,  # m       Thickness
        E: float = 3.5e9,  # Pa      Young's modulus
        nu: float = 0.3,  #         Poisson's ratio
        d1: float = 8e-5,  # kg/(ms) Frequency independent loss
        d3: float = 1.4e-5,  # kg m/s  Frequency dependent loss
        Ts0: float = 2620,  # N/m       Tension per unit length
        n_max_modes: int = 36,  #         Number of modal coordinates (rounded down to a perfect square)
        use_nonlinear: bool = True,  #         Use nonlinear wave equation
        include_energy: bool = False,  #         Calculate the energy of the system
    ):
        # TODO: use a ratio for the side lengths, and use the same ratio for gridpoints
        # Attributes intrinsic to the Wave Equation PDE
        self.pde_num_variables = 1
        self.pde_num_spatial_dimensions = 2
        self.pde_order_time_derivatives = 2

        # Attributes of the simulation
        self.sampling_rate = sampling_rate
        self.final_time = final_time
        self.n_gridpoints_x = n_gridpoints_x
        self.length_x = length_x
        self.aspect_ratio = aspect_ratio
        self.use_nonlinear = use_nonlinear
        self.include_energy = include_energy

        # Attributes of the membrane
        self.rho = rho
        self.h = h
        self.E = E
        self.nu = nu
        self.d1 = d1
        self.d3 = d3
        self.Ts0 = Ts0

        # Use the same number of modes in x and y; this might not be a good idea for aspect ratios far from 1
        self.n_max_modes_x = int(np.floor(np.sqrt(n_max_modes)))
        self.n_max_modes_y = int(np.floor(np.sqrt(n_max_modes)))
        self.n_max_modes = self.n_max_modes_x * self.n_max_modes_y
        # if n_max_modes > n_gridpoints - 2:
        #     self.n_max_modes = n_gridpoints - 2
        #     print(f"n_max_modes too high, setting to {self.n_max_modes}")
        # else:
        #     self.n_max_modes = n_max_modes

        x = np.linspace(0, self.length_x, self.n_gridpoints_x)
        self.dx = x[1] - x[0]  # m     spatial sampling interval

        # calculate the gridpoints and length in y using the aspect ratio
        # We want dy to be as close as possible to dx
        self.length_y = self.length_x * self.aspect_ratio
        self.n_gridpoints_y = (
            int(np.floor((self.n_gridpoints_x - 1) * self.aspect_ratio)) + 1
        )

        y = np.linspace(0, self.length_y, self.n_gridpoints_y)
        self.x = x
        self.y = y

        #  spatial grid
        self.grid_x, self.grid_y = np.meshgrid(x, y, indexing="ij")

        self.dy = y[1] - y[0]  # m     spatial sampling interval

        #  temporal grid use arange to make sure that the timestep corresponds exactly to the sampling frequency
        self.dt = 1 / self.sampling_rate  # s     temporal sampling interval
        self.timesteps = np.arange(0, self.final_time, self.dt)

        self.mu_x = np.arange(1, self.n_max_modes_x + 1)
        self.mu_y = np.arange(1, self.n_max_modes_y + 1)
        self.wavenumbers_x = self.mu_x * np.pi / self.length_x
        self.wavenumbers_y = self.mu_y * np.pi / self.length_y

        self.grid_wavenumber_x, self.grid_wavenumber_y = np.meshgrid(
            self.wavenumbers_x, self.wavenumbers_y
        )
        self.modes = np.zeros(
            (
                self.n_max_modes_x,
                self.n_max_modes_y,
                self.n_gridpoints_x,
                self.n_gridpoints_y,
            )
        )
        self.lambdas = np.zeros((self.n_max_modes_x * self.n_max_modes_y))
        # modes and eigenvalues of the modes (lambdas)
        for i, wx in enumerate(self.wavenumbers_x):
            for j, wy in enumerate(self.wavenumbers_y):
                self.modes[i, j, :, :] = np.sin(wx * self.grid_x) * np.sin(
                    wy * self.grid_y
                )
                self.lambdas[i * self.n_max_modes_y + j] = wx**2 + wy**2
        self.norm_factor = 4 / (
            self.length_x * self.length_y
        )  # normalization factor for the modes
        self.D = self.E * (self.h) ** 3 / (12 * (1 - self.nu**2))
        CNL = self.E * self.h / (1 - self.nu**2)

        beta_mu = self.D * self.lambdas**2 + self.Ts0 * self.lambdas

        # calculate the matrices
        self.Lambdadiag = np.diag(self.lambdas)
        # self.H_1 = np.diag(self.lambdas) / (self.rho * self.h)

        self.M_v = np.diag(self.d1 + self.d3 * self.lambdas) / (self.rho * self.h)

        self.M_u = np.diag(beta_mu) / (self.rho * self.h)

        # coefficient for the nonlinear term (C_b in the equations above)
        self.Cb = (
            2
            * self.E
            / (self.rho * self.length_x**2 * self.length_y**2 * (1 - self.nu**2))
        )

    def print_matrices(self):
        print(f"Lambdadiag shape: {self.Lambdadiag.shape}")
        # print(f"H_1 shape: {self.H_1.shape}")
        print(f"M_v shape: {self.M_v.shape}")
        print(f"M_u shape: {self.M_u.shape}")
        return

    def print_solver_info(self):
        # Print some information
        print(f"dx: {self.dx} in meters")
        print(f"dy: {self.dy} in meters")
        print(f"dt: {self.dt} in seconds")
        print(
            f"number of points in the x direction (n_gridpoints_x): {self.n_gridpoints_x}"
        )
        print(
            f"number of points in the y direction (n_gridpoints_y): {self.n_gridpoints_y}"
        )
        print(f"time in samples (nt): {self.timesteps.shape}")
        print(f"number of modes in x direction (n_max_modes_x): {self.n_max_modes_x}")
        print(f"number of modes in y direction (n_max_modes_y): {self.n_max_modes_y}")
        print(
            f"number of modes (n_max_modes): {self.n_max_modes_x * self.n_max_modes_y}"
        )
        print(f"length in x direction (length_x): {self.length_x} in meters")
        print(f"length in y direction (length_y): {self.length_y} in meters")
        # Print the shapes of the grids and the wavenumbers
        print(f"grid_x shape: {self.grid_x.shape}")
        print(f"grid_y shape: {self.grid_y.shape}")
        print(f"wavenumbers_x shape: {self.wavenumbers_x.shape}")
        print(f"wavenumbers_y shape: {self.wavenumbers_y.shape}")
        print(f"modes shape: {self.modes.shape}")
        print(f"lambdas shape: {self.lambdas.shape}")
        return

    def to_modal(
        self,
        u,  # displacement
        v,  # velocity
        integrator: str = "simpson",
    ) -> tuple[np.ndarray, np.ndarray]:
        """Project the displacement and velocity to modal coordinates.
        Also flatten the arrays to be 1D.
        """
        bar_u = np.zeros(self.n_max_modes)
        # bar_z2 = np.zeros(self.n_max_modes)
        bar_v = np.zeros(self.n_max_modes)

        if integrator == "simpson":
            for i in range(self.n_max_modes_x):
                for j in range(self.n_max_modes_y):
                    uu = u * self.modes[i, j, :, :]
                    bar_u[i * self.n_max_modes_y + j] = simpson(
                        [simpson(uu_y, x=self.y) for uu_y in uu], x=self.x
                    )
                    vv = v * self.modes[i, j, :, :]
                    bar_v[i * self.n_max_modes_y + j] = simpson(
                        [simpson(vv_y, x=self.y) for vv_y in vv], x=self.x
                    )
        elif integrator == "trapz":
            # This is unverified, use simpson for now
            raise NotImplementedError
            # for i in range(self.n_max_modes_x):
            #     for j in range(self.n_max_modes_y):
            #         bar_u[i * self.n_max_modes_y + j] = self.dx*self.dy*np.sum(self.modes[i, j, :, :] * u)
            #         # bar_z_dot[i * self.n_max_modes_y + j] = np.sum(
            #     self.modes[i, j, :, :] * z_dot
            # )
        else:
            raise ValueError(f"Integrator {integrator} not recognised")

        return bar_u, bar_v

    def to_modal_vecs(
        self,
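        u,  # batch of displacements, shape (n_states, n_gridpoints_x, n_gridpoints_y)
        v,  # batch of velocities, same shape as u
        integrator: str = "simpson",
    ) -> tuple[np.ndarray, np.ndarray]:
        """Project a batch of displacements and velocities to modal coordinates.
        Returns arrays of shape (n_states, n_max_modes).
        """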
        bar_u = np.zeros((u.shape[0], self.n_max_modes))
        bar_v = np.zeros((u.shape[0], self.n_max_modes))
        if integrator == "simpson":
            for i in range(self.n_max_modes_x):
                for j in range(self.n_max_modes_y):
                    uu = u * self.modes[i, j, :, :]
                    bar_u[:, i * self.n_max_modes_y + j] = simpson(
                        [simpson(uu_y, x=self.y) for uu_y in uu], x=self.x
                    )
                    vv = v * self.modes[i, j, :, :]
                    bar_v[:, i * self.n_max_modes_y + j] = simpson(
                        [simpson(vv_y, x=self.y) for vv_y in vv], x=self.x
                    )
        elif integrator == "trapz":
            raise NotImplementedError
        else:
            raise ValueError(f"Integrator {integrator} not recognised")
        return bar_u, bar_v

    def to_displacement(
        self,
        bar_u,  # modal displacement
        bar_v,  # modal velocity
    ) -> tuple[np.ndarray, np.ndarray]:
        """Sum the modal displacements and velocities to get the displacement and velocity"""

        u = np.zeros((self.n_gridpoints_x, self.n_gridpoints_y))
        v = np.zeros((self.n_gridpoints_x, self.n_gridpoints_y))
        norm_factor = self.norm_factor
        for i in range(self.n_max_modes_x):
            for j in range(self.n_max_modes_y):
                u += bar_u[i * self.n_max_modes_y + j] * self.modes[i, j, :, :]
                v += bar_v[i * self.n_max_modes_y + j] * self.modes[i, j, :, :]
        u *= norm_factor
        v *= norm_factor
        return u, v

    def enforce_boundary_conditions(self, u, v):
        """Enforce Dirichlet (simply supported) boundary conditions, in place.
        ATTN: v is the time derivative of u, not a spatial derivative.
        """
        u[0, :] = 0
        u[-1, :] = 0
        u[:, 0] = 0
        u[:, -1] = 0
        v[0, :] = 0
        v[-1, :] = 0
        v[:, 0] = 0
        v[:, -1] = 0
        return u, v

    def nonlinear_membrane(
        self,
        t,  # time
        state,  # state vector
    ) -> np.ndarray:  # time derivative of the state vector, [bar_udot, bar_vdot]
        # unpack the state vector
        bar_u = state[: self.n_max_modes]  # displacement in modal coordinates
        bar_v = state[self.n_max_modes :]  # velocity in modal coordinates

        # add axis for calculation
        bar_u = bar_u[..., None]

        # calculate the $\bar{b}$ vector
        # ATTN: this b doesn't include the minus sign
        if self.use_nonlinear:
            deformation = bar_u.T @ self.Lambdadiag @ bar_u
            b = self.Cb * deformation * self.Lambdadiag @ bar_u

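        # update the state vector
        # .squeeze() is needed: bar_u has shape (n_max_modes, 1), and combining it with the
        # 1-D bar_v terms without squeezing would broadcast to (n_max_modes, n_max_modes)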
        bar_udot = bar_v
        bar_vdot = -self.M_v @ bar_v - self.M_u @ bar_u.squeeze()
        if self.use_nonlinear:
            bar_vdot -= b.squeeze()

        # return the state derivatives
        return np.concatenate([bar_udot, bar_vdot])

    def calculate_energy(
        self,
        bar_u: np.ndarray,  # modal coordinates of the displacement
        bar_v: np.ndarray,  # modal coordinates of the velocity
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Calculate the kinetic, elastic potential, stiffness potential and nonlinear potential energy of the system.

        Args:
            bar_u (np.ndarray): modal coordinates of the displacement. Shape (nt, n_max_modes)
            bar_v (np.ndarray): modal coordinates of the velocity. Shape (nt, n_max_modes)

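        Returns:
            tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: kinetic, linear elastic,
            linear stiffness and nonlinear elastic potential energy, each of shape (nt,)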
        """
        # calculate the kinetic energy
        energy_kinetic = (
            0.5 * self.rho * self.h * self.norm_factor * np.sum(bar_v**2, axis=-1)
        )
        # calculate the potential energy
        bar_uu = bar_u**2
        summation = np.sum(bar_uu * self.lambdas, axis=-1)
        energy_potential_elastic_linear = 0.5 * self.Ts0 * self.norm_factor * summation
        energy_potential_stiffness_linear = (
            0.5 * self.D * self.norm_factor * np.sum(bar_uu * self.lambdas**2, axis=-1)
        )
        if self.use_nonlinear:
            TNL = (
                summation
                * self.norm_factor
                * self.E
                * self.h
                / (2 * (1 - self.nu**2) * self.length_x * self.length_y)
            )
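            # Tension-modulation potential C_NL / (8 S_0) * (int |grad u|^2)^2, i.e.
            # TNL * (int |grad u|^2) / 4, with int |grad u|^2 = norm_factor * summation
            energy_potential_elastic_nonlinear = (
                0.25 * TNL * self.norm_factor * summation
            )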
        else:
            energy_potential_elastic_nonlinear = np.zeros_like(
                energy_potential_elastic_linear
            )

        return (
            energy_kinetic,
            energy_potential_elastic_linear,
            energy_potential_stiffness_linear,
            energy_potential_elastic_nonlinear,
        )

    def solve(
        self,
        u0: np.ndarray = None,  # initial displacement (default: None)
        v0: np.ndarray = None,  # initial velocity (default: None)
    ) -> tuple:
        """Solve the wave equation with the given initial conditions.

        Returns (t, u, v), followed by the four energy terms if include_energy is True.
        """

        # set the initial conditions if not given
        u0 = u0 if u0 is not None else np.zeros_like(self.grid_x)
        v0 = v0 if v0 is not None else np.zeros_like(self.grid_x)

        # Enforce the simply supported boundary conditions at the edges of the membrane
        u0, v0 = self.enforce_boundary_conditions(u0, v0)
        # transform the initial conditions to modal coordinates
        bar_u0, bar_v0 = self.to_modal(u0, v0)
        # solve the wave equation in modal coordinates
        sol = solve_ivp(
            fun=self.nonlinear_membrane,
            t_span=[0, self.final_time],
            y0=np.concatenate([bar_u0, bar_v0], axis=0),
            t_eval=self.timesteps,
            method="DOP853",
            rtol=1e-12,
            atol=1e-14,
        )
        # sol.y has dimensions (n_max_modes*2, nt)

        # unpack the solution
        bar_u = (sol.y[: self.n_max_modes]).T
        bar_v = (sol.y[self.n_max_modes :]).T

        # Print the shapes of the solutions
        print(f"bar_u shape: {bar_u.shape}")
        print(f"bar_v shape: {bar_v.shape}")
        # transform back to the physical domain
        # This loop is probably extremely slow but is only done once
        u = np.zeros(
            (
                len(sol.t),
                self.n_gridpoints_x,
                self.n_gridpoints_y,
            )
        )
        v = np.zeros(
            (
                len(sol.t),
                self.n_gridpoints_x,
                self.n_gridpoints_y,
            )
        )
        for i in range(len(sol.t)):
            u[i, :, :], v[i, :, :] = self.to_displacement(bar_u[i, :], bar_v[i, :])

        # calculate the energy of the system
        if self.include_energy:
            (
                energy_kinetic,
                energy_potential_elastic_linear,
                energy_potential_stiffness_linear,
                energy_potential_elastic_nonlinear,
            ) = self.calculate_energy(bar_u, bar_v)
            return (
                sol.t,
                u,
                v,
                energy_kinetic,
                energy_potential_elastic_linear,
                energy_potential_stiffness_linear,
                energy_potential_elastic_nonlinear,
            )
        return (
            sol.t,
            u,
            v,
        )

:::
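`to_modal`, `to_modal_vecs` and `to_displacement` above loop over the modes in Python. The following is a sketch, not part of the original solver, of how broadcasting, the `axis` argument of `scipy.integrate.simpson` and `np.einsum` could replace those loops. The function names `project_to_modal` and `modal_to_field` are hypothetical. The sketch keeps the class's conventions: a mode array of shape `(n_modes_x, n_modes_y, Nx, Ny)`, flattening `mu = i * n_modes_y + j`, and an inverse scaled by `norm_factor = 4 / (Lx * Ly)`. The intermediate product has `n_modes_x * n_modes_y * Nx * Ny` entries per field, so memory grows quickly for large batches.

```python
import numpy as np
from scipy.integrate import simpson


def project_to_modal(fields, modes, x, y):
    """Forward transform: int int u K_ij dy dx for one field (Nx, Ny) or a batch (..., Nx, Ny).

    Returns (..., n_modes_x * n_modes_y), flattened as mu = i * n_modes_y + j.
    """
    products = fields[..., None, None, :, :] * modes   # (..., nmx, nmy, Nx, Ny)
    inner = simpson(products, x=y, axis=-1)             # integrate over y
    coeffs = simpson(inner, x=x, axis=-1)               # integrate over x
    return coeffs.reshape(*coeffs.shape[:-2], -1)


def modal_to_field(bar_u, modes, norm_factor):
    """Inverse transform: u = norm_factor * sum_mu bar_u[mu] K_mu, batched over leading axes."""
    nmx, nmy = modes.shape[:2]
    coeffs = bar_u.reshape(*bar_u.shape[:-1], nmx, nmy)
    return norm_factor * np.einsum("...ij,ijxy->...xy", coeffs, modes)


# Round-trip check on a 41 x 37 grid with 3 x 3 modes
Lx, Ly = 0.4, 0.36
x, y = np.linspace(0, Lx, 41), np.linspace(0, Ly, 37)
X, Y = np.meshgrid(x, y, indexing="ij")
kx = np.arange(1, 4) * np.pi / Lx
ky = np.arange(1, 4) * np.pi / Ly
modes = np.sin(kx[:, None, None, None] * X) * np.sin(ky[None, :, None, None] * Y)
norm_factor = 4 / (Lx * Ly)

rng = np.random.default_rng(0)
true_coeffs = rng.standard_normal((5, 9))   # batch of 5 hypothetical modal states
fields = modal_to_field(true_coeffs, modes, norm_factor)
recovered = project_to_modal(fields, modes, x, y)
print(np.max(np.abs(recovered - true_coeffs)))
```

On this grid, each axis has an even number of intervals and the modes are of low order. Composite Simpson then integrates the mode products essentially exactly, so the round trip is expected to agree to near machine precision. This is a property of the chosen test grid, not a general guarantee.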

sampling_rate = 16000
# Initialize the solver
solver = Wave2dSolverTensionModulated(
    sampling_rate=sampling_rate,
    final_time=0.1,
    n_max_modes=225,
    n_gridpoints_x=41,
    Ts0=800,
    d1=0.5,
    d3=5e-3,
    use_nonlinear=True,
    include_energy=True,
)
# Plot some of the modes
nx = 3
ny = 2
fig, ax = plt.subplots(nx, ny, figsize=(5, 15))
aspect_ratio_dxdy = solver.dx / solver.dy
# Print this aspect ratio and the one used in the solver
solver.print_solver_info()
print(f"Aspect ratio: {aspect_ratio_dxdy}")
print(f"Solver aspect ratio: {solver.aspect_ratio}")
for i in range(nx):
    for j in range(ny):
        print(f"Mode max value: {np.max(solver.modes[i, j, :, :])}")
        ax[i, j].imshow(
            solver.modes[i, j, :, :], aspect=aspect_ratio_dxdy, origin="lower"
        )
        ax[i, j].set_title(f"Mode nx={i + 1}, ny={j + 1}")
        ax[i, j].set_xlabel("y")
        ax[i, j].set_ylabel("x")
        # ax[i, j].tight_layout()
# Compare the performance of to_modal and to_modal_vecs


# Generate a tensor of states
n_states = 10
u = np.random.randn(n_states, solver.n_gridpoints_x, solver.n_gridpoints_y)
v = np.random.randn(n_states, solver.n_gridpoints_x, solver.n_gridpoints_y)

# Vectorised version
start = time.time()
bar_u_vec, bar_v_vec = solver.to_modal_vecs(u, v)
end = time.time()
print(f"Vectorised version took {end - start} seconds")

# Non-vectorised version
start = time.time()
bar_u = np.zeros((n_states, solver.n_max_modes))
bar_v = np.zeros((n_states, solver.n_max_modes))
for i in range(n_states):
    bar_u[i], bar_v[i] = solver.to_modal(u[i], v[i])
end = time.time()
print(f"Non-vectorised version took {end - start} seconds")

# Compare the results
print(f"bar_u_vec shape: {bar_u_vec.shape}")
print(f"bar_u shape: {bar_u.shape}")
print(f"bar_v_vec shape: {bar_v_vec.shape}")
print(f"bar_v shape: {bar_v.shape}")
print(f"bar_u_vec - bar_u: {np.max(np.abs(bar_u_vec - bar_u))}")
print(f"bar_v_vec - bar_v: {np.max(np.abs(bar_v_vec - bar_v))}")
def gaussian_pulse(x_grid, y_grid, x0, y0, sigma):
    return np.exp(-((x_grid - x0) ** 2 + (y_grid - y0) ** 2) / (2 * sigma**2))


def noise(grid):
    return np.random.randn(*grid.shape)
# Solve the wave equation
# u0 = np.zeros((solver.n_gridpoints_x, solver.n_gridpoints_y))
u0 = solver.modes[0, 0, :, :]
# Create a gaussian impulse
# u0 = np.zeros((solver.n_gridpoints_x, solver.n_gridpoints_y))
u0 = np.random.randn(solver.n_gridpoints_x, solver.n_gridpoints_y)
ctr = (0.2 * solver.length_x, 0.6 * solver.length_y)
std = 0.05 * solver.length_x
for i in range(solver.n_gridpoints_x):
    for j in range(solver.n_gridpoints_y):
        u0[i, j] = np.exp(
            -((solver.grid_x[i, j] - ctr[0]) ** 2 + (solver.grid_y[i, j] - ctr[1]) ** 2)
            / (2 * std**2)
        )

u0_alt = gaussian_pulse(solver.grid_x, solver.grid_y, ctr[0], ctr[1], std)
# Compare the two initial conditions
print(np.allclose(u0, u0_alt))

# v0 = np.zeros_like(u0)
# u0 = 0.01*u0
v0 = 20 * u0
u0 = np.zeros_like(u0)
# Add a delta impulse to the initial velocity, close to the center
# v0[solver.n_gridpoints_x//2, solver.n_gridpoints_y//2] = 20

t, u, v, e_k, e_ple, e_pls, e_pne = solver.solve(u0=u0, v0=v0)
print(f"t shape: {t.shape}")
print(f"u shape: {u.shape}")
print(f"v shape: {v.shape}")
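# Check that the boundary conditions are maintained, up to floating-point error
# (np.sin(n * np.pi) is of order 1e-16, not exactly zero)
tol = 1e-10 * np.max(np.abs(u))
print(np.all(np.abs(u[:, 0, :]) <= tol))
print(np.all(np.abs(u[:, -1, :]) <= tol))
print(np.all(np.abs(u[:, :, 0]) <= tol))
print(np.all(np.abs(u[:, :, -1]) <= tol))

# Plot the sum of the displacement along each edge for all timesteps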
plt.figure()
plt.plot(t, np.sum(u[:, 0, :], axis=1), label="bottom")
plt.plot(t, np.sum(u[:, -1, :], axis=1), label="top")
plt.plot(t, np.sum(u[:, :, 0], axis=1), label="left")
plt.plot(t, np.sum(u[:, :, -1], axis=1), label="right")
plt.legend()
plt.xlabel("Time")
plt.ylabel("Sum of displacement at the edges")
plt.title("Boundary conditions")
plt.show()
vmax_u = np.max(np.abs(u))
vmax_v = np.max(np.abs(v))

# N_plots=100
fig, ax = plt.subplots(1, 2)


def animate(i):
    ax[0].clear()
    ax[1].clear()
    ax[0].imshow(
        u[i, :, :], aspect=aspect_ratio_dxdy, origin="lower", vmin=-vmax_u, vmax=vmax_u
    )
    ax[1].imshow(
        v[i, :, :], aspect=aspect_ratio_dxdy, origin="lower", vmin=-vmax_v, vmax=vmax_v
    )
    ax[0].set_title(f"Displacement")
    ax[0].set_xlabel("y")
    ax[0].set_ylabel("x")
    ax[1].set_title(f"Velocity")
    ax[1].set_xlabel("y")
    ax[1].set_ylabel("x")


fig.set_figheight(5)
fig.set_figwidth(10)
ani = FuncAnimation(fig, animate, frames=range(0, 320, 50), interval=300, repeat=False)
plt.close()
from IPython.display import HTML

HTML(ani.to_jshtml())
# Max amplitude after t_late seconds of the simulation
t_late = 0.04
sample_ind = int(t_late * sampling_rate)
max_amplitude = np.max(np.abs(u[sample_ind, :, :]))
print(f"Max amplitude after {t_late}s : {max_amplitude}")
# Max speed after t_late seconds of the simulation
max_speed = np.max(np.abs(v[sample_ind, :, :]))
print(f"Max speed after {t_late}s : {max_speed}")

# Plot the state at the beginning and after t_late seconds of the simulation
fig, ax = plt.subplots(2, 2)
ax[0, 0].imshow(
    u[0, :, :], aspect=aspect_ratio_dxdy, origin="lower", vmin=-vmax_u, vmax=vmax_u
)
ax[0, 0].set_title("Initial displacement")
ax[0, 1].imshow(
    v[0, :, :], aspect=aspect_ratio_dxdy, origin="lower", vmin=-vmax_v, vmax=vmax_v
)
ax[0, 1].set_title("Initial velocity")
ax[1, 0].imshow(
    u[sample_ind, :, :],
    aspect=aspect_ratio_dxdy,
    origin="lower",
    vmin=-vmax_u,
    vmax=vmax_u,
)
ax[1, 0].set_title(f"Displacement at t = {t_late} s")
ax[1, 1].imshow(
    v[sample_ind, :, :],
    aspect=aspect_ratio_dxdy,
    origin="lower",
    vmin=-vmax_v,
    vmax=vmax_v,
)
ax[1, 1].set_title(f"Velocity at t = {t_late} s")
plt.show()
# Plot the time-domain output and the spectrogram of a single point on the membrane
import matplotlib.pyplot as plt
from scipy.signal import spectrogram
import IPython.display as ipd

# print the shape of u and v
print("u shape:", u.shape)
print("v shape:", v.shape)

# Print the maximum value of u and v
print("Max u:", np.max(np.abs(u)))
print("Max v:", np.max(np.abs(v)))

out_point = (int(u.shape[1] * np.random.random()), int(u.shape[2] * np.random.random()))
# Check that the output point is within the bounds of the data
assert out_point[0] < u.shape[1]
assert out_point[1] < u.shape[2]

u_out = u[:, out_point[0], out_point[1]]

# plot the time domain output of a single point on the membrane
plt.figure()
plt.plot(u_out)
plt.title(f"u at point {out_point}")
plt.show()

ipd.display(ipd.Audio(u_out, rate=sampling_rate))

# Plot the spectrogram for the output of a single point on the membrane
win_size = 512
win_over = win_size // 8  # hop size between frames; the overlap is win_size - win_over
f, time_spec, Sxx = spectrogram(
    u_out,
    fs=sampling_rate,
    nperseg=win_size,
    nfft=win_size,
    noverlap=win_size - win_over,
)
Sxx_norm = Sxx / Sxx.max()
Sxx_norm_db = 10 * np.log10(Sxx_norm)
# Clip all values below -60 dB to -60 dB
Sxx_norm_db[Sxx_norm_db < -60] = -60
plt.figure()
plt.pcolormesh(time_spec, f, Sxx_norm_db)
plt.ylabel("Frequency [Hz]")
plt.semilogy()
# plt.yscale('log')
plt.ylim(100, 1000)
plt.yticks(
    [10, 100, 200, 300, 400, 600, 800, 2000], [10, 100, 200, 300, 400, 600, 800, 2000]
)
# plt.xlim(0, 0.25)
plt.xlabel("Time [sec]")
plt.title(f"Spectrogram of u at point {out_point}")
plt.colorbar(label="dB")
plt.show()
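With tension modulation, large amplitudes raise the tension and therefore the modal frequencies. The partials are therefore expected to glide down in pitch as the vibration decays. One rough way to look for this in the output above is to follow the strongest spectrogram bin over time. The helper below (`dominant_frequency_track`, a hypothetical name) is a sketch. Its frequency resolution is limited to `fs / nperseg`, about 31 Hz for a 512-sample window at 16 kHz, so small glides may not be visible.

```python
import numpy as np
from scipy.signal import spectrogram


def dominant_frequency_track(signal, fs, nperseg=512, hop=64, fmin=50.0):
    """Return frame times and the frequency of the strongest bin above fmin per frame."""
    f, t, Sxx = spectrogram(signal, fs=fs, nperseg=nperseg, noverlap=nperseg - hop)
    band = f >= fmin                             # ignore DC and very low frequencies
    peak_bins = np.argmax(Sxx[band, :], axis=0)  # strongest bin in each frame
    return t, f[band][peak_bins]


# Synthetic check with a 300 Hz sine; in the notebook one would pass u_out and sampling_rate
fs = 16000
t_sig = np.arange(0, 0.5, 1 / fs)
test_signal = np.sin(2 * np.pi * 300.0 * t_sig)
frame_times, track = dominant_frequency_track(test_signal, fs)
print(track[:5])
```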
from scipy.sparse import diags, kron, eye


# Functions for calculating the energy of the system
def calculate_energy_kinetic(solver, v):
    vv = v**2
    integral = simpson([simpson(vv_y, x=solver.y) for vv_y in vv], x=solver.x)
    integral_trapz = solver.dx * solver.dy * np.sum(vv)
    # print(f"Integral simpson: {integral}")
    # print(f"Integral trapz: {integral_trapz}")
    return solver.rho * solver.h * integral / 2


def calculate_energy_potential_elastic_linear(solver, u):
    Nx = solver.n_gridpoints_x
    Ny = solver.n_gridpoints_y
    dx = solver.dx
    dy = solver.dy

    # Second-derivative finite-difference matrices (dense; combined below with sparse kron)
    Dxx = diags([1, -2, 1], [-1, 0, 1], shape=(Nx, Nx)).toarray() / dx**2
    Dyy = diags([1, -2, 1], [-1, 0, 1], shape=(Ny, Ny)).toarray() / dy**2

    # Dirichlet boundary conditions: zero the rows of the boundary points
    Dxx[0, :] = 0
    Dxx[-1, :] = 0
    Dyy[0, :] = 0
    Dyy[-1, :] = 0
    Dxx2 = kron(Dxx, eye(Ny))
    Dyy2 = kron(eye(Nx), Dyy)
    nabla2 = Dxx2 + Dyy2

    # Flatten the displacement
    u_flat = u.reshape(-1, 1)

    integrand_flat = (nabla2 @ u_flat) * u_flat
    integrand = integrand_flat.reshape(Nx, Ny)
    integral = simpson(
        [simpson(integrand_y, x=solver.y) for integrand_y in integrand], x=solver.x
    )
    return -solver.Ts0 * integral / 2


def calculate_energy_potential_stiffness_linear(solver, u):
    Nx = solver.n_gridpoints_x
    Ny = solver.n_gridpoints_y
    dx = solver.dx
    dy = solver.dy

    # Second-derivative finite-difference matrices (dense; combined below with sparse kron)
    Dxx = diags([1, -2, 1], [-1, 0, 1], shape=(Nx, Nx)).toarray() / dx**2
    Dyy = diags([1, -2, 1], [-1, 0, 1], shape=(Ny, Ny)).toarray() / dy**2

    # Dirichlet boundary conditions: zero the rows of the boundary points
    Dxx[0, :] = 0
    Dxx[-1, :] = 0
    Dyy[0, :] = 0
    Dyy[-1, :] = 0
    Dxx2 = kron(Dxx, eye(Ny))
    Dyy2 = kron(eye(Nx), Dyy)
    nabla2 = Dxx2 + Dyy2
    nabla4 = nabla2 @ nabla2

    # Flatten the displacement
    u_flat = u.reshape(-1, 1)

    integrand_flat = (nabla4 @ u_flat) * u_flat
    integrand = integrand_flat.reshape(Nx, Ny)
    integral = simpson(
        [simpson(integrand_y, x=solver.y) for integrand_y in integrand], x=solver.x
    )
    return solver.D * integral / 2


def calculate_energy_potential_elastic_nonlinear(solver, u):
    Nx = solver.n_gridpoints_x
    Ny = solver.n_gridpoints_y
    dx = solver.dx
    dy = solver.dy

    # First-derivative (central difference) matrices
    Dx = diags([-1, 0, 1], [-1, 0, 1], shape=(Nx, Nx)).toarray() / (2 * dx)
    Dy = diags([-1, 0, 1], [-1, 0, 1], shape=(Ny, Ny)).toarray() / (2 * dy)

    # Zero the boundary rows, as for the second-derivative matrices.
    # NOTE: this also sets the gradient on the edges to zero, although it is nonzero
    # on a simply supported edge, so the integral is underestimated near the edges.
    Dx[0, :] = 0
    Dx[-1, :] = 0
    Dy[0, :] = 0
    Dy[-1, :] = 0
    Dx2 = kron(Dx, eye(Ny))
    Dy2 = kron(eye(Nx), Dy)

    # Flatten the displacement
    u_flat = u.reshape(-1, 1)

    # Squared components of the displacement gradient
    Dx2u_sq = (Dx2 @ u_flat) ** 2
    Dy2u_sq = (Dy2 @ u_flat) ** 2

    integrand_flat_T = Dx2u_sq + Dy2u_sq
    integrand_T = integrand_flat_T.reshape(Nx, Ny)
    integral_T = simpson(
        [simpson(integrand_y, x=solver.y) for integrand_y in integrand_T], x=solver.x
    )

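    # Berger approximation of the tension modulation:
    # T_NL = E h / (2 (1 - nu^2) Lx Ly) * integral of |grad u|^2
    TNL = (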
        solver.E
        * solver.h
        / (2 * (1 - solver.nu**2) * solver.length_x * solver.length_y)
        * integral_T
    )

    # Second-order centered finite-difference matrices for the Laplacian
    # (built with diags, then converted to dense arrays)
    Dxx = diags([1, -2, 1], [-1, 0, 1], shape=(Nx, Nx)).toarray()
    Dxx /= dx**2
    Dyy = diags([1, -2, 1], [-1, 0, 1], shape=(Ny, Ny)).toarray()
    Dyy /= dy**2

    # Simply supported edges: zero the boundary rows
    Dxx[0, :] = 0
    Dxx[-1, :] = 0
    Dyy[0, :] = 0
    Dyy[-1, :] = 0
    Dxx2 = kron(Dxx, eye(Ny))
    Dyy2 = kron(eye(Nx), Dyy)
    nabla2 = Dxx2 + Dyy2

    integrand_flat = (nabla2 @ u_flat) * u_flat
    integrand = integrand_flat.reshape(Nx, Ny)
    integral = simpson(
        [simpson(integrand_y, x=solver.y) for integrand_y in integrand], x=solver.x
    )

    return -TNL * integral / 2
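For a single mode \(u = a K_{n,m}\), the gradient integral in the Berger tension has the closed form \(\int \lVert\nabla u\rVert^2 d\mathbf{x} = a^2 \lambda_{n,m} \lVert K_{n,m}\rVert_2^2\). A minimal sketch comparing this with a finite-difference quadrature (`np.gradient` with second-order one-sided edges), and with a variant that zeroes the edge slopes the same way the zeroed boundary rows of `Dx` and `Dy` do above. Here `E` plays the role of \(Q\) in the equations, and all material, geometry and grid values are illustrative assumptions.

```python
import numpy as np


def trapz_2d(f, x, y):
    """Trapezoidal rule on a uniform tensor grid; f has shape (len(x), len(y))."""
    wx = np.full(x.size, x[1] - x[0])
    wx[[0, -1]] *= 0.5
    wy = np.full(y.size, y[1] - y[0])
    wy[[0, -1]] *= 0.5
    return wx @ f @ wy


# Illustrative plate parameters (assumptions, not the solver's values)
E, h, nu = 2.0e11, 1.0e-3, 0.3
Lx, Ly = 0.4, 0.3
n, m, a = 2, 1, 1.0e-4  # mode indices and amplitude
x = np.linspace(0.0, Lx, 81)
y = np.linspace(0.0, Ly, 61)
X, Y = np.meshgrid(x, y, indexing="ij")
u = a * np.sin(n * np.pi * X / Lx) * np.sin(m * np.pi * Y / Ly)
lam = np.pi**2 * ((n / Lx) ** 2 + (m / Ly) ** 2)
berger = E * h / (2 * (1 - nu**2) * Lx * Ly)  # prefactor of the Berger tension

# Centered differences inside, second-order one-sided differences at the edges
ux, uy = np.gradient(u, x[1] - x[0], y[1] - y[0], edge_order=2)
grad_integral = trapz_2d(ux**2 + uy**2, x, y)

# Same integrand with the edge slopes set to zero, as the zeroed rows do
ux_zeroed, uy_zeroed = ux.copy(), uy.copy()
ux_zeroed[[0, -1], :] = 0.0
uy_zeroed[:, [0, -1]] = 0.0
grad_integral_zeroed = trapz_2d(ux_zeroed**2 + uy_zeroed**2, x, y)

grad_integral_modal = a**2 * lam * Lx * Ly / 4  # a^2 * lambda * ||K||^2

for label, value in [
    ("finite differences", grad_integral),
    ("zeroed edge slopes", grad_integral_zeroed),
    ("modal closed form", grad_integral_modal),
]:
    print(f"T_NL ({label}): {berger * value:.6e} N/m")
```

The zeroed-slope variant comes out a few percent low on this grid. Because the discarded edge strip has a width proportional to the grid spacing, that error shrinks only linearly under refinement, and it could appear as a small offset between the physical and modal nonlinear energies.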


# Compute the kinetic, linear potential and nonlinear potential energies over time in physical coordinates
block_size = len(t) - 1
start_t_samples = 0
t_energy = t[start_t_samples : start_t_samples + block_size]
energy_kinetic = np.zeros_like(t_energy)
energy_potential_elastic_linear = np.zeros_like(t_energy)
energy_potential_stiffness_linear = np.zeros_like(t_energy)
energy_potential_elastic_nonlinear = np.zeros_like(t_energy)
energy_total = np.zeros_like(t_energy)
for i in range(block_size):
    energy_kinetic[i] = calculate_energy_kinetic(solver, v[start_t_samples + i, :, :])
    energy_potential_elastic_linear[i] = calculate_energy_potential_elastic_linear(
        solver, u[start_t_samples + i, :, :]
    )
    energy_potential_stiffness_linear[i] = calculate_energy_potential_stiffness_linear(
        solver, u[start_t_samples + i, :, :]
    )
    energy_potential_elastic_nonlinear[i] = (
        calculate_energy_potential_elastic_nonlinear(
            solver, u[start_t_samples + i, :, :]
        )
    )

if solver.use_nonlinear:
    energy_total = (
        energy_kinetic
        + energy_potential_elastic_linear
        + energy_potential_stiffness_linear
        + energy_potential_elastic_nonlinear
    )
else:
    energy_total = (
        energy_kinetic
        + energy_potential_elastic_linear
        + energy_potential_stiffness_linear
    )

energy_linear = (
    energy_kinetic + energy_potential_elastic_linear + energy_potential_stiffness_linear
)
# Calculate the energy of the system using the modal coordinates
import time


def calculate_energy_kinetic_modal(solver, bar_v):
    """Kinetic energy rho * h / 2 * norm_factor * sum(bar_v**2)."""
    bar_vv = bar_v**2
    summation = np.sum(bar_vv)
    return solver.rho * solver.h * solver.norm_factor * summation / 2


def calculate_energy_potential_elastic_linear_modal(solver, bar_u):
    """Linear membrane energy Ts0 / 2 * norm_factor * sum(lambdas * bar_u**2)."""
    bar_uu = bar_u**2
    summation = np.sum(bar_uu * solver.lambdas)
    return solver.Ts0 * solver.norm_factor * summation / 2


def calculate_energy_potential_stiffness_linear_modal(solver, bar_u):
    """Bending energy D / 2 * norm_factor * sum(lambdas**2 * bar_u**2)."""
    bar_uu = bar_u**2
    summation = np.sum(bar_uu * solver.lambdas**2)
    return solver.D * solver.norm_factor * summation / 2


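def calculate_energy_potential_elastic_nonlinear_modal(solver, bar_u):
    """Nonlinear tension energy; norm_factor * sum(lambdas * bar_u**2) is the
    modal form of the gradient integral used for TNL in physical coordinates."""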
    bar_uu = bar_u**2
    summation = np.sum(bar_uu * solver.lambdas)
    TNL = (
        summation
        * solver.norm_factor
        * solver.E
        * solver.h
        / (2 * (1 - solver.nu**2) * solver.length_x * solver.length_y)
    )
    return TNL * solver.norm_factor * summation / 2
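The modal energy functions rely on the orthogonality of the sine modes: if \(u = \sum_{n,m} \bar{u}_{n,m} K_{n,m}\), then \(\int u^2 = \lVert K\rVert_2^2 \sum \bar{u}_{n,m}^2\), \(\int \lVert\nabla u\rVert^2 = \lVert K\rVert_2^2 \sum \lambda_{n,m}\bar{u}_{n,m}^2\) and \(\int (\nabla^2 u)^2 = \lVert K\rVert_2^2 \sum \lambda_{n,m}^2\bar{u}_{n,m}^2\), with \(\lVert K\rVert_2^2 = L_x L_y/4\) for every mode. A sketch checking these identities with analytic derivatives and a trapezoidal rule, which integrates these low-order sine and cosine products exactly on a uniform grid that includes the edges. Treating `norm_factor` as \(\lVert K\rVert_2^2\) is an assumption, since its definition in the solver is not shown in this section.

```python
import numpy as np


def trapz_2d(f, x, y):
    """Trapezoidal rule on a uniform tensor grid; f has shape (len(x), len(y))."""
    wx = np.full(x.size, x[1] - x[0])
    wx[[0, -1]] *= 0.5
    wy = np.full(y.size, y[1] - y[0])
    wy[[0, -1]] *= 0.5
    return wx @ f @ wy


# Illustrative geometry and a random superposition of a few modes
Lx, Ly = 0.4, 0.3
x = np.linspace(0.0, Lx, 65)
y = np.linspace(0.0, Ly, 49)
X, Y = np.meshgrid(x, y, indexing="ij")
modes = [(n, m) for n in range(1, 5) for m in range(1, 4)]
rng = np.random.default_rng(0)
bar_u = rng.standard_normal(len(modes))
lambdas = np.array([np.pi**2 * ((n / Lx) ** 2 + (m / Ly) ** 2) for n, m in modes])
norm_factor = Lx * Ly / 4  # ||K_{n,m}||^2, assumed to match solver.norm_factor

# Build u, its gradient and its Laplacian analytically from the modal sum
u = np.zeros_like(X)
ux = np.zeros_like(X)
uy = np.zeros_like(X)
lap_u = np.zeros_like(X)
for c, (n, m), lam in zip(bar_u, modes, lambdas):
    kx, ky = n * np.pi / Lx, m * np.pi / Ly
    sx, cx = np.sin(kx * X), np.cos(kx * X)
    sy, cy = np.sin(ky * Y), np.cos(ky * Y)
    u += c * sx * sy
    ux += c * kx * cx * sy
    uy += c * ky * sx * cy
    lap_u -= c * lam * sx * sy

# Physical-coordinate integrals versus the modal sums used above
pairs = {
    "u^2": (trapz_2d(u**2, x, y), norm_factor * np.sum(bar_u**2)),
    "|grad u|^2": (
        trapz_2d(ux**2 + uy**2, x, y),
        norm_factor * np.sum(lambdas * bar_u**2),
    ),
    "(lap u)^2": (
        trapz_2d(lap_u**2, x, y),
        norm_factor * np.sum(lambdas**2 * bar_u**2),
    ),
}
for name, (physical, modal) in pairs.items():
    print(f"{name}: physical={physical:.6e} modal={modal:.6e}")
```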


# Compute the same energy terms over time from the modal coordinates
energy_kinetic_modal = np.zeros_like(t_energy)
energy_potential_elastic_linear_modal = np.zeros_like(t_energy)
energy_potential_stiffness_linear_modal = np.zeros_like(t_energy)
energy_potential_elastic_nonlinear_modal = np.zeros_like(t_energy)

# Time these calculations
start_time = time.time()

bar_u_vec, bar_v_vec = solver.to_modal_vecs(
    u[start_t_samples : start_t_samples + block_size],
    v[start_t_samples : start_t_samples + block_size],
)
end_time = time.time()
print(f"Time to calculate modal coordinates: {end_time - start_time:.3f} s")
start_time = time.time()
for i in range(block_size):
    bar_u, bar_v = bar_u_vec[i], bar_v_vec[i]
    energy_kinetic_modal[i] = calculate_energy_kinetic_modal(solver, bar_v)
    energy_potential_elastic_linear_modal[i] = (
        calculate_energy_potential_elastic_linear_modal(solver, bar_u)
    )
    energy_potential_stiffness_linear_modal[i] = (
        calculate_energy_potential_stiffness_linear_modal(solver, bar_u)
    )
    energy_potential_elastic_nonlinear_modal[i] = (
        calculate_energy_potential_elastic_nonlinear_modal(solver, bar_u)
    )

end_time = time.time()
print(f"Time to calculate energy in modal coordinates: {end_time - start_time:.3f} s")
if solver.use_nonlinear:
    energy_total_modal = (
        energy_kinetic_modal
        + energy_potential_elastic_linear_modal
        + energy_potential_stiffness_linear_modal
        + energy_potential_elastic_nonlinear_modal
    )
else:
    energy_total_modal = (
        energy_kinetic_modal
        + energy_potential_elastic_linear_modal
        + energy_potential_stiffness_linear_modal
    )

energy_linear_modal = (
    energy_kinetic_modal
    + energy_potential_elastic_linear_modal
    + energy_potential_stiffness_linear_modal
)
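`solver.to_modal_vecs` is defined elsewhere in the notebook. Assuming the expansion \(u = \sum \bar{u}_{n,m} K_{n,m}\), so that \(\bar{u}_{n,m} = \int u K_{n,m}\,d\mathbf{x} / \lVert K_{n,m}\rVert_2^2\), one way to write such a conversion is a batched quadrature projection with two matrix products per snapshot. The helper names (`project_to_modes`, `sine_basis`, `trapezoid_weights`), the trapezoidal weights and the mode limits are hypothetical choices for illustration, not the solver's implementation.

```python
import numpy as np


def trapezoid_weights(grid):
    """Trapezoidal quadrature weights for a uniform 1D grid."""
    w = np.full(grid.size, grid[1] - grid[0])
    w[[0, -1]] *= 0.5
    return w


def sine_basis(grid, length, n_modes):
    """sin(k * pi * grid / length) for k = 1..n_modes, shape (grid.size, n_modes)."""
    return np.sin(np.outer(grid, np.arange(1, n_modes + 1)) * np.pi / length)


def project_to_modes(u, x, y, Lx, Ly, n_max, m_max):
    """Project snapshots u[t, i, j] onto K_{n,m}; returns bar_u[t, n - 1, m - 1]."""
    left = (trapezoid_weights(x)[:, None] * sine_basis(x, Lx, n_max)).T  # (n_max, Nx)
    right = trapezoid_weights(y)[:, None] * sine_basis(y, Ly, m_max)  # (Ny, m_max)
    # Batched matmul: (n_max, Nx) @ (T, Nx, Ny) @ (Ny, m_max) -> (T, n_max, m_max)
    return left @ u @ right / (Lx * Ly / 4)


# Round trip on synthetic data (illustrative sizes)
Lx, Ly, Nx, Ny = 0.4, 0.3, 65, 49
n_max, m_max, n_steps = 5, 4, 3
x = np.linspace(0.0, Lx, Nx)
y = np.linspace(0.0, Ly, Ny)
rng = np.random.default_rng(1)
coeffs = rng.standard_normal((n_steps, n_max, m_max))
u = np.einsum(
    "tnm,in,jm->tij", coeffs, sine_basis(x, Lx, n_max), sine_basis(y, Ly, m_max)
)

bar_u = project_to_modes(u, x, y, Lx, Ly, n_max, m_max)
print(np.max(np.abs(bar_u - coeffs)))
```

Because the trapezoidal rule preserves the orthogonality of these sine modes exactly on a uniform grid that includes the edges, the round trip recovers the coefficients to rounding error as long as the modes stay well below the grid resolution.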
# Compare the energy of the system using the modal coordinates and the physical coordinates

# Kinetic energy
plt.figure()
plt.plot(t_energy, energy_kinetic, label="Kinetic physical")
plt.plot(t_energy, energy_kinetic_modal, linestyle="--", label="Kinetic modal")
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Kinetic energy of the system")
plt.show()

# Linear potential elastic energy
plt.figure()
plt.plot(
    t_energy, energy_potential_elastic_linear, label="Linear potential elastic physical"
)
plt.plot(
    t_energy,
    energy_potential_elastic_linear_modal,
    linestyle="--",
    label="Linear potential elastic modal",
)
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Linear potential elastic energy of the system")
plt.show()

# Linear potential stiffness energy
plt.figure()
plt.plot(
    t_energy,
    energy_potential_stiffness_linear,
    label="Linear potential stiffness physical",
)
plt.plot(
    t_energy,
    energy_potential_stiffness_linear_modal,
    linestyle="--",
    label="Linear potential stiffness modal",
)
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Linear potential stiffness energy of the system")
plt.show()

# Nonlinear potential elastic energy
plt.figure()
plt.plot(
    t_energy,
    energy_potential_elastic_nonlinear,
    label="Nonlinear potential elastic physical",
)
plt.plot(
    t_energy,
    energy_potential_elastic_nonlinear_modal,
    linestyle="--",
    label="Nonlinear potential elastic modal",
)
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Nonlinear potential elastic energy of the system")
plt.show()

# Total energy
plt.figure()
plt.plot(t_energy, energy_total, label="Total physical")
plt.plot(t_energy, energy_total_modal, linestyle="--", label="Total modal")
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Total energy of the system")
plt.show()

# Linear energy
plt.figure()
plt.plot(t_energy, energy_linear, label="Linear physical")
plt.plot(t_energy, energy_linear_modal, linestyle="--", label="Linear modal")
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Total linear energy of the system")
plt.show()
# Plot the difference between the physical and modal energies
# Total energy
plt.figure()
plt.plot(t_energy, energy_total - energy_total_modal, label="Total")
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Difference between the two total energies")
plt.show()

# Linear energy
plt.figure()
plt.plot(t_energy, energy_linear - energy_linear_modal, label="Linear")
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Difference between the two linear energies")
plt.show()

# Kinetic energy
plt.figure()
plt.plot(t_energy, energy_kinetic - energy_kinetic_modal, label="Kinetic")
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Difference between the two kinetic energies")
plt.show()

# Linear potential elastic energy
plt.figure()
plt.plot(
    t_energy,
    energy_potential_elastic_linear - energy_potential_elastic_linear_modal,
    label="Linear potential elastic",
)
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Difference between the two linear potential elastic energies")
plt.show()

# Linear potential stiffness energy
plt.figure()
plt.plot(
    t_energy,
    energy_potential_stiffness_linear - energy_potential_stiffness_linear_modal,
    label="Linear potential stiffness",
)
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Difference between the two linear potential stiffness energies")
plt.show()

# Nonlinear potential elastic energy
plt.figure()
plt.plot(
    t_energy,
    energy_potential_elastic_nonlinear - energy_potential_elastic_nonlinear_modal,
    label="Nonlinear potential elastic",
)
plt.legend()
plt.xlabel("Time")
plt.ylabel("Energy")
plt.title("Difference between the two nonlinear potential elastic energies")
plt.show()
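# The initial condition adds only kinetic energy, so the total energy at t = 0
# should equal the initial kinetic energy. The energy is not conserved afterwards
# (the model includes the damping terms d_1 and d_3), so plot its relative change
# over time.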
assert np.allclose(energy_total[0], energy_kinetic[0])
assert np.allclose(energy_total_modal[0], energy_kinetic_modal[0])

plt.figure()
# plt.plot(t_energy, energy_total/energy_total[0] -1, label="Total")
plt.plot(t_energy, energy_total_modal / energy_total_modal[0] - 1, label="Total modal")
plt.legend()
plt.xlabel("Time")
plt.ylabel("Relative change in energy")
plt.xlim(0.0, 0.02)
plt.title("Change in total energy of the system")
plt.show()
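For the linear part of the boxed modal equation, multiplying by \(\dot{\bar{u}}_{n,m}\) gives the energy balance
\[ \frac{d}{dt}\left[\frac{\rho h}{2}\dot{\bar{u}}_{n,m}^2 + \frac{\lambda_{n,m}\left(\lambda_{n,m}D + T_0\right)}{2}\bar{u}_{n,m}^2\right] = -\left(d_3\lambda_{n,m} + d_1\right)\dot{\bar{u}}_{n,m}^2, \]
so with nonzero \(d_1\) or \(d_3\) the energy of each mode can only decrease. A minimal sketch that checks this balance for one mode, integrated with a classical fourth-order Runge-Kutta scheme (a stand-in, not the solver's time stepper) and illustrative coefficient values: the stored energy plus the energy dissipated so far should stay equal to the initial energy.

```python
import numpy as np

# Illustrative single-mode coefficients (assumptions, not the solver's values)
rho_h = 0.5  # rho * h
damping = 1.5  # d_3 * lambda + d_1
stiffness = 4.0e5  # lambda * (lambda * D + T_0)


def rhs(state):
    """Time derivative of (u, v) for rho_h * u'' + damping * u' + stiffness * u = 0."""
    u, v = state
    return np.array([v, -(damping * v + stiffness * u) / rho_h])


dt, n_steps = 1.0e-5, 5000
states = np.empty((n_steps + 1, 2))
states[0] = [0.0, 1.0]  # zero displacement, unit initial velocity
for k in range(n_steps):
    s = states[k]
    k1 = rhs(s)
    k2 = rhs(s + 0.5 * dt * k1)
    k3 = rhs(s + 0.5 * dt * k2)
    k4 = rhs(s + dt * k3)
    states[k + 1] = s + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

u, v = states[:, 0], states[:, 1]
energy = 0.5 * rho_h * v**2 + 0.5 * stiffness * u**2
power = damping * v**2  # instantaneous dissipated power
# Cumulative dissipated energy (trapezoidal rule in time)
dissipated = np.concatenate(([0.0], np.cumsum(0.5 * dt * (power[1:] + power[:-1]))))
balance = energy + dissipated  # should stay close to energy[0]

print(f"energy lost: {energy[0] - energy[-1]:.4e}")
print(f"max balance error: {np.max(np.abs(balance - energy[0])):.2e}")
```

A drift in the solver's total energy that is much larger than the energy dissipated by damping would point to a discretization or bookkeeping issue rather than to the physics.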
# Plot the nonlinear energy as a fraction of the total energy
plt.figure()
plt.plot(
    t_energy,
    energy_potential_elastic_nonlinear / energy_total,
    label="Nonlinear physical",
)
plt.plot(
    t_energy,
    energy_potential_elastic_nonlinear_modal / energy_total_modal,
    linestyle="--",
    label="Nonlinear modal",
)
plt.legend()
plt.xlabel("Time")
plt.ylabel("Fraction of total energy")
plt.xlim(0.0, 0.02)
plt.title("Nonlinear potential elastic energy as a fraction of the total energy")
plt.show()
# Check that the energies computed internally by the solver (e_k, e_ple, e_pls, e_pne) match the modal energies computed here
assert np.allclose(energy_kinetic_modal[:300], e_k[:300])
assert np.allclose(energy_potential_elastic_linear_modal[:300], e_ple[:300])
assert np.allclose(energy_potential_stiffness_linear_modal[:300], e_pls[:300])
assert np.allclose(energy_potential_elastic_nonlinear_modal[:300], e_pne[:300])