Training an RNN

A script to train a generic RNN model.

source

create_train_state

 create_train_state (model:flax.linen.module.Module, rng:jax.Array,
                     x_shape:Tuple[int,int,int,int], num_steps:int,
                     norm:str='layer', learning_rate:float=0.001,
                     grad_clip:optax._src.base.GradientTransformation=optax.clip_by_global_norm(...),
                     components_to_freeze:List[str]=[],
                     schedule_type:str='constant', debug:bool=False)
|  | Type | Default | Details |
|---|---|---|---|
| model | Module |  |  |
| rng | Array |  |  |
| x_shape | Tuple |  |  |
| num_steps | int |  | number of training steps |
| norm | str | layer | "layer" or "batch" |
| learning_rate | float | 0.001 |  |
| grad_clip | GradientTransformation | optax.clip_by_global_norm(...) |  |
| components_to_freeze | List | [] |  |
| schedule_type | str | constant | "cosine" or "constant" |
| debug | bool | False | print debug information |
| **Returns** | **TrainState** |  |  |
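
A minimal usage sketch. The model hyperparameters are taken from the test config shown further down; the import paths for `create_train_state` and `BatchFNORNN`, and the `(batch, time, grid, channels)` layout of `x_shape`, are assumptions:

```python
import jax
import optax

from physmodjax.fno.rnn import BatchFNORNN  # path taken from the config below
from physmodjax.scripts.train_rnn import create_train_state  # assumed module path

# Hyperparameters mirror the test experiment config.
model = BatchFNORNN(
    hidden_channels=4,
    grid_size=101,
    n_spectral_layers=2,
    out_channels=2,
)

rng = jax.random.PRNGKey(3407)
state = create_train_state(
    model,
    rng,
    x_shape=(1, 100, 101, 2),  # assumed (batch, time, grid, channels) layout
    num_steps=1_000,
    norm="layer",
    learning_rate=1e-3,
    grad_clip=optax.clip_by_global_norm(1.0),
    schedule_type="constant",
)
```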

source

train

 train (model_cls, datamodule, cfg:omegaconf.dictconfig.DictConfig,
        checkpoint_manager:orbax.checkpoint.checkpoint_manager.CheckpointManager)
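
`train` wires together a model class, a datamodule, the Hydra config, and an Orbax `CheckpointManager`. A sketch of that wiring, assuming a config file matching the dump below; the config file location, checkpoint directory, and `train` import path are illustrative:

```python
import orbax.checkpoint as ocp
from omegaconf import OmegaConf

from physmodjax.fno.rnn import BatchFNORNN
from physmodjax.scripts.dataset_generation import DirectoryDataModule
from physmodjax.scripts.train_rnn import train  # assumed module path

# Hypothetical config file containing the keys shown in the dump below.
cfg = OmegaConf.load("conf/experiment/test.yaml")

# Datamodule arguments mirror the test experiment config.
datamodule = DirectoryDataModule(batch_size=1, data_directory="data/test")

# Orbax expects an absolute checkpoint directory.
checkpoint_manager = ocp.CheckpointManager(
    "/tmp/physmodjax_checkpoints",
    ocp.PyTreeCheckpointer(),
)

train(BatchFNORNN, datamodule, cfg, checkpoint_manager)
```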

source

train_rnn

 train_rnn (cfg:omegaconf.dictconfig.DictConfig)

Train an RNN model.
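
`train_rnn` is the Hydra entry point exercised from the command line below. It can also be driven programmatically via Hydra's compose API; in this sketch the `config_path` and `config_name`, and the module path of `train_rnn`, are guesses at the repository layout:

```python
from hydra import compose, initialize

from physmodjax.scripts.train_rnn import train_rnn  # assumed module path

# Compose the same config the CLI run below uses.
with initialize(config_path="../../conf", version_base=None):
    cfg = compose(config_name="train_rnn", overrides=["+experiment=test"])

train_rnn(cfg)
```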

# TODO: Make a ROOT_DIR global variable that can be used anywhere to run commands reproducibly. Maybe force hydra to always run there?
!cd ../.. ; env HYDRA_FULL_ERROR=1 WANDB_MODE=disabled train_rnn +experiment=test
model:
  _target_: physmodjax.fno.rnn.BatchFNORNN
  hidden_channels: 4
  grid_size: 101
  n_spectral_layers: 2
  out_channels: 2
datamodule:
  _target_: physmodjax.scripts.dataset_generation.DirectoryDataModule
  batch_size: 1
  data_directory: data/test
jax:
  platform_name: cpu
  preallocate_gpu_memory: false
optimiser:
  _target_: optax.adam
  learning_rate: 0.001
gradient_clip:
  _target_: optax.clip_by_global_norm
  max_norm: 1.0
seed: 3407
epochs: 1
wandb:
  project: physmodjax
  entity: iir-modal
  group: rnn-test
  job_type: train
  name: null
project: physmodjax

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1702859671.726977   71570 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
[2023-12-18 00:34:31,800][jax._src.xla_bridge][INFO] - Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
[2023-12-18 00:34:31,801][jax._src.xla_bridge][INFO] - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
jax devices:  [CpuDevice(id=0)]
100%|█████████████████████████████████████████████| 1/1 [00:33<00:00, 33.96s/it]
[2023-12-18 00:35:08,662][absl][INFO] - OCDBT is initialized successfully.
[2023-12-18 00:35:08,662][absl][INFO] - Saving item to /home/carlos/projects/physmodjax/outputs/2023-12-18/00-34-31/checkpoints.
[2023-12-18 00:35:08,687][absl][INFO] - Renaming /home/carlos/projects/physmodjax/outputs/2023-12-18/00-34-31/checkpoints.orbax-checkpoint-tmp-1702859708663140 to /home/carlos/projects/physmodjax/outputs/2023-12-18/00-34-31/checkpoints
[2023-12-18 00:35:08,687][absl][INFO] - Finished saving checkpoint to `/home/carlos/projects/physmodjax/outputs/2023-12-18/00-34-31/checkpoints`.