Training

PyTorch Lightning modules for training.

source

MultiShapeMultiMaterialLitModule

 MultiShapeMultiMaterialLitModule (model:torch.nn.modules.module.Module,
                                   optimizer:Type[torch.optim.optimizer.Op
                                   timizer], scheduler:Type[torch.optim.lr
                                   _scheduler.LRScheduler], criterion:torc
                                   h.nn.modules.module.Module=FFTLoss(),
                                   **kwargs)

Hooks to be used in LightningModule.

Try to run a single batch:

from neuralresonator.data import MultiShapeMultiMaterialDataModule
from neuralresonator.models import FC
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate

# Extra keyword arguments forwarded to the datasets; empty for this smoke test.
# ({} is the idiomatic literal form of dict().)
dataset_args = {}

# NOTE(review): all three splits deliberately point at the same index map —
# this is a single-batch smoke test, not a real train/val/test split.
datamodule = MultiShapeMultiMaterialDataModule(
    train_index_map_path="data/index_map.csv",
    val_index_map_path="data/index_map.csv",
    test_index_map_path="data/index_map.csv",
)

# Hydra-style configuration for the LightningModule under test.
# `_partial_` makes hydra's instantiate return functools.partial objects for
# the optimizer and scheduler, so the module can bind them to its parameters
# in configure_optimizers.
config = {
    "_target_": "neuralresonator.training.MultiShapeMultiMaterialLitModule",
    "model": {
        "_target_": "neuralresonator.models.CoefficientsFC",
        "input_size": 1007,
        # Six hidden layers of width 1024 each.
        "hidden_sizes": [1024] * 6,
        "n_parallel": 32,
        "n_biquads": 2,
    },
    "criterion": {
        "_target_": "neuralresonator.utilities.FFTLoss",
    },
    "optimizer": {
        "_target_": "torch.optim.Adam",
        "_partial_": True,
        "lr": 0.0001,
    },
    "scheduler": {
        "_target_": "torch.optim.lr_scheduler.ExponentialLR",
        "_partial_": True,
        "gamma": 0.999,
        "verbose": True,
    },
}
cfg = OmegaConf.create(config)
# BUG FIX: `pl` was used below (`pl.Trainer`) without ever being imported,
# which raises NameError. Import the conventional alias explicitly.
import lightning.pytorch as pl
from lightning.pytorch import loggers

# Build the LightningModule from the hydra config and log to Weights & Biases.
model = instantiate(cfg)
logger = loggers.WandbLogger(project="neuralresonator")

# One train batch and one val batch for a single epoch: minimal smoke test.
trainer = pl.Trainer(
    limit_train_batches=1,
    max_epochs=1,
    limit_val_batches=1,
    logger=logger,
)

trainer.fit(model=model, datamodule=datamodule)
/home/diaz/anaconda3/envs/modal/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:197: UserWarning: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
  rank_zero_warn(
/home/diaz/anaconda3/envs/modal/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:197: UserWarning: Attribute 'criterion' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['criterion'])`.
  rank_zero_warn(
/home/diaz/anaconda3/envs/modal/lib/python3.10/site-packages/lightning/pytorch/loggers/wandb.py:395: UserWarning: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
/home/diaz/anaconda3/envs/modal/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:612: UserWarning: Checkpoint directory ./neuralresonator/fm3vnhcd/checkpoints exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | model     | CoefficientsFC    | 7.7 M 
1 | criterion | FFTLoss           | 0     
2 | encoder   | EfficientNet      | 5.3 M 
3 | mse       | MeanSquaredError  | 0     
4 | mae       | MeanAbsoluteError | 0     
------------------------------------------------
12.9 M    Trainable params
0         Non-trainable params
12.9 M    Total params
51.785    Total estimated model params size (MB)
/home/diaz/anaconda3/envs/modal/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/diaz/anaconda3/envs/modal/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/diaz/anaconda3/envs/modal/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:280: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=1` reached.
Adjusting learning rate of group 0 to 1.0000e-04.
# Checkpointing
# NOTE(review): no import of MultiShapeMultiMaterialLitModule is visible in
# this document — import it explicitly so the load below cannot NameError.
from neuralresonator.training import MultiShapeMultiMaterialLitModule

# hparams were captured by save_hyperparameters inside the module, so they
# travel with the checkpoint and allow load_from_checkpoint to rebuild it.
print(f"Model hparams: {model.hparams}")
trainer.save_checkpoint("checkpoint.ckpt")

# Load checkpoint: reconstruct the module from saved hparams + weights.
model = MultiShapeMultiMaterialLitModule.load_from_checkpoint("checkpoint.ckpt")
Model hparams: "criterion": FFTLoss()
"model":     CoefficientsFC(
  (fc): FC(
    (activation): LeakyReLU(negative_slope=0.2, inplace=True)
    (network): Sequential(
      (0): FCBlock(
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        (linear): Linear(in_features=1007, out_features=1024, bias=True)
        (ln): Identity()
      )
      (1): FCBlock(
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        (linear): Linear(in_features=1024, out_features=1024, bias=True)
        (ln): Identity()
      )
      (2): FCBlock(
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        (linear): Linear(in_features=1024, out_features=1024, bias=True)
        (ln): Identity()
      )
      (3): FCBlock(
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        (linear): Linear(in_features=1024, out_features=1024, bias=True)
        (ln): Identity()
      )
      (4): FCBlock(
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        (linear): Linear(in_features=1024, out_features=1024, bias=True)
        (ln): Identity()
      )
      (5): FCBlock(
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        (linear): Linear(in_features=1024, out_features=1024, bias=True)
        (ln): Identity()
      )
      (6): FCBlock(
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        (linear): Linear(in_features=1024, out_features=1024, bias=True)
        (ln): Identity()
      )
      (7): Linear(in_features=1024, out_features=320, bias=True)
    )
  )
)
"optimizer": functools.partial(<class 'torch.optim.adam.Adam'>, lr=0.0001)
"scheduler": functools.partial(<class 'torch.optim.lr_scheduler.ExponentialLR'>, gamma=0.999, verbose=True)
/home/diaz/anaconda3/envs/modal/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:197: UserWarning: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
  rank_zero_warn(
/home/diaz/anaconda3/envs/modal/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:197: UserWarning: Attribute 'criterion' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['criterion'])`.
  rank_zero_warn(