# jax.config.update("jax_enable_x64", False)
# os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true" # add xla flags
# Simulation parameters for the quick solver check in this notebook.
dur = 0.001  # s, simulated duration (passed as final_time)
num_variations = 1024  # not used in this section
fs = 48000  # 1/s, temporal sampling rate
num_points = 40  # grid points per room_size (spatial_delta = room_size / num_points)
simulated_modes = 25  # passed as n_max_modes
room_size = 1  # m, room length in x
room_aspect_ratio = 1  # ly = room_aspect_ratio * room_size
num_example_timesteps = 100  # not used in this section
#######################################################################################################################
# Solver instance built from the simulation parameters defined above.
solver = WaveSolver2DJax(
    final_time=dur,
    sampling_rate=fs,
    lx=room_size,
    ly=room_aspect_ratio * room_size,
    spatial_delta=room_size / num_points,
    n_max_modes=simulated_modes,
)
2D linear wave solver
Adapted from the Paker repository.
::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
import numpy as np
import jax.numpy as jnp
import jax
:::
::: {#cell-5 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
class WaveSolver2DJax:
    """Modal (FTM) solver for the linear 2D wave equation on an lx-by-ly rectangle.

    The field is expanded in ``n_max_modes**2`` cosine/sine eigenfunctions,
    each stored twice (the mode and its complex conjugate). A modal state
    evolves as ``exp(smu * t)``, so advancing one sample multiplies it by
    ``exp(smu * T)``. The precomputed grids ``K1_sp``, ``K2_vx`` and ``K3_vy``
    (shape ``(numXs, numYs, 2 * n_max_modes**2)``) project modal states back
    to sound pressure and the x/y particle velocities.
    """

    def __init__(
        self,
        sampling_rate: int = 48000,  # 1/s Temporal sampling frequency
        final_time: float = 0.02,  # s Duration of the simulation
        rho: float = 1.2041,  # kg/m**3 Density
        n_max_modes: int = 10,  # Number of modal expansion terms (per axis)
        lx: float = 1.0,  # m Length in x direction
        ly: float = 1.0,  # m Length in y direction
        c0: float = 343,  # m/s Speed of sound
        damping=1.0,  # Damping factor (subtracted from every smu)
        spatial_delta=1e-3,  # m Spatial sampling grid
    ):
        """Precompute modal exponents, scaling factors and eigenfunction grids."""
        T = 1 / sampling_rate
        self.numT = round(final_time / T)
        # Column k of the modal state is exp(k * smu * T) * u0, i.e. time k * T,
        # so the time vector uses exactly that spacing. (linspace(0, final_time,
        # numT) would space samples by final_time / (numT - 1) instead.)
        self.t = np.arange(self.numT) * T  # time vector
        self.numXs = round(lx / spatial_delta)
        self.numYs = round(ly / spatial_delta)
        xs = np.linspace(0, lx, num=self.numXs, endpoint=True)  # space vector
        ys = np.linspace(0, ly, num=self.numYs, endpoint=True)  # space vector

        # 1-based mode numbers for every (mu_x, mu_y) pair, flattened.
        xv, yv = np.meshgrid(np.arange(n_max_modes), np.arange(n_max_modes))
        mux = xv.flatten() + 1
        muy = yv.flatten() + 1

        lamX = mux * np.pi / lx
        lamY = muy * np.pi / ly

        smu = 1j * c0 * np.sqrt(lamX**2 + lamY**2)

        # Append the complex-conjugate modes; the wavenumbers repeat accordingly.
        smu = np.hstack((smu, np.conj(smu)))
        lamX = np.hstack((lamX, lamX))
        lamY = np.hstack((lamY, lamY))

        # add damping
        smu = smu - damping

        ## FTM - scaling factor (every eigenfunction below is divided by nmu)
        nx = -8 * lamX**2 / (rho * smu**2) * lx * ly
        ny = -8 * lamY**2 / (rho * smu**2) * lx * ly
        nc = 8 / (rho * c0**2) * lx * ly
        nmu = nx + ny + nc

        ## FTM - Eigenfunctions; x and y may be scalars or broadcastable arrays.
        self.K1 = lambda x, y: 4 * np.cos(lamX * x) * np.cos(lamY * y)
        self.K2 = (
            lambda x, y: 4 * lamX / (smu * rho) * np.sin(lamX * x) * np.cos(lamY * y)
        )
        self.K3 = (
            lambda x, y: 4 * lamY / (smu * rho) * np.cos(lamX * x) * np.sin(lamY * y)
        )

        self.Ka1 = (
            lambda x, y: -4 * lamX / (smu * rho) * np.sin(lamX * x) * np.cos(lamY * y)
        )
        self.Ka2 = (
            lambda x, y: -4 * lamY / (smu * rho) * np.cos(lamX * x) * np.sin(lamY * y)
        )
        self.Ka3 = lambda x, y: 4 * np.cos(lamX * x) * np.cos(lamY * y)

        # Evaluate the eigenfunctions on the whole grid at once by broadcasting
        # (axis 0: x, axis 1: y, axis 2: mode) instead of a Python double loop
        # over grid points; the arithmetic per element is the same.
        grid_x = xs[:, None, None]
        grid_y = ys[None, :, None]
        K1_sp = self.K1(grid_x, grid_y) / nmu  # Eigenfunctions for sound pressure
        K2_vx = self.K2(grid_x, grid_y) / nmu  # Particle velocity, x-direction
        K3_vy = self.K3(grid_x, grid_y) / nmu  # Particle velocity, y-direction

        # Explicitly keep the quantities the other methods need.
        self.lx = lx
        self.ly = ly
        self.smu = smu
        self.nmu = nmu
        self.T = T
        self.Fs = sampling_rate
        self.xs = xs
        self.ys = ys
        self.K1_sp = K1_sp
        self.K2_vx = K2_vx
        self.K3_vy = K3_vy
        self.lamX = lamX
        self.lamY = lamY

    def create_impulse(self, xe_rel, ye_rel):
        """Return the modal initial state for an impulse at one point.

        Args:
            xe_rel: excitation x position as a fraction of ``lx``.
            ye_rel: excitation y position as a fraction of ``ly``.

        Returns:
            Array of shape ``(smu.size,)``: ``Ka3`` evaluated at the excitation point.
        """
        # Model a spatial delta at the excitation position (xe, ye).
        xe = xe_rel * self.lx
        ye = ye_rel * self.ly
        return self.Ka3(xe, ye)

    def create_random_initial(
        self,
        rng: "np.random.Generator | None" = None,
    ) -> np.ndarray:
        """Return a random modal initial state.

        For every mode, a random 1D profile along x and one along y (uniform in
        [-1, 1)) are projected onto the mode's cosine shape with the
        trapezoidal rule, and the two projections are multiplied.

        Args:
            rng: generator to draw from; a fresh ``np.random.default_rng()`` is
                created when omitted (so no generator is shared between calls).

        Returns:
            Real array of shape ``(smu.size,)``.
        """
        if rng is None:
            rng = np.random.default_rng()
        # NumPy 2.0 renamed trapz to trapezoid and deprecated the old name.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz

        xs, ys = self.xs, self.ys
        lamX, lamY = self.lamX, self.lamY
        n_modes = self.smu.size

        rx = rng.uniform(-1, 1, (n_modes, xs.size))  # Shape: (n_modes, xs.size)
        ry = rng.uniform(-1, 1, (n_modes, ys.size))  # Shape: (n_modes, ys.size)
        funX = 4 * np.cos(lamX[:, None] * xs) * rx  # Broadcasting lamX and rx
        funY = np.cos(lamY[:, None] * ys) * ry  # Broadcasting lamY and ry
        integX = trapezoid(funX, xs, axis=-1)  # Shape: (n_modes,)
        integY = trapezoid(funY, ys, axis=-1)  # Shape: (n_modes,)
        return integX * integY

    def solve(
        self,
        u0,
        v0=None,
        parallel=True,
    ):
        """Propagate the modal state ``u0`` over ``self.t`` and project to space.

        Args:
            u0: initial modal state of shape ``(smu.size,)``, e.g. from
                ``create_impulse`` or ``create_random_initial``.
            v0: currently unused.
            parallel: if True, compute all time steps at once with a JAX
                associative scan; otherwise step sequentially with NumPy.

        Returns:
            ``(ybar, y_sp, y_vx, y_vy)`` as float32 real parts. ``ybar`` has
            shape ``(smu.size, t.size)``; the spatial fields have shape
            ``(numXs, numYs, t.size)``.
        """
        T = self.T
        smu = self.smu
        n_steps = self.t.size

        ## Simulation - state equation
        if parallel:
            # The prefix product of the per-step factor exp(smu * T) yields
            # exp(k * smu * T) for k = 1 .. n_steps - 1.
            smu_steps = jnp.repeat(smu[None, :], n_steps - 1, axis=0)
            ybar = jax.lax.associative_scan(jnp.multiply, jnp.exp(smu_steps * T)) * u0
            ybar = jnp.concatenate((u0[..., None], ybar.T), axis=1)
        else:
            step = np.exp(smu * T)  # loop-invariant per-step factor
            ybar = np.zeros((smu.size, n_steps), dtype=complex)
            ybar[:, 0] = u0
            for k in range(1, n_steps):
                ybar[:, k] = step * ybar[:, k - 1]

        # project back to spatial domain
        y_sp = self.K1_sp @ ybar
        y_vx = self.K2_vx @ ybar
        y_vy = self.K3_vy @ ybar

        ybar = jnp.real(ybar).astype(jnp.float32)
        y_sp = jnp.real(y_sp).astype(jnp.float32)
        y_vx = jnp.real(y_vx).astype(jnp.float32)
        y_vy = jnp.real(y_vy).astype(jnp.float32)

        return ybar, y_sp, y_vx, y_vy
:::
# jax.config.update("jax_enable_x64", False)
import matplotlib.pyplot as plt

# Compare the sequential (loop) solution with the parallel (associative scan)
# solution of the same initial condition. solve() defaults to parallel=True,
# so the reference run must request parallel=False explicitly.
fe_x = solver.create_random_initial(np.random.default_rng(42))
# fe_x = solver.create_impulse(0.5, 0.5)
ybar, y_sp, y_vx, y_vy = solver.solve(fe_x, parallel=False)
ybar_vander, y_sp_vander, y_vx_vander, y_vy_vander = solver.solve(fe_x, parallel=True)

diff_ybar = np.abs(ybar - ybar_vander)
diff_y_sp = np.abs(y_sp - y_sp_vander)
diff_y_vx = np.abs(y_vx - y_vx_vander)
diff_y_vy = np.abs(y_vy - y_vy_vander)

print("ybar diff", np.max(diff_ybar))
print("y_sp diff", np.max(diff_y_sp))
print("y_vx diff", np.max(diff_y_vx))
print("y_vy diff", np.max(diff_y_vy))

print(y_sp.shape, y_vx.shape, y_vy.shape)
# Frame to display; clamped so short simulations do not raise IndexError.
frame = min(200, y_sp.shape[-1] - 1)
fig, ax = plt.subplots(1, 3, figsize=(10, 5))
ax[0].imshow(y_sp[:, :, frame])
ax[1].imshow(y_vx[:, :, frame])
ax[2].imshow(y_vy[:, :, frame])
Save a fast-loading dataset (one `.npy` file per initial condition)
::: {#cell-10 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
from einops import rearrange
from pathlib import Path
from fastcore.script import call_parse
from tqdm import tqdm
:::
::: {#cell-11 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
@call_parse
def create_2d_wave_data(
    num_ics: int,  # number of initial conditions
    data_dir: str,  # directory to save the data
    W: int = 40,  # width
    dur: float = 0.01,  # time
    sample_rate: int = 48000,  # sampling rate
    simulated_modes: int = 25,  # number of modes
    seed: int = 42,  # random seed
    ic_type: str = "random",  # type of initial condition (random or impulse)
):
    """Simulate ``num_ics`` runs of the 2D wave solver and save each to disk.

    Run ``i`` (starting at 1) is written to ``{data_dir}/ic_{i:05d}.npy`` as a
    float32 array of shape ``(time, x, y, 3)``; the last axis stacks sound
    pressure, x-velocity and y-velocity.

    Raises:
        ValueError: if ``ic_type`` is neither ``"random"`` nor ``"impulse"``.
    """
    # Fail before any global state is changed or files are written.
    if ic_type not in ("random", "impulse"):
        raise ValueError(f"ic_type must be 'random' or 'impulse', got {ic_type!r}")

    # set the global precision to 64-bit
    jax.config.update("jax_enable_x64", True)

    room_size = 1
    room_aspect_ratio = 1

    solver = WaveSolver2DJax(
        final_time=dur,
        sampling_rate=sample_rate,
        lx=room_size,
        ly=room_aspect_ratio * room_size,
        spatial_delta=room_size / W,
        n_max_modes=simulated_modes,
    )

    # Approximate output size: time x grid x 3 channels x 4 bytes (float32).
    # Taken from the solver's actual grid so it stays right for any aspect ratio.
    total_size_bytes = num_ics * solver.numT * solver.numXs * solver.numYs * 3 * 4
    print("Total size in GB", total_size_bytes / 1e9)

    # make sure the data directory exists
    out_dir = Path(data_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # seed the random number generator
    rng = np.random.default_rng(seed)
    for i in tqdm(range(1, num_ics + 1)):
        if ic_type == "impulse":
            # impulse at a uniformly random relative position (x drawn first)
            u0 = solver.create_impulse(rng.uniform(), rng.uniform())
        else:
            u0 = solver.create_random_initial(rng)
        _, y_sp, y_vx, y_vy = solver.solve(u0)
        run_2d = jnp.stack((y_sp, y_vx, y_vy), axis=-1)
        run_2d = rearrange(run_2d, "h w t c -> t h w c")
        jnp.save(out_dir / f"ic_{i:05d}.npy", run_2d)
:::
import matplotlib.pyplot as plt

data_dir = "2d_wave_data"
create_2d_wave_data(1000, data_dir, ic_type="random")

# check the data, the first 3 files should be different
runs = [np.load(f"{data_dir}/ic_{i:05d}.npy") for i in range(1, 4)]
for data in runs:
    print(data.shape)
    frame = min(100, data.shape[0] - 1)  # clamp for short simulations
    fig, ax = plt.subplots(1, 3, figsize=(10, 5))
    for channel in range(3):
        ax[channel].imshow(data[frame, :, :, channel])
    plt.show()

# Verify the claim above instead of relying on visual inspection only.
for a, b in ((0, 1), (0, 2), (1, 2)):
    assert not np.array_equal(runs[a], runs[b]), (
        f"ic_{a + 1:05d} and ic_{b + 1:05d} are identical"
    )