# Metrics and losses

> Metrics and losses for training and evaluation.
::: {#cell-3 .cell .export}
import jax
import jax.numpy as jnp
import numpy as np
:::
::: {#cell-4 .cell .export}
def squared_error(y_true, y_pred):
    """Return the squared difference ``(y_true - y_pred) ** 2``.

    No reduction is applied: for array inputs the result is element-wise
    and has the shape produced by the subtraction.

    Args:
        y_true: Reference values.
        y_pred: Predicted values.

    Returns:
        The squared difference between ``y_true`` and ``y_pred``.
    """
    return (y_true - y_pred) ** 2
def absolute_error(y_true, y_pred):
    """Return the absolute difference ``|y_true - y_pred|``.

    No reduction is applied: for array inputs the result is element-wise
    and has the shape produced by the subtraction.

    Args:
        y_true: Reference values.
        y_pred: Predicted values.

    Returns:
        The absolute difference between ``y_true`` and ``y_pred``.
    """
    return jnp.abs(y_true - y_pred)
def mse(y_true, y_pred, axis=None):
    """Return the mean squared error between ``y_true`` and ``y_pred``.

    Args:
        y_true: Reference values.
        y_pred: Predicted values.
        axis: Axis or axes forwarded to ``jnp.mean``. ``None`` (the default)
            averages over every element.

    Returns:
        The mean of ``(y_true - y_pred) ** 2`` over ``axis``.
    """
    return jnp.mean((y_true - y_pred) ** 2, axis=axis)
def mae(y_true, y_pred, axis=None):
    """Return the mean absolute error between ``y_true`` and ``y_pred``.

    Args:
        y_true: Reference values.
        y_pred: Predicted values.
        axis: Axis or axes forwarded to ``jnp.mean``. ``None`` (the default)
            averages over every element.

    Returns:
        The mean of ``|y_true - y_pred|`` over ``axis``.
    """
    return jnp.mean(jnp.abs(y_true - y_pred), axis=axis)
def mse_relative(y_true, y_pred, axis=None):
    """Return the mean squared error divided by the mean of ``y_true ** 2``.

    Both means are taken over the same ``axis``. No guard is applied when
    the denominator is zero.

    Args:
        y_true: Reference values; also used for the normalisation term.
        y_pred: Predicted values.
        axis: Axis or axes forwarded to ``jnp.mean``. ``None`` (the default)
            averages over every element.

    Returns:
        ``mean((y_true - y_pred) ** 2) / mean(y_true ** 2)`` over ``axis``.
    """
    squared_err_mean = jnp.mean((y_true - y_pred) ** 2, axis=axis)
    reference_power = jnp.mean(y_true**2, axis=axis)
    return squared_err_mean / reference_power
def mae_relative(y_true, y_pred, axis=None):
    """Return the mean absolute error divided by the mean of ``|y_true|``.

    Both means are taken over the same ``axis``. No guard is applied when
    the denominator is zero.

    Args:
        y_true: Reference values; also used for the normalisation term.
        y_pred: Predicted values.
        axis: Axis or axes forwarded to ``jnp.mean``. ``None`` (the default)
            averages over every element.

    Returns:
        ``mean(|y_true - y_pred|) / mean(|y_true|)`` over ``axis``.
    """
    abs_err_mean = jnp.mean(jnp.abs(y_true - y_pred), axis=axis)
    reference_magnitude = jnp.mean(jnp.abs(y_true), axis=axis)
    return abs_err_mean / reference_magnitude
def accumulate_metrics(metrics):
= jax.device_get(metrics)
metrics return {k: np.mean([metric[k] for metric in metrics]) for k in metrics[0]}
:::
::: {#cell-5 .cell .export}
def log_metrics(
    metrics: dict[str, float],
    step_metrics: dict[str, float],
    prefix: str,
) -> dict[str, float]:
    """Add the metrics of one step to a log dictionary under a prefix.

    Each entry of ``step_metrics`` is converted with ``float`` and stored
    in ``metrics`` under the key ``"<prefix>/<key>"``. ``metrics`` is
    modified in place and also returned.

    Args:
        metrics: Dictionary of metrics to log; updated in place.
        step_metrics: Metrics produced by a single step.
        prefix: Prefix identifying the step, prepended to every key.

    Returns:
        The same ``metrics`` dictionary, including the new entries.
    """
    for key, value in step_metrics.items():
        metrics[f"{prefix}/{key}"] = float(value)
    return metrics
:::
::: {#cell-6 .cell .export}
def make_best_fn(metric: str):
    """Build a function that reads one metric from a mapping as a float.

    Args:
        metric: Key of the value to extract.

    Returns:
        A callable taking an object ``x`` and returning
        ``float(x[metric])``.
    """
    return lambda x: float(x[metric])
:::
::: {#cell-7 .cell .export}
def aggregate_metrics(
list,
metrics: int = 0,
axis: -> dict:
) = jax.device_get(metrics)
metrics = {k: [dic[k] for dic in metrics] for k in metrics[0]}
dic_of_lists = {
dic_of_arr_stacked =0) for k in dic_of_lists.keys()
k: jnp.stack(dic_of_lists[k], axis
}= {
metrics_agg f"{k}_mean": jnp.mean(dic_of_arr_stacked[k], axis=axis)
for k in dic_of_arr_stacked.keys()
}
metrics_agg.update(
{f"{k}_std": jnp.std(dic_of_arr_stacked[k], axis=axis)
for k in dic_of_arr_stacked.keys()
}
)
metrics_agg.update(
{f"{k}_min": jnp.min(dic_of_arr_stacked[k], axis=axis)
for k in dic_of_arr_stacked.keys()
}
)
metrics_agg.update(
{f"{k}_max": jnp.max(dic_of_arr_stacked[k], axis=axis)
for k in dic_of_arr_stacked.keys()
}
)
return metrics_agg
:::