Optimising the parameters of the Von Karman plate

In this notebook we optimise the bending stiffness of a Von Karman plate using backpropagation through time.

First, we generate a target simulation of the plate. We define the parameters of the plate and the excitation.

Code
# --- Simulation settings ----------------------------------------------------
n_modes = 20  # passed to both eigendecompositions below
sampling_rate = 44100  # samples per second
sampling_period = 1.0 / sampling_rate  # seconds per sample
amplitude = 0.5  # amplitude handed to the raised cosine excitation

# --- Coarsest grid used by the multiresolution eigendecomposition -----------
# NOTE: this module-level `h` is the grid spacing; it is not the same quantity
# as the `h` keyword given to PlateParameters below.
h = 0.004  # grid spacing in the lowest resolution
nx = 50  # number of grid points in the x direction in the lowest resolution
ny = 75  # number of grid points in the y direction in the lowest resolution
levels = 2  # number of grid refinements to perform

# --- Plate -------------------------------------------------------------------
# NOTE(review): presumably E = Young's modulus, nu = Poisson's ratio,
# rho = volumetric density, h = thickness, l1/l2 = side lengths and
# Ts0 = tension -- confirm against the PlateParameters definition.
params = PlateParameters(
    E=2e12,
    nu=0.3,
    rho=7850,
    h=5e-4,
    l1=0.2,
    l2=0.3,
    Ts0=100,
)

# (x, y) pairs; further down they are multiplied by the final grid counts and
# truncated to obtain the grid indices of the excitation and readout points.
force_position = (0.05, 0.05)
readout_position = (0.1, 0.1)
Code
# Boundary-condition coefficient tables: four rows of two coefficients each.
# NOTE(review): presumably one row per plate edge, with a very large value
# (1e15) acting as an effectively rigid constraint and 0 as a free one --
# confirm against multiresolution_eigendecomposition.

# boundary conditions for the transverse modes
bcs_phi = np.array(
    [
        [1e15, 0],
        [1e15, 0],
        [1e15, 0],
        [1e15, 0],
    ]
)
# boundary conditions for the in-plane modes
bcs_psi = np.array(
    [
        [1e15, 1e15],
        [1e15, 1e15],
        [1e15, 1e15],
        [1e15, 1e15],
    ]
)

# In-plane modes. `levels` is taken from the parameter cell instead of being
# hard-coded, so changing the number of refinements there actually takes
# effect here.
psi, zeta_mu_squared, nx_final, ny_final, h_final, psi_norms = (
    multiresolution_eigendecomposition(
        params,
        n_modes,
        bcs_psi,
        h,
        nx,
        ny,
        levels=levels,
    )
)

# Transverse modes, computed with the same coarse grid and refinement count.
# nx_final, ny_final and h_final are reassigned by this second call; both
# calls receive the same h, nx, ny and levels.
phi, lambda_mu_squared, nx_final, ny_final, h_final, phi_norms = (
    multiresolution_eigendecomposition(
        params,
        n_modes,
        bcs_phi,
        h,
        nx,
        ny,
        levels=levels,
    )
)

# Coupling matrix between the in-plane and transverse mode shapes, evaluated
# on the refined grid.
H = compute_coupling_matrix_numerical(
    psi,
    phi,
    h_final,
    nx_final,
    ny_final,
)
# Scale the coupling matrix by sqrt(E / (2 * rho)).
e = params.E / (2 * params.rho)
H = H * np.sqrt(e)
lambda_mu = jnp.sqrt(lambda_mu_squared)
Refining grid to h = 0.002, nx = 100, ny = 150
Refining grid to h = 0.002, nx = 100, ny = 150
Code
# generate a 1d raised cosine excitation, sampled at the notebook-wide
# sampling rate rather than a duplicated literal
rc = create_1d_raised_cosine(
    duration=1.0,
    start_time=0.001,
    end_time=0.003,
    amplitude=amplitude,
    sample_rate=sampling_rate,
)

# Unflatten the mode shapes to (y index, x index, mode). The target shape is
# passed positionally: the `shape=` keyword of np.reshape only exists in
# NumPy >= 2.1, whereas the positional form works on every version.
phi_reshaped = np.reshape(
    phi,
    (ny_final + 1, nx_final + 1, n_modes),
    order="F",
)

# Mode-shape values at the excitation point; the (x, y) position is scaled by
# the grid counts and truncated to an integer grid index.
mode_gains_at_pos = phi_reshaped[
    int(force_position[1] * ny_final),
    int(force_position[0] * nx_final),
    :,
]

# Same lookup for the readout point.
mode_gains_at_readout = phi_reshaped[
    int(readout_position[1] * ny_final),
    int(readout_position[0] * nx_final),
    :,
]
# the modal excitation needs to be scaled by A_inv and divided by the density
# NOTE(review): only the division by the density happens here -- presumably
# the A_inv scaling is applied elsewhere; confirm.
mode_gains_at_pos_normalised = mode_gains_at_pos / params.density

# Excitation lengths derived from the sampling rate with integer arithmetic:
# 0.3 s for the optimisation runs and 1 s for the long version.
n_samples_short = 3 * sampling_rate // 10
n_samples_long = sampling_rate
modal_excitation_normalised_short = (
    rc[:n_samples_short, None] * mode_gains_at_pos_normalised
)
modal_excitation_normalised_long = (
    rc[:n_samples_long, None] * mode_gains_at_pos_normalised
)

Loss landscape

Let’s take a small detour to explore how the loss function varies with respect to a single parameter. First, we define the loss function.

Code
def combined_loss_fn(
    pars,
    lm_loss_weight=1.0,
    ot_loss_weight=1.0,
    sc_loss_weight=1.0,
    time_loss_weight=1.0,
):
    """Simulate the plate with ``pars`` and score it against the target.

    The simulation is driven by ``modal_excitation_normalised_short`` and
    compared with the target signal ``out_pos_gt`` and its STFT magnitude
    ``out_pos_gt_fft_mag``.

    Args:
        pars: parameters forwarded unchanged to ``simulate_vkplate``.
        lm_loss_weight: weight of the log-magnitude spectral loss.
        ot_loss_weight: weight of the spectral Wasserstein loss.
        sc_loss_weight: weight of the spectral convergence loss.
        time_loss_weight: weight of the time-domain mean squared error.

    Returns:
        ``(combined_loss, (lm_loss, ot_loss, sc_loss, time_loss))`` where
        the first entry is the weighted sum and the tuple holds the
        unweighted terms.
    """
    eps = 1e-10  # offset added to the magnitudes before taking the log

    simulated = simulate_vkplate(pars, modal_excitation_normalised_short)
    simulated_mag = jnp.abs(stft(simulated)) * out_pos_gt_fft_mag_scale

    # Time domain: mean squared error between simulation and target.
    residual = simulated - out_pos_gt
    time_loss = jnp.mean(jnp.square(residual))

    # Log-magnitude: mean absolute difference of the log spectra.
    log_target = safe_log(out_pos_gt_fft_mag + eps)
    log_simulated = safe_log(simulated_mag + eps)
    lm_loss = jnp.mean(jnp.abs(log_target - log_simulated))

    # Spectral Wasserstein distance, mapped over the leading axis of both
    # magnitude arrays and then averaged. The two trailing flags are passed
    # unmapped; their meaning is defined by `spectral_wasserstein`.
    batched_wasserstein = jax.vmap(
        spectral_wasserstein,
        in_axes=(0, 0, None, None),
    )
    ot_loss = jnp.mean(
        batched_wasserstein(simulated_mag, out_pos_gt_fft_mag, True, True)
    )

    sc_loss = spectral_convergence_loss(simulated_mag, out_pos_gt_fft_mag)

    # Weighted terms are accumulated in a fixed order (lm, ot, sc, time).
    weighted_lm = lm_loss * lm_loss_weight
    weighted_ot = ot_loss * ot_loss_weight
    weighted_sc = sc_loss * sc_loss_weight
    weighted_time = time_loss * time_loss_weight
    combined_loss = weighted_lm + weighted_ot + weighted_sc + weighted_time

    return combined_loss, (lm_loss, ot_loss, sc_loss, time_loss)

Plot the loss landscape for the bending stiffness

Code
def compute_losses_for_stiffness(
    bending_stiffness,
    loss_fn,
):
    """Evaluate ``loss_fn`` for a single bending stiffness value.

    Args:
        bending_stiffness: value stored under the ``"bending_stiffness"`` key.
        loss_fn: callable taking the parameter dictionary.

    Returns:
        Whatever ``loss_fn`` returns for the assembled parameters.
    """
    # The tension is held fixed at the plate value divided by the density.
    normalised_tension = params.Ts0 / params.density
    candidate = dict(
        bending_stiffness=bending_stiffness,
        Ts0=normalised_tension,
    )
    return loss_fn(candidate)


# 100 evenly spaced stiffness values in a +/- 4.5 window centred on the
# ground-truth bending stiffness.
bending_stiffness_normalised_range = jnp.linspace(
    gt_pars["bending_stiffness"] - 4.5,
    gt_pars["bending_stiffness"] + 4.5,
    100,
)

# Vectorise the per-stiffness evaluation so the whole sweep runs as one
# batched call; combined_loss_fn is used with its default (all 1.0) weights.
landscape_fn = jax.vmap(
    partial(compute_losses_for_stiffness, loss_fn=combined_loss_fn)
)
losses_combined, (losses_lm, losses_ot, losses_sc, losses_time) = landscape_fn(
    bending_stiffness_normalised_range
)

Optimise the bending stiffness

Now we optimise the bending stiffness, starting from an initial value of 10.0.

Target
Initial

Optimisation loop. This might take a while to run, depending on the length of the sequence we optimise over and the number of iterations. Here we optimise over a 0.3 second sequence (13230 samples), which is on the longer side for this sort of optimisation.

Code
# Adam driven by a one-cycle cosine schedule that peaks at `learning_rate`
# and spans the whole run.
iterations = 1000
learning_rate = 2e-1
scheduler = optax.cosine_onecycle_schedule(
    transition_steps=iterations,
    peak_value=learning_rate,
)
optimiser = optax.adam(learning_rate=scheduler)
state = optimiser.init(pars)

# Objective used for training: only the Wasserstein term and a lightly
# weighted spectral convergence term contribute; the log-magnitude and
# time-domain terms are given zero weight.
training_loss_fn = partial(
    combined_loss_fn,
    lm_loss_weight=0.0,
    ot_loss_weight=1.0,
    sc_loss_weight=0.001,
    time_loss_weight=0.0,
)
# has_aux=True because the loss returns (value, individual_terms).
value_and_grad = jax.value_and_grad(training_loss_fn, has_aux=True)


@jax.jit
def train_step(pars, state):
    """Run one optimiser update.

    Args:
        pars: current parameters.
        state: current optimiser state.

    Returns:
        ``(new_pars, new_state, loss)``; ``loss`` is the value evaluated at
        the incoming ``pars``, i.e. before the update is applied.
    """
    (loss, _aux), gradients = value_and_grad(pars)
    step, new_state = optimiser.update(gradients, state, pars)
    new_pars = optax.apply_updates(pars, step)
    return new_pars, new_state, loss


# Run the optimisation, showing the current loss and estimate next to the
# ground truth in the progress bar. The reported loss is the one returned by
# train_step for this iteration.
bar = tqdm(range(iterations))
for i in bar:
    pars, state, loss = train_step(pars, state)
    description = (
        f"Loss: {loss:.3f}, "
        f"bending stiffness: {pars['bending_stiffness']:.4f}, "
        f"ground truth: {gt_pars['bending_stiffness']:.4f}"
    )
    bar.set_description(description)
  0%|          | 0/1000 [00:00<?, ?it/s]2025-04-04 12:06:51.270375: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng28{k2=1,k3=0} for conv %cudnn-conv.4 = (f32[1,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[1,1024,15277]{2,1,0} %bitcast.12099, f32[1,1024,1024]{2,1,0} %bitcast.12104), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 12:06:51.550709: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.280406645s
Trying algorithm eng28{k2=1,k3=0} for conv %cudnn-conv.4 = (f32[1,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[1,1024,15277]{2,1,0} %bitcast.12099, f32[1,1024,1024]{2,1,0} %bitcast.12104), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
Loss: 0.000, bending stiffness: 5.8328, ground truth: 5.8328: 100%|██████████| 1000/1000 [03:49<00:00,  4.36it/s]

Target
Optimised

Multistart parallel optimisation

We can also optimise from a wider range of initial bending stiffness values by running multiple starts in parallel. Here we use 50 evenly spaced initial values between 2 and 50. NB: with longer sequences this takes longer to run, partly because the XLA compilation needs more time to find an appropriate implementation for the STFT.

Code
# Same objective as in the single-start run: Wasserstein term plus a lightly
# weighted spectral convergence term, with the log-magnitude and time-domain
# terms given zero weight.
training_loss_fn = partial(
    combined_loss_fn,
    lm_loss_weight=0.0,
    ot_loss_weight=1.0,
    sc_loss_weight=0.001,
    time_loss_weight=0.0,
)
# has_aux=True because the loss returns (value, individual_terms).
value_and_grad = jax.value_and_grad(training_loss_fn, has_aux=True)


def losses_and_grads_for_single_stiffness(bending_stiffness):
    """Return the loss and its derivative for one bending stiffness value.

    Args:
        bending_stiffness: stiffness value to evaluate.

    Returns:
        ``(loss, gradient)`` where ``gradient`` is the derivative of the loss
        with respect to the bending stiffness only.
    """
    candidate = dict(
        bending_stiffness=bending_stiffness,
        Ts0=params.Ts0 / params.density,
    )
    (loss, _aux), gradients = value_and_grad(candidate)

    # The gradient with respect to the tension entry is discarded.
    return loss, gradients["bending_stiffness"]


# Batched loss/gradient evaluation: one entry per starting point.
compute_vec_loss_grad = jax.vmap(losses_and_grads_for_single_stiffness)

# Starting points: evenly spaced values between 2 and 50.
num_starts = 50
start_points = jnp.linspace(2.0, 50.0, num_starts)

# One-cycle cosine schedule feeding Adam, preceded by gradient clipping.
iterations = 100
learning_rate = 1e-2
scheduler = optax.cosine_onecycle_schedule(
    transition_steps=iterations,
    peak_value=learning_rate,
)
# NOTE(review): the parameters are a single vector holding every start, so
# clip_by_global_norm computes its norm across all starts together -- confirm
# that coupling the otherwise independent runs this way is intended.
optimiser = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=scheduler),
)
opt_state = optimiser.init(start_points)


@jax.jit
def train_step(pars, state):
    """Update every starting point once.

    Args:
        pars: vector of current stiffness values, one per start.
        state: current optimiser state.

    Returns:
        ``(new_pars, new_state, losses)``; ``losses`` are evaluated at the
        incoming ``pars``, i.e. before the update is applied.
    """
    losses, gradients = compute_vec_loss_grad(pars)
    step, new_state = optimiser.update(gradients, state, pars)
    new_pars = optax.apply_updates(pars, step)
    return new_pars, new_state, losses


bar = tqdm(range(iterations))
for i in bar:
    # train_step evaluates the losses at the parameters it receives and
    # returns them together with the *updated* parameters. Keep a handle on
    # the evaluated values so the progress bar reports a loss and a parameter
    # that actually belong together.
    evaluated_points = start_points
    start_points, opt_state, losses = train_step(evaluated_points, opt_state)

    best_idx = jnp.argmin(losses)
    best_loss = losses[best_idx]
    best_param = evaluated_points[best_idx]
    bar.set_description(f"Best loss: {best_loss:.6f}, Best param: {best_param:.6f}")

# Get final results: re-evaluate the losses at the final parameters so the
# reported loss corresponds to the reported bending stiffness.
final_losses, _ = compute_vec_loss_grad(start_points)
best_idx = jnp.argmin(final_losses)
best_param = start_points[best_idx]
best_loss = final_losses[best_idx]

print(f"Best bending stiffness found: {best_param:.6f}")
print(f"Ground truth: {gt_pars['bending_stiffness']:.6f}")
print(f"Best loss: {best_loss:.6f}")
  0%|          | 0/100 [00:00<?, ?it/s]2025-04-04 11:57:23.800760: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng55{k2=8,k13=1,k14=3,k18=1,k22=0,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:57:27.647765: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 4.84707795s
Trying algorithm eng55{k2=8,k13=1,k14=3,k18=1,k22=0,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:57:28.648244: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng56{k2=8,k12=-1,k13=1,k14=3,k15=0,k17=512,k18=1,k22=0,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:57:32.474092: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 4.826270609s
Trying algorithm eng56{k2=8,k12=-1,k13=1,k14=3,k15=0,k17=512,k18=1,k22=0,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:57:33.474217: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng54{k2=5,k12=-1,k13=1,k14=2,k15=0,k17=512,k18=1,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:57:40.681969: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 8.207817642s
Trying algorithm eng54{k2=5,k12=-1,k13=1,k14=2,k15=0,k17=512,k18=1,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:57:41.682097: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng53{k2=5,k13=1,k14=2,k18=1,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:57:48.827023: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 8.144989253s
Trying algorithm eng53{k2=5,k13=1,k14=2,k18=1,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:57:49.827151: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng54{k2=1,k12=-1,k13=0,k14=3,k15=0,k17=256,k18=1,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:57:56.732148: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 7.905064559s
Trying algorithm eng54{k2=1,k12=-1,k13=0,k14=3,k15=0,k17=256,k18=1,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:57:57.732277: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng28{k2=3,k3=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:58:09.450809: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 12.718606294s
Trying algorithm eng28{k2=3,k3=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:58:10.450932: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng41{k2=0,k12=-1,k13=2,k14=3,k15=0,k17=512,k18=1,k22=0,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:58:25.186257: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 15.735370362s
Trying algorithm eng41{k2=0,k12=-1,k13=2,k14=3,k15=0,k17=512,k18=1,k22=0,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:58:26.186537: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng35{k2=5,k5=2,k14=6} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:58:37.667011: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 12.480538189s
Trying algorithm eng35{k2=5,k5=2,k14=6} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:58:38.667145: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng48{k2=2,k6=2,k13=1,k14=0,k22=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:58:54.286052: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 16.618976657s
Trying algorithm eng48{k2=2,k6=2,k13=1,k14=0,k22=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:58:55.286177: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng48{k2=15,k6=2,k13=1,k14=0,k22=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:59:15.215454: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 20.92933587s
Trying algorithm eng48{k2=15,k6=2,k13=1,k14=0,k22=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:59:16.215570: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng56{k2=6,k12=-1,k13=0,k14=1,k15=0,k17=342,k18=1,k22=0,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:59:48.464969: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 33.249455776s
Trying algorithm eng56{k2=6,k12=-1,k13=0,k14=1,k15=0,k17=342,k18=1,k22=0,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 11:59:49.465108: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng55{k2=3,k13=2,k14=2,k18=1,k22=0,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 12:00:04.464957: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 15.999924232s
Trying algorithm eng55{k2=3,k13=2,k14=2,k18=1,k22=0,k23=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 12:00:05.465090: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng28{k2=0,k3=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 12:00:41.373249: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 36.908233659s
Trying algorithm eng28{k2=0,k3=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 12:00:42.373385: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng28{k2=1,k3=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-04-04 12:01:41.780386: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1m0.407078507s
Trying algorithm eng28{k2=1,k3=0} for conv %cudnn-conv.4 = (f32[50,1,14254]{2,1,0}, u8[0]{0}) custom-call(f32[50,1024,15277]{2,1,0} %bitcast.12196, f32[1,1024,1024]{2,1,0} %bitcast.12201), window={size=1024}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="jit(train_step)/jit(main)/conv_general_dilated" source_file="/tmp/ipykernel_100446/1261401191.py" source_line=38}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
Best loss: 0.000006, Best param: 5.832441: 100%|██████████| 100/100 [07:57<00:00,  4.78s/it] 
Best bending stiffness found: 5.832441
Ground truth: 5.832808
Best loss: 0.000005