Pseudo-spectral Solver

A pseudo-spectral solver for the 1D wave equation.

The wave equation in 1D is given by:

\[\frac{\partial^2 u}{\partial t^2} = c^2 \frac{\partial^2 u}{\partial x^2}\]

where \(c\) is the wave speed.

In the pseudo-spectral method, spatial derivatives are evaluated in the frequency (wavenumber) domain. The field is Fourier transformed, multiplied by the appropriate power of \(ik\), and transformed back to the spatial domain. Time derivatives stay in physical space, and the resulting system is integrated in time with an ODE solver.

The second derivative in space is computed as:

\[ \frac{\partial^2 u}{\partial x^2} = \mathcal{F}^{-1} \left[ (ik)^2 \mathcal{F} \left( u \right) \right] \]

where \(\mathcal{F}\) and \(\mathcal{F}^{-1}\) are the forward and inverse Fourier transforms, respectively, and \(k\) is the wavenumber.
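As a worked instance, assume a periodic grid of \(N\) points (with \(N\) even) and spacing \(\Delta x\). This is the standard discrete Fourier transform setting, taken here as an assumption rather than from the solver. The multiplier and the discrete wavenumbers are then

\[ (ik)^2 = -k^2, \qquad k_j = \frac{2\pi j}{N \Delta x}, \quad j = -\tfrac{N}{2}, \dots, \tfrac{N}{2} - 1. \]

For a single mode \(u(x) = \sin(k_0 x)\) with \(k_0\) one of the \(k_j\) and \(|k_0| < \pi / \Delta x\), the transform is nonzero only at \(\pm k_0\), so

\[ \mathcal{F}^{-1} \left[ (ik)^2 \mathcal{F} \left( \sin(k_0 x) \right) \right] = -k_0^2 \sin(k_0 x), \]

which is the exact second derivative.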

```python
import matplotlib.pyplot as plt
```


fourier_derivative_2

```python
fourier_derivative_2(u: numpy.ndarray, k: numpy.ndarray)
```

Compute the 2nd derivative of a function in Fourier space, and return the result in physical space.

| Parameter | Type | Details |
|---|---|---|
| u | ndarray | function in physical space |
| k | ndarray | wave number array |
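The body of `fourier_derivative_2` is not shown on this page. The following is a minimal NumPy sketch of what such a function could look like. It assumes `u` is sampled on a uniform periodic grid and that `k` is built with `np.fft.fftfreq`. The helper `wavenumbers` is hypothetical, and the module's own implementation may differ.

```python
import numpy as np


def wavenumbers(n_gridpoints: int, dx: float) -> np.ndarray:
    """Hypothetical helper: angular wavenumbers in NumPy's FFT ordering."""
    return 2 * np.pi * np.fft.fftfreq(n_gridpoints, d=dx)


def fourier_derivative_2(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Sketch: second derivative of u computed in Fourier space.

    Assumes u is sampled on a uniform periodic grid and that k follows
    the ordering of np.fft.fftfreq.
    """
    u_hat = np.fft.fft(u)                 # physical -> wavenumber space
    d2u_hat = (1j * k) ** 2 * u_hat       # multiply by (ik)^2 = -k^2
    return np.real(np.fft.ifft(d2u_hat))  # back to physical space; imaginary part is round-off


# Usage on [0, 2*pi): the exact result for cos(2x) + sin(5x) is -4 cos(2x) - 25 sin(5x)
n = 64
dx = 2 * np.pi / n
x = np.arange(n) * dx
u = np.cos(2 * x) + np.sin(5 * x)
d2u = fourier_derivative_2(u, wavenumbers(n, dx))
exact = -4 * np.cos(2 * x) - 25 * np.sin(5 * x)
```

Because both test modes are resolved on the grid, `d2u` matches `exact` to round-off. A non-periodic function would instead show errors near the ends, which is why the boundary treatment matters.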

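To integrate in time with an ODE solver, the second-order equation is rewritten as a first-order system in the displacement \(u\) and the velocity \(v = \partial u / \partial t\). Its right-hand side is

\[\frac{\partial u}{\partial t} = v, \qquad \frac{\partial v}{\partial t} = c^2 \mathcal{F}^{-1} \left[ (ik)^2 \mathcal{F} \left( u \right) \right]\]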
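Writing the wave equation as the first-order system \(\partial u / \partial t = v\), \(\partial v / \partial t = c^2 \, \partial^2 u / \partial x^2\) makes it fit the `f(t, y)` callback that general-purpose ODE integrators expect. Below is a minimal sketch. It assumes the state is packed as one flat array `y = [u, v]` and that `k` comes from `np.fft.fftfreq`. The name `rhs` and this packing are hypothetical, not the module's actual code.

```python
import numpy as np


def rhs(t: float, y: np.ndarray, k: np.ndarray, c: float = 1.0) -> np.ndarray:
    """Hypothetical right-hand side for the packed state y = [u, v], v = du/dt.

    t is unused (the equation is autonomous) but kept so the signature
    matches the usual f(t, y) convention of ODE solvers.
    """
    n = len(k)
    u, v = y[:n], y[n:]
    # c^2 * d2u/dx2 evaluated spectrally: IFFT[(ik)^2 FFT(u)]
    d2u = np.real(np.fft.ifft((1j * k) ** 2 * np.fft.fft(u)))
    return np.concatenate([v, c**2 * d2u])


# Usage: u = sin(x) at rest on a periodic grid over [0, 2*pi), wave speed c = 2
n = 32
dx = 2 * np.pi / n
x = np.arange(n) * dx
k = 2 * np.pi * np.fft.fftfreq(n, d=dx)
y = np.concatenate([np.sin(x), np.zeros(n)])
dydt = rhs(0.0, y, k, c=2.0)
# dydt[:n] is the velocity (all zeros here); dydt[n:] approximates -4 sin(x)
```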

Wave1dSolverPseudoSpectral

```python
Wave1dSolverPseudoSpectral(
    sampling_rate: float,
    final_time: float,
    length: float,
    n_gridpoints: int,
    wave_speed: float = 1,
)
```

This class solves the 1D wave equation using the pseudo-spectral method. It is inspired by week 5 of the Coursera course at https://www.coursera.org/learn/computers-waves-simulations/home/week/5 and assumes Dirichlet boundary conditions at both ends of the string.
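A plain FFT treats the sampled field as periodic, so Dirichlet conditions need extra care in a spectral method. One standard approach is to reflect the field oddly about the right end. This turns the Fourier series into a sine series that vanishes at both ends. The sketch is shown only as an illustration and is not necessarily what this class does. The function name `dirichlet_second_derivative` is hypothetical.

```python
import numpy as np


def dirichlet_second_derivative(u: np.ndarray, dx: float) -> np.ndarray:
    """Spectral d2u/dx2 for samples u[0..N] on [0, L] with u[0] = u[N] = 0.

    Sketch of the odd-extension (sine-series) approach, not the class's code.
    """
    n = len(u) - 1
    # Odd reflection about x = L gives a 2L-periodic sequence of 2N samples:
    # [u_0, ..., u_{N-1}, -u_N, -u_{N-1}, ..., -u_1]
    extended = np.concatenate([u[:-1], -u[:0:-1]])
    k = 2 * np.pi * np.fft.fftfreq(2 * n, d=dx)
    d2 = np.real(np.fft.ifft(-(k**2) * np.fft.fft(extended)))
    return d2[: n + 1]  # keep only the samples on the original interval [0, L]


# Usage: u = sin(pi x) vanishes at both ends of [0, 1]; d2u/dx2 = -pi^2 sin(pi x)
x = np.linspace(0.0, 1.0, 33)
u = np.sin(np.pi * x)
d2u = dirichlet_second_derivative(u, dx=x[1] - x[0])
```

The same result can be obtained with a type-I discrete sine transform. The FFT form is used here only because it reuses the machinery from the periodic case.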

| Parameter | Type | Default | Details |
|---|---|---|---|
| sampling_rate | float | | sampling rate in Hz |
| final_time | float | | final time in seconds |
| length | float | | length of the string in meters |
| n_gridpoints | int | | number of points in the string |
| wave_speed | float | 1 | wave speed in m/s |
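The class body is not reproduced here. The sketch below shows one way these parameters could drive a complete time integration, and every choice in it is an assumption. The grid spacing `length / (n_gridpoints - 1)` and the step `1 / sampling_rate` are inferred from the printed output of the test further down. Classical fourth-order Runge-Kutta stands in for whatever ODE solver the class uses. Dirichlet ends are imposed crudely by zeroing the end points after each step. The function name `solve_wave_1d` is hypothetical.

```python
import numpy as np


def solve_wave_1d(u0: np.ndarray, v0: np.ndarray, sampling_rate: float,
                  final_time: float, length: float, wave_speed: float = 1.0):
    """Hypothetical stand-in for Wave1dSolverPseudoSpectral.solve.

    Assumptions: dx = length / (n - 1), dt = 1 / sampling_rate, RK4 stepping,
    FFT wavenumbers, and Dirichlet ends enforced by zeroing u and v there.
    """
    n = len(u0)
    dx = length / (n - 1)
    dt = 1.0 / sampling_rate
    nt = int(round(final_time * sampling_rate))
    k = 2 * np.pi * np.fft.fftfreq(n, d=dx)

    def f(y: np.ndarray) -> np.ndarray:
        # y has shape (2, n): row 0 is u, row 1 is v = du/dt
        u, v = y
        d2u = np.real(np.fft.ifft(-(k**2) * np.fft.fft(u)))
        return np.stack([v, wave_speed**2 * d2u])

    y = np.stack([np.asarray(u0, dtype=float), np.asarray(v0, dtype=float)])
    us = np.empty((nt, n))
    vs = np.empty((nt, n))
    for i in range(nt):
        us[i], vs[i] = y  # store the state at the start of step i
        k1 = f(y)
        k2 = f(y + 0.5 * dt * k1)
        k3 = f(y + 0.5 * dt * k2)
        k4 = f(y + dt * k3)
        y = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        y[:, [0, -1]] = 0.0  # crude Dirichlet condition at both ends
    t = np.arange(nt) * dt
    return t, us, vs


# Usage: a hypothetical Gaussian pulse at rest on a 1 m string, 10 ms at 48 kHz
x = np.linspace(0.0, 1.0, 100)
u0 = np.exp(-((x - 0.5) ** 2) / (2 * 0.05**2))
t, u, v = solve_wave_1d(u0, np.zeros_like(u0), sampling_rate=48000,
                        final_time=0.01, length=1.0)
print(t.shape, u.shape, v.shape)  # (480,) (480, 100) (480, 100)
```

RK4 is used only because it fits in a few lines. Its stability interval on the imaginary axis is about \(\pm 2\sqrt{2}\). With 100 points on a 1 m string at 48 kHz, \(c \, k_{\max} \, \Delta t = \pi / (\Delta x \cdot 48000) \approx 0.0065\), so the step is comfortably stable.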

Test the solver

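```python
import numpy as np
from physmodjax.solver.generator import Gaussian

# Parameters matching the printed output (100 grid points, 48 kHz, 1 s)
n_gridpoints = 100
solver = Wave1dSolverPseudoSpectral(
    sampling_rate=48000,
    final_time=1,
    length=1,
    n_gridpoints=n_gridpoints,
    wave_speed=1,
)

# Gaussian initial displacement, string initially at rest
u0 = Gaussian(num_points=n_gridpoints)()
v0 = np.zeros_like(u0)

# t has shape (nt,); u and v have shape (nt, n_gridpoints)
t, u, v = solver.solve(u0, v0)
```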
```
dx: 0.010101010101010102 in meters
dt: 2.0833333333333333e-05 in seconds
number of points (n_gridpoints): (100,)
time in samples (nt): (48000,)
(48000,) (48000, 100) (48000, 100)
```
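The printed `dx`, `dt` and `nt` are consistent with `n_gridpoints = 100`, a 48 kHz sampling rate and a one-second run. This holds under the relations below, which are inferred from the output and not read from the class source.

```python
# Assumed relations (inferred from the printed values, not from the class source)
length = 1.0          # m
n_gridpoints = 100
sampling_rate = 48000  # Hz
final_time = 1.0      # s

dx = length / (n_gridpoints - 1)  # grid includes both end points
dt = 1.0 / sampling_rate
nt = int(round(final_time * sampling_rate))

print(dx, dt, nt)
```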
```python
# show the solution viewed from above
plt.figure(figsize=(5, 10))
plt.pcolormesh(solver.grid, t[::100], u[::100])
plt.xlabel("x")
plt.ylabel("t")
plt.colorbar()
plt.show()
```