from physmodjax.solver.generator import Gaussian
Finite element solver for the 1D wave equation
A finite-element solver for the 1D wave equation.
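We use finite elements to obtain the mass \(\mathbf{M}\) and stiffness \(\mathbf{K}\) matrices. The resulting second-order system is rewritten in first-order state-space form (shown below), and the time stepping uses the bilinear (Tustin) discretization of that system.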
\[ \begin{bmatrix} \dot{\mathbf{x}} \\ \ddot{\mathbf{x}} \end{bmatrix} = \begin{bmatrix} 0 & \mathbf{I} \\ -\mathbf{M}^{-1}\mathbf{K} & 0 \end{bmatrix} \begin{bmatrix} \mathbf{x} \\ \dot{\mathbf{x}} \end{bmatrix} + \begin{bmatrix} 0 \\ \mathbf{M}^{-1} \end{bmatrix} \mathbf{f} \]
::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
import skfem
from skfem.helpers import dot, grad
from skfem import BilinearForm, ElementLineP1, Basis
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple
import jax.numpy as jnp
:::
::: {#cell-5 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def discretize(A, step):
    """
    Jax compatible bilinear discretization from https://github.com/srush/annotated-s4

    Maps a continuous-time state matrix ``A`` to the discrete-time transition
    matrix ``inv(I - step/2 * A) @ (I + step/2 * A)``.

    Parameters
    ----------
    A : square 2-D array
        Continuous-time state matrix; only ``A.shape[0]`` is used to size the
        identity, so ``A`` must be square for the products to be defined.
    step : float
        Time step of the discretization.

    Returns
    -------
    Ab : jax array with the same shape as ``A``
        Discrete-time state transition matrix.
    """
    identity = jnp.eye(A.shape[0])
    half_step_A = (step / 2.0) * A
    Ab = jnp.linalg.inv(identity - half_step_A) @ (identity + half_step_A)
    return Ab
:::
::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class Wave1dSolverFE:
    """
    This class solves the 1D wave equation using finite elements and state space discretization.

    The constructor assembles the finite-element mass and stiffness matrices,
    removes the boundary nodes, builds the first-order state matrix
    ``[[0, I], [-inv(M) @ K, 0]]`` and discretizes it with ``discretize``.
    ``solve`` then repeatedly applies the discrete matrix to the stacked
    ``[position, velocity]`` state of the interior nodes.
    """

    def __init__(
        self,
        sampling_rate: float,  # sampling rate in Hz
        final_time: float,  # final time in seconds
        length: float,  # length of the string in meters
        n_gridpoints: int,  # number of points in the string
        wave_speed: float = 1,  # wave speed in m/s
    ):
        self.sampling_rate = sampling_rate
        self.final_time = final_time
        self.length = length
        self.wave_speed = wave_speed
        self.n_gridpoints = n_gridpoints

        # derived time and space grids
        self.dt = 1 / self.sampling_rate
        self.timesteps = np.arange(0, self.final_time, self.dt)
        self.grid = np.linspace(0, self.length, self.n_gridpoints)
        self.dx = self.grid[1] - self.grid[0]

        self.pde_order_time_derivatives = 2

        print(f"dx: {self.dx} in meters")
        print(f"dt: {self.dt} in seconds")
        print(f"number of points (n_gridpoints): {self.grid.shape}")
        print(f"time in samples (nt): {self.timesteps.shape}")

        # set constants
        rho = 1
        k = self.wave_speed**2 * rho

        # assemble the mass and stiffness matrices
        @BilinearForm
        def mass(u, v, w):
            return rho * u * v

        @BilinearForm
        def laplace(u, v, _):
            return k * dot(grad(u), grad(v))

        # NOTE(review): the mesh spans [0, 1] while self.grid spans
        # [0, self.length]; the two only coincide when length == 1 -- confirm
        # whether the mesh should be built from self.grid instead.
        mesh = skfem.MeshLine(np.linspace(0, 1, n_gridpoints))
        element = ElementLineP1()
        basis = Basis(mesh, element)

        K = skfem.asm(laplace, basis)
        M = skfem.asm(mass, basis)
        # eliminate the boundary nodes so only interior nodes remain
        K, M, _, _ = skfem.condense(K, M, D=mesh.boundary_nodes())

        # convert to dense matrices
        # and get -inv(M) @ K
        L = -np.linalg.inv(M.todense()) @ K.todense()

        # get the state transition matrix
        identity = np.eye(L.shape[0])
        zeros = np.zeros_like(L)
        A = np.block([[zeros, identity], [L, zeros]])

        # get the discrete state transition matrix
        self.A_d = discretize(A, self.dt)

    def solve(
        self,
        u0: np.ndarray,  # initial position
        v0: np.ndarray,  # initial velocity
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:  # Returns time, position, velocity
        """
        Step the discrete system once per entry of ``self.timesteps``.

        Only the interior samples ``u0[1:-1]`` and ``v0[1:-1]`` enter the
        state; the first and last columns of the returned arrays are never
        written and therefore stay zero.

        NOTE(review): the state is advanced before it is stored, so row 0
        holds the state after one step rather than the initial condition --
        confirm this is intended.
        """
        gridpoints_no_boundary = self.n_gridpoints - 2

        # stacked column state [position; velocity] of the interior nodes
        h = np.concatenate([u0[1:-1], v0[1:-1]])[..., None]

        u = np.zeros((self.timesteps.shape[0], self.n_gridpoints))
        v = np.zeros_like(u)

        for idx, _ in enumerate(self.timesteps):
            h = self.A_d @ h
            u[idx, 1:-1] = h[:gridpoints_no_boundary].squeeze()
            v[idx, 1:-1] = h[gridpoints_no_boundary:].squeeze()

        return self.timesteps, u, v
:::
Test
# Demo: pluck a string with a Gaussian initial displacement and zero initial
# velocity, then plot one time sample and the full space-time solution.
n_gridpoints = 200
solver = Wave1dSolverFE(
    sampling_rate=2000,
    final_time=1,
    length=1,
    n_gridpoints=n_gridpoints,
    wave_speed=10,
)

u0 = Gaussian(num_points=n_gridpoints)()
v0 = np.zeros_like(u0)

t, u, v = solver.solve(u0, v0)

# displacement along the string at time index 50
plt.plot(solver.grid, u[50], label="initial")

# show the solution viewed from above
plt.figure(figsize=(5, 10))
plt.pcolormesh(solver.grid, t, u)
plt.xlabel("x")
plt.ylabel("t")
plt.colorbar()
plt.show()