Checkpoint utils

::: {#cell-3 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

from pathlib import Path
from typing import Any

import flax.linen as nn
import hydra
import jax
import orbax.checkpoint as obc
import wandb
from flax.training import train_state
from omegaconf import OmegaConf, DictConfig
from wandb.apis import public
import numpy as np

:::

::: {#cell-4 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def restore_experiment_state(
    run_path: Path,  # Path to the run directory (e.g. "outputs/2024-01-23/22-15-11")
    best: bool = True,  # If True, restore the best checkpoint instead of the latest
    best_metric: str = "val/mae_rel",  # Metric to use for best checkpoint selection
    step_to_restore: int = None,  # If not None, restore the checkpoint at this step, incompatible with `best`
    kwargs: dict = None,  # Additional arguments to pass to the model
    device: jax.Device = None,  # Device to restore the model on
) -> tuple[train_state.TrainState, nn.Module, obc.CheckpointManager]:
    """
    Restores the train state from a run.

    The run directory is expected to contain a `checkpoints` folder and the
    Hydra config at `.hydra/config.yaml`. The model class is instantiated
    from `cfg.model` and built with `training=False`.

    Args:
        run_path (Path): Path to the run directory (e.g. "outputs/2024-01-23/22-15-11")
        best (bool): If True, restore the best checkpoint instead of the latest
        best_metric (str): Key of the checkpoint metrics used to rank
            checkpoints (lower is better, `best_mode="min"`). Required when
            `best=True`.
        step_to_restore (int): If not None, restore the checkpoint at this
            step. Cannot be combined with `best=True`.
        kwargs (dict): Additional arguments to pass to the model. `None`
            (the default) means no extra arguments.
        device (jax.Device): Device to restore the model on. Defaults to
            `jax.devices()[0]`.

    Returns:
        train_state.TrainState: The train state of the experiment
        nn.Module: The model used in the experiment
        CheckpointManager: The checkpoint manager. Note that the manager's
            `with` context has already been exited when the caller receives it.

    Raises:
        ValueError: If `best` and `step_to_restore` are both set, if
            `best=True` without a `best_metric`, if no checkpoint step can be
            resolved, or if the checkpoint contains neither a usable `state`
            nor a `default` item with a `model` entry.
    """

    # Make sure the path is a Path object
    run_path = Path(run_path)

    # Avoid a shared mutable default: normalise `None` to a fresh dict
    model_kwargs = {} if kwargs is None else kwargs

    # These are hardcoded, do not change
    ckpt_path = run_path / "checkpoints"
    config_path = run_path / ".hydra" / "config.yaml"
    cfg = OmegaConf.load(config_path)

    # Check if either `best` or `step_to_restore` is set, but not both
    if best:
        if step_to_restore is not None:
            raise ValueError(
                "You cannot set both `best` and `step_to_restore`. Please choose one."
            )
        if best_metric is None:
            raise ValueError(
                "If `best=True`, you must provide a `best_metric` to determine the best checkpoint."
            )

    options = obc.CheckpointManagerOptions(
        max_to_keep=1,
        create=True,
        best_fn=lambda x: float(
            x[best_metric]
        ),  # Shouldn't be hardcoded here, not a problem atm because we only save one step, best
        best_mode="min",
    )

    def set_restore_type(x: Any) -> obc.RestoreArgs:
        # Every leaf is restored as a NumPy array, regardless of its metadata
        return obc.RestoreArgs(restore_type=np.ndarray)

    with obc.CheckpointManager(
        ckpt_path,
        options=options,
        item_handlers={
            "state": obc.PyTreeCheckpointHandler(),
            "default": obc.PyTreeCheckpointHandler(),
        },
    ) as checkpoint_manager:
        model_cls: nn.Module = hydra.utils.instantiate(cfg.model)
        model = model_cls(training=False, **model_kwargs)

        # Pick the step: an explicit `step_to_restore` wins over best/latest
        step = (
            checkpoint_manager.latest_step()
            if not best
            else checkpoint_manager.best_step()
        )
        step = step_to_restore if step_to_restore is not None else step
        if step is None:
            # Fail early with a readable message instead of passing `None` on
            raise ValueError(f"No checkpoint found in {ckpt_path}")

        # Get checkpoint metadata
        metadatas = checkpoint_manager.item_metadata(step)
        print(f"Restoring checkpoint from step {step}...")

        # Restore the checkpoint, using the metadata tree as the target structure.
        # Current checkpoints store everything under the "state" item.
        if "state" in metadatas and metadatas.state is not None:
            restore_args_state = jax.tree_util.tree_map(
                set_restore_type, metadatas["state"]
            )
            metadatas = checkpoint_manager.restore(
                step=step,
                args=obc.args.Composite(
                    state=obc.args.PyTreeRestore(
                        item=metadatas["state"],
                        restore_args=restore_args_state,
                    ),
                ),
            )
            metadata_state = metadatas.state

        # Backwards compatibility for older checkpoints, stored under
        # "default" with the train state nested in a "model" entry
        elif "default" in metadatas and metadatas.default is not None:
            print("This is a checkpoint with old formatting")
            if "model" not in metadatas.default:
                raise ValueError("No model found in the checkpoint")

            restore_args_default = jax.tree_util.tree_map(
                set_restore_type, metadatas["default"]
            )
            metadatas = checkpoint_manager.restore(
                step=step,
                args=obc.args.Composite(
                    default=obc.args.PyTreeRestore(
                        item=metadatas["default"],
                        restore_args=restore_args_default,
                    ),
                ),
            )
            metadata_state = metadatas.default["model"]
        else:
            raise ValueError("No state found in the checkpoint")

        if "batch_stats" in metadata_state:
            # Define TrainState with optional batch_stats
            class TrainState(train_state.TrainState):
                key: jax.Array
                batch_stats: Any = None  # Optional field

            # Initialize the state from the restored arrays; `key` and `tx`
            # are filled with empty placeholders
            state = TrainState(
                key={},
                step=0,
                apply_fn=model.apply,
                params=metadata_state["params"],
                tx={},
                opt_state=metadata_state["opt_state"],
                batch_stats=metadata_state["batch_stats"],
            )
        else:
            # Define TrainState without batch_stats
            class TrainState(train_state.TrainState):
                key: jax.Array

            state = TrainState(
                key={},
                step=0,
                apply_fn=model.apply,
                params=metadata_state["params"],
                tx={},
                opt_state=metadata_state["opt_state"],
            )

        # If device is not specified, use the default device
        if device is None:
            device = jax.devices()[0]
        # Move the state to the specified device
        state = jax.device_put(state, device)

        return state, model, checkpoint_manager

:::

::: {#cell-5 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def download_ckpt_single_run(
    run_name: str,
    project: str = "iir-modal/physmodjax",
    tmp_dir: Path = Path("/tmp/physmodjax"),
    overwrite: bool = False,
) -> tuple[Path, DictConfig]:
    """
    Downloads the model checkpoint artifact of a single wandb run.

    The run is looked up by its display name. Artifacts of type "model"
    logged by the run are downloaded to `tmp_dir / <artifact name>`, and the
    run config is saved next to the checkpoint as `.hydra/config.yaml`.

    Args:
        run_name (str): Display name of the run on wandb
        project (str): wandb project path passed to `Api.runs`
        tmp_dir (Path): Directory under which checkpoints are stored
        overwrite (bool): If False and the checkpoint directory already
            exists, the download is skipped and the existing path is returned
            (the config file is not rewritten in that case).

    Returns:
        Path: Local path of the checkpoint
        DictConfig: The configuration of the run

    Raises:
        ValueError: If zero or more than one run matches `run_name`, if the
            run has no artifacts, or if none of its artifacts has type "model".
    """
    # Accept plain strings as well as Path objects
    tmp_dir = Path(tmp_dir)

    filter_dict = {
        "display_name": run_name,
    }

    api: public.Api = wandb.Api()

    runs: public.Runs = api.runs(project, filter_dict)

    # Explicit exceptions instead of asserts, which are stripped under `-O`
    if len(runs) == 0:
        raise ValueError(f"No runs found with name {run_name}")
    if len(runs) > 1:
        raise ValueError(f"More than one run found with name {run_name}")

    run: public.Run = runs[0]
    conf = OmegaConf.create(run.config)

    artifacts: public.RunArtifacts = run.logged_artifacts()

    # check if no artifacts
    if len(artifacts) == 0:
        raise ValueError(f"No artifacts found for run {run_name}")

    # Stays None when the run has artifacts but none of them is a model
    checkpoint_path = None
    for artifact in artifacts:
        if artifact.type != "model":
            continue
        checkpoint_path = tmp_dir / artifact.name
        if checkpoint_path.exists() and not overwrite:
            print(f"Checkpoint already exists at {checkpoint_path}, skipping")
            return checkpoint_path, conf
        artifact.download(checkpoint_path)

    if checkpoint_path is None:
        raise ValueError(f"No artifact of type 'model' found for run {run_name}")

    # save config next to checkpoint
    conf_path = checkpoint_path / ".hydra" / "config.yaml"
    conf_path.parent.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, conf_path)

    print(f"Downloaded checkpoint to {checkpoint_path}")
    return checkpoint_path, conf

:::

::: {#cell-6 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def restore_run_from_wandb(
    run_name: str,  # Name of the run to restore
    wandb_project: str,  # Name of the project on wandb
    device: jax.Device = None,  # Device to restore the model on
) -> tuple[train_state.TrainState, nn.Module, DictConfig]:
    """
    Restores the train state of a wandb run.

    The checkpoint and config are fetched with `download_ckpt_single_run`.
    If the config defines `checkpoint_manager_options.best_fn.metric`, the
    best checkpoint according to that metric is restored; otherwise the
    latest checkpoint is restored.
    The model is built with `n_steps` set to the second entry of
    `cfg.datamodule.num_steps_train`.

    Returns the train state, the model and the run config.
    """

    # Fetch the checkpoint directory and the run configuration
    checkpoint_path, cfg = download_ckpt_single_run(run_name, wandb_project)

    kwargs = {"n_steps": cfg.datamodule.num_steps_train[1]}

    # Walk down cfg.checkpoint_manager_options.best_fn.metric one attribute
    # at a time, stopping at the first level that is missing
    node = cfg
    has_metric = True
    for attr_name in ("checkpoint_manager_options", "best_fn", "metric"):
        if not hasattr(node, attr_name):
            has_metric = False
            break
        node = getattr(node, attr_name)

    if has_metric:
        print(f"Restoring best checkpoint according to metric {node}")
        best_metric, best = node, True
    else:
        print("Restoring latest checkpoint")
        best_metric, best = None, False

    # Restore the checkpoint; the checkpoint manager is not needed here
    state, model, _ = restore_experiment_state(
        checkpoint_path,
        best=best,
        best_metric=best_metric,
        kwargs=kwargs,
        device=device,
    )

    return state, model, cfg

:::

import random
import string
from pathlib import Path

from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig

from physmodjax.scripts.train_rnn import train_rnn

# Scratch cell: trains a small model for one epoch so that a checkpoint is
# logged to wandb under a random run name (used by the cells further down).
data_array = "/media/fast/data/datasets/physmodjax/ftm_nonlinear_varamp/ftm_string_nonlin_1000_Noise_4000Hz_1.0s.npy"
# data_array = "/home/carlos/projects/physmodjax/data/icaasp25/debug_validation_test/ftm_membrane_nonlin_3_Gaussian2d_16000Hz.npy"
batch_size = 1
split = [0.334, 0.334, 0.334]
extract_channels = [0, 1]
output_dir = ""  # placeholder, overwritten below from cfg.hydra.run.dir
# Random 12-character alphanumeric run name
charset = string.ascii_letters + string.digits
random_name = "".join(random.choices(charset, k=12))

# Compose the Hydra config programmatically (instead of via the CLI)
with initialize(version_base=None, config_path="../../conf"):
    cfg = compose(
        return_hydra_config=True,
        config_name="train_rnn",
        overrides=[
            "+experiment=1d_lru",
            f"++datamodule.data_array={data_array}",
            f"++datamodule.batch_size={batch_size}",
            f"++datamodule.split={split}",
            f"++datamodule.extract_channels={extract_channels}",
            "++model.d_vars=2",
            "++epochs=1",
            "++epochs_val=1",
            "++wandb.project=physmodjax",
            "++wandb.entity=iir-modal",
            f"++wandb.name={random_name}",
        ],
    )
    # NOTE(review): this registers Python's built-in `eval` as the "eval"
    # resolver, so any `${eval:...}` expression in the config is executed as
    # Python code. Only use with trusted configs.
    OmegaConf.register_new_resolver("eval", eval, replace=True)
    OmegaConf.resolve(cfg)

    # Print the config without the top-level keys containing "hydra"
    cfg_no_hydra = {k: v for (k, v) in cfg.items() if "hydra" not in k}
    print(OmegaConf.to_yaml(cfg_no_hydra))

    # Make the composed config available through HydraConfig.get()
    HydraConfig.instance().set_config(cfg)
    print(OmegaConf.to_yaml(HydraConfig.get().runtime))

    output_dir = Path(cfg.hydra.run.dir).absolute()
    # HydraConfig.get().runtime["output_dir"] = output_dir
    # NOTE(review): same call as a few lines above with the same `cfg`;
    # looks redundant - confirm before removing
    HydraConfig.instance().set_config(cfg)

    print(f"Output dir: {output_dir}")

    train_rnn(cfg)
INFO:2025-06-03 18:40:13,669:jax._src.xla_bridge:867: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
model:
  _target_: physmodjax.models.ssm.BatchStackedSSMModel
  _partial_: true
  training: true
  norm: batch
  activation: gelu
  d_model: 101
  d_vars: 2
  n_layers: 5
  ssm_first_layer:
    _target_: physmodjax.models.ssm.LRU
    _partial_: true
    r_min: 0.99
    r_max: 1.0
    d_hidden: 128
  ssm:
    _target_: physmodjax.models.ssm.LRU
    _partial_: true
    r_min: 0.99
    r_max: 1.0
    d_hidden: 128
datamodule:
  _target_: physmodjax.utils.data.DirectoryDataModule
  split:
  - 0.334
  - 0.334
  - 0.334
  batch_size: 1
  extract_channels:
  - 0
  - 1
  total_num_train: 4000
  total_num_val: 4000
  total_num_test: 4000
  num_steps_train:
  - 1
  - 3999
  num_steps_val:
  - 1
  - 3999
  mode: split
  standardize_dataset: true
  windowed: false
  cache: true
  data_array: /media/fast/data/datasets/physmodjax/ftm_nonlinear_varamp/ftm_string_nonlin_1000_Noise_4000Hz_1.0s.npy
jax:
  platform_name: null
  preallocate_gpu_memory: false
optimiser:
  learning_rate: 0.0001
  schedules:
    regular:
      _target_: optax.constant_schedule
      value: 0.0001
    ssm:
      _target_: optax.constant_schedule
      value: 2.5e-05
  optimizers:
    regular:
      _target_: optax.adamw
      learning_rate:
        _target_: optax.constant_schedule
        value: 0.0001
    ssm:
      _target_: optax.adam
      learning_rate:
        _target_: optax.constant_schedule
        value: 2.5e-05
  gradient_transform:
    _target_: optax.multi_transform
    transforms:
      regular:
        _target_: optax.adamw
        learning_rate:
          _target_: optax.constant_schedule
          value: 0.0001
      ssm:
        _target_: optax.adam
        learning_rate:
          _target_: optax.constant_schedule
          value: 2.5e-05
    param_labels: null
  keys_to_labels:
    ssm:
    - nu_log
    - theta_log
    - gamma_log
    - B_re
    - B_im
    - C_re
    - C_im
    regular: []
gradient_clip:
  _target_: optax.clip_by_global_norm
  max_norm: 1.0
loss:
  _target_: physmodjax.utils.losses.mse_loss
  _partial_: true
seed: 3407
epochs: 1
epochs_val: 1
debug: false
frozen: []
init_with_analytical: false
init_from_linear: false
enable_checkpointing: true
early_stopping:
  _target_: flax.training.early_stopping.EarlyStopping
  min_delta: 0.001
  patience: 10
schedule_type: constant
verbosity: info
checkpoint_manager_options:
  _target_: orbax.checkpoint.CheckpointManagerOptions
  max_to_keep: 1
  create: true
  best_fn:
    _target_: physmodjax.utils.metrics.make_best_fn
    metric: val/mae_rel
  best_mode: min
wandb:
  project: physmodjax
  entity: iir-modal
  name: FktCAm7xArHa

version: 1.3.2
version_base: '1.3'
cwd: /home/carlos/projects/physmodjax_private/nbs/utils
config_sources:
- path: hydra.conf
  schema: pkg
  provider: hydra
- path: /home/carlos/projects/physmodjax_private/conf
  schema: file
  provider: main
- path: ''
  schema: structured
  provider: schema
output_dir: ???
choices:
  experiment: 1d_lru
  loss: default
  gradient_clip: default.yaml
  optimiser: default.yaml
  jax: default.yaml
  datamodule: string
  model: 1d_lru
  hydra/env: default
  hydra/callbacks: null
  hydra/job_logging: default
  hydra/hydra_logging: default
  hydra/hydra_help: default
  hydra/help: default
  hydra/sweeper: basic
  hydra/launcher: basic
  hydra/output: default

Output dir: /home/carlos/projects/physmodjax_private/nbs/utils/outputs/2025-06-03/18-40-13
[06/03/25 18:40:13] INFO     Unable to initialize backend 'tpu': INTERNAL: Failed to open         xla_bridge.py:867
                             libtpu.so: libtpu.so: cannot open shared object file: No such file                    
                             or directory                                                                          
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[8], line 54
     50 HydraConfig.instance().set_config(cfg)
     52 print(f"Output dir: {output_dir}")
---> 54 train_rnn(cfg)

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/hydra/main.py:83, in main.<locals>.main_decorator.<locals>.decorated_main(cfg_passthrough)
     80 @functools.wraps(task_function)
     81 def decorated_main(cfg_passthrough: Optional[DictConfig] = None) -> Any:
     82     if cfg_passthrough is not None:
---> 83         return task_function(cfg_passthrough)
     84     else:
     85         args_parser = get_args_parser()

File ~/projects/physmodjax_private/physmodjax/scripts/train_rnn.py:686, in train_rnn(cfg)
    683 # Initialise logging
    684 output_dir = Path(HydraConfig.get().run.dir).absolute()
--> 686 run = wandb.init(
    687     dir=output_dir,
    688     config=OmegaConf.to_container(
    689         cfg,
    690         resolve=True,
    691         throw_on_missing=False,
    692     ),
    693     **cfg.wandb,
    694 )
    696 model_cls = hydra.utils.instantiate(cfg.model)
    697 datamodule = hydra.utils.instantiate(cfg.datamodule)

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/wandb_init.py:1255, in init(job_type, dir, config, project, entity, reinit, tags, group, name, notes, magic, config_exclude_keys, config_include_keys, anonymous, mode, allow_val_change, resume, force, tensorboard, sync_tensorboard, monitor_gym, save_code, id, fork_from, resume_from, settings)
   1253 try:
   1254     wi = _WandbInit()
-> 1255     wi.setup(kwargs)
   1256     return wi.init()
   1258 except KeyboardInterrupt as e:

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/wandb_init.py:193, in _WandbInit.setup(self, kwargs)
    184 settings__disable_service = (kwargs.get("settings") or {}).get(
    185     "_disable_service"
    186 ) or os.environ.get(wandb.env._DISABLE_SERVICE)
    188 setup_settings = {
    189     "mode": mode or settings_mode,
    190     "_disable_service": settings__disable_service,
    191 }
--> 193 self._wl = wandb_setup.setup(settings=setup_settings)
    194 # Make sure we have a logger setup (might be an early logger)
    195 assert self._wl is not None

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/wandb_setup.py:386, in setup(settings)
    330 def setup(settings: Optional[Settings] = None) -> Optional["_WandbSetup"]:
    331     """Prepares W&B for use in the current process and its children.
    332 
    333     You can usually ignore this as it is implicitly called by `wandb.init()`.
   (...)
    384         ```
    385     """
--> 386     ret = _setup(settings=settings)
    387     return ret

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/wandb_setup.py:326, in _setup(settings, _reset)
    323     teardown()
    324     return None
--> 326 wl = _WandbSetup(settings=settings)
    327 return wl

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/wandb_setup.py:304, in _WandbSetup.__init__(self, settings)
    302     _WandbSetup._instance._update(settings=settings)
    303     return
--> 304 _WandbSetup._instance = _WandbSetup__WandbSetup(settings=settings, pid=pid)

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/wandb_setup.py:115, in _WandbSetup__WandbSetup.__init__(self, pid, settings, environ)
    112 wandb.termsetup(self._settings, logger)
    114 self._check()
--> 115 self._setup()

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/wandb_setup.py:250, in _WandbSetup__WandbSetup._setup(self)
    247 if not self._settings._noop and not self._settings._disable_service:
    248     from wandb.sdk.lib import service_connection
--> 250     self._connection = service_connection.connect_to_service(self._settings)
    252 sweep_path = self._settings.sweep_param_path
    253 if sweep_path:

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/lib/service_connection.py:40, in connect_to_service(settings)
     37 if conn:
     38     return conn
---> 40 return _start_and_connect_service(settings)

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/lib/service_connection.py:75, in _start_and_connect_service(settings)
     68 """Starts a service process and returns a connection to it.
     69 
     70 An atexit hook is registered to tear down the service process and wait for
     71 it to complete. The hook does not run in processes started using the
     72 multiprocessing module.
     73 """
     74 proc = service._Service(settings)
---> 75 proc.start()
     77 port = proc.sock_port
     78 assert port

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/service/service.py:232, in _Service.start(self)
    231 def start(self) -> None:
--> 232     self._launch_server()

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/service/service.py:224, in _Service._launch_server(self)
    222 self._startup_debug_print("wait_ports")
    223 try:
--> 224     self._wait_for_ports(fname, proc=internal_proc)
    225 except Exception as e:
    226     _sentry.reraise(e)

File ~/projects/physmodjax_private/.venv/lib/python3.10/site-packages/wandb/sdk/service/service.py:105, in _Service._wait_for_ports(self, fname, proc)
     96     raise ServiceStartProcessError(
     97         f"The wandb service process exited with {proc.returncode}. "
     98         "Ensure that `sys.executable` is a valid python interpreter. "
   (...)
    102         context=context,
    103     )
    104 if not os.path.isfile(fname):
--> 105     time.sleep(0.2)
    106     continue
    107 try:

KeyboardInterrupt: 
# instantiate the datamodule

datamodule = hydra.utils.instantiate(cfg.datamodule)
train_dataloader = datamodule.train_dataloader
val_dataloader = datamodule.val_dataloader
test_dataloader = datamodule.test_dataloader
# Download the checkpoint of the run trained above (default wandb project).
# Note: this rebinds `cfg` to the config stored with the wandb run.
checkpoint_path, cfg = download_ckpt_single_run(random_name)
# Build the model with the datamodule's `num_steps_target_val` as `n_steps`
kwargs = {"n_steps": datamodule.num_steps_target_val}
# Restore with the defaults: best checkpoint according to "val/mae_rel"
state, model, ckpt_manager = restore_experiment_state(
    checkpoint_path,
    kwargs=kwargs,
)
from functools import partial

import numpy as np

from physmodjax.utils.metrics import (
    accumulate_metrics,
    mae,
    mae_relative,
    mse,
    mse_relative,
)
@partial(jax.jit, static_argnames=("model", "norm"))
def val_step(
    state: train_state.TrainState,
    x,
    y,
    model,
    norm,
):
    """
    Runs the model on `x` and computes validation metrics against `y`.

    `model` and `norm` are static arguments of the jitted function. When
    `norm` is "batch", the batch statistics stored in `state` are passed to
    the model together with the parameters.

    Returns a tuple `(metrics, pred)` where `metrics` maps "val/mse",
    "val/mae", "val/mse_rel" and "val/mae_rel" to their values.
    """
    # Collect the variable collections the model needs
    variables = {"params": state.params}
    if norm == "batch":
        variables["batch_stats"] = state.batch_stats

    pred = model.apply(variables, x)

    # Metric name -> metric function, evaluated in this order
    metric_fns = (
        ("val/mse", mse),
        ("val/mae", mae),
        ("val/mse_rel", mse_relative),
        ("val/mae_rel", mae_relative),
    )
    metrics = {name: metric_fn(y, pred) for name, metric_fn in metric_fns}
    return metrics, pred


# Recompute the validation metrics with the restored state, batch by batch
val_batch_metrics = []
for x, y in val_dataloader:
    metrics, pred = val_step(
        state,
        x=x,
        y=y,
        model=model,
        norm=cfg.model.norm,
    )
    val_batch_metrics.append(metrics)
# Combine the per-batch metrics (see `accumulate_metrics` for how they are reduced)
val_batch_metrics = accumulate_metrics(val_batch_metrics)

# Metrics recorded by the checkpoint manager for its best step
metrics = ckpt_manager.metrics(ckpt_manager.best_step())
# Keep only the entries whose key contains "val"
val_metrics = {k: v for k, v in metrics.items() if "val" in k}

# The recomputed metrics must match the ones stored with the checkpoint
for key, value in val_metrics.items():
    assert np.isclose(value, val_batch_metrics[key], atol=1e-6), (
        f"Metric {key} does not match: {value} != {val_batch_metrics[key]}"
    )

::: {#cell-12 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def fix_prepend_ones(
    run_name: str,
    wandb_project: str,
):
    """
    Normalises the `prepend_ones` option in the config of a downloaded run.

    If the config has `model.dynamics_model`:
      * a `prepend_ones` key found under `model.dynamics_model.model` is removed;
      * if `model.dynamics_model` itself has no `prepend_ones`, it is set to False.

    The (possibly updated) config is written to `.hydra/config.yaml` inside
    the checkpoint directory and returned.
    """
    checkpoint_path, conf = download_ckpt_single_run(run_name, project=wandb_project)

    if hasattr(conf.model, "dynamics_model"):
        dynamics_cfg = conf.model.dynamics_model

        # Drop the flag when it sits one level too deep
        misplaced = hasattr(dynamics_cfg, "model") and hasattr(
            dynamics_cfg.model, "prepend_ones"
        )
        if misplaced:
            print(f"Model {run_name} has prepend_ones in the wrong place")
            del dynamics_cfg.model.prepend_ones

        # Make sure the flag exists at the dynamics-model level
        if not hasattr(dynamics_cfg, "prepend_ones"):
            print(f"Model {run_name} has no prepend_ones")
            dynamics_cfg.prepend_ones = False

    # Persist the config next to the checkpoint
    hydra_dir = checkpoint_path / ".hydra"
    hydra_dir.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, hydra_dir / "config.yaml")
    return conf

:::

::: {#cell-13 .cell 0=‘e’ 1=‘x’ 2=‘p’ 3=‘o’ 4=‘r’ 5=‘t’}

def get_config_single_run(
    run_name: str,  # Name of the run to restore
    project: str,  # Name of the project on wandb
) -> DictConfig:
    """
    Fetches the configuration of the wandb run whose display name is `run_name`.

    Raises a ValueError unless exactly one run in `project` matches.
    """
    matches: public.Runs = wandb.Api().runs(project, {"display_name": run_name})

    # Exactly one run must match the display name
    n_matches = len(matches)
    if n_matches == 0:
        raise ValueError(f"No runs found with name {run_name}")
    if n_matches > 1:
        raise ValueError(f"More than one run found with name {run_name}")

    return OmegaConf.create(matches[0].config)

:::