Checkpoint utils

/Users/diaz/mambaforge/envs/physmodjax/lib/python3.10/site-packages/fastcore/docscrape.py:230: UserWarning: potentially wrong underline length... 
Returns: 
------- in 
Restores the train state from a run.
...
  else: warn(msg)

source

restore_experiment_state

 restore_experiment_state (run_path:pathlib.Path, best:bool=True,
                           step_to_restore:int=None,
                           x0_shape:Tuple[int]=(1, 101, 1),
                           x_shape:Tuple[int]=(1, 1, 101, 1),
                           kwargs:dict={})

*Restores the train state from a run.*

Args:
- run_path (Path): Path to the run directory (e.g. “outputs/2024-01-23/22-15-11”)

Returns:
- train_state.TrainState: The train state of the experiment
- nn.Module: The model used in the experiment
- CheckpointManager: The checkpoint manager
|                 | Type      | Default          | Details                                                        |
|-----------------|-----------|------------------|----------------------------------------------------------------|
| run_path        | Path      |                  | Path to the run directory (e.g. “outputs/2024-01-23/22-15-11”) |
| best            | bool      | True             | If True, restore the best checkpoint instead of the latest     |
| step_to_restore | int       | None             | If not None, restore the checkpoint at this step               |
| x0_shape        | Tuple     | (1, 101, 1)      | Shape of the initial condition                                 |
| x_shape         | Tuple     | (1, 1, 101, 1)   | Shape of the input data                                        |
| kwargs          | dict      | {}               | Additional arguments to pass to the model                      |
| **Returns**     | **Tuple** |                  |                                                                |

source

download_ckpt_single_run

 download_ckpt_single_run (run_name:str, project:str,
                           tmp_dir:pathlib.Path=Path('/tmp/physmodjax'),
                           overwrite:bool=False)
from pathlib import Path

from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf

from physmodjax.scripts.train_rnn import train_rnn

# Compose the `train_rnn` config for the 1d_koopman experiment with a tiny
# data split and a single epoch, then run training once.
# NOTE(review): this path is resolved relative to the notebook's working
# directory; the recorded run failed with "The data array does not exist",
# so confirm where the .npy file actually lives.
data_array = "../data/ftm_string_nonlin_1000_Noise_4000Hz_1.0s.npy"
batch_size = 1
split = [0.01, 0.01, 0.01]
extract_channels = [0]

with initialize(version_base=None, config_path="../../conf"):
    cfg = compose(
        return_hydra_config=True,
        config_name="train_rnn",
        overrides=[
            "+experiment=1d_koopman",
            f"++datamodule.data_array={data_array}",
            f"++datamodule.batch_size={batch_size}",
            f"++datamodule.split={split}",
            f"++datamodule.extract_channels={extract_channels}",
            "++model.d_vars=1",
            "++epochs=1",
            "++epochs_val=1",
            "++wandb.project=physmodjax",
            "++wandb.entity=iir-modal",
        ],
    )
    # Register the `eval` resolver so configs can use `${eval:...}`
    # interpolations. Using Python's `eval` is only acceptable because the
    # config files are local and trusted.
    OmegaConf.register_new_resolver("eval", eval, replace=True)
    OmegaConf.resolve(cfg)

    # Print the experiment config without the (verbose) hydra section.
    cfg_no_hydra = {k: v for k, v in cfg.items() if "hydra" not in k}
    print(OmegaConf.to_yaml(cfg_no_hydra))

    # HydraConfig is only populated automatically when running under
    # @hydra.main; set it by hand so code calling HydraConfig.get()
    # (presumably inside train_rnn) works from a notebook. Setting it once is
    # enough, because cfg is not modified afterwards.
    HydraConfig.instance().set_config(cfg)
    print(OmegaConf.to_yaml(HydraConfig.get().runtime))

    output_dir = Path(cfg.hydra.run.dir).absolute()
    print(f"Output dir: {output_dir}")

    train_rnn(cfg)
wandb: WARNING Path /Users/diaz/projects/physmodjax/nbs/utils/outputs/2024-09-03/12-28-33/wandb/ wasn't writable, using system temp directory.
model:
  _target_: physmodjax.models.autoencoders.BatchedKoopmanAutoencoder1D
  _partial_: true
  d_vars: 1
  d_model: 101
  norm: layer
  encoder_model:
    _target_: physmodjax.models.mlp.MLP
    _partial_: true
    hidden_channels:
    - 128
    - 128
    - 256
    kernel_init:
      _target_: flax.linen.initializers.orthogonal
  decoder_model:
    _target_: physmodjax.models.mlp.MLP
    _partial_: true
    hidden_channels:
    - 128
    - 128
    - 101
    kernel_init:
      _target_: flax.linen.initializers.orthogonal
  dynamics_model:
    _target_: physmodjax.models.recurrent.LRUDynamics
    _partial_: true
    d_hidden: 128
    r_min: 0.99
    r_max: 0.999
    max_phase: 6.28
    clip_eigs: true
datamodule:
  _target_: physmodjax.utils.data.DirectoryDataModule
  split:
  - 0.01
  - 0.01
  - 0.01
  batch_size: 1
  extract_channels:
  - 0
  total_num_train: 4000
  total_num_val: 4000
  total_num_test: 4000
  num_steps_train:
  - 1
  - 3999
  num_steps_val:
  - 1
  - 3999
  mode: split
  standardize_dataset: true
  windowed: false
  cache: true
  data_array: ../data/ftm_string_nonlin_1000_Noise_4000Hz_1.0s.npy
jax:
  platform_name: null
  preallocate_gpu_memory: false
optimiser:
  _target_: optax.adamw
  learning_rate: 0.0001
gradient_clip:
  _target_: optax.clip_by_global_norm
  max_norm: 1.0
loss:
  _target_: physmodjax.utils.losses.lindyn_loss
  _partial_: true
  encdec_weight: 1.0
  lindyn_weight: 0.01
  pred_weight: 1.0
seed: 3407
epochs: 1
epochs_val: 1
frozen: []
init_from_linear: false
schedule_type: constant
wandb:
  group: 1d
  project: physmodjax
  entity: iir-modal

version: 1.3.2
version_base: '1.3'
cwd: /Users/diaz/projects/physmodjax/nbs/utils
config_sources:
- path: hydra.conf
  schema: pkg
  provider: hydra
- path: /Users/diaz/projects/physmodjax/conf
  schema: file
  provider: main
- path: ''
  schema: structured
  provider: schema
output_dir: ???
choices:
  experiment: 1d_koopman
  loss: lindyn
  gradient_clip: default.yaml
  optimiser: default.yaml
  jax: default.yaml
  datamodule: string
  model: 1d_koopman
  hydra/env: default
  hydra/callbacks: null
  hydra/job_logging: default
  hydra/hydra_logging: default
  hydra/hydra_help: default
  hydra/help: default
  hydra/sweeper: basic
  hydra/launcher: basic
  hydra/output: default

Output dir: /Users/diaz/projects/physmodjax/nbs/utils/outputs/2024-09-03/12-28-33
Finishing last run (ID:bg96gmb3) before initializing another...
View run chocolate-dream-1762 at: https://wandb.ai/iir-modal/physmodjax/runs/bg96gmb3
View project at: https://wandb.ai/iir-modal/physmodjax
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: /var/folders/f_/jbsvj3wx2gv9z_2s7ywlc8p00000gn/T/wandb/run-20240903_120018-bg96gmb3/logs
Successfully finished last run (ID:bg96gmb3). Initializing new run:
Tracking run with wandb version 0.17.8
Run data is saved locally in /var/folders/f_/jbsvj3wx2gv9z_2s7ywlc8p00000gn/T/wandb/run-20240903_122833-4z5togcj
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
File ~/mambaforge/envs/physmodjax/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:92, in _call_target(_target_, _partial_, args, kwargs, full_key)
     91 try:
---> 92     return _target_(*args, **kwargs)
     93 except Exception as e:

File ~/projects/physmodjax/physmodjax/utils/data.py:407, in DirectoryDataModule.__init__(self, data_array, split, batch_size, extract_channels, num_steps_train, num_steps_val, standardize_dataset, mean, std, mode, windowed, hankelize, shuffle_train, cache, rolling_windows, total_num_train, total_num_val, total_num_test)
    405 self.val_batch_size = batch_size
--> 407 assert Path(data_array).exists() or all(
    408     [Path(d).exists() for d in data_array]
    409 ), "The data array does not exist"
    411 # if we have a list of directories, we will concatenate the data
    412 # else we will assume that the directory contains the data

AssertionError: The data array does not exist

The above exception was the direct cause of the following exception:

InstantiationException                    Traceback (most recent call last)
Cell In[12], line 41
     37 HydraConfig.instance().set_config(cfg)
     39 print(f"Output dir: {output_dir}")
---> 41 train_rnn(cfg)

File ~/mambaforge/envs/physmodjax/lib/python3.10/site-packages/hydra/main.py:83, in main.<locals>.main_decorator.<locals>.decorated_main(cfg_passthrough)
     80 @functools.wraps(task_function)
     81 def decorated_main(cfg_passthrough: Optional[DictConfig] = None) -> Any:
     82     if cfg_passthrough is not None:
---> 83         return task_function(cfg_passthrough)
     84     else:
     85         args_parser = get_args_parser()

File ~/projects/physmodjax/physmodjax/scripts/train_rnn.py:482, in train_rnn(cfg)
    471 run = wandb.init(
    472     dir=output_dir,
    473     config=OmegaConf.to_container(
   (...)
    478     **cfg.wandb,
    479 )
    481 model_cls = hydra.utils.instantiate(cfg.model)
--> 482 datamodule = hydra.utils.instantiate(cfg.datamodule)
    484 # Log data info
    485 wandb.config.update({"output_dir": output_dir})

File ~/mambaforge/envs/physmodjax/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:226, in instantiate(config, *args, **kwargs)
    223     _convert_ = config.pop(_Keys.CONVERT, ConvertMode.NONE)
    224     _partial_ = config.pop(_Keys.PARTIAL, False)
--> 226     return instantiate_node(
    227         config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
    228     )
    229 elif OmegaConf.is_list(config):
    230     # Finalize config (convert targets to strings, merge with kwargs)
    231     config_copy = copy.deepcopy(config)

File ~/mambaforge/envs/physmodjax/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:347, in instantiate_node(node, convert, recursive, partial, *args)
    342                 value = instantiate_node(
    343                     value, convert=convert, recursive=recursive
    344                 )
    345             kwargs[key] = _convert_node(value, convert)
--> 347     return _call_target(_target_, partial, args, kwargs, full_key)
    348 else:
    349     # If ALL or PARTIAL non structured or OBJECT non structured,
    350     # instantiate in dict and resolve interpolations eagerly.
    351     if convert == ConvertMode.ALL or (
    352         convert in (ConvertMode.PARTIAL, ConvertMode.OBJECT)
    353         and node._metadata.object_type in (None, dict)
    354     ):

File ~/mambaforge/envs/physmodjax/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:97, in _call_target(_target_, _partial_, args, kwargs, full_key)
     95 if full_key:
     96     msg += f"\nfull_key: {full_key}"
---> 97 raise InstantiationException(msg) from e

InstantiationException: Error in call to target 'physmodjax.utils.data.DirectoryDataModule':
AssertionError('The data array does not exist')
full_key: datamodule
# Instantiate the datamodule from the composed config and grab its loaders.
datamodule = hydra.utils.instantiate(cfg.datamodule)
train_dataloader = datamodule.train_dataloader
val_dataloader = datamodule.val_dataloader
test_dataloader = datamodule.test_dataloader

# Fetch the checkpoint of a trained run. `project` is a required argument of
# download_ckpt_single_run, so pass the W&B project the training config logs
# to.
# NOTE: this rebinds `cfg` to the second returned value, presumably the
# downloaded run's config. Later cells read `cfg.model.norm` from it.
checkpoint_path, cfg = download_ckpt_single_run(
    "eager-valley-1758", project=cfg.wandb.project
)

# Extra keyword arguments forwarded to the model when restoring.
kwargs = {"n_steps": datamodule.num_steps_target_val}
state, model, ckpt_manager = restore_experiment_state(
    checkpoint_path,
    kwargs=kwargs,
)
Checkpoint already exists at /tmp/physmodjax/checkpoints_fiug7qv5:v0, skipping
Using data_info from config: [1, 101, 1]
Restoring checkpoint from step 1...
/home/diaz/anaconda3/envs/physmodjax_private/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1552: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
  warnings.warn(
from functools import partial
from physmodjax.utils.metrics import (
    mse,
    mae,
    mse_relative,
    mae_relative,
    accumulate_metrics,
)
import numpy as np
@partial(jax.jit, static_argnames=("model", "norm"))
def val_step(
    state: train_state.TrainState,
    x,
    y,
    model,
    norm,
):
    """Run the model on one validation batch and score the prediction.

    ``model`` and ``norm`` are static jit arguments. This means the Python
    branch on ``norm`` is resolved at trace time.

    Args:
        state: Train state supplying ``params`` (and ``batch_stats`` when
            ``norm == "batch"``).
        x: Model input for the batch.
        y: Target the prediction is compared against.
        model: Module whose ``apply`` produces the prediction.
        norm: Normalisation type from the config. ``"batch"`` additionally
            passes ``state.batch_stats`` to ``model.apply``.

    Returns:
        A ``(metrics, pred)`` tuple. ``metrics`` maps ``"val/mse"``,
        ``"val/mae"``, ``"val/mse_rel"`` and ``"val/mae_rel"`` to the value
        of each metric computed on ``(y, pred)``.
    """
    variables = {"params": state.params}
    if norm == "batch":
        # Batch-norm models also need their stored batch statistics.
        variables["batch_stats"] = state.batch_stats
    pred = model.apply(variables, x)

    scorers = (
        ("val/mse", mse),
        ("val/mae", mae),
        ("val/mse_rel", mse_relative),
        ("val/mae_rel", mae_relative),
    )
    metrics = {name: score(y, pred) for name, score in scorers}
    return metrics, pred


# Recompute the validation metrics with the restored state and check that they
# match the metrics stored with the best checkpoint. restore_experiment_state
# was called with its default best=True, so best_step() is the step that was
# restored.
val_batch_metrics = []
for x, y in val_dataloader:
    batch_metrics, pred = val_step(
        state,
        x=x,
        y=y,
        model=model,
        norm=cfg.model.norm,
    )
    val_batch_metrics.append(batch_metrics)
val_batch_metrics = accumulate_metrics(val_batch_metrics)

metrics = ckpt_manager.metrics(ckpt_manager.best_step())
# Compare only keys with the "val/" prefix that val_step emits. A substring
# match on "val" would also pick up unrelated keys (e.g. "interval", "eval/...").
val_metrics = {k: v for k, v in metrics.items() if k.startswith("val/")}

# Guard against a vacuous check: with no matching keys, the loop below would
# pass without comparing anything.
assert val_metrics, "No validation metrics stored with the best checkpoint"
missing = [k for k in val_metrics if k not in val_batch_metrics]
assert not missing, f"Metrics not recomputed by val_step: {missing}"

for key, value in val_metrics.items():
    assert np.isclose(
        value, val_batch_metrics[key], atol=1e-6
    ), f"Metric {key} does not match: {value} != {val_batch_metrics[key]}"