import matplotlib.pyplot as plt
Pseudo-spectral Solver
A pseudo-spectral solver for the 1D wave equation.
The wave equation in 1D is given by:
\[\frac{\partial^2 u}{\partial t^2} = c^2 \frac{\partial^2 u}{\partial x^2}\]
where \(c\) is the wave speed.
In the pseudo-spectral method, we transform the field to the frequency domain, evaluate the spatial derivatives there, and then transform back to the spatial domain. The resulting system is advanced in time with an ODE solver.
The 2nd derivative in space is given by:
\[ \frac{\partial^2 u}{\partial x^2} = \mathcal{F}^{-1} \left[ (ik)^2 \mathcal{F} \left( u \right) \right] \]
where \(\mathcal{F}\) and \(\mathcal{F}^{-1}\) are the forward and inverse Fourier transforms, respectively, and \(k\) is the wavenumber.
::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
import numpy as np
from scipy.integrate import solve_ivp
from typing import Tuple
:::
::: {#cell-8 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def fourier_derivative_2(
    u: np.ndarray,  # function in physical space
    k: np.ndarray,  # wave number array
) -> np.ndarray:
    """
    Compute the 2nd derivative of a function in Fourier space, and return the result in physical space.

    The samples `u` are transformed with the FFT, each Fourier coefficient is
    multiplied by `(i k)^2`, and the product is transformed back. Only the real
    part of the inverse transform is returned.

    `k` must broadcast against `np.fft.fft(u)`; it is expected to follow the
    FFT frequency ordering (e.g. `2 * np.pi * np.fft.fftfreq(n, d=dx)`).
    """
    u_hat = np.fft.fft(u)
    dudx2 = (1j * k) ** 2 * u_hat  # 2nd derivative
    return np.real(np.fft.ifft(dudx2))
:::
Define the right-hand side of the equation as
::: {#cell-10 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class Wave1dSolverPseudoSpectral:
    """
    This class solves the 1D wave equation using the pseudo-spectral method.

    The second spatial derivative is evaluated with `fourier_derivative_2` and
    the time integration is delegated to `scipy.integrate.solve_ivp` (RK45).

    Inspired by the content in https://www.coursera.org/learn/computers-waves-simulations/home/week/5
    It assumes Dirichlet boundary conditions on both ends of the string: the
    first and last grid points of position and velocity are set to zero on
    every evaluation of the right-hand side.
    """

    def __init__(
        self,
        sampling_rate: float,  # sampling rate in Hz
        final_time: float,  # final time in seconds
        length: float,  # length of the string in meters
        n_gridpoints: int,  # number of points in the string
        wave_speed: float = 1,  # wave speed in m/s
    ):
        """
        Build the time axis, the spatial grid and the wave numbers, and print
        a short summary of the discretisation.
        """
        self.sampling_rate = sampling_rate
        self.final_time = final_time
        self.length = length
        self.wave_speed = wave_speed
        self.n_gridpoints = n_gridpoints

        # Time axis: samples at 0, dt, 2*dt, ... strictly below final_time
        self.dt = 1 / self.sampling_rate
        self.timesteps = np.arange(0, self.final_time, self.dt)

        # Spatial grid: n_gridpoints samples on [0, length], both ends included
        self.grid = np.linspace(0, self.length, self.n_gridpoints)
        self.dx = self.grid[1] - self.grid[0]

        self.pde_order_time_derivatives = 2

        print(f"dx: {self.dx} in meters")
        print(f"dt: {self.dt} in seconds")
        print(f"number of points (n_gridpoints): {self.grid.shape}")
        print(f"time in samples (nt): {self.timesteps.shape}")

        # Wave numbers for Fourier differentiation in space
        # (angular wave numbers in FFT ordering, despite the attribute name)
        self.n_modes = 2 * np.pi * np.fft.fftfreq(self.n_gridpoints, d=self.dx)

    def solve(
        self,
        u0: np.ndarray,  # initial position
        v0: np.ndarray,  # initial velocity
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:  # Returns time, position, velocity
        """
        Solve the wave equation using the pseudo-spectral method.

        `u0` and `v0` are concatenated into one state vector whose first
        `n_gridpoints` entries are the position and whose remaining entries are
        the velocity. The solution is sampled at `self.timesteps`; position and
        velocity are returned transposed, i.e. with time along the first axis.
        """

        def wave_equation(
            t,  # time
            state,  # state vector
            c: float,  # wave speed
            k: np.ndarray,  # wave number array
            n_gridpoints: int,  # number of grid points
        ) -> np.ndarray:  # state at timestep t and position x (u(x, t))
            """
            Right hand side of the wave equation
            """
            # NOTE: basic slicing yields views, so the boundary assignments
            # below also write into the `state` array handed in by the caller.
            u = state[:n_gridpoints]  # position
            v = state[n_gridpoints:]  # velocity

            # Set Dirichlet boundary conditions
            # before computing the 2nd derivative
            u[0] = 0
            u[-1] = 0

            # 2nd derivative of position
            dudx2 = fourier_derivative_2(u, k)

            # d(velocity)/dt = c^2 * d^2u/dx^2
            dv_dt = c**2 * dudx2

            # Zero boundary velocity so the end points of u do not move
            v[0] = 0
            v[-1] = 0

            # return the state derivatives
            return np.concatenate([v, dv_dt])

        # solve the wave equation
        sol = solve_ivp(
            fun=wave_equation,
            t_span=[0, self.final_time],
            y0=np.concatenate([u0, v0], axis=0),
            method="RK45",
            t_eval=self.timesteps,
            args=(self.wave_speed, self.n_modes, self.n_gridpoints),
        )

        return sol.t, sol.y[: self.n_gridpoints].T, sol.y[self.n_gridpoints :].T
:::
Test the solver
from physmodjax.solver.generator import Gaussian
# Demo: pluck-like Gaussian initial displacement on a 1 m string, zero initial velocity
n_gridpoints = 1000
solver = Wave1dSolverPseudoSpectral(
    sampling_rate=44100,
    final_time=1,
    length=1,
    n_gridpoints=n_gridpoints,
    wave_speed=1,
)

u0 = Gaussian(num_points=n_gridpoints)()
v0 = np.zeros_like(u0)

t, u, v = solver.solve(u0, v0)

# show the solution viewed from above
# (every 100th time sample is plotted to keep the mesh small)
plt.figure(figsize=(5, 10))
plt.pcolormesh(solver.grid, t[::100], u[::100])
plt.xlabel("x")
plt.ylabel("t")
plt.colorbar()
plt.show()