Metrics and losses
Metrics and losses for training and evaluation.
::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
import jax
import jax.numpy as jnp
import numpy as np:::
::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def squared_error(y_true, y_pred):
    """Return the elementwise squared residual ``(y_true - y_pred) ** 2``.

    No reduction is applied; the result has the broadcast shape of the inputs.
    """
    residual = y_true - y_pred
    return residual * residual
def absolute_error(y_true, y_pred):
    """Return the elementwise absolute residual ``|y_true - y_pred|``.

    No reduction is applied; the result has the broadcast shape of the inputs.
    """
    residual = y_true - y_pred
    return jnp.abs(residual)
def mse(y_true, y_pred, axis=None):
    """Mean squared error between ``y_true`` and ``y_pred``.

    Args:
        y_true: Target values.
        y_pred: Predicted values (broadcast against ``y_true``).
        axis: Axis/axes passed to ``jnp.mean``; ``None`` averages over all
            elements.
    """
    residual = y_true - y_pred
    return jnp.mean(residual * residual, axis=axis)
def mae(y_true, y_pred, axis=None):
    """Mean absolute error between ``y_true`` and ``y_pred``.

    Args:
        y_true: Target values.
        y_pred: Predicted values (broadcast against ``y_true``).
        axis: Axis/axes passed to ``jnp.mean``; ``None`` averages over all
            elements.
    """
    residual = y_true - y_pred
    return jnp.mean(jnp.abs(residual), axis=axis)
def mse_relative(y_true, y_pred, axis=None):
    """Relative MSE: ``mean((y_true - y_pred)**2) / mean(y_true**2)``.

    Both means are taken over the same ``axis`` (``None`` = all elements), so
    this is a ratio of means, not a mean of per-element ratios. If ``y_true``
    is all zeros along ``axis`` the denominator is zero and the floating-point
    division yields inf/nan.
    """
    residual = y_true - y_pred
    error_power = jnp.mean(residual * residual, axis=axis)
    signal_power = jnp.mean(y_true * y_true, axis=axis)
    return error_power / signal_power
def mae_relative(y_true, y_pred, axis=None):
    """Relative MAE: ``mean(|y_true - y_pred|) / mean(|y_true|)``.

    Both means are taken over the same ``axis`` (``None`` = all elements), so
    this is a ratio of means. If ``y_true`` is all zeros along ``axis`` the
    denominator is zero and the floating-point division yields inf/nan.
    """
    residual = y_true - y_pred
    error_magnitude = jnp.mean(jnp.abs(residual), axis=axis)
    target_magnitude = jnp.mean(jnp.abs(y_true), axis=axis)
    return error_magnitude / target_magnitude
def accumulate_metrics(metrics):
metrics = jax.device_get(metrics)
return {k: np.mean([metric[k] for metric in metrics]) for k in metrics[0]}:::
::: {#cell-5 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def log_metrics(
    metrics: dict[str, float],
    step_metrics: dict[str, float],
    prefix: str,
) -> dict[str, float]:
    """Copy ``step_metrics`` into ``metrics`` under ``"<prefix>/<key>"`` names.

    ``metrics`` is modified in place and the same dict object is returned.
    Each value is converted with ``float()`` (e.g. so scalar arrays are
    stored as plain Python floats).

    Args:
        metrics: Accumulating log dict; updated in place.
        step_metrics: Metrics for the current step, keyed by metric name.
        prefix: Namespace prepended to each key, joined with ``"/"``.

    Returns:
        The updated ``metrics`` dict.
    """
    for name, raw_value in step_metrics.items():
        converted = float(raw_value)
        metrics[prefix + "/" + name] = converted
    return metrics
::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def make_best_fn(metric: str):
    """Build a function that reads ``metric`` from its argument as a float.

    Args:
        metric: Key looked up (via ``[]`` indexing) on the object passed to
            the returned function.

    Returns:
        A callable ``f(x)`` that returns ``float(x[metric])``.
    """
    def _extract(record):
        return float(record[metric])

    return _extract
::: {#cell-7 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}
def aggregate_metrics(
    metrics: list,
    axis: int = 0,
) -> dict:
    """Summarise a list of metric dicts with mean, std, min and max.

    For every metric name in the first dict, the values from all entries are
    stacked along a new leading axis (``jnp.stack(..., axis=0)``) and then
    reduced along ``axis``.

    Args:
        metrics: List of dicts mapping metric names to values. Values may live
            on device; they are fetched with ``jax.device_get`` first. Every
            dict is expected to contain the keys of the first one.
        axis: Axis of the stacked array to reduce over; ``0`` reduces across
            the list entries.

    Returns:
        Dict with keys ``"<name>_mean"``, ``"<name>_std"``, ``"<name>_min"``
        and ``"<name>_max"``, grouped by statistic in that order. ``jnp.std``
        uses its default ``ddof=0`` (population standard deviation). An empty
        ``metrics`` list yields an empty dict.
    """
    metrics = jax.device_get(metrics)
    if not metrics:
        # Nothing to aggregate; previously this raised IndexError on metrics[0].
        return {}

    stacked = {
        key: jnp.stack([entry[key] for entry in metrics], axis=0)
        for key in metrics[0]
    }

    # Statistic is the outer loop so keys come out grouped as before:
    # all *_mean, then all *_std, *_min, *_max.
    reducers = (
        ("mean", jnp.mean),
        ("std", jnp.std),
        ("min", jnp.min),
        ("max", jnp.max),
    )
    return {
        f"{key}_{suffix}": reduce_fn(values, axis=axis)
        for suffix, reduce_fn in reducers
        for key, values in stacked.items()
    }