Using a JupyterHub to prototype code is fine, but when you need access to more resources, it is much more efficient to submit batch jobs to Slurm with sbatch.
This section covers the workflow.
Write a Python script
The first step is to put all your code in a Python script that will run during the job.
Let’s call it main.py:
main.py
from datasets import load_dataset
import numpy as np
from torchvision.transforms import v2 as T
import grain.python as grain
import jax
import jax.numpy as jnp
from flax import nnx
from transformers import FlaxViTForImageClassification
import optax
import matplotlib.pyplot as plt
from time import time

train_size = 5 * 750
val_size = 5 * 250

train_dataset = load_dataset("food101", split=f"train[:{train_size}]")
val_dataset = load_dataset("food101", split=f"validation[:{val_size}]")

labels_mapping = {}
index = 0
for i in range(0, len(val_dataset), 250):
    label = val_dataset[i]["label"]
    if label not in labels_mapping:
        labels_mapping[label] = index
        index += 1

inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

img_size = 224

def to_np_array(pil_image):
    return np.asarray(pil_image.convert("RGB"))

def normalize(image):
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    image = image.astype(np.float32) / 255.0
    return (image - mean) / std

tv_train_transforms = T.Compose([
    T.RandomResizedCrop((img_size, img_size), scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.2),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

tv_test_transforms = T.Compose([
    T.Resize((img_size, img_size)),
    T.Lambda(to_np_array),
    T.Lambda(normalize),
])

def get_transform(fn):
    def wrapper(batch):
        batch["image"] = [
            fn(pil_image) for pil_image in batch["image"]
        ]
        batch["label"] = [
            labels_mapping[label] for label in batch["label"]
        ]
        return batch
    return wrapper

train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)

train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)

seed = 12
train_batch_size = 32
val_batch_size = 2 * train_batch_size

train_sampler = grain.IndexSampler(
    len(train_dataset),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

val_sampler = grain.IndexSampler(
    len(val_dataset),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)

train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(train_batch_size, drop_remainder=True),
    ]
)

val_loader = grain.DataLoader(
    data_source=val_dataset,
    sampler=val_sampler,
    worker_count=4,
    worker_buffer_size=2,
    operations=[
        grain.Batch(val_batch_size),
    ]
)

class VisionTransformer(nnx.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        in_channels: int = 3,
        img_size: int = 224,
        patch_size: int = 16,
        num_layers: int = 12,
        num_heads: int = 12,
        mlp_dim: int = 3072,
        hidden_size: int = 768,
        dropout_rate: float = 0.1,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        # Patch and position embedding
        n_patches = (img_size // patch_size) ** 2
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            use_bias=True,
            rngs=rngs,
        )

        initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        self.position_embeddings = nnx.Param(
            initializer(
                rngs.params(), (1, n_patches + 1, hidden_size), jnp.float32
            )
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

        self.cls_token = nnx.Param(jnp.zeros((1, 1, hidden_size)))

        # Transformer Encoder blocks
        self.encoder = nnx.Sequential(*[
            TransformerEncoder(
                hidden_size, mlp_dim, num_heads, dropout_rate, rngs=rngs
            )
            for i in range(num_layers)
        ])
        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)

        # Classification head
        self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        # Patch and position embedding
        patches = self.patch_embeddings(x)
        batch_size = patches.shape[0]
        patches = patches.reshape(batch_size, -1, patches.shape[-1])

        cls_token = jnp.tile(self.cls_token, [batch_size, 1, 1])
        x = jnp.concat([cls_token, patches], axis=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)

        # Encoder blocks
        x = self.encoder(embeddings)
        x = self.final_norm(x)

        # fetch the first token
        x = x[:, 0]

        # Classification
        return self.classifier(x)

class TransformerEncoder(nnx.Module):
    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
    ) -> None:
        self.norm1 = nnx.LayerNorm(hidden_size, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_size,
            dropout_rate=dropout_rate,
            broadcast_dropout=False,
            decode=False,
            deterministic=False,
            rngs=rngs,
        )
        self.norm2 = nnx.LayerNorm(hidden_size, rngs=rngs)

        self.mlp = nnx.Sequential(
            nnx.Linear(hidden_size, mlp_dim, rngs=rngs),
            nnx.gelu,
            nnx.Dropout(dropout_rate, rngs=rngs),
            nnx.Linear(mlp_dim, hidden_size, rngs=rngs),
            nnx.Dropout(dropout_rate, rngs=rngs),
        )

    def __call__(self, x: jax.Array) -> jax.Array:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

model = VisionTransformer(num_classes=1000)

tf_model = FlaxViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224'
)

def vit_inplace_copy_weights(*, src_model, dst_model):
    assert isinstance(src_model, FlaxViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)

    tf_model_params = src_model.params
    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)

    flax_model_params = nnx.state(dst_model, nnx.Param)
    flax_model_params_fstate = flax_model_params.flat_state()

    params_name_mapping = {
        ("cls_token",): ("vit", "embeddings", "cls_token"),
        ("position_embeddings",): (
            "vit", "embeddings", "position_embeddings"
        ),
        **{
            ("patch_embeddings", x): (
                "vit", "embeddings", "patch_embeddings", "projection", x
            )
            for x in ["kernel", "bias"]
        },
        **{
            ("encoder", "layers", i, "attn", y, x): (
                "vit", "encoder", "layer", str(i),
                "attention", "attention", y, x
            )
            for x in ["kernel", "bias"]
            for y in ["key", "value", "query"]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, "attn", "out", x): (
                "vit", "encoder", "layer", str(i),
                "attention", "output", "dense", x
            )
            for x in ["kernel", "bias"]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, "mlp", "layers", y1, x): (
                "vit", "encoder", "layer", str(i), y2, "dense", x
            )
            for x in ["kernel", "bias"]
            for y1, y2 in [(0, "intermediate"), (3, "output")]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, y1, x): (
                "vit", "encoder", "layer", str(i), y2, x
            )
            for x in ["scale", "bias"]
            for y1, y2 in [
                ("norm1", "layernorm_before"),
                ("norm2", "layernorm_after")
            ]
            for i in range(12)
        },
        **{
            ("final_norm", x): ("vit", "layernorm", x)
            for x in ["scale", "bias"]
        },
        **{
            ("classifier", x): ("classifier", x)
            for x in ["kernel", "bias"]
        }
    }

    nonvisited = set(flax_model_params_fstate.keys())

    for key1, key2 in params_name_mapping.items():
        assert key1 in flax_model_params_fstate, key1
        assert key2 in tf_model_params_fstate, (key1, key2)

        nonvisited.remove(key1)

        src_value = tf_model_params_fstate[key2]
        if key2[-1] == "kernel" and key2[-2] in ("key", "value", "query"):
            shape = src_value.shape
            src_value = src_value.reshape((shape[0], 12, 64))

        if key2[-1] == "bias" and key2[-2] in ("key", "value", "query"):
            src_value = src_value.reshape((12, 64))

        if key2[-4:] == ("attention", "output", "dense", "kernel"):
            shape = src_value.shape
            src_value = src_value.reshape((12, 64, shape[-1]))

        dst_value = flax_model_params_fstate[key1]
        assert src_value.shape == dst_value.value.shape, (
            key2, src_value.shape, key1, dst_value.value.shape
        )
        dst_value.value = src_value.copy()
        assert dst_value.value.mean() == src_value.mean(), (
            dst_value.value, src_value.mean()
        )

    assert len(nonvisited) == 0, nonvisited

    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))

vit_inplace_copy_weights(src_model=tf_model, dst_model=model)

model.classifier = nnx.Linear(model.classifier.in_features, 5, rngs=nnx.Rngs(0))

num_epochs = 3
learning_rate = 0.001
momentum = 0.8

total_steps = len(train_dataset) // train_batch_size

lr_schedule = optax.linear_schedule(learning_rate, 0.0, num_epochs * total_steps)
optimizer = nnx.Optimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))

def compute_losses_and_logits(model: nnx.Module, images: jax.Array, labels: jax.Array):
    logits = model(images)
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()
    return loss, logits

@nnx.jit
def train_step(
    model: nnx.Module, optimizer: nnx.Optimizer, batch: dict[str, np.ndarray]
):
    # Convert np.ndarray to jax.Array on GPU
    images = jnp.array(batch["image"])
    labels = jnp.array(batch["label"], dtype=jnp.int32)

    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, logits), grads = grad_fn(model, images, labels)

    optimizer.update(grads)  # In-place updates.

    return loss

@nnx.jit
def eval_step(
    model: nnx.Module, batch: dict[str, np.ndarray], eval_metrics: nnx.MultiMetric
):
    # Convert np.ndarray to jax.Array on GPU
    images = jnp.array(batch["image"])
    labels = jnp.array(batch["label"], dtype=jnp.int32)
    loss, logits = compute_losses_and_logits(model, images, labels)

    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=labels,
    )

eval_metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average('loss'),
    accuracy=nnx.metrics.Accuracy(),
)

train_metrics_history = {
    "train_loss": [],
}

eval_metrics_history = {
    "val_loss": [],
    "val_accuracy": [],
}

def train_one_epoch(epoch):
    model.train()
    for batch in train_loader:
        loss = train_step(model, optimizer, batch)
        # Record the loss for the final plot (the tqdm progress bar is gone)
        train_metrics_history["train_loss"].append(loss.item())

def evaluate_model(epoch):
    model.eval()

    eval_metrics.reset()
    for val_batch in val_loader:
        eval_step(model, val_batch, eval_metrics)

    for metric, value in eval_metrics.compute().items():
        eval_metrics_history[f'val_{metric}'].append(value)

    print(f"[val] epoch: {epoch + 1}/{num_epochs}")
    print(f"- total loss: {eval_metrics_history['val_loss'][-1]:0.4f}")
    print(f"- Accuracy: {eval_metrics_history['val_accuracy'][-1]:0.4f}")

start = time()
for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
end = time()
print(f"Training took {round((end - start) / 60, 1)} minutes")

plt.plot(train_metrics_history["train_loss"], label="Loss value during the training")
plt.legend()
plt.savefig('loss.png')

fig, axs = plt.subplots(1, 2, figsize=(10, 10))
axs[0].set_title("Loss value on validation set")
axs[0].plot(eval_metrics_history["val_loss"])
axs[1].set_title("Accuracy on validation set")
axs[1].plot(eval_metrics_history["val_accuracy"])
plt.savefig('validation.png')
We have to make a few changes to our code:
Strip the code of anything unnecessary that you might have used during prototyping.
It doesn't make sense to use tqdm anymore, so remove the corresponding code.
We can't display graphs anymore, so we save them to files with plt.savefig().
When we aren’t using IPython (directly or via Jupyter), we don’t have access to the built-in magic commands such as %%time to time the execution of a cell. Instead, we use the following snippet:
start = time()
<Code to time>
end = time()
print(f"Training took {round((end - start) / 60, 1)} minutes")
In this case, since it is the training that we want to time:
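start = time()
for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
end = time()
print(f"Training took {round((end - start) / 60, 1)} minutes")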
When I tested this earlier, training took 36.8 minutes.
Write a Slurm script

Our training cluster doesn't require an account, and it has neither GPUs nor huge amounts of memory. Moreover, our code only contains 5 classes of foods to make training much faster. Finally, our Python virtual environment is in /project so that we can all access it, while you would normally store it in your home. Here is what a job script (let's call it train.sh) could look like for our cluster.
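This is a minimal sketch: the resource values are illustrative and <project> is a placeholder for the actual path to the shared virtual environment:

#!/bin/bash
#SBATCH --job-name=vit-train       # name displayed by sq
#SBATCH --cpus-per-task=4          # grain spawns 4 worker processes
#SBATCH --mem=8G                   # modest memory, as noted above
#SBATCH --time=0:50:0              # the run took ~37 min when tested

# Activate the shared virtual environment in /project
source /project/<project>/env/bin/activate

python main.py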
If you were to train our model on an Alliance cluster at scale, the script would thus look something like this:
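Again a sketch rather than a definitive script; the account name and resource values are placeholders to adapt to your project and workload:

#!/bin/bash
#SBATCH --account=def-<account>    # your Alliance account
#SBATCH --job-name=vit-train
#SBATCH --cpus-per-task=4
#SBATCH --mem=32G                  # a lot more memory
#SBATCH --gpus-per-node=1          # sometimes you will need several GPUs
#SBATCH --time=1-00:00             # training at scale can take days

# Activate the virtual environment stored in your home
source ~/env/bin/activate

python main.py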
This assumes that you have a Python virtual environment in ~/env with all necessary packages installed.
Also note that if you are using the Alliance supercomputer Cedar, a policy blocks jobs from running in the /home filesystem, so you will have to copy your files to /scratch or your /project and run the job from there.
Notice the following differences:
we provide an account name,
we ask for a lot more time (training at scale)—this could even be days or weeks,
we ask for a lot more memory,
we ask for a GPU—sometimes you will need several GPUs (remember that the same JAX code can run on any device),
we source a virtual environment located in our home directory.
Run the script
sbatch train.sh
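Slurm confirms the submission and prints the ID of your new job, e.g.:

Submitted batch job 12345678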
Monitor the job
To see whether your job is still running and to get the job ID, you can use the Alliance alias:
sq
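While the job is waiting or running, sq prints one line per job, looking something like this (the exact columns vary between clusters):

   JOBID     USER      ACCOUNT   NAME  ST  TIME_LEFT NODES CPUS MIN_MEM NODELIST (REASON)
12345678   userxx  def-account  train   R      45:13     1    4      8G node1 (None)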
PD ➔ the job is pending
R ➔ the job is running
No output ➔ the job is done
While your job is running, you can monitor it by opening a new terminal and, from the login node, running:
srun --jobid=<jobID> --pty bash
Replace <jobID> with the job ID you got by running sq.
Then launch htop:
alias htop='htop -u $USER -s PERCENT_CPU'
htop                    # monitor all your processes
htop --filter "python"  # filter processes by name
Check average memory usage with:
sstat -j <jobID> --format=AveRSS
Or maximum memory usage with:
sstat -j <jobID> --format=MaxRSS
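sstat accepts a comma-separated list of fields, so you can also query both at once:

sstat -j <jobID> --format=JobID,AveRSS,MaxRSS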
Get the results
The results will be in a file created by Slurm, called slurm-<jobID>.out by default (you can change this name by adding an option to your Slurm script, as shown below).
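For instance, adding this line to train.sh names the output file after the job name (%x) and the job ID (%j):

#SBATCH --output=%x-%j.out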
You can look at them with:
bat slurm-<jobID>.out
Retrieve files
We created two images (loss.png and validation.png). To retrieve them, you can use scp from your computer:
scp username@hostname:path/file path
For instance:
scp userxx@hostname:loss.png ~/
Replace hostname with the hostname for this cluster and ~/ with the path where you want to download your file.
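To download both images at once, you can quote a glob so that it gets expanded on the cluster rather than on your machine:

scp userxx@hostname:'*.png' ~/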