Fine-tuning the model
In this section, we fine-tune our model with our sample (5 classes) of the Food-101 dataset [1].
Context
Load packages
# to have a progress bar during training
import tqdm
# to visualize evolution of loss and sample data
import matplotlib.pyplot as plt
Training and evaluation functions
bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"

def train_one_epoch(epoch):
    model.train()
    with tqdm.tqdm(
        desc=f"[train] epoch: {epoch}/{num_epochs}, ",
        total=total_steps,
        bar_format=bar_format,
        leave=True,
    ) as pbar:
        for batch in train_loader:
            loss = train_step(model, optimizer, batch)
            train_metrics_history["train_loss"].append(loss.item())
            pbar.set_postfix({"loss": loss.item()})
            pbar.update(1)

def evaluate_model(epoch):
    model.eval()
    eval_metrics.reset()
    for val_batch in val_loader:
        eval_step(model, val_batch, eval_metrics)
    for metric, value in eval_metrics.compute().items():
        eval_metrics_history[f'val_{metric}'].append(value)
    print(f"[val] epoch: {epoch + 1}/{num_epochs}")
    print(f"- total loss: {eval_metrics_history['val_loss'][-1]:0.4f}")
    print(f"- Accuracy: {eval_metrics_history['val_accuracy'][-1]:0.4f}")
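To make the bookkeeping in these two functions concrete, here is a minimal dry run of the same pattern without JAX. Everything in it is a hypothetical stand-in (`fake_train_step`, `fake_compute`, the toy loader and its values); only the two history dictionaries and the `val_` key prefix mirror the code above. It shows that the training history gains one entry per batch, whereas each validation history gains one entry per epoch.

```python
# Hypothetical stand-ins for objects that the earlier sections define.
num_epochs = 2
train_loader = [[0.9, 0.7], [0.5, 0.3]]  # two toy "batches" per epoch

train_metrics_history = {"train_loss": []}
eval_metrics_history = {"val_loss": [], "val_accuracy": []}


def fake_train_step(batch):
    """Stand-in for train_step: returns a made-up scalar loss."""
    return sum(batch) / len(batch)


def fake_compute(epoch):
    """Stand-in for eval_metrics.compute(): a dict of metric name -> value."""
    return {"loss": 1.0 / (epoch + 1), "accuracy": 0.5 + 0.25 * epoch}


for epoch in range(num_epochs):
    # One history entry per batch, as in train_one_epoch.
    for batch in train_loader:
        train_metrics_history["train_loss"].append(fake_train_step(batch))
    # One history entry per metric per epoch, as in evaluate_model.
    for metric, value in fake_compute(epoch).items():
        eval_metrics_history[f"val_{metric}"].append(value)

print(len(train_metrics_history["train_loss"]))  # number of training steps run
print(len(eval_metrics_history["val_loss"]))     # number of epochs run
```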
Train the model
%%time
for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
File <timed exec>:1
NameError: name 'num_epochs' is not defined
OOM issues
I ran out of memory when running this code on my machine.
Out of memory (OOM) problems are common when trying to train a model with JAX on GPUs. See for instance this question on Stack Overflow and this issue in the JAX repo.
According to the JAX documentation on GPU memory allocation, you can try the following:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'
Or, if you use IPython (or Jupyter, which runs IPython), you can set the same environment variables with the built-in %env magic command:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.5
None of these solutions worked for me, either on my machine or on Cedar, and I am starting to suspect that there is a problem with this particular version of jaxlib.
Without GPUs (so on our training cluster), training will take much longer, but you won’t run into this problem.
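The JAX documentation presents these three memory variables as separate options rather than a single recipe, and they are read when JAX sets up its GPU backend, so the safe choice is to set them before the first `import jax`. The sketch below illustrates that ordering. The helper name `set_xla_option` is hypothetical, and `JAX_PLATFORMS=cpu` is shown as one way to force CPU execution on a machine that does have a GPU.

```python
import os
import sys


def set_xla_option(name, value):
    """Set an environment variable, refusing if JAX is already imported.

    Hypothetical helper: the settings are only reliable when they are in
    place before JAX initializes its backend.
    """
    if "jax" in sys.modules:
        raise RuntimeError(f"set {name} before importing jax")
    os.environ[name] = str(value)


# Option 1: allocate GPU memory as needed instead of preallocating a large pool.
set_xla_option("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

# Option 2 (alternative): keep preallocation but cap it at a fraction of GPU memory.
# set_xla_option("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.5")

# Option 3 (alternative): allocate and free exactly what is needed; slow.
# set_xla_option("XLA_PYTHON_CLIENT_ALLOCATOR", "platform")

# Fallback: skip the GPU entirely and run on CPU.
# set_xla_option("JAX_PLATFORMS", "cpu")

# import jax  # only import JAX after the environment is configured
```

In a notebook, this means running the configuration cell first, after a kernel restart, and before any cell that imports JAX.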
Metrics graphs
If we hadn’t run out of memory, we could graph our metrics.
Evolution of the loss during training:
plt.plot(train_metrics_history["train_loss"], label="Loss value during the training")
plt.legend()
Loss and accuracy on the validation set:
fig, axs = plt.subplots(1, 2, figsize=(10, 10))
axs[0].set_title("Loss value on validation set")
axs[0].plot(eval_metrics_history["val_loss"])
axs[1].set_title("Accuracy on validation set")
axs[1].plot(eval_metrics_history["val_accuracy"])
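Note that the training loss is recorded once per step, while the validation metrics are recorded once per epoch, so the two curves have different lengths. If you want both on a per-epoch scale, one option (not part of the original code) is to average the training loss within each epoch. A minimal sketch, assuming every epoch ran the same number of steps; the function name and the toy values are hypothetical:

```python
def mean_loss_per_epoch(step_losses, steps_per_epoch):
    """Average a flat list of per-step losses into one value per epoch.

    Assumes every epoch ran exactly `steps_per_epoch` steps; a trailing
    partial epoch is ignored.
    """
    n_epochs = len(step_losses) // steps_per_epoch
    return [
        sum(step_losses[e * steps_per_epoch:(e + 1) * steps_per_epoch]) / steps_per_epoch
        for e in range(n_epochs)
    ]


# Made-up losses for 2 epochs of 4 steps each.
toy_losses = [2.0, 1.0, 1.0, 0.0, 0.5, 0.5, 0.25, 0.75]
epoch_means = mean_loss_per_epoch(toy_losses, steps_per_epoch=4)
print(epoch_means)
```

With the real history, the call would be `mean_loss_per_epoch(train_metrics_history["train_loss"], total_steps)`, provided `total_steps` matches the number of batches actually processed per epoch.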
Check sample data
And we could look at the model predictions for 5 items:
test_indices = [1, 250, 500, 750, 1000]

test_images = jnp.array([val_dataset[i]["image"] for i in test_indices])
expected_labels = [val_dataset[i]["label"] for i in test_indices]

model.eval()
preds = model(test_images)

num_samples = len(test_indices)
names_map = train_dataset.features["label"].names

probas = nnx.softmax(preds, axis=1)
pred_labels = probas.argmax(axis=1)

fig, axs = plt.subplots(1, num_samples, figsize=(20, 10))
for i in range(num_samples):
    img, expected_label = test_images[i], expected_labels[i]

    pred_label = pred_labels[i].item()
    proba = probas[i, pred_label].item()
    if img.dtype in (np.float32, ):
        img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)

    expected_label_str = names_map[inv_labels_mapping[expected_label]]
    pred_label_str = names_map[inv_labels_mapping[pred_label]]
    axs[i].set_title(f"Expected: {expected_label_str} vs \nPredicted: {pred_label_str}, P={proba:.2f}")
    axs[i].imshow(img)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[6], line 18
     15         img = ((img - img.min()) / (img.max() - img.min()) * 255.0).astype(np.uint8)
     17     expected_label_str = names_map[inv_labels_mapping[expected_label]]
---> 18     pred_label_str = names_map[inv_labels_mapping[pred_label]]
     19     axs[i].set_title(f"Expected: {expected_label_str} vs \nPredicted: {pred_label_str}, P={proba:.2f}")
     20     axs[i].imshow(img)

KeyError: 693
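The `KeyError: 693` means that the predicted class index is not a key of `inv_labels_mapping`. One plausible cause (an assumption, since the model definition is not shown in this section) is that the model is still producing scores for more classes than the five we kept. The sketch below reproduces the post-processing steps in plain Python on made-up logits and uses `dict.get` so that an out-of-range prediction is reported rather than raising. The mapping values, class names, and helper functions are all hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


# Hypothetical mapping: model output index -> original dataset label index.
inv_labels_mapping = {0: 2, 1: 0, 2: 4}
# Hypothetical class names, indexed by original dataset label.
names_map = [f"food_{k}" for k in range(6)]


def label_name(pred_label):
    """Look up a class name, tolerating indices outside the mapping."""
    original = inv_labels_mapping.get(pred_label)
    if original is None:
        return f"<unknown class {pred_label}>"
    return names_map[original]


logits_ok = [0.1, 2.0, -1.0]             # 3 outputs: matches the mapping
logits_bad = [0.1, 2.0, -1.0, 5.0, 0.0]  # 5 outputs: index 3 has no entry

for logits in (logits_ok, logits_bad):
    probas = softmax(logits)
    pred_label = argmax(probas)
    print(label_name(pred_label), f"P={probas[pred_label]:.2f}")
```

A quick check on the real model would be to compare `preds.shape[1]` with `len(inv_labels_mapping)` before plotting; if they differ, the lookup in the loop above cannot be trusted.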