def restore_experiment_state(
    run_path: Path,  # Path to the run directory (e.g. "outputs/2024-01-23/22-15-11")
    best: bool = True,  # If True, restore the best checkpoint instead of the latest
    best_metric: str = "val/mae_rel",  # Metric to use for best checkpoint selection
    step_to_restore: int = None,  # If not None, restore the checkpoint at this step, incompatible with `best`
    kwargs: dict = {},  # Additional arguments to pass to the model
    device: jax.Device = None,  # Device to restore the model on
) -> tuple[train_state.TrainState, nn.Module, obc.CheckpointManager]:
    """
    Restores the train state from a run.

    Args:
        run_path (Path): Path to the run directory (e.g. "outputs/2024-01-23/22-15-11")
        best (bool): If True, restore the best checkpoint instead of the latest
        best_metric (str): Metric to use for best checkpoint selection
        step_to_restore (int): If not None, restore the checkpoint at this step
        kwargs (dict): Additional arguments to pass to the model
        device (jax.Device): Device to restore the model on

    Returns:
        train_state.TrainState: The train state of the experiment
        nn.Module: The model used in the experiment
        CheckpointManager: The checkpoint manager
    """
    # Make sure the path is a Path object
    run_path = Path(run_path)

    # These are hardcoded, do not change
    ckpt_path = run_path / "checkpoints"
    config_path = run_path / ".hydra" / "config.yaml"

    cfg = OmegaConf.load(config_path)

    # Check that `best` and `step_to_restore` are not both set
    if best:
        if step_to_restore is not None:
            raise ValueError(
                "You cannot set both `best` and `step_to_restore`. Please choose one."
            )
        if best_metric is None:
            raise ValueError(
                "If `best=True`, you must provide a `best_metric` to determine the best checkpoint."
            )

    options = obc.CheckpointManagerOptions(
        max_to_keep=1,
        create=True,
        best_fn=lambda x: float(
            x[best_metric]
        ),  # Shouldn't be hardcoded here, not a problem atm because we only save one step, the best
        best_mode="min",
    )

    def set_restore_type(x: Any) -> obc.RestoreArgs:
        return obc.RestoreArgs(restore_type=np.ndarray)

    with obc.CheckpointManager(
        ckpt_path,
        options=options,
        item_handlers={
            "state": obc.PyTreeCheckpointHandler(),
            "default": obc.PyTreeCheckpointHandler(),
        },
    ) as checkpoint_manager:
        model_cls: nn.Module = hydra.utils.instantiate(cfg.model)
        model = model_cls(training=False, **kwargs)

        # Get checkpoint metadata
        step = (
            checkpoint_manager.latest_step()
            if not best
            else checkpoint_manager.best_step()
        )
        step = step_to_restore if step_to_restore is not None else step
        metadatas = checkpoint_manager.item_metadata(step)

        print(f"Restoring checkpoint from step {step}...")

        # Restore the metadata
        # Backwards compatibility for older checkpoints
        if "state" in metadatas and metadatas.state is not None:
            restore_args_state = jax.tree_util.tree_map(
                set_restore_type, metadatas["state"]
            )
            metadatas = checkpoint_manager.restore(
                step=step,
                args=obc.args.Composite(
                    state=obc.args.PyTreeRestore(
                        item=metadatas["state"],
                        restore_args=restore_args_state,
                    ),
                ),
            )
            metadata_state = metadatas.state
        elif "default" in metadatas and metadatas.default is not None:
            print("This is a checkpoint with old formatting")
            if "model" not in metadatas.default:
                raise ValueError("No model found in the checkpoint")
            restore_args_default = jax.tree_util.tree_map(
                set_restore_type, metadatas["default"]
            )
            metadatas = checkpoint_manager.restore(
                step=step,
                args=obc.args.Composite(
                    default=obc.args.PyTreeRestore(
                        item=metadatas["default"],
                        restore_args=restore_args_default,
                    ),
                ),
            )
            metadata_state = metadatas.default["model"]
        else:
            raise ValueError("No state found in the checkpoint")

        if "batch_stats" in metadata_state:
            # Define TrainState with optional batch_stats
            class TrainState(train_state.TrainState):
                key: jax.Array
                batch_stats: Any = None  # Optional field

            # Initialize the empty state
            state = TrainState(
                key={},
                step=0,
                apply_fn=model.apply,
                params=metadata_state["params"],
                tx={},
                opt_state=metadata_state["opt_state"],
                batch_stats=metadata_state["batch_stats"],
            )
        else:
            # Define TrainState without batch_stats
            class TrainState(train_state.TrainState):
                key: jax.Array

            state = TrainState(
                key={},
                step=0,
                apply_fn=model.apply,
                params=metadata_state["params"],
                tx={},
                opt_state=metadata_state["opt_state"],
            )

        # If device is not specified, use the default device
        if device is None:
            device = jax.devices()[0]

        # Move the state to the specified device
        state = jax.device_put(state, device)

        return state, model, checkpoint_manager
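As a quick illustration, a hypothetical call that restores the best checkpoint from a local Hydra run directory might look like this; the run path and the `n_steps` value are placeholders, not values from an actual run:

    # Hypothetical usage sketch (placeholder path and kwargs)
    state, model, ckpt_manager = restore_experiment_state(
        run_path=Path("outputs/2024-01-23/22-15-11"),
        best=True,
        best_metric="val/mae_rel",
        kwargs={"n_steps": 1000},  # example model kwarg only
        device=jax.devices()[0],
    )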
def download_ckpt_single_run(
    run_name: str,
    project: str = "iir-modal/physmodjax",
    tmp_dir: Path = Path("/tmp/physmodjax"),
    overwrite: bool = False,
) -> tuple[Path, DictConfig]:
    filter_dict = {
        "display_name": run_name,
    }

    # if wandb.run is None:
    #     wandb.init()

    api: public.Api = wandb.Api()
    runs: public.Runs = api.runs(project, filter_dict)

    assert len(runs) > 0, f"No runs found with name {run_name}"
    assert len(runs) == 1, f"More than one run found with name {run_name}"

    run: public.Run = runs[0]
    conf = OmegaConf.create(run.config)

    artifacts: public.RunArtifacts = run.logged_artifacts()
    artifact: wandb.Artifact

    # Check if there are no artifacts
    if len(artifacts) == 0:
        raise ValueError(f"No artifacts found for run {run_name}")

    for artifact in artifacts:
        if artifact.type == "model":
            checkpoint_path = tmp_dir / artifact.name
            if checkpoint_path.exists() and not overwrite:
                print(f"Checkpoint already exists at {checkpoint_path}, skipping")
                return checkpoint_path, conf
            else:
                artifact.download(checkpoint_path)

    # Save config next to checkpoint
    conf_path = checkpoint_path / ".hydra" / "config.yaml"
    conf_path.parent.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, conf_path)

    print(f"Downloaded checkpoint to {checkpoint_path}")

    return checkpoint_path, conf
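For reference, a hypothetical call might look like the following; the run name is a placeholder and the project defaults to the one above:

    # Hypothetical usage sketch (placeholder run name)
    checkpoint_path, conf = download_ckpt_single_run(
        run_name="my-training-run",
        project="iir-modal/physmodjax",
        overwrite=False,
    )
    print(checkpoint_path)  # e.g. /tmp/physmodjax/<artifact-name>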
def restore_run_from_wandb(
    run_name: str,  # Name of the run to restore
    wandb_project: str,  # Name of the project on wandb
    device: jax.Device = None,  # Device to restore the model on
) -> tuple[train_state.TrainState, nn.Module, DictConfig]:
    """
    Restores the train state from a wandb run.
    Assumes that we want the model with the best checkpoint according to the metric
    in the config, or the latest step otherwise.
    Assumes that we want to initialize the model with an output length equal to the
    number of steps used in training.
    """
    # Download the checkpoint and config from wandb
    checkpoint_path, cfg = download_ckpt_single_run(run_name, wandb_project)

    kwargs = {"n_steps": cfg.datamodule.num_steps_train[1]}

    if (
        hasattr(cfg, "checkpoint_manager_options")
        and hasattr(cfg.checkpoint_manager_options, "best_fn")
        and hasattr(cfg.checkpoint_manager_options.best_fn, "metric")
    ):
        print(
            f"Restoring best checkpoint according to metric {cfg.checkpoint_manager_options.best_fn.metric}"
        )
        best_metric = cfg.checkpoint_manager_options.best_fn.metric
        best = True
    else:
        print("Restoring latest checkpoint")
        best_metric = None
        best = False

    # Restore the checkpoint
    state, model, _ = restore_experiment_state(
        checkpoint_path,
        best=best,
        best_metric=best_metric,
        kwargs=kwargs,
        device=device,
    )

    return state, model, cfg
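An end-to-end restore from W&B might then look like this hypothetical sketch (run name and project are placeholders):

    # Hypothetical usage sketch (placeholder run name and project)
    state, model, cfg = restore_run_from_wandb(
        run_name="my-training-run",
        wandb_project="iir-modal/physmodjax",
        device=jax.devices("cpu")[0],
    )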
def fix_prepend_ones(
    run_name: str,
    wandb_project: str,
):
    checkpoint_path, conf = download_ckpt_single_run(run_name, project=wandb_project)

    # If conf.model.dynamics_model.model exists but there is no prepend_ones key, add it.
    if hasattr(conf.model, "dynamics_model"):
        if hasattr(conf.model.dynamics_model, "model"):
            if hasattr(conf.model.dynamics_model.model, "prepend_ones"):
                print(f"Model {run_name} has prepend_ones in the wrong place")
                del conf.model.dynamics_model.model.prepend_ones
        if not hasattr(conf.model.dynamics_model, "prepend_ones"):
            print(f"Model {run_name} has no prepend_ones")
            conf.model.dynamics_model.prepend_ones = False

    # Save the updated config
    conf_path = checkpoint_path / ".hydra" / "config.yaml"
    conf_path.parent.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, conf_path)

    return conf
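If a cached config needs this fix, it could be applied per run as in this hypothetical sketch (placeholder names):

    # Hypothetical usage sketch (placeholder run name and project)
    fixed_conf = fix_prepend_ones(
        run_name="my-training-run",
        wandb_project="iir-modal/physmodjax",
    )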
def get_config_single_run(
    run_name: str,  # Name of the run to restore
    project: str,  # Name of the project on wandb
) -> DictConfig:
    """
    Retrieves the configuration of a single run from WandB.
    """
    filter_dict = {
        "display_name": run_name,
    }

    api: public.Api = wandb.Api()
    runs: public.Runs = api.runs(project, filter_dict)

    if len(runs) == 0:
        raise ValueError(f"No runs found with name {run_name}")
    if len(runs) > 1:
        raise ValueError(f"More than one run found with name {run_name}")

    run: public.Run = runs[0]
    conf = OmegaConf.create(run.config)

    return conf
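A hypothetical call, with placeholder names, that fetches only the config without downloading any artifacts:

    # Hypothetical usage sketch (placeholder run name and project)
    cfg = get_config_single_run(
        run_name="my-training-run",
        project="iir-modal/physmodjax",
    )
    print(OmegaConf.to_yaml(cfg))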