PyTorch Lightning and SLURM

Jan 30, 2022
Deep Learning


PyTorch Lightning is a framework for doing deep learning research with PyTorch. It makes things like checkpointing, logging, and distributed training much easier.
In this blog post, I will focus primarily on how to set up PyTorch Lightning to work on a SLURM cluster. In the typical SLURM workflow, you submit a job to the cluster. Oftentimes, if you request a very long timeout, your job will sit in the queue for a long time. The workaround is to submit a short job (i.e., 1-6 hours, or enough time for one training and one validation loop) and checkpoint the training state for the next job. That is the emphasis of this post: how to set up PyTorch Lightning to work with a SLURM cluster and checkpoint training.
Dependencies:
  • torch ≥ 1.8
  • pytorch-lightning ≥ 1.5.9
  • hydra ≥ 1.1
  • hydra-submitit-launcher ≥ 1.1.6
In this post, I will show you two ways of accomplishing this:
  1. Setting a model checkpoint to store HPC weights,
  2. Using fault-tolerant training.

SLURM Introduction

If you are not familiar with SLURM, it is a workload manager and resource management tool used in many supercomputers and computer clusters. If you already work in a SLURM environment, you may skip this section.
First, let’s take a look at a sample script for submitting a job to SLURM.

# Parameters
#SBATCH --account=realitylab
#SBATCH --constraint=[a40|rtx6k|titan]
#SBATCH --cpus-per-task=5
#SBATCH --error=/mmfs1/gscratch/realitylab/tjenrung/AV-codecs/logs/experiment=audiovideostream_debug/multiruns/2022-01-27_09-26-02/.submitit/%j/%j_0_log.err
#SBATCH --gpus-per-node=1
#SBATCH --job-name=run
#SBATCH --mem=256GB
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --open-mode=append
#SBATCH --output=/mmfs1/gscratch/realitylab/tjenrung/AV-codecs/logs/experiment=audiovideostream_debug/multiruns/2022-01-27_09-26-02/.submitit/%j/%j_0_log.out
#SBATCH --partition=ckpt
#SBATCH --signal=USR1@180
#SBATCH --time=120
#SBATCH --wckey=submitit

# command
srun --output /mmfs1/gscratch/realitylab/tjenrung/AV-codecs/logs/experiment=audiovideostream_debug/multiruns/2022-01-27_09-26-02/.submitit/%j/%j_%t_log.out --error /mmfs1/gscratch/realitylab/tjenrung/AV-codecs/logs/experiment=audiovideostream_debug/multiruns/2022-01-27_09-26-02/.submitit/%j/%j_%t_log.err --unbuffered /mmfs1/gscratch/realitylab/tjenrung/AV-codecs/.env/bin/python -u -m submitit.core._submit /mmfs1/gscratch/realitylab/tjenrung/AV-codecs/logs/experiment=audiovideostream_debug/multiruns/2022-01-27_09-26-02/.submitit/%j
Example script for submitting a job to SLURM
Let’s explain a few details of the above script. Here, we are running the training script with the following settings:
  • Account --account=realitylab: the account whose resources are used for the job (Note: this is relevant only if you have multiple accounts under different affiliations in the SLURM cluster).
  • Partition --partition=ckpt: a partition to which you want to submit a job. Usually, a computing cluster is separated into CPU partition, GPU partition, and checkpoint partition. The checkpoint partition is mostly used for small short tasks and shared across the entire SLURM cluster whereas a normal CPU/GPU partition is specific to a particular account.
  • Resource Requested cpus-per-task gpus-per-node mem nodes ntasks-per-node: these are specifications for CPU/GPU/Memory that you request.
  • Time --time=120: the time limit, in minutes, for your job. In our context, you don’t want to request a very long time, because you want to minimize the queue wait between checkpoints.
  • End-time Signal --signal=USR1@180: this requests that the signal SIGUSR1 be sent around 180 seconds before the end of your allocation. This is very useful for our application, as we want to wrap things up (e.g., save the model’s weights and other important states) before the job ends.
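PyTorch Lightning registers a SIGUSR1 handler for you when it detects a SLURM environment, but the mechanism is easy to sketch by hand. Below is a minimal illustration of catching the end-time signal and persisting state before the allocation expires (save_checkpoint here is a placeholder, not Lightning's API):

```python
import os
import signal

def save_checkpoint(path):
    # placeholder for real checkpointing logic (e.g., trainer.save_checkpoint)
    with open(path, "w") as f:
        f.write("training state")

def handle_preemption(signum, frame):
    # SLURM delivers SIGUSR1 ~180 s before the time limit (per --signal=USR1@180),
    # leaving just enough time to persist the training state
    save_checkpoint("hpc_ckpt.ckpt")

signal.signal(signal.SIGUSR1, handle_preemption)

# simulate the end-of-allocation signal for demonstration
os.kill(os.getpid(), signal.SIGUSR1)
print(os.path.exists("hpc_ckpt.ckpt"))  # True
```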

Hydra and Submitit

Hydra is an open-source Python framework that simplifies complex applications that may have lots of configurations involved. In our context of deep learning experiments, you oftentimes encounter situations where you want to change one of the variables in the data modules or models. This becomes a problem once you make too many changes and are unable to keep track of what’s in each experiment. Hydra is designed to solve this problem.
For example, this is a sample configuration file that I use for my project.
# @package _global_

# to execute this experiment run:
# python experiment=audiovideostream.yaml

defaults:
  - override /mode: exp.yaml
  - override /trainer: null
  - override /model: null
  - override /datamodule: null
  - override /callbacks: null
  - override /logger: null
  - override /hydra/launcher@_here_: submitit_slurm

# we override default configurations with nulls to prevent them from loading at all
# instead we define all modules and their paths directly in this config,
# so everything is stored in one place
name: "audiovideostream"

seed: 12345

trainer:
  _target_: pytorch_lightning.Trainer
  gpus: 1
  min_epochs: 1
  max_epochs: 200
  gradient_clip_val: 2.0
  accelerator: gpu
  strategy: ddp
  weights_summary: top
  resume_from_checkpoint: null

model:
  _target_: src.models.audiovideostream.AudioVideoStreamLitModel
  lr: 0.0001
  b1: 0.5
  b2: 0.9
  audio_enc_channels: 32
  audio_emb_channels: 80
  video_enc_channels: 5
  video_emb_channels: 10
  loss_type: all
  video_only: True
  audio_only: False

datamodule:
  _target_: src.datamodules.audiovideostream_datamodule.AudioVideoStreamDataModule
  hdf5_path: ${data_dir}/VoxCeleb2/dataset.hdf5
  video_fps: 25
  audio_fps: 16000
  length: 3.0
  preload: True
  train_val_test_split: [156, 4, 4]
  batch_size: 2
  num_workers: 0
  pin_memory: True

callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val/g_loss"
    save_top_k: -1
    save_last: True
    mode: "min"
    dirpath: "."
    filename: "hpc_ckpt_{epoch:d}"
    auto_insert_metric_name: False

logger:
  csv:
    _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
    save_dir: "."
    name: "csv/${name}"
    version: ${name}
    prefix: ""
  tensorboard:
    _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
    save_dir: "tensorboard/"
    name: null
    version: ${name}
    log_graph: False
    default_hp_metric: True
    prefix: ""

hydra:
  launcher:
    _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher
    submitit_folder: ${hydra.sweep.dir}/.submitit/%j
    account: realitylab
    partition: ckpt
    timeout_min: 120
    cpus_per_task: 5
    gpus_per_node: 1
    tasks_per_node: 1
    mem_gb: 256
    nodes: 1
    name: ${}
    comment: null
    constraint: "[a40|rtx6k|titan]"
    exclude: null
    cpus_per_gpu: null
    gpus_per_task: null
    mem_per_gpu: null
    mem_per_cpu: null
    signal_delay_s: 600
    max_num_timeout: 20
    additional_parameters: {}
    array_parallelism: 256
    setup: []
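As an aside, the _target_ keys in the config above are how Hydra knows which class to build: it imports the dotted path and calls it with the remaining keys. A simplified stand-in for hydra.utils.instantiate, demonstrated on a standard-library class so it runs without Hydra installed:

```python
import importlib

def instantiate(config):
    """Simplified sketch of Hydra's _target_ mechanism: import the dotted
    path and call it with the remaining config keys as keyword arguments."""
    module_path, attr = config["_target_"].rsplit(".", 1)
    target = getattr(importlib.import_module(module_path), attr)
    kwargs = {k: v for k, v in config.items() if k != "_target_"}
    return target(**kwargs)

# demonstrate with a standard-library class instead of a project module
cfg = {"_target_": "collections.Counter", "a": 1, "b": 2}
counter = instantiate(cfg)
print(counter["a"], counter["b"])  # 1 2
```

This is the same pattern Hydra applies to the trainer, model, and datamodule entries above.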
Hydra’s key features taken from its website are:
  • Hierarchical configuration composable from multiple sources,
  • Configuration can be specified or overridden from the command line,
  • Dynamic command line tab completion,
  • Run your application locally or launch it to run remotely,
  • Run multiple jobs with different arguments with a single command.
The Submitit launcher is a plugin that provides a SLURM launcher based on Submitit. For instance, you can embed the following configuration in a Hydra config file.
_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher
submitit_folder: ${hydra.sweep.dir}/.submitit/%j
timeout_min: 60
cpus_per_task: 1
gpus_per_node: 0
tasks_per_node: 1
mem_gb: 4
nodes: 1
name: ${}
Combining Hydra and the Submitit plugin, we can have a single configuration file that describes the entire experiment. This is very useful for reproducing or understanding an experiment.
More details can be found in the Hydra and Submitit documentation.

PyTorch Lightning and Checkpointing

In this section, I will first describe how to manually restore training state in PyTorch Lightning. Then, I will describe the parts of PyTorch Lightning's design that support high-performance computing scenarios where faults (e.g., server shutdowns, job preemption) can occur.
Manually Restore Training State
To load a model along with its weights, biases, and hyperparameters, you can use the following method.
# Load training state to a model
model = MyLightningModule.load_from_checkpoint(PATH)

# Load training state to a trainer
model = LitModel()
trainer = Trainer()
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
Note that load_from_checkpoint is a classmethod: it returns a new model instance rather than mutating the one it is called on. If your code is written the incorrect way, you may silently fail to load the training state.
model = LitModel()
model.load_from_checkpoint(PATH)         # incorrect way of loading state
model = model.load_from_checkpoint(PATH) # correct way of loading state
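To see why the first call has no effect, here is a toy class (not Lightning's actual implementation) with a load_from_checkpoint classmethod:

```python
class Model:
    def __init__(self, weight=0.0):
        self.weight = weight

    @classmethod
    def load_from_checkpoint(cls, path):
        # stand-in for real deserialization: pretend the file stores weight=42
        return cls(weight=42.0)

model = Model()
model.load_from_checkpoint("ckpt.pt")   # return value discarded; model unchanged
print(model.weight)  # 0.0
model = Model.load_from_checkpoint("ckpt.pt")
print(model.weight)  # 42.0
```

The classmethod builds and returns a fresh object; unless you rebind the variable to the return value, your original model keeps its old weights.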
Automatically Restore Training State
If you want PyTorch Lightning to automatically resume training from the previous training state, it helps to understand how this is done under the hood. Note that this post is based on PyTorch Lightning 1.5.9; there may be changes in future versions.
From [Link], the logic for automatically restoring the training state is implemented in the following function.
def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
    """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:

    1. from HPC weights if found
    2. from fault-tolerant auto-saved checkpoint if found
    3. from `checkpoint_path` file if provided
    4. don't restore
    """
    self.resume_checkpoint_path = self._hpc_resume_path or self._fault_tolerant_auto_resume_path or checkpoint_path
    checkpoint_path = self.resume_checkpoint_path
    if not checkpoint_path:
        log.detail("`checkpoint_path` not specified. Skipping checkpoint loading.")
        return

    rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
    self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
Basically, PyTorch Lightning will load the training state in the following order:
  1. HPC weights (from /hpc_ckpt_{epoch:d}.ckpt),
  2. Fault-tolerant auto-saved checkpoint (from /.pl_auto_save.ckpt),
  3. Checkpoint path (from ckpt_path in the Trainer).
Let’s walk through a training scenario where you would like to fine-tune a model on some task. You load a checkpoint base.ckpt, then train the model for some number of epochs until the job stops because of your wall-time limit or a server interruption. At this point there will be weight files like hpc_ckpt_10.ckpt (or .pl_auto_save.ckpt if you use fault-tolerant training). On the next training loop, you run the same training script, and PyTorch Lightning will resume from hpc_ckpt_10.ckpt (or .pl_auto_save.ckpt). Note that training will not resume from base.ckpt even if it is still provided in the training script, because its priority is lower than that of hpc_ckpt_10.ckpt or .pl_auto_save.ckpt.
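To make the priority concrete, here is a small self-contained sketch (not Lightning's actual code) of how the resume path could be selected:

```python
import os
import re

def pick_resume_path(dirpath, ckpt_path=None):
    # 1. highest-epoch HPC checkpoint (hpc_ckpt_{epoch:d}.ckpt), if any
    hpc = []
    for name in os.listdir(dirpath):
        match = re.fullmatch(r"hpc_ckpt_(\d+)\.ckpt", name)
        if match:
            hpc.append((int(match.group(1)), name))
    hpc_path = os.path.join(dirpath, max(hpc)[1]) if hpc else None
    # 2. fault-tolerant auto-save, if present
    auto = os.path.join(dirpath, ".pl_auto_save.ckpt")
    auto_path = auto if os.path.exists(auto) else None
    # 3. explicit checkpoint path, else None (start from scratch)
    return hpc_path or auto_path or ckpt_path
```

With both hpc_ckpt_10.ckpt and base.ckpt available, this returns the HPC checkpoint, matching the behavior described above.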
In the following sections, I will show how to set up the training to create an HPC weight or a fault-tolerant checkpoint.

[Recommended] Automatically Reloading HPC weight

In the previous section, we discussed how PyTorch Lightning can automatically reload weights from /hpc_ckpt_{epoch:d}.ckpt. Now, we show one way of setting this up through a Hydra configuration file.
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: "val/g_loss"
  save_top_k: -1                               # important 
  dirpath: "."                                 # important 
  filename: "hpc_ckpt_{epoch:d}"               # important 
  auto_insert_metric_name: False               # important 
Example Hydra configuration for automatically reloading HPC weight.
Essentially, you want to use pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint to save checkpoints named hpc_ckpt_{epoch:d} at every single epoch. Note that the format must be exact; otherwise the automatic weight-loading logic will break. If you want to save weights under another format or name, you can add another checkpoint callback with the desired configuration.
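To see why the exact template matters: with auto_insert_metric_name=False, the filename template is formatted verbatim with the epoch number, producing exactly the file names that the automatic restore logic scans for.

```python
# the HPC-restore logic looks for files named hpc_ckpt_{N}.ckpt;
# this is what the template in the config above produces
template = "hpc_ckpt_{epoch:d}"
names = [template.format(epoch=e) + ".ckpt" for e in (1, 2, 10)]
print(names)  # ['hpc_ckpt_1.ckpt', 'hpc_ckpt_2.ckpt', 'hpc_ckpt_10.ckpt']
```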

[Experimental] Fault-tolerant Training

For more details on fault-tolerant training, please see the PyTorch Lightning documentation. Note that fault-tolerant training is an experimental feature and contains some bugs; based on my experience, it sometimes doesn’t work as expected. You may be able to track the progress of the fault-tolerant training implementation [here].
Fault-tolerant training is a mechanism for PyTorch Lightning to recover from hardware/software failures. This is particularly relevant on a SLURM cluster, where instances can be shut down at any time. Under the hood, PyTorch Lightning keeps track of the following state updates during training:
  • Samplers indices and random state
  • Optimizers, LR schedulers, callbacks, etc.
  • Loop progression
  • Logging internal states
To enable fault-tolerant training in PyTorch Lightning, you need to set the environment variable PL_FAULT_TOLERANT_TRAINING=1. An example execution script is
PL_FAULT_TOLERANT_TRAINING=1 python experiment=audiovideostream_debug --multirun
During training, PyTorch Lightning will automatically create a file named .pl_auto_save.ckpt to store the required states.
Note: if you also set up the model checkpoint as in the previous section, .pl_auto_save.ckpt will not be loaded, because HPC weights take priority.