@hydra.main(version_base=None, config_path="../../conf", config_name="train_rnn")
def train_rnn(cfg: DictConfig) -> None:
    """Train an RNN model from a Hydra config.

    Instantiates the model class and datamodule from ``cfg``, runs ``train``
    under an Orbax checkpoint manager (keeping the single best checkpoint by
    ``val/mse``), and uploads the checkpoint directory as a W&B artifact.

    Args:
        cfg: Hydra-composed configuration (see ``conf/train_rnn``). Expected
            sections used here: ``model``, ``datamodule``, ``jax``, ``wandb``.
    """
    # Allow ``${eval:...}`` interpolations inside the config.
    OmegaConf.register_new_resolver(
        "eval",
        eval,  # NOTE(review): eval on config strings — trusted input only
        replace=True,
    )
    logging.debug(OmegaConf.to_yaml(cfg, resolve=True))

    jax.config.update("jax_platform_name", cfg.jax.platform_name)
    # BUG FIX: the original passed jax.devices() as a lazy arg to a format
    # string with no placeholder, so the devices were never logged.
    logging.debug("jax devices: %s", jax.devices())

    # Set matplotlib backend to Agg when running headless on a cluster.
    matplotlib.use("Agg")

    # Initialise experiment tracking; log under the Hydra run directory.
    output_dir = Path(HydraConfig.get().run.dir).absolute()
    wandb.require("core")
    run = wandb.init(
        dir=output_dir,
        config=OmegaConf.to_container(
            cfg,
            resolve=True,
            throw_on_missing=False,
        ),
        **cfg.wandb,
    )

    model_cls = hydra.utils.instantiate(cfg.model)
    datamodule = hydra.utils.instantiate(cfg.datamodule)

    # Log data info. std/mean are optional on the datamodule — fall back to None.
    wandb.config.update({"output_dir": output_dir})
    wandb.config.update({"data_info": datamodule.get_info()})
    wandb.config.update(
        {"data_std": datamodule.std if hasattr(datamodule, "std") else None}
    )
    wandb.config.update(
        {"data_mean": datamodule.mean if hasattr(datamodule, "mean") else None}
    )

    # Keep only the best checkpoint, selected by minimum validation MSE.
    options = obc.CheckpointManagerOptions(
        max_to_keep=1,
        create=True,
        best_fn=lambda metrics: float(metrics["val/mse"]),
        best_mode="min",
    )

    with obc.CheckpointManager(
        directory=Path(output_dir) / "checkpoints",
        options=options,
        item_handlers={"state": obc.PyTreeCheckpointHandler()},
    ) as checkpoint_manager:
        # ``train`` is defined elsewhere in this module; its return value
        # was unused here, so it is not bound.
        train(
            model_cls=model_cls,
            datamodule=datamodule,
            cfg=cfg,
            checkpoint_manager=checkpoint_manager,
        )
        # Block until async checkpoint writes have landed on disk.
        checkpoint_manager.wait_until_finished()
        logging.info(
            f"Checkpoint best step {checkpoint_manager.best_step()}, number of steps: {checkpoint_manager.all_steps()}"
        )

    # Save model checkpoints to W&B as a versioned artifact.
    artifact = wandb.Artifact(
        name=f"checkpoints_{wandb.run.id}",
        type="model",
    )
    artifact.add_dir(checkpoint_manager.directory, name="checkpoints")
    run.log_artifact(artifact)
    wandb.finish()
:::
# TODO: Make a ROOT_DIR global variable that can be used anywhere to run commands reproducibly. Maybe force hydra to always run there?!
cd ../.. ; env HYDRA_FULL_ERROR=1 WANDB_MODE=disabled train_rnn +experiment=test
model:
  _target_: physmodjax.fno.rnn.BatchFNORNN
  hidden_channels: 4
  grid_size: 101
  n_spectral_layers: 2
  out_channels: 2
datamodule:
  _target_: physmodjax.scripts.dataset_generation.DirectoryDataModule
  batch_size: 1
  data_directory: data/test
jax:
  platform_name: cpu
  preallocate_gpu_memory: false
optimiser:
  _target_: optax.adam
  learning_rate: 0.001
gradient_clip:
  _target_: optax.clip_by_global_norm
  max_norm: 1.0
seed: 3407
epochs: 1
wandb:
  project: physmodjax
  entity: iir-modal
  group: rnn-test
  job_type: train
  name: null
project: physmodjax
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1702859671.726977 71570 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
[2023-12-18 00:34:31,800][jax._src.xla_bridge][INFO] - Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
[2023-12-18 00:34:31,801][jax._src.xla_bridge][INFO] - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
jax devices: [CpuDevice(id=0)]
100%|█████████████████████████████████████████████| 1/1 [00:33<00:00, 33.96s/it]
[2023-12-18 00:35:08,662][absl][INFO] - OCDBT is initialized successfully.
[2023-12-18 00:35:08,662][absl][INFO] - Saving item to /home/carlos/projects/physmodjax/outputs/2023-12-18/00-34-31/checkpoints.
[2023-12-18 00:35:08,687][absl][INFO] - Renaming /home/carlos/projects/physmodjax/outputs/2023-12-18/00-34-31/checkpoints.orbax-checkpoint-tmp-1702859708663140 to /home/carlos/projects/physmodjax/outputs/2023-12-18/00-34-31/checkpoints
[2023-12-18 00:35:08,687][absl][INFO] - Finished saving checkpoint to `/home/carlos/projects/physmodjax/outputs/2023-12-18/00-34-31/checkpoints`.