import hydra
from hydra import initialize, compose
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf
from physmodjax.scripts.train_rnn import train_rnn
from pathlib import Path
Checkpoint utils
restore_experiment_state

restore_experiment_state (run_path:pathlib.Path, best:bool=True, step_to_restore:int=None, x0_shape:Tuple[int]=(1, 101, 1), x_shape:Tuple[int]=(1, 1, 101, 1), kwargs:dict={})

Restores the train state from a run.

Args:
    run_path (Path): Path to the run directory (e.g. “outputs/2024-01-23/22-15-11”)

Returns:
    train_state.TrainState: The train state of the experiment
    nn.Module: The model used in the experiment
    CheckpointManager: The checkpoint manager
| | Type | Default | Details |
|---|---|---|---|
| run_path | Path | | Path to the run directory (e.g. “outputs/2024-01-23/22-15-11”) |
| best | bool | True | If True, restore the best checkpoint instead of the latest |
| step_to_restore | int | None | If not None, restore the checkpoint at this step |
| x0_shape | Tuple | (1, 101, 1) | Shape of the initial condition |
| x_shape | Tuple | (1, 1, 101, 1) | Shape of the input data |
| kwargs | dict | {} | Additional arguments to pass to the model |
| **Returns** | **Tuple** | | |
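For orientation, a minimal restore sketch using the defaults documented above. The run directory is a placeholder; point it at one of your own Hydra output folders.

```python
# Minimal sketch (hypothetical run directory): restore the best checkpoint
# from a local Hydra run and get back the train state, model and manager.
run_path = Path("outputs/2024-01-23/22-15-11")
state, model, ckpt_manager = restore_experiment_state(
    run_path,
    best=True,               # best checkpoint instead of the latest
    x0_shape=(1, 101, 1),    # shape of the initial condition
    x_shape=(1, 1, 101, 1),  # shape of the input data
)
```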
download_ckpt_single_run
download_ckpt_single_run (run_name:str, project:str, tmp_dir:pathlib.Path=Path('/tmp/physmodjax'), overwrite:bool=False)
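A sketch of the download-then-restore flow used later in this notebook. The run name and project are placeholders; substitute a run from your own W&B entity.

```python
# Sketch (placeholder run/project names): fetch the checkpoint artifact for a
# W&B run into the temporary directory, then restore it locally.
ckpt_path, run_cfg = download_ckpt_single_run(
    run_name="eager-valley-1758",
    project="physmodjax",
)
state, model, ckpt_manager = restore_experiment_state(ckpt_path)
```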
= "../data/ftm_string_nonlin_1000_Noise_4000Hz_1.0s.npy"
data_array = 1
batch_size = [0.01, 0.01, 0.01]
split = [0]
extract_channels = ""
output_dir
with initialize(version_base=None, config_path="../../conf"):
    cfg = compose(
        return_hydra_config=True,
        config_name="train_rnn",
        overrides=[
            "+experiment=1d_koopman",
            f"++datamodule.data_array={data_array}",
            f"++datamodule.batch_size={batch_size}",
            f"++datamodule.split={split}",
            f"++datamodule.extract_channels={extract_channels}",
            "++model.d_vars=1",
            "++epochs=1",
            "++epochs_val=1",
            "++wandb.project=physmodjax",
            "++wandb.entity=iir-modal",
        ],
    )

OmegaConf.register_new_resolver("eval", eval, replace=True)
OmegaConf.resolve(cfg)

cfg_no_hydra = {k: v for (k, v) in cfg.items() if "hydra" not in k}
print(OmegaConf.to_yaml(cfg_no_hydra))

HydraConfig.instance().set_config(cfg)
print(OmegaConf.to_yaml(HydraConfig.get().runtime))

output_dir = Path(cfg.hydra.run.dir).absolute()
# HydraConfig.get().runtime["output_dir"] = output_dir
HydraConfig.instance().set_config(cfg)

print(f"Output dir: {output_dir}")

train_rnn(cfg)
wandb: WARNING Path /Users/diaz/projects/physmodjax/nbs/utils/outputs/2024-09-03/12-28-33/wandb/ wasn't writable, using system temp directory.
model:
_target_: physmodjax.models.autoencoders.BatchedKoopmanAutoencoder1D
_partial_: true
d_vars: 1
d_model: 101
norm: layer
encoder_model:
_target_: physmodjax.models.mlp.MLP
_partial_: true
hidden_channels:
- 128
- 128
- 256
kernel_init:
_target_: flax.linen.initializers.orthogonal
decoder_model:
_target_: physmodjax.models.mlp.MLP
_partial_: true
hidden_channels:
- 128
- 128
- 101
kernel_init:
_target_: flax.linen.initializers.orthogonal
dynamics_model:
_target_: physmodjax.models.recurrent.LRUDynamics
_partial_: true
d_hidden: 128
r_min: 0.99
r_max: 0.999
max_phase: 6.28
clip_eigs: true
datamodule:
_target_: physmodjax.utils.data.DirectoryDataModule
split:
- 0.01
- 0.01
- 0.01
batch_size: 1
extract_channels:
- 0
total_num_train: 4000
total_num_val: 4000
total_num_test: 4000
num_steps_train:
- 1
- 3999
num_steps_val:
- 1
- 3999
mode: split
standardize_dataset: true
windowed: false
cache: true
data_array: ../data/ftm_string_nonlin_1000_Noise_4000Hz_1.0s.npy
jax:
platform_name: null
preallocate_gpu_memory: false
optimiser:
_target_: optax.adamw
learning_rate: 0.0001
gradient_clip:
_target_: optax.clip_by_global_norm
max_norm: 1.0
loss:
_target_: physmodjax.utils.losses.lindyn_loss
_partial_: true
encdec_weight: 1.0
lindyn_weight: 0.01
pred_weight: 1.0
seed: 3407
epochs: 1
epochs_val: 1
frozen: []
init_from_linear: false
schedule_type: constant
wandb:
group: 1d
project: physmodjax
entity: iir-modal
version: 1.3.2
version_base: '1.3'
cwd: /Users/diaz/projects/physmodjax/nbs/utils
config_sources:
- path: hydra.conf
schema: pkg
provider: hydra
- path: /Users/diaz/projects/physmodjax/conf
schema: file
provider: main
- path: ''
schema: structured
provider: schema
output_dir: ???
choices:
experiment: 1d_koopman
loss: lindyn
gradient_clip: default.yaml
optimiser: default.yaml
jax: default.yaml
datamodule: string
model: 1d_koopman
hydra/env: default
hydra/callbacks: null
hydra/job_logging: default
hydra/hydra_logging: default
hydra/hydra_help: default
hydra/help: default
hydra/sweeper: basic
hydra/launcher: basic
hydra/output: default
Output dir: /Users/diaz/projects/physmodjax/nbs/utils/outputs/2024-09-03/12-28-33
Finishing last run (ID:bg96gmb3) before initializing another...
View run chocolate-dream-1762 at: https://wandb.ai/iir-modal/physmodjax/runs/bg96gmb3
View project at: https://wandb.ai/iir-modal/physmodjax
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at:
/var/folders/f_/jbsvj3wx2gv9z_2s7ywlc8p00000gn/T/wandb/run-20240903_120018-bg96gmb3/logs
Successfully finished last run (ID:bg96gmb3). Initializing new run:
Tracking run with wandb version 0.17.8
Run data is saved locally in
/var/folders/f_/jbsvj3wx2gv9z_2s7ywlc8p00000gn/T/wandb/run-20240903_122833-4z5togcj
View project at https://wandb.ai/iir-modal/physmodjax
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
File ~/mambaforge/envs/physmodjax/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:92, in _call_target(_target_, _partial_, args, kwargs, full_key)
     91 try:
---> 92     return _target_(*args, **kwargs)
     93 except Exception as e:

File ~/projects/physmodjax/physmodjax/utils/data.py:407, in DirectoryDataModule.__init__(self, data_array, split, batch_size, extract_channels, num_steps_train, num_steps_val, standardize_dataset, mean, std, mode, windowed, hankelize, shuffle_train, cache, rolling_windows, total_num_train, total_num_val, total_num_test)
    405 self.val_batch_size = batch_size
--> 407 assert Path(data_array).exists() or all(
    408     [Path(d).exists() for d in data_array]
    409 ), "The data array does not exist"

AssertionError: The data array does not exist

The above exception was the direct cause of the following exception:

InstantiationException                    Traceback (most recent call last)
Cell In[12], line 41
     37 HydraConfig.instance().set_config(cfg)
     39 print(f"Output dir: {output_dir}")
---> 41 train_rnn(cfg)

File ~/mambaforge/envs/physmodjax/lib/python3.10/site-packages/hydra/main.py:83, in main.<locals>.main_decorator.<locals>.decorated_main(cfg_passthrough)
     82 if cfg_passthrough is not None:
---> 83     return task_function(cfg_passthrough)

File ~/projects/physmodjax/physmodjax/scripts/train_rnn.py:482, in train_rnn(cfg)
    481 model_cls = hydra.utils.instantiate(cfg.model)
--> 482 datamodule = hydra.utils.instantiate(cfg.datamodule)

File ~/mambaforge/envs/physmodjax/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:226, in instantiate(config, *args, **kwargs)
--> 226 return instantiate_node(
    227     config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
    228 )

File ~/mambaforge/envs/physmodjax/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:347, in instantiate_node(node, convert, recursive, partial, *args)
--> 347 return _call_target(_target_, partial, args, kwargs, full_key)

File ~/mambaforge/envs/physmodjax/lib/python3.10/site-packages/hydra/_internal/instantiate/_instantiate2.py:97, in _call_target(_target_, _partial_, args, kwargs, full_key)
     95 if full_key:
     96     msg += f"\nfull_key: {full_key}"
---> 97 raise InstantiationException(msg) from e

InstantiationException: Error in call to target 'physmodjax.utils.data.DirectoryDataModule': AssertionError('The data array does not exist')
full_key: datamodule
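The failure above is simply the dataset file being absent on this machine. A small guard, not part of the original notebook, that checks the path before launching the run (the message string is mine):

```python
# Hedged sketch: only launch training if the dataset referenced in the
# overrides exists locally; otherwise Hydra instantiation of the datamodule
# fails with the AssertionError shown above.
if Path(data_array).exists():
    train_rnn(cfg)
else:
    print(f"Dataset not found at {data_array}; skipping the training run.")
```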
# instantiate the datamodule
datamodule = hydra.utils.instantiate(cfg.datamodule)
train_dataloader = datamodule.train_dataloader
val_dataloader = datamodule.val_dataloader
test_dataloader = datamodule.test_dataloader

checkpoint_path, cfg = download_ckpt_single_run("eager-valley-1758")
kwargs = {"n_steps": datamodule.num_steps_target_val}
state, model, ckpt_manager = restore_experiment_state(
    checkpoint_path,
    kwargs=kwargs,
)
Checkpoint already exists at /tmp/physmodjax/checkpoints_fiug7qv5:v0, skipping
Using data_info from config: [1, 101, 1]
Restoring checkpoint from step 1...
/home/diaz/anaconda3/envs/physmodjax_private/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1552: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
from functools import partial

import jax
import numpy as np
from flax.training import train_state

from physmodjax.utils.metrics import (
    mse,
    mae,
    mse_relative,
    mae_relative,
    accumulate_metrics,
)


@partial(jax.jit, static_argnames=("model", "norm"))
def val_step(
    state: train_state.TrainState,
    x,
    y,
    model,
    norm,
):
    # Batch-normalised models carry batch statistics in the train state.
    if norm in ["batch"]:
        pred = model.apply(
            {"params": state.params, "batch_stats": state.batch_stats}, x
        )
    else:
        pred = model.apply({"params": state.params}, x)

    metrics = {
        "val/mse": mse(y, pred),
        "val/mae": mae(y, pred),
        "val/mse_rel": mse_relative(y, pred),
        "val/mae_rel": mae_relative(y, pred),
    }
    return metrics, pred
val_batch_metrics = []
for x, y in val_dataloader:
    metrics, pred = val_step(
        state,
        x=x,
        y=y,
        model=model,
        norm=cfg.model.norm,
    )
    val_batch_metrics.append(metrics)

val_batch_metrics = accumulate_metrics(val_batch_metrics)
# compare the recomputed validation metrics against those logged by the
# checkpoint manager for the best step
metrics = ckpt_manager.metrics(ckpt_manager.best_step())
val_metrics = {k: v for k, v in metrics.items() if "val" in k}

for key, value in val_metrics.items():
    assert np.isclose(
        value, val_batch_metrics[key], atol=1e-6
    ), f"Metric {key} does not match: {value} != {val_batch_metrics[key]}"