Scripts

Scripts for training, dataset generation, and model export.

Train


source

train

 train (cfg:omegaconf.dictconfig.DictConfig)

source

log_hyperparameters

 log_hyperparameters (object_dict:dict)

Log hyperparameters to all loggers.

# Optionally pin the run to a single GPU. os.environ values must be strings
# (assigning an int raises TypeError), so the device index is quoted.
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Compose `train.yaml` from ../configs and run `train` on it.
# `return_hydra_config=True` keeps the `hydra` node in `cfg`, which the
# `paths.*` overrides below interpolate (`${hydra.runtime.*}`).
# Overrides: a single epoch, a fixed seed, and no logger.
with initialize(version_base=None, config_path="../configs"):
    cfg = compose(
        config_name="train.yaml",
        return_hydra_config=True,
        overrides=[
            "trainer.max_epochs=1",
            "hydra.runtime.output_dir=outputs",
            "paths.output_dir=${hydra.runtime.output_dir}",
            "paths.work_dir=${hydra.runtime.cwd}",
            "seed=42",
            "logger=null",
            # `++` adds the key if absent, otherwise overrides it.
            # NOTE(review): train/val/test all point at the same index map —
            # presumably intended only as a quick smoke test; confirm.
            "++datamodule.train_index_map_path=data/index_map.csv",
            "++datamodule.val_index_map_path=data/index_map.csv",
            "++datamodule.test_index_map_path=data/index_map.csv",
        ],
    )
    train(cfg)

Generate dataset


source

gen_dataset

 gen_dataset (cfg:omegaconf.dictconfig.DictConfig)

Export

Export the encoder and coefficient model of a trained checkpoint to TorchScript files.

@hydra.main(version_base=None, config_path="../configs", config_name="export")
def export(
    cfg: DictConfig,
):
    """Export submodules of a trained checkpoint to TorchScript.

    Loads ``MultiShapeMultiMaterialLitModule`` from ``cfg.ckpt_path``, switches
    it to eval mode, then scripts and saves, in this order:

    * ``encoder`` -> ``cfg.encoder_path``
    * ``model``   -> ``cfg.coefficient_model_path``

    Args:
        cfg: Hydra config composed from the ``export`` config in ``../configs``.
    """
    lit_module = MultiShapeMultiMaterialLitModule.load_from_checkpoint(cfg.ckpt_path)
    lit_module.eval()

    # Pairs of (submodule attribute on the LightningModule, config key holding
    # its output path). Both are looked up inside the loop, so each submodule
    # is scripted and saved before the next pair is touched.
    export_plan = (
        ("encoder", "encoder_path"),
        ("model", "coefficient_model_path"),
    )
    for module_attr, path_key in export_plan:
        scripted = torch.jit.script(getattr(lit_module, module_attr))
        torch.jit.save(scripted, getattr(cfg, path_key))