Fitting to a real (thick) plate

This notebook shows how to fit the physical parameters of a plate to a real recording.

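Load a recording of a struck plate and preprocess it: resample it to the working sample rate if needed, keep the first second, high-pass filter it at 800 Hz, and estimate its spectral envelope with an order-128 LPC model.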

Code
# Load the plate recording. The variable keeps its historical name
# (`stiff_string_real`) because later cells may refer to it.
stiff_string_real, file_sr = sf.read("audio/single.wav")
# soundfile may return a (frames, channels) array for multichannel files; the
# filtering and LPC below work along the last axis, so downmix to mono first.
if stiff_string_real.ndim > 1:
    stiff_string_real = stiff_string_real.mean(axis=1)
if file_sr != sample_rate:
    print(f"Resampling from {file_sr} to {sample_rate}")
    stiff_string_real = soxr.resample(
        stiff_string_real,
        in_rate=file_sr,
        out_rate=sample_rate,
    )

print("The sample rate is", sample_rate)

scale = 1
# NOTE(review): `duration` does not match the 1 s excerpt kept below; it is left
# unchanged in case later cells depend on it — confirm and reconcile.
duration = 2.0
offset = int(0.00 * sample_rate)  # start of the excerpt, in samples
stop = int(1 * sample_rate)  # length of the excerpt, in samples
# Keep at most `stop` samples starting at `offset`; a shorter file is not padded.
stiff_string_real = stiff_string_real[offset : offset + stop]

# High-pass filter the audio (4th-order Butterworth, 800 Hz cutoff).
b, a = butter(N=4, Wn=800, btype="high", fs=sample_rate)
stiff_string_real = lfilter(b, a, stiff_string_real)

# Spectrum of the (filtered) recording.
u_stiff_string_rfft = np.fft.rfft(stiff_string_real)

# Spectral envelope: an order-128 all-pole (LPC) model of the recording.
a_lpc_cpu_solve_autocorr, g_lpc_solve_autocorr = lpc_cpu_solve(
    stiff_string_real,
    128,
    method="autocorrelation",
    biased=False,
)

w, h = freqz(
    b=g_lpc_solve_autocorr,
    a=np.concatenate([[1], a_lpc_cpu_solve_autocorr]),
    worN=u_stiff_string_rfft.shape[0],
    fs=sample_rate,
)

# `sample_rate` samples of the LPC filter's impulse response, obtained by
# sampling its frequency response on an rfft grid and inverting it.
H = g_lpc_solve_autocorr / np.fft.rfft(
    np.concatenate([[1], a_lpc_cpu_solve_autocorr]),
    n=sample_rate,
)
y = np.fft.irfft(H, n=sample_rate)
y_rfft = np.abs(np.fft.rfft(y))

# Time axis of the trimmed excerpt, derived from its actual length.
t = np.arange(len(stiff_string_real)) / sample_rate
# Frequency (Hz) of every RFFT bin, so the spectrum shares the envelope's x-axis
# (previously plotted against bin index, which only equals Hz for 1 s excerpts).
rfft_freqs = np.fft.rfftfreq(len(stiff_string_real), d=1.0 / sample_rate)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.set_title("RFFT and spectral envelope")
ax.semilogx(rfft_freqs, to_db(np.abs(u_stiff_string_rfft)), label="RFFT")
ax.semilogx(w, to_db(np.abs(h)), label="LPC spectral envelope", ls="--")
ax.set_xlabel("Frequency (Hz)")
# `which=` must be a keyword: the first positional argument of grid() is `visible`.
ax.grid(which="both")
ax.legend()
fig.tight_layout()

display_audio_with_title(stiff_string_real, sample_rate, "Original")
display_audio_with_title(y, sample_rate, "LPC fit and filtered")
Resampling from 48000 to 44100
The sample rate is 44100
Original
LPC fit and filtered

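Sample the LPC envelope on a frequency grid spaced uniformly on the Bark scale (from 200 Hz up to 20 kHz) and normalise it to a peak of 1; this is the target response for the fit.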

Code
# Frequency grid spaced uniformly on the Bark scale. The upper edge is capped at
# Nyquist so the envelope is never evaluated above the band that `sample_rate`
# can represent (the cap is inactive at 44.1 kHz, where it equals 20 kHz).
hz_range = np.array([200, min(20000, sample_rate / 2)])
barkrange = hz2bark(hz_range)
worN = bark2hz(np.linspace(barkrange[0], barkrange[1], 20000))

w, h = freqz(
    g_lpc_solve_autocorr,
    a=np.concatenate([[1], a_lpc_cpu_solve_autocorr]),
    worN=worN,
    fs=sample_rate,
)
# Peak-normalised magnitude of the LPC envelope: the fitting target.
target_freq_resp = jnp.abs(h) / jnp.max(jnp.abs(h))

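Define the initial (unconstrained) parameters and the accessor functions that impose the constraints by mapping them to physical ranges, e.g. a sigmoid for the side lengths and a negated ReLU for the damping.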

Code
# Scale applied to the sigmoid of the raw bending-stiffness parameter.
RANGE_BENDING_STIFFNESS = 15000

# Raw, unconstrained parameters; the get_* accessors map them to physical
# ranges. Values are drawn from a seeded generator, so the order of the
# assignments below determines every value — do not reorder them.
# NOTE(review): "d1", "d3" and "Ts0" are drawn but get_Ts0 returns 0.0 and no
# visible accessor reads d1/d3; they are kept so the random stream is unchanged.
rng = np.random.default_rng(654)
pars = dict()
pars["bending_stiffness"] = rng.normal()
pars["gamma_mu"] = jnp.linspace(5, 15, n_modes)  # per-mode damping, pre-ReLU
pars["d1"] = rng.normal()
pars["d3"] = rng.normal()
pars["Ts0"] = rng.normal()
pars["l1"] = rng.normal()
pars["l2"] = rng.normal()
pars["z0"] = rng.normal(size=(n_modes, 1)).astype(np.float32)
pars["gain"] = rng.normal(scale=1e-3)
pars["zero_radii"] = rng.normal(size=(n_modes)).astype(np.float32)
pars["zero_angles"] = rng.normal(size=(n_modes)).astype(np.float32)


def get_bending_stiffness(params):
    """Map the raw ``bending_stiffness`` entry into (0, RANGE_BENDING_STIFFNESS)."""
    squashed = jax.nn.sigmoid(params["bending_stiffness"])
    return squashed * RANGE_BENDING_STIFFNESS


def get_l1(params):
    """Return the first plate side length, squashed into (0, 1) by a sigmoid."""
    raw_length = params["l1"]
    return jax.nn.sigmoid(raw_length)


def get_l2(params):
    """Return the second plate side length, squashed into (0, 1) by a sigmoid."""
    raw_length = params["l2"]
    return jax.nn.sigmoid(raw_length)


def get_z0(params):
    """Return the per-mode excitation weights ``z0`` unchanged (unconstrained)."""
    weights = params["z0"]
    return weights


def get_gamma_mu(params):
    """Return the per-mode damping, constrained to be non-positive.

    The raw value is passed through a ReLU and negated, so any raw value <= 0
    yields zero damping. (An alternative, strictly negative parameterisation
    would be ``-jnp.exp(params["gamma_mu"])``.)
    """
    magnitude = jax.nn.relu(params["gamma_mu"])
    return -magnitude


def get_Ts0(params):
    """Return the tension term, currently disabled.

    The raw ``Ts0`` entry is ignored and a constant 0.0 is returned, so the
    model's squared frequencies depend on bending stiffness alone.
    """
    tension = 0.0
    return tension


def get_gain(params):
    """Return the global output gain unchanged (unconstrained)."""
    overall_gain = params["gain"]
    return overall_gain


def get_zeros(pars):
    """Return one complex zero per mode, constrained inside the unit circle.

    Radius is ``sigmoid(zero_radii)`` in (0, 1); angle is
    ``2*pi*sigmoid(zero_angles)``, i.e. a full turn of the unit circle.
    """
    magnitude = jax.nn.sigmoid(pars["zero_radii"])
    phase = 2j * np.pi * jax.nn.sigmoid(pars["zero_angles"])
    return magnitude * jnp.exp(phase)

Simulate the plate using the initial parameters, and compare its frequency response and impulse response with the target.

Code
def tf_modified(
    pars,
    lambda_mu,
    dt,
):
    """Build second-order (biquad) coefficients, one section per mode.

    Each section's denominator has a complex-conjugate pole pair with radius
    ``exp(gamma_mu * dt)`` and angle ``omega_mu * dt``. Its numerator has the
    conjugate zero pair given by ``get_zeros``.

    Args:
        pars: raw parameter dict, read through the ``get_*`` accessors.
        lambda_mu: plate eigenvalues; presumably one per mode, matching the
            length of the per-mode parameters — confirm against callers.
        dt: sampling period used to discretise the continuous-time poles.

    Returns:
        ``(b, a)``: numerator and denominator coefficients stacked on the last
        axis as ``[1, c1, c2]``.
    """
    omega_mu_squared = (
        get_bending_stiffness(pars) * lambda_mu**2 + get_Ts0(pars) * lambda_mu
    )
    gamma_mu = get_gamma_mu(pars)
    # Overdamped modes (gamma_mu**2 > omega_mu_squared) give a negative
    # radicand; sqrt then returns NaN, which poisons the loss and every gradient.
    # Floor at a tiny positive value rather than 0: sqrt's derivative at 0 is
    # infinite, and inf * 0 from the clamp would still produce NaN gradients.
    radicand = jnp.maximum(omega_mu_squared - gamma_mu**2, 1e-12)
    omega_mu = jnp.sqrt(radicand)

    # Discretise: continuous pole gamma_mu + j*omega_mu maps to
    # z = exp((gamma_mu + j*omega_mu) * dt).
    radius = jnp.exp(gamma_mu * dt)
    real = radius * jnp.cos(omega_mu * dt)

    zeros = get_zeros(pars)
    b1 = -2.0 * zeros.real
    b2 = zeros.real**2 + zeros.imag**2  # |zero|**2

    a1 = -2.0 * real
    a2 = radius**2

    ones = jnp.ones_like(lambda_mu)

    b = jnp.stack([ones, b1, b2], axis=-1)
    a = jnp.stack([ones, a1, a2], axis=-1)
    return b, a


def simulate_membrane(pars):
    """Simulate the plate model's magnitude response on the ``worN`` grid.

    Plate eigenvalues are computed from the side lengths (``get_l1`` /
    ``get_l2``). The ``n_modes`` smallest are kept, and each becomes one
    biquad via ``tf_modified``. Numerators are scaled by ``z0`` and the
    global gain.

    Returns:
        ``(pred_freq_resp, b, a)``: the mean of the per-section magnitude
        responses, plus the scaled numerator and the denominator coefficients.
    """
    wavenumbers_x, wavenumbers_y = plate_wavenumbers(
        n_max_modes_x,
        n_max_modes_y,
        get_l1(pars),
        get_l2(pars),
    )
    eigenvalues = plate_eigenvalues(wavenumbers_x, wavenumbers_y).reshape(-1)
    # Keep only the lowest `n_modes` eigenvalues.
    lambda_mu = eigenvalues.sort()[:n_modes]

    b, a = tf_modified(pars, lambda_mu, dt)
    numerator = b * get_z0(pars) * get_gain(pars)
    per_mode_resp = tf_freqz(numerator, a, worN, sample_rate)
    # Magnitudes are averaged over axis 0 (presumably the mode axis — confirm
    # against tf_freqz), so relative phase between modes is ignored.
    mean_magnitude = jnp.mean(jnp.abs(per_mode_resp), axis=0)
    return mean_magnitude, numerator, a


# Model response and coefficients at the initial parameters.
initial_freq_resp, b, a = simulate_membrane(pars)

# Overlay the initial model response on the Bark-sampled target.
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.semilogx(worN, to_db(target_freq_resp), label="Target")
ax.semilogx(worN, to_db(initial_freq_resp), label="Initial", ls="--")
ax.grid(which="both")
_ = ax.legend()


# A `sample_rate`-long unit impulse, filtered by all modal sections; the
# result is averaged over axis 1 (presumably the mode axis — confirm against
# iir_filter_parallel).
x = jnp.zeros(shape=(sample_rate), dtype=jnp.float32).at[0].set(1.0)
pred_imp_resp = iir_filter_parallel(b, a, x).mean(axis=1)

display_audio_with_title(y, sample_rate, "Target")
display_audio_with_title(pred_imp_resp, sample_rate, "Initial")
Target
Initial

Optimise the parameters by gradient descent, using Adam with global-norm gradient clipping and a one-cycle cosine learning-rate schedule.

Code
iterations = 30_000
learning_rate = 1e-2
# One-cycle cosine schedule over the whole run, peaking at `learning_rate`.
scheduler = optax.cosine_onecycle_schedule(
    transition_steps=iterations,
    peak_value=learning_rate,
)
# Gradients are clipped to a global norm of 2.0 before the Adam update.
grad_clip = optax.clip_by_global_norm(2.0)
adam_step = optax.adam(learning_rate=scheduler)
optimiser = optax.chain(grad_clip, adam_step)
state = optimiser.init(pars)


@jax.jit
def train_step(pars, state):
    """Run one optimisation step: evaluate the loss, differentiate, update.

    Returns:
        ``(pars, state, loss)``: updated parameters, updated optimiser state,
        and the loss evaluated at the incoming parameters.
    """

    def loss_fn(pars):
        # Weighted combination of spectral losses between model and target.
        pred, _, _ = simulate_membrane(pars)

        log_pred = safe_log(pred)
        log_target = safe_log(target_freq_resp)

        lin_l2_loss = jnp.mean(jnp.square(pred - target_freq_resp))
        log_l1_loss = jnp.mean(jnp.abs(log_pred - log_target))
        sc_loss = spectral_convergence_loss(log_pred, log_target)
        ot_loss = jnp.mean(
            spectral_wasserstein(
                pred,
                target_freq_resp,
                squared=True,
                is_mag=True,
            )
        )

        weighted_log_l1 = log_l1_loss * 0.1
        weighted_ot = ot_loss * 0.001
        return lin_l2_loss + weighted_log_l1 + sc_loss + weighted_ot

    loss, grads = jax.value_and_grad(loss_fn)(pars)

    updates, new_state = optimiser.update(grads, state, pars)
    new_pars = optax.apply_updates(pars, updates)
    return new_pars, new_state, loss


bar = tqdm(range(iterations))
for i in bar:
    pars, state, loss = train_step(pars, state)
    # Formatting these JAX values waits for the step to finish on the device.
    description = (
        f"Loss: {loss:.3f}, "
        f"bending_stiffness: {get_bending_stiffness(pars):.3f}, "
        f"l1: {get_l1(pars):.3f}, "
        f"l2: {get_l2(pars):.3f}"
    )
    bar.set_description(description)
Loss: 0.158, bending_stiffness: 3998.537, l1: 0.738, l2: 0.357: 100%|██████████| 30000/30000 [00:47<00:00, 632.80it/s]
Target
Optimised