Metrics and losses

Metrics and losses for training and evaluation.

::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

import jax
import jax.numpy as jnp
import numpy as np

:::

::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def squared_error(y_true, y_pred):
    """Return the element-wise squared error ``(y_true - y_pred) ** 2``.

    No reduction is applied; the result has the broadcast shape of the inputs.
    """
    residual = y_true - y_pred
    return residual**2


def absolute_error(y_true, y_pred):
    """Return the element-wise absolute error ``|y_true - y_pred|`` (no reduction)."""
    residual = y_true - y_pred
    return jnp.abs(residual)


def mse(y_true, y_pred, axis=None):
    """Mean squared error between ``y_true`` and ``y_pred``.

    Args:
        y_true: Reference values.
        y_pred: Predicted values (broadcast against ``y_true``).
        axis: Axis (or axes) forwarded to ``jnp.mean``; ``None`` averages over
            every element.

    Returns:
        ``mean((y_true - y_pred) ** 2)`` reduced over ``axis``.
    """
    squared = (y_true - y_pred) ** 2
    return jnp.mean(squared, axis=axis)


def mae(y_true, y_pred, axis=None):
    """Mean absolute error between ``y_true`` and ``y_pred``.

    Args:
        y_true: Reference values.
        y_pred: Predicted values (broadcast against ``y_true``).
        axis: Axis (or axes) forwarded to ``jnp.mean``; ``None`` averages over
            every element.

    Returns:
        ``mean(|y_true - y_pred|)`` reduced over ``axis``.
    """
    magnitude = jnp.abs(y_true - y_pred)
    return jnp.mean(magnitude, axis=axis)


def mse_relative(y_true, y_pred, axis=None):
    """Mean squared error normalized by the mean square of ``y_true``.

    Computes ``mean((y_true - y_pred) ** 2) / mean(y_true ** 2)``, with both
    means taken over ``axis`` (``None`` = all elements). No epsilon is added,
    so an all-zero ``y_true`` along ``axis`` gives a zero denominator.
    """
    numerator = jnp.mean((y_true - y_pred) ** 2, axis=axis)
    denominator = jnp.mean(y_true**2, axis=axis)
    return numerator / denominator


def mae_relative(y_true, y_pred, axis=None):
    """Mean absolute error normalized by the mean magnitude of ``y_true``.

    Computes ``mean(|y_true - y_pred|) / mean(|y_true|)``, with both means
    taken over ``axis`` (``None`` = all elements). No epsilon is added, so an
    all-zero ``y_true`` along ``axis`` gives a zero denominator.
    """
    numerator = jnp.mean(jnp.abs(y_true - y_pred), axis=axis)
    denominator = jnp.mean(jnp.abs(y_true), axis=axis)
    return numerator / denominator


def accumulate_metrics(metrics):
    """Average a list of per-step metric dicts into a single dict.

    Args:
        metrics: Sequence of dicts. The keys of the first dict determine the
            output keys, and every dict must contain them. Values may be device
            arrays; the whole list is fetched to host with ``jax.device_get``.

    Returns:
        Dict mapping each key to ``np.mean`` of that key's values across all
        dicts. ``np.mean`` is called without an axis, so array-valued entries
        are also averaged over all of their elements. An empty ``metrics``
        yields ``{}``.
    """
    # Guard explicitly: ``host_metrics[0]`` below would otherwise raise an
    # opaque IndexError when there is nothing to average.
    if not metrics:
        return {}
    host_metrics = jax.device_get(metrics)
    return {
        key: np.mean([step[key] for step in host_metrics])
        for key in host_metrics[0]
    }

:::

::: {#cell-5 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def log_metrics(
    metrics: dict[str, float],
    step_metrics: dict[str, float],
    prefix: str,
) -> dict[str, float]:
    """Add ``step_metrics`` to ``metrics`` under ``"<prefix>/<key>"`` names.

    Args:
        metrics: Log dictionary to extend. It is modified in place.
        step_metrics: Metrics for one step; each value is converted with
            ``float()``.
        prefix: Namespace prepended to every key, separated by ``"/"``.

    Returns:
        The same ``metrics`` object, now containing the prefixed entries.
        Existing entries with the same name are overwritten.
    """
    for name, raw in step_metrics.items():
        as_float = float(raw)
        metrics["/".join((prefix, name))] = as_float
    return metrics

:::

::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def make_best_fn(metric: str):
    """Build a callable that reads ``metric`` from a mapping as a Python float.

    The returned function takes an indexable object ``x`` (e.g. a metrics
    dict) and returns ``float(x[metric])``. Presumably used as the scoring
    key when selecting a "best" checkpoint/epoch — confirm with callers.
    """

    def _extract(x):
        return float(x[metric])

    return _extract

:::

::: {#cell-7 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def aggregate_metrics(
    metrics: list,
    axis: int = 0,
) -> dict:
    """Summarize a list of metric dicts with mean / std / min / max.

    Each key's values (one per dict in ``metrics``) are stacked with
    ``jnp.stack(..., axis=0)``, so the list index becomes axis 0 of the
    stacked array. The four statistics are then reduced over ``axis``.

    Args:
        metrics: Sequence of dicts. The keys of the first dict determine the
            output keys, and every dict must contain them. Values are fetched
            to host with ``jax.device_get`` before stacking.
        axis: Axis of the stacked array to reduce over. The default ``0``
            summarizes across the list entries.

    Returns:
        Dict with keys ``"<key>_mean"``, ``"<key>_std"``, ``"<key>_min"``,
        ``"<key>_max"`` holding JAX arrays. Insertion order is all ``_mean``
        entries, then all ``_std``, ``_min``, ``_max``. ``_std`` uses
        ``jnp.std``'s default ``ddof=0`` (population std). An empty
        ``metrics`` yields ``{}``.
    """
    # Guard explicitly: ``host_metrics[0]`` below would otherwise raise an
    # opaque IndexError when there is nothing to aggregate.
    if not metrics:
        return {}
    host_metrics = jax.device_get(metrics)
    stacked = {
        key: jnp.stack([entry[key] for entry in host_metrics], axis=0)
        for key in host_metrics[0]
    }

    reducers = (
        ("mean", jnp.mean),
        ("std", jnp.std),
        ("min", jnp.min),
        ("max", jnp.max),
    )
    aggregated = {}
    # Statistic-major iteration keeps the key order of the original
    # implementation (all *_mean first, then *_std, *_min, *_max).
    for suffix, reducer in reducers:
        for key, values in stacked.items():
            aggregated[f"{key}_{suffix}"] = reducer(values, axis=axis)
    return aggregated

:::