Training at scale

Author

Marie-Hélène Burle

Using a JupyterHub to prototype code is fine, but when you need more resources, it is much more resource-efficient to submit batch jobs to the Slurm scheduler with sbatch.

This section covers the workflow.

Write a Python script

The first step is to put all your code in a Python script that will be executed during the job.

Let’s call it main.py:

from datasets import load_dataset
import numpy as np
from torchvision.transforms import v2 as T
import grain.python as grain
import jax
import jax.numpy as jnp
from flax import nnx
from transformers import FlaxViTForImageClassification
import optax
import matplotlib.pyplot as plt
from time import time

# 5 food classes, with 750 training and 250 validation images each
train_size = 5 * 750
val_size = 5 * 250

train_dataset = load_dataset("food101",
                             split=f"train[:{train_size}]")

val_dataset = load_dataset("food101",
                           split=f"validation[:{val_size}]")

# Map the dataset's label ids to indices 0-4, sampling one image every 250
labels_mapping = {}
index = 0
for i in range(0, len(val_dataset), 250):
    label = val_dataset[i]["label"]
    if label not in labels_mapping:
        labels_mapping[label] = index
        index += 1

inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

img_size = 224

def to_np_array(pil_image):
    return np.asarray(pil_image.convert("RGB"))

def normalize(image):
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    image = image.astype(np.float32) / 255.0
    return (image - mean) / std

tv_train_transforms = T.Compose([
    T.RandomResizedCrop((img_size, img_size), scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

tv_test_transforms = T.Compose([
    T.Resize((img_size, img_size)),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

def get_transform(fn):
    def wrapper(batch):
        batch["image"] = [
            fn(pil_image) for pil_image in batch["image"]
        ]
        batch["label"] = [
            labels_mapping[label] for label in batch["label"]
        ]
        return batch
    return wrapper

train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)

train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)

seed = 12
train_batch_size = 32
val_batch_size = 2 * train_batch_size

train_sampler = grain.IndexSampler(
    len(train_dataset),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

val_sampler = grain.IndexSampler(
    len(val_dataset),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(train_batch_size, drop_remainder=True),
    ]
)

val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(val_batch_size),
    ]
)

class VisionTransformer(nnx.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        # Patch and position embedding
        n_patches = (img_size // patch_size) ** 2
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            use_bias=True,
            rngs=rngs,
        )

        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        self.position_embeddings = nnx.Param(
            initializer(
                rngs.params(),
                (1, n_patches + 1, hidden_size),
                jnp.float32
            )
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

        self.cls_token = nnx.Param(jnp.zeros((1, 1, hidden_size)))

        # Transformer Encoder blocks
        self.encoder = nnx.Sequential(*[
            TransformerEncoder(
                hidden_size,
                mlp_dim,
                num_heads,
                dropout_rate,
                rngs=rngs
            )
            for i in range(num_layers)
        ])
        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)

        # Classification head
        self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # Patch and position embedding
        patches = self.patch_embeddings(x)
        batch_size = patches.shape[0]
        patches = patches.reshape(batch_size, -1, patches.shape[-1])

        cls_token = jnp.tile(self.cls_token, [batch_size, 1, 1])
        x = jnp.concat([cls_token, patches], axis=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)

        # Encoder blocks
        x = self.encoder(embeddings)
        x = self.final_norm(x)

        # Fetch the first (classification) token
        x = x[:, 0]

        # Classification
        return self.classifier(x)

class TransformerEncoder(nnx.Module):
    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:

        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)

        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# Instantiate with 1000 classes so the pretrained ImageNet weights fit
model = VisionTransformer(num_classes=1000)

tf_model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

def vit_inplace_copy_weights(*, src_model, dst_model):
    assert isinstance(src_model, FlaxViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)

    tf_model_params = src_model.params
    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)

    flax_model_params = nnx.state(dst_model, nnx.Param)
    flax_model_params_fstate = flax_model_params.flat_state()

    params_name_mapping = {
        ("cls_token",): ("vit", "embeddings", "cls_token"),
        ("position_embeddings",): (
            "vit",
            "embeddings",
            "position_embeddings"
        ),
        **{
            ("patch_embeddings", x): (
                "vit",
                "embeddings",
                "patch_embeddings",
                "projection",
                x
            )
            for x in ["kernel", "bias"]
        },
        **{
            ("encoder", "layers", i, "attn", y, x): (
                "vit",
                "encoder",
                "layer",
                str(i),
                "attention",
                "attention",
                y,
                x
            )
            for x in ["kernel", "bias"]
            for y in ["key", "value", "query"]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, "attn", "out", x): (
                "vit",
                "encoder",
                "layer",
                str(i),
                "attention",
                "output",
                "dense",
                x
            )
            for x in ["kernel", "bias"]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, "mlp", "layers", y1, x): (
                "vit",
                "encoder",
                "layer",
                str(i),
                y2,
                "dense",
                x
            )
            for x in ["kernel", "bias"]
            for y1, y2 in [(0, "intermediate"), (3, "output")]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, y1, x): (
                "vit", "encoder", "layer", str(i), y2, x
            )
            for x in ["scale", "bias"]
            for y1, y2 in [
                    ("norm1", "layernorm_before"),
                    ("norm2", "layernorm_after")
            ]
            for i in range(12)
        },
        **{
            ("final_norm", x): ("vit", "layernorm", x)
            for x in ["scale", "bias"]
        },
        **{
            ("classifier", x): ("classifier", x)
            for x in ["kernel", "bias"]
        }
    }

    nonvisited = set(flax_model_params_fstate.keys())

    for key1, key2 in params_name_mapping.items():
        assert key1 in flax_model_params_fstate, key1
        assert key2 in tf_model_params_fstate, (key1, key2)

        nonvisited.remove(key1)

        src_value = tf_model_params_fstate[key2]
        if key2[-1] == "kernel" and key2[-2] in ("key", "value", "query"):
            shape = src_value.shape
            src_value = src_value.reshape((shape[0], 12, 64))

        if key2[-1] == "bias" and key2[-2] in ("key", "value", "query"):
            src_value = src_value.reshape((12, 64))

        if key2[-4:] == ("attention", "output", "dense", "kernel"):
            shape = src_value.shape
            src_value = src_value.reshape((12, 64, shape[-1]))

        dst_value = flax_model_params_fstate[key1]
        assert src_value.shape == dst_value.value.shape, (
            key2, src_value.shape, key1, dst_value.value.shape
        )
        dst_value.value = src_value.copy()
        assert dst_value.value.mean() == src_value.mean(), (
            dst_value.value, src_value.mean()
        )

    assert len(nonvisited) == 0, nonvisited

    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))

vit_inplace_copy_weights(src_model=tf_model, dst_model=model)

# Replace the 1000-class ImageNet head with a new 5-class head for fine-tuning
model.classifier = nnx.Linear(model.classifier.in_features, 5, rngs=nnx.Rngs(0))

num_epochs = 3
learning_rate = 0.001
momentum = 0.8
total_steps = len(train_dataset) // train_batch_size

lr_schedule = optax.linear_schedule(learning_rate, 0.0, num_epochs * total_steps)

optimizer = nnx.Optimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))

def compute_losses_and_logits(model: nnx.Module, images: jax.Array, labels: jax.Array):
    logits = model(images)

    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()
    return loss, logits

@nnx.jit
def train_step(
    model: nnx.Module, optimizer: nnx.Optimizer, batch: dict[str, np.ndarray]
):
    # Convert np.ndarray to jax.Array on GPU
    images = jnp.array(batch["image"])
    labels = jnp.array(batch["label"], dtype=jnp.int32)

    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, logits), grads = grad_fn(model, images, labels)

    optimizer.update(grads)  # In-place updates.

    return loss

@nnx.jit
def eval_step(
    model: nnx.Module, batch: dict[str, np.ndarray], eval_metrics: nnx.MultiMetric
):
    # Convert np.ndarray to jax.Array on GPU
    images = jnp.array(batch["image"])
    labels = jnp.array(batch["label"], dtype=jnp.int32)
    loss, logits = compute_losses_and_logits(model, images, labels)

    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=labels,
    )

eval_metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average('loss'),
    accuracy=nnx.metrics.Accuracy(),
)

train_metrics_history = {
    "train_loss": [],
}

eval_metrics_history = {
    "val_loss": [],
    "val_accuracy": [],
}

def train_one_epoch(epoch):
    model.train()

    for train_batch in train_loader:
        loss = train_step(model, optimizer, train_batch)
        train_metrics_history["train_loss"].append(loss.item())

def evaluate_model(epoch):
    model.eval()

    eval_metrics.reset()
    for val_batch in val_loader:
        eval_step(model, val_batch, eval_metrics)

    for metric, value in eval_metrics.compute().items():
        eval_metrics_history[f'val_{metric}'].append(value)

    print(f"[val] epoch: {epoch + 1}/{num_epochs}")
    print(f"- total loss: {eval_metrics_history['val_loss'][-1]:0.4f}")
    print(f"- Accuracy: {eval_metrics_history['val_accuracy'][-1]:0.4f}")

start = time()

for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)

end = time()

print(f"Training took {round((end - start) / 60, 1)} minutes")

plt.plot(train_metrics_history["train_loss"], label="Loss value during the training")
plt.legend()
plt.savefig('loss.png')

fig, axs = plt.subplots(1, 2, figsize=(10, 10))
axs[0].set_title("Loss value on validation set")
axs[0].plot(eval_metrics_history["val_loss"])
axs[1].set_title("Accuracy on validation set")
axs[1].plot(eval_metrics_history["val_accuracy"])
plt.savefig('validation.png')

We have to make a few changes to our code:

  • Strip your code of anything unnecessary that you might have used during prototyping.

  • It doesn’t make sense to use tqdm progress bars in a non-interactive job, so remove the corresponding code.

  • We can’t display the graphs anymore, so we save them to files with plt.savefig().

  • When we aren’t using IPython (directly or via Jupyter), we don’t have access to the built-in magic commands such as %%time to time the execution of a cell. Instead, we use the following snippet:

start = time()

<Code to time>

end = time()

print(f"Training took {round((end - start) / 60, 1)} minutes")

In this case, since it is the training that we want to time:

start = time()

for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)

end = time()

print(f"Training took {round((end - start) / 60, 1)} minutes")

Write a Slurm script

Then you need to write a Bash script for the Slurm scheduler.

Our training cluster is made of 50 nodes of the c4-30gb flavour (each node has 4 CPUs and 30 GB of memory, i.e. 7500M per CPU).

If we want to train on a single CPU using the maximum amount of memory for that CPU, this is what our script looks like (let’s call it train.sh):

train.sh
#!/bin/bash
#SBATCH --time=60
#SBATCH --mem-per-cpu=7500M

module load python/3.12.4 arrow/19.0.1
source /project/60055/env/bin/activate

python main.py

When I tested this earlier, training took 36.8 minutes.
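Since our grain DataLoaders use worker_count=4, you could also request all 4 CPUs of a node so that each data loading worker gets its own core. A possible variant of train.sh (untested here, so timings will differ):

#!/bin/bash
#SBATCH --time=60
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=7500M

module load python/3.12.4 arrow/19.0.1
source /project/60055/env/bin/activate

python main.py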

Our training cluster doesn’t require an account and it has neither GPUs nor huge amounts of memory. Moreover, our code only uses 5 classes of food to make training much faster. Finally, our Python virtual environment is in /project so that we can all access it, whereas you would normally store it in your home directory.

If you were to train our model on an Alliance cluster at scale, the script would thus look something like this:

train.sh
#!/bin/bash
#SBATCH --account=def-<name>
#SBATCH --time=5:0:0
#SBATCH --mem=50G
#SBATCH --gpus-per-node=1

module load python/3.12.4 arrow/19.0.1
source ~/env/bin/activate

python main.py

Replace <name> with your Alliance account name.

This assumes that you have a Python virtual environment in ~/env with all necessary packages installed.
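If you don’t have such an environment yet, you could create it along these lines (a sketch: the package list simply follows the imports in main.py, and the Alliance recommends adding --no-index when all packages are available as cluster wheels):

module load python/3.12.4 arrow/19.0.1
python -m venv ~/env
source ~/env/bin/activate
pip install --upgrade pip
pip install datasets grain jax flax transformers optax torchvision matplotlib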

Also note that if you are using the Alliance supercomputer Cedar, a policy on this cluster blocks jobs from running in the /home filesystem, so you will have to copy your files to /scratch or /project and submit the job from there.
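For instance (illustrative paths; on Alliance clusters, ~/scratch is a symlink to your scratch space):

cp main.py train.sh ~/scratch/
cd ~/scratch
sbatch train.sh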

Notice the following differences:

  • we provide an account name,
  • we ask for a lot more time (training at scale can take days or even weeks),
  • we ask for a lot more memory,
  • we ask for a GPU (sometimes you will need several GPUs; remember that the same JAX code can run on any device),
  • we source a virtual environment which is in our home.

Run the script

sbatch train.sh
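sbatch returns immediately after printing the job ID (yours will differ):

Submitted batch job 1234567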

Monitor the job

To see whether your job is still running and to get the job ID, you can use the Alliance alias:

sq
  • PD         ➔ the job is pending
  • R          ➔ the job is running
  • No output  ➔ the job is done

While your job is running, you can monitor it by opening a new terminal and, from the login node, running:

srun --jobid=<jobID> --pty bash

Replace <jobID> with the job ID you got by running sq.

Then launch htop:

alias htop='htop -u $USER -s PERCENT_CPU'
htop                     # monitor all your processes
htop --filter "python"   # filter processes by name

Check average memory usage with:

sstat -j <jobID> --format=AveRSS

Or maximum memory usage with:

sstat -j <jobID> --format=MaxRSS
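You can also query several fields at once (all standard sstat field names):

sstat -j <jobID> --format=JobID,AveCPU,AveRSS,MaxRSS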

Get the results

The results will be in a file created by Slurm and called, by default, slurm-<jobID>.out (you can change the name of this file by adding an option in your Slurm script).
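For instance, adding this directive to train.sh would name the output file after the job ID (%j is Slurm’s placeholder for the job ID):

#SBATCH --output=vit-%j.out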

You can look at them with:

bat slurm-<jobID>.out

Retrieve files

We created two images (loss.png and validation.png). To retrieve them, you can use scp from your computer:

scp username@hostname:path/file path

For instance:

scp userxx@hostname:loss.png ~/

Replace hostname with the hostname for this cluster and ~/ with the path where you want to download your file.
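If you prefer rsync, it uses the same source and destination syntax:

rsync -av userxx@hostname:validation.png ~/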