Transfer Learning - MATLAB & Simulink

Transfer Learning

What Is Transfer Learning?

Transfer learning is a deep learning approach in which a model trained for one task is used as the starting point for a model that performs a similar task. Updating and retraining a network with transfer learning is usually faster and easier than training a network from scratch. Transfer learning is used for image classification, object detection, speech recognition, and other applications.

Why Does Transfer Learning Matter?

Transfer learning lets you take advantage of the expertise in the deep learning community. Popular pretrained models offer a robust architecture and let you skip training from scratch. Transfer learning is a common technique for supervised learning because:

  • It enables you to train models with less labeled data by reusing popular models already trained on large datasets.
  • It can reduce training time and computing resources. With transfer learning, the neural network weights are not learned from scratch, because the pretrained model has already learned them on a previous task.
  • You can use model architectures developed by the deep learning research community, including popular architectures such as GoogLeNet and YOLO.

Transfer knowledge from a pretrained model (model 1, trained with the larger dataset 1) to another model (model 2) that can be trained with less labeled data (the smaller dataset 2).

Training from Scratch or Transfer Learning?

To create a deep learning model, you can train a model from scratch or perform transfer learning using a pretrained model.

Developing and training a model from scratch works better for highly specific tasks for which preexisting models cannot be used. The downside of this approach is that it typically requires a large amount of data to produce accurate results. Training from scratch also works well when a smaller network can achieve the desired accuracy. For example, recurrent neural networks (RNNs) and long short-term memory (LSTM) networks are particularly effective with sequential data of varying length, and they are used for problems such as signal classification and time series forecasting.

Transfer learning is useful for tasks for which various pretrained models exist. For example, many popular convolutional neural networks (CNNs) are pretrained on the ImageNet dataset; the full dataset contains over 14 million images, and the subset commonly used for pretraining covers 1,000 classes. If you need to classify images of flowers from your garden (or any images not included in the ImageNet dataset) and you have a limited number of flower images, you can transfer layers and their weights from a SqueezeNet network, replace the final layers, and retrain the model with the images you have.

Transfer learning can help you achieve higher model accuracy in less time.

Comparing the network performance (accuracy) of training from scratch and transfer learning. The transfer learning curve shows a higher start, slope, and asymptote.

Transfer Learning Applications

Transfer learning is popular in many deep learning applications, such as:

Speech and Audio Processing

See the MATLAB example Transfer Learning with Pretrained Audio Networks in Deep Network Designer.

Text Analytics

See the MATLAB GitHub example Fine Tune BERT Model for Japanese Text.

Pretrained Models for Transfer Learning

At the center of transfer learning is the pretrained deep learning model, built by deep learning researchers and trained on thousands or millions of sample data points.

Many pretrained models are available, and each has advantages and drawbacks to consider:

  1. Prediction speed: How fast can the model predict new inputs? Prediction speed depends on factors such as hardware and batch size, but it also depends on the model's architecture and size.
  2. Size: What is the desired memory footprint for the model? The importance of your model’s size depends on where and how you intend to deploy it. Will it run on embedded hardware or a desktop? The size of the network is significant when deploying to a resource-constrained target.
  3. Accuracy: How well does the model perform before retraining? A model that performs well for the ImageNet dataset will likely perform well on new, similar tasks. However, a low accuracy score on ImageNet does not necessarily mean the model will perform poorly on all tasks.

Comparing model size, prediction speed, and accuracy for popular pretrained CNN models.

Which Model Is Best for Your Transfer Learning Workflow?

With many transfer learning models to choose from, it’s important to remember the tradeoffs involved and the overall goals of your specific project. A good approach is to try a variety of models to find the one that fits your application best.

Simple models for getting started, such as GoogLeNet, VGG-16, and VGG-19, let you iterate quickly and experiment with different data preprocessing steps and training options. Once you see what settings work well, you can try a more accurate network to see if that improves your results.

Lightweight and computationally efficient models, such as SqueezeNet, MobileNet-v2, and ShuffleNet, are good options when the deployment environment limits model size.
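
As a rough illustration of weighing these tradeoffs, the following Python sketch (a toy, not MATLAB code) keeps only the candidate models that fit a deployment budget and ranks the survivors by accuracy. The `Candidate` class, the `shortlist` function, the model names, and every number are hypothetical placeholders, not measurements of real networks; in practice you would fill them in by profiling each pretrained model on your own hardware and data.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Candidate:
    """Hypothetical profile of one pretrained model (illustration only)."""
    name: str
    size_mb: float      # memory footprint of the network
    latency_ms: float   # prediction time per input on the target hardware
    accuracy: float     # accuracy from a benchmark or your own validation data


def shortlist(candidates, max_size_mb, max_latency_ms):
    """Keep models that fit the deployment budget, most accurate first."""
    fits = [c for c in candidates
            if c.size_mb <= max_size_mb and c.latency_ms <= max_latency_ms]
    return sorted(fits, key=lambda c: c.accuracy, reverse=True)


# Placeholder values for illustration only, not measurements of real networks.
candidates = [
    Candidate("model_a", size_mb=5.0, latency_ms=8.0, accuracy=0.81),
    Candidate("model_b", size_mb=27.0, latency_ms=20.0, accuracy=0.86),
    Candidate("model_c", size_mb=500.0, latency_ms=90.0, accuracy=0.88),
]

for c in shortlist(candidates, max_size_mb=30.0, max_latency_ms=25.0):
    print(c.name, c.accuracy)  # prints model_b 0.86, then model_a 0.81
```

Treating size and speed as hard limits and accuracy as the ranking key mirrors the advice to start from the deployment constraints and then pick the most accurate model that satisfies them.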

How Do You Get Pretrained Models in MATLAB?

You can explore the MATLAB Deep Learning Model Hub to access the latest models by category and get tips on choosing a model. You can load most models with a single MATLAB function, such as the darknet19 function.

You can also get pretrained networks from external platforms. You can convert a model from TensorFlow™, PyTorch®, or ONNX™ to a MATLAB model using an import function, such as the importNetworkFromTensorFlow function.

Get pretrained deep learning models directly from MATLAB (including the MATLAB Deep Learning Model Hub) or from external deep learning platforms (PyTorch, TensorFlow, and ONNX).

Transfer Learning Applied to Soft Sensor Design

Read how Poclain Hydraulics took advantage of pretrained networks in MATLAB to speed up the design of soft sensors.

“We identified two neural networks that were already implemented in MATLAB and these neural networks helped us embed the codes into hardware to do real-time predictions of temperature.”

Transfer Learning with MATLAB

Using MATLAB with Deep Learning Toolbox™ lets you access hundreds of pretrained models and perform transfer learning with built-in functions or interactive apps. For different transfer learning applications, you may also need to use other toolboxes such as Computer Vision Toolbox™, Audio Toolbox™, Lidar Toolbox™, or Text Analytics Toolbox™.

The Transfer Learning Workflow

The transfer learning workflow includes getting a pretrained network, modifying it, retraining it, and then using it to predict on new data.

Diagram of steps in the transfer learning workflow.

While there are various transfer learning architectures and applications, most transfer learning workflows follow a common series of steps. The following illustration shows the transfer learning workflow for image classification. Transfer learning is performed on a pretrained GoogLeNet model, a popular network that is 22 layers deep and trained to classify images into 1000 object categories.

  1. Select a pretrained model. It can help to select a simple model when getting started.

Architecture of the GoogLeNet model, a pretrained CNN that is 22 layers deep and trained to classify images into 1000 object categories.

  2. Replace the final layers. To retrain the network to classify a new set of images and classes, you replace the last learnable layer and the final classification layer of the GoogLeNet model. The final fully connected layer (last learnable layer) is modified to contain the same number of nodes as the number of new classes. The new classification layer will produce an output based on the probabilities calculated by the softmax layer.

Replace the last learnable layer and classification layer of a CNN model before retraining it; replacing these final layers is essential to transfer learning.

After you modify the layers, the final fully connected layer specifies the new number of classes the network will learn, and the classification layer assigns outputs to the new categories. For example, GoogLeNet was originally trained on 1000 categories, but by replacing the final layers, you can retrain it to classify only the five (or any other number of) categories of objects you are interested in.
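
To make the layer replacement concrete, here is a toy Python sketch (not MATLAB or Deep Learning Toolbox code). The `Layer` class, `replace_final_layers`, `classify`, the layer names, and the weights are hypothetical and only mimic the idea: copy the pretrained layers, swap the last fully connected layer for one with one output per new class, swap the classification layer for one that holds the new labels, and turn the new scores into probabilities with a softmax. For simplicity, `classify` assumes the input features were already extracted by the earlier layers.

```python
import copy
import math
import random
from dataclasses import dataclass, field


@dataclass
class Layer:
    """Toy stand-in for a network layer (hypothetical, illustration only)."""
    name: str
    kind: str  # e.g. "conv", "fc", "softmax", "classification"
    weights: list = field(default_factory=list)  # fc: one row per output class
    classes: list = field(default_factory=list)  # labels of a classification layer


def replace_final_layers(pretrained, new_classes, seed=0):
    """Copy `pretrained`, resize the last fc layer to len(new_classes) outputs,
    and swap in a new classification layer. Earlier layers are kept as-is."""
    layers = copy.deepcopy(pretrained)
    rng = random.Random(seed)
    fc_idx = max(i for i, layer in enumerate(layers) if layer.kind == "fc")
    in_features = len(layers[fc_idx].weights[0])
    # Freshly initialized weights: one row per new class.
    layers[fc_idx] = Layer("new_fc", "fc", weights=[
        [rng.gauss(0.0, 0.01) for _ in range(in_features)] for _ in new_classes])
    cls_idx = max(i for i, layer in enumerate(layers)
                  if layer.kind == "classification")
    layers[cls_idx] = Layer("new_classoutput", "classification",
                            classes=list(new_classes))
    return layers


def softmax(scores):
    """Convert raw scores to probabilities (shifted by the max for stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def classify(features, layers):
    """Score precomputed features with the last fc layer, apply softmax,
    and return the most probable label plus all class probabilities."""
    fc = [layer for layer in layers if layer.kind == "fc"][-1]
    labels = [layer for layer in layers if layer.kind == "classification"][-1].classes
    scores = [sum(w * x for w, x in zip(row, features)) for row in fc.weights]
    probs = softmax(scores)
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best], probs


# Hypothetical "pretrained" network: its fc layer maps 4 features to 1000 classes.
pretrained = [
    Layer("conv1", "conv", weights=[[0.2, -0.1], [0.05, 0.3]]),
    Layer("fc1000", "fc", weights=[[0.001] * 4 for _ in range(1000)]),
    Layer("prob", "softmax"),
    Layer("classoutput", "classification", classes=[f"class{i}" for i in range(1000)]),
]
flowers = ["daisy", "dandelion", "rose", "sunflower", "tulip"]
model = replace_final_layers(pretrained, flowers)
label, probs = classify([1.0, 0.5, -0.2, 0.3], model)
print(len(probs), label in flowers)  # 5 True
```

Only the last fully connected layer and the classification layer change; the earlier layer keeps its pretrained weights, which is what lets the retrained network reuse features learned on the original task.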

  3. Optionally freeze the weights. You can freeze the weights of earlier layers in the network by setting the learning rates in those layers to zero. During training, the parameters of frozen layers are not updated, which can significantly speed up network training. If the new dataset is small, freezing weights can also prevent overfitting the network to the new dataset.
  4. Retrain the model. Retraining updates the network to learn and identify features associated with the new images and categories. In most cases, retraining requires less data than training a model from scratch.
  5. Predict and assess network accuracy. After the model is retrained, you can classify new images and evaluate how well the network performs.
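
The freezing, retraining, and assessment steps can be sketched in a few lines of Python, again as a toy illustration rather than MATLAB code. The `Param` class and the `freeze`, `sgd_step`, and `accuracy` functions are hypothetical names, and the weights, gradients, and labels are made-up values. Each weight carries a learning-rate factor; setting the factor to zero (freezing) makes its effective learning rate zero, so it keeps its pretrained value while the new final-layer weights are still updated.

```python
from dataclasses import dataclass


@dataclass
class Param:
    """A single trainable weight with a per-weight learning-rate factor."""
    value: float
    learn_rate_factor: float = 1.0  # 0.0 means the weight is frozen


def freeze(params):
    """Freeze weights by setting their learning-rate factor to zero."""
    for p in params:
        p.learn_rate_factor = 0.0


def sgd_step(params, grads, learn_rate):
    """Plain gradient descent; the effective step for each weight is
    learn_rate * learn_rate_factor, so frozen weights never move."""
    for p, g in zip(params, grads):
        p.value -= learn_rate * p.learn_rate_factor * g


def accuracy(predicted, actual):
    """Fraction of predictions that match the true labels."""
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)


# Hypothetical weights: early (pretrained) weights are frozen, the new head is trained.
early = [Param(0.5), Param(-1.2)]
head = [Param(0.0), Param(0.0)]
freeze(early)
sgd_step(early + head, grads=[0.3, 0.3, 0.3, -0.3], learn_rate=0.1)
print([p.value for p in early])  # [0.5, -1.2]

# Assess the retrained model on held-out labels (made-up example data).
predicted = ["rose", "tulip", "rose", "daisy"]
actual = ["rose", "tulip", "daisy", "daisy"]
print(accuracy(predicted, actual))  # 0.75
```

Because frozen weights skip their updates entirely, a real framework can also skip computing their gradients, which is where the training speedup mentioned in step 3 comes from.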

An Interactive Approach to Transfer Learning

Using the Deep Network Designer app, you can interactively complete the entire transfer learning workflow—including selecting or importing (from MATLAB, TensorFlow, or PyTorch) a pretrained model, modifying the final layers, and retraining the network using new data—with little or no coding.

Learn More About Transfer Learning

Watch these videos to get started with transfer learning at the command line or by using Deep Network Designer.

Related Topics