Finding pretrained models for transfer learning

Author

Marie-Hélène Burle

Training models from scratch requires far too much data, time, and computing power (or money) to be practical in most cases. This is why transfer learning has become so common: by starting with a model trained on a related problem, you save time and can achieve good results with little data.

Now, where do you find such models?

In this workshop, we will see how to use the pre-trained models included in PyTorch libraries, have a look at some of the most popular repositories of pre-trained models, and learn how to search for models in the literature and on GitHub.

What are pre-trained models?

Transfer learning

If you build models from scratch, expect their performance to be mediocre. Totally naive models with random weights and biases usually need to be trained for a long time on very large datasets, using vast amounts of computing resources, before they produce competitive results. You may not even have enough data to train a model from scratch.

Instead of starting from zero, however, you can use a model that has been trained on a similar task. For instance, if your goal is to create a model able to identify bird species from pictures, you could look for an image recognition model trained on a classic dataset such as ImageNet. Well-known examples include AlexNet (2012) and ResNet (2015). These models have already learned features that are useful to you, so you will get better performance with less training time and less data. This is called transfer learning.

How transfer learning works

Typically, you remove the last layer (with AlexNet, for instance, you would remove the classification layer) and replace it with a layer suited to your task. Optionally, you can then fine-tune the model.

Fine-tuning a model consists of freezing the first layers (fixing their weights and biases) while retraining the model on data specific to the new task. Only the last few layers are trained, which greatly reduces the number of parameters being updated and takes advantage of the general features learned by the early layers of the source model.

I will talk about transfer learning in another workshop, but today, we are focusing on finding a suitable pre-trained model.

Note that the most powerful recent transformers, such as GPT-3, GPT-4, and their competitors, perform well on a range of tasks without any retraining.

How to find a pre-trained model

Key to transfer learning is the search for an appropriate source model. The great news is that the world of machine learning research is incredibly open: many teams make their papers and models available online. But you need a way to navigate this abundance of resources.

Things you should probably care about when looking for a pre-trained model include:

  • How pertinent is the model relative to your task?
  • Does the model have an open license?
  • Is the performance good?
  • Is the model size suitable for the resources you have?
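For the last point, one quick way to gauge whether a model fits your resources is to count its parameters and estimate the memory they occupy. The sketch below assumes a PyTorch model; the helper name model_footprint and the toy model are made up for the example, and the estimate covers only the weights and buffers, not the activations or optimizer state needed during training.

```python
from torch import nn


def model_footprint(model: nn.Module) -> tuple[int, float]:
    """Return (number of parameters, approximate size in MB) for a model."""
    n_params = sum(p.numel() for p in model.parameters())
    # Bytes taken by parameters plus buffers (e.g. batch norm running statistics)
    n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    n_bytes += sum(b.numel() * b.element_size() for b in model.buffers())
    return n_params, n_bytes / 1024**2


# Tiny stand-in model so the sketch runs without downloading anything
toy = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
n_params, size_mb = model_footprint(toy)
print(f"{n_params} parameters, about {size_mb:.6f} MB")
```

The same function can be called on any model you load from the sources listed below.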

Models in PyTorch libraries

The PyTorch ecosystem contains domain-specific libraries (e.g. torchvision, torchtext, torchaudio). Alongside many other utilities, these libraries provide pretrained models for vision, text, and audio.

These models are particularly convenient to use since they are fully integrated into PyTorch.

Instantiating the ResNet-18 architecture (with random weights, since no weights argument is passed) is as simple as:

import torchvision
model = torchvision.models.resnet18()

Initializing a pretrained ResNet-50 model with the best currently available weights is as simple as:

from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights=ResNet50_Weights.DEFAULT)

PyTorch Hub

PyTorch Hub is a repository of pretrained models.

Loading ResNet-18 from the hub is done with:

import torch
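# Recent torchvision versions select weights with the weights argument (pretrained=True is deprecated)
model = torch.hub.load('pytorch/vision', 'resnet18', weights='IMAGENET1K_V1')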

Your turn:

Look for a small image classification model in the PyTorch Hub.

Hugging Face

Hugging Face, launched in 2016, provides a Model Hub. Let’s explore it together.

Note that Hugging Face also has a Dataset Hub.

Your turn:

Find a pre-trained PyTorch model for image classification, trained on ImageNet, with an open license, and smaller than 100 MB.

timm

For computer vision specifically, the timm (PyTorch Image Models) library contains more than 700 pretrained models, as well as scripts, utilities, optimizers, data loaders, etc. The repository is hosted on GitHub.

You can load timm models hosted on the Hugging Face Hub with (replace author/model with the ID of the model repository):

import timm
model = timm.create_model('hf_hub:author/model', pretrained=True)

GitHub

A large number of open-source models are hosted on GitHub, and the platform can be searched directly for specific models.

Your turn:

Search GitHub for pre-trained PyTorch models for image classification.

Literature

Although it is a less direct way to find pre-trained models, the literature is invaluable for keeping up (or trying to) with what people are doing in the field.

Papers With Code gathers machine learning papers with open source code.

arXiv is an open-access repository of scientific preprints created in 1991 by Paul Ginsparg (now at Cornell University, which hosts the repository). It contains a huge number of e-prints on machine learning in its computer science and statistics sections. arxiv-sanity, created by Andrej Karpathy, tracks arXiv machine learning papers and is easier to browse.