import matplotlib.pyplot as plt
Pseudo-spectral Solver
The wave equation in 1D is given by:
\[\frac{\partial^2 u}{\partial t^2} = c^2 \frac{\partial^2 u}{\partial x^2}\]
where \(c\) is the wave speed.
In the pseudo-spectral method, spatial derivatives are computed in the wavenumber domain. The field is Fourier transformed, multiplied by the appropriate power of \(ik\), and transformed back to the spatial domain. The resulting system is then integrated in time with an ODE solver.
The 2nd derivative in space is given by:
\[ \frac{\partial^2 u}{\partial x^2} = \mathcal{F}^{-1} \left[ (ik)^2 \mathcal{F} \left( u \right) \right] \]
where \(\mathcal{F}\) and \(\mathcal{F}^{-1}\) are the forward and inverse Fourier transforms, respectively, and \(k\) is the wavenumber.
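For a single Fourier mode the second derivative reduces to a multiplication. This is why the spectral second derivative scales each Fourier coefficient by \(-k^2\). For a periodic grid of \(N\) points with spacing \(\Delta x\), the discrete wavenumbers below are assumed to follow the standard FFT ordering (even \(N\)):

\[
u(x) = e^{ikx}
\;\Rightarrow\;
\frac{\partial^2 u}{\partial x^2} = (ik)^2 e^{ikx} = -k^2 e^{ikx},
\qquad
k_j = \frac{2\pi j}{N\,\Delta x},\quad j = -\tfrac{N}{2}, \dots, \tfrac{N}{2}-1 .
\]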
fourier_derivative_2
fourier_derivative_2 (u:numpy.ndarray, k:numpy.ndarray)
Compute the 2nd derivative of a function in Fourier space, and return the result in physical space.
| | Type | Details |
|---|---|---|
| u | ndarray | function in physical space |
| k | ndarray | wave number array |
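Below is a minimal NumPy sketch of what a function with this signature might compute. It assumes a periodic grid and wavenumbers built with `numpy.fft.fftfreq`. The name `fourier_derivative_2_sketch` is hypothetical, and the code is not the physmodjax source.

```python
import numpy as np


def fourier_derivative_2_sketch(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: second derivative of a periodic signal via the FFT."""
    u_hat = np.fft.fft(u)                  # physical space -> wavenumber space
    d2u_hat = (1j * k) ** 2 * u_hat        # multiply by (ik)^2 = -k^2
    return np.real(np.fft.ifft(d2u_hat))   # back to physical space, drop round-off imaginary part


# Periodic grid on [0, 2*pi) and the matching angular wavenumbers
n = 64
length = 2 * np.pi
x = np.arange(n) * length / n
k = 2 * np.pi * np.fft.fftfreq(n, d=length / n)

u = np.sin(3 * x)
d2u = fourier_derivative_2_sketch(u, k)    # should approximate -9 * sin(3x)
print(np.max(np.abs(d2u + 9 * np.sin(3 * x))))  # tiny, near round-off level
```

Building `k` with `2 * np.pi * np.fft.fftfreq(n, d=dx)` puts the wavenumbers in the same order as the FFT output. As a result, the elementwise product lines up coefficient by coefficient.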
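Define the right-hand side of the equation by rewriting it as a first-order system in time, with velocity \(v = \partial u / \partial t\):
\[\frac{\partial u}{\partial t} = v, \qquad \frac{\partial v}{\partial t} = c^2 \frac{\partial^2 u}{\partial x^2}\]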
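Below is a minimal sketch of such a right-hand side, for the system \(\partial_t u = v\), \(\partial_t v = c^2 \partial_x^2 u\). It assumes the state is stored as one flat vector \([u, v]\) and that the spatial derivative is taken with the FFT on a periodic grid. The function `wave_rhs` is hypothetical. A flat state vector is the layout that integrators such as `scipy.integrate.solve_ivp` expect; its `fun` has signature `fun(t, y)`, so this function would need a small wrapper to be used there.

```python
import numpy as np


def wave_rhs(state: np.ndarray, k: np.ndarray, c: float) -> np.ndarray:
    """Hypothetical RHS of du/dt = v, dv/dt = c^2 * d2u/dx2 with state = [u, v]."""
    n = state.shape[0] // 2
    u, v = state[:n], state[n:]
    # Spectral second derivative: multiply by (ik)^2 in wavenumber space
    d2u = np.real(np.fft.ifft((1j * k) ** 2 * np.fft.fft(u)))
    return np.concatenate([v, c**2 * d2u])


# A single sine mode at rest: du/dt should be 0 and dv/dt = -c^2 * 4 * sin(2x)
n = 32
x = np.arange(n) * 2 * np.pi / n
k = 2 * np.pi * np.fft.fftfreq(n, d=2 * np.pi / n)
state = np.concatenate([np.sin(2 * x), np.zeros(n)])
dstate = wave_rhs(state, k, c=2.0)
```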
Wave1dSolverPseudoSpectral
Wave1dSolverPseudoSpectral (sampling_rate:float, final_time:float, length:float, n_gridpoints:int, wave_speed:float=1)
This class solves the 1D wave equation using the pseudo-spectral method. It is inspired by the course material at https://www.coursera.org/learn/computers-waves-simulations/home/week/5 and assumes Dirichlet boundary conditions on both ends of the string.
| | Type | Default | Details |
|---|---|---|---|
| sampling_rate | float | | sampling rate in Hz |
| final_time | float | | final time in seconds |
| length | float | | length of the string in meters |
| n_gridpoints | int | | number of points in the string |
| wave_speed | float | 1 | wave speed in m/s |
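The parameters above can be tied together in a small stand-alone solver. This is a sketch under stated assumptions, not the physmodjax implementation. The function `solve_wave_1d_pseudospectral` is hypothetical, and it makes three assumptions:

* It takes `dx = length / (n_gridpoints - 1)` and `dt = 1 / sampling_rate`. These are consistent with the dx of 0.0101... m (1/99) and dt of 2.083e-05 s (1/48000) printed in the test further down.
* It treats the field as periodic through the FFT instead of enforcing the Dirichlet condition the class documents.
* It advances time with a hand-written classical RK4 step in place of whatever ODE solver the class uses.

```python
import numpy as np


def solve_wave_1d_pseudospectral(
    u0: np.ndarray,
    v0: np.ndarray,
    sampling_rate: float,
    final_time: float,
    length: float,
    wave_speed: float = 1.0,
):
    """Hypothetical sketch (not physmodjax): periodic FFT derivatives + classical RK4."""
    n = u0.shape[0]
    dx = length / (n - 1)                        # assumed: grid includes both endpoints
    dt = 1.0 / sampling_rate                     # one time step per output sample
    nt = int(round(final_time * sampling_rate))  # number of stored time samples
    k = 2 * np.pi * np.fft.fftfreq(n, d=dx)      # angular wavenumbers in FFT order

    def rhs(u, v):
        # du/dt = v, dv/dt = c^2 * d2u/dx2 using the spectral second derivative
        d2u = np.real(np.fft.ifft(-(k**2) * np.fft.fft(u)))
        return v, wave_speed**2 * d2u

    t = np.arange(nt) * dt
    u_out = np.empty((nt, n))
    v_out = np.empty((nt, n))
    u, v = u0.astype(float), v0.astype(float)
    for i in range(nt):
        u_out[i], v_out[i] = u, v                # store the state at time t[i]
        k1u, k1v = rhs(u, v)
        k2u, k2v = rhs(u + 0.5 * dt * k1u, v + 0.5 * dt * k1v)
        k3u, k3v = rhs(u + 0.5 * dt * k2u, v + 0.5 * dt * k2v)
        k4u, k4v = rhs(u + dt * k3u, v + dt * k3v)
        u = u + dt / 6 * (k1u + 2 * k2u + 2 * k3u + k4u)
        v = v + dt / 6 * (k1v + 2 * k2v + 2 * k3v + k4v)
    return t, u_out, v_out


# Usage: a Gaussian pulse at rest on a 1 m string (the pulse shape is an assumption)
x = np.linspace(0.0, 1.0, 100)
u0 = np.exp(-((x - 0.5) ** 2) / (2 * 0.05**2))
v0 = np.zeros_like(u0)
t, u, v = solve_wave_1d_pseudospectral(u0, v0, sampling_rate=48000, final_time=0.01, length=1.0)
print(t.shape, u.shape, v.shape)  # (480,) (480, 100) (480, 100)
```

While the pulse stays away from the ends, the result should match d'Alembert's solution \(u(x, t) = \tfrac{1}{2}[u_0(x - ct) + u_0(x + ct)]\). This makes it a convenient sanity check for any implementation.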
Test the solver
import numpy as np
from physmodjax.solver.generator import Gaussian

n_gridpoints = 1000
solver = Wave1dSolverPseudoSpectral(
    sampling_rate=44100,
    final_time=1,
    length=1,
    n_gridpoints=n_gridpoints,
    wave_speed=1,
)
# initial displacement: Gaussian pulse; initial velocity: string at rest
u0 = Gaussian(num_points=n_gridpoints)()
v0 = np.zeros_like(u0)
t, u, v = solver.solve(u0, v0)
print(t.shape, u.shape, v.shape)
dx: 0.010101010101010102 in meters
dt: 2.0833333333333333e-05 in seconds
number of points (n_gridpoints): (100,)
time in samples (nt): (48000,)
(48000,) (48000, 100) (48000, 100)
# show the solution viewed from above
plt.figure(figsize=(5, 10))
plt.pcolormesh(solver.grid, t[::100], u[::100])
plt.xlabel("x")
plt.ylabel("t")
plt.colorbar()
plt.show()