Tips on Importing Models from TensorFlow, PyTorch, and ONNX
This topic provides tips on how to overcome common hurdles in importing a model from TensorFlow™, PyTorch®, or ONNX™ as a MATLAB® network. You can read each section of this topic independently. For a high-level overview of the import and export functions in Deep Learning Toolbox™, see Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX.
Import Functions of Deep Learning Toolbox
This table lists the Deep Learning Toolbox import functions. Use these functions to import networks or layer graphs from TensorFlow, PyTorch, and ONNX.
You must have the relevant support package to run these import functions. If the support package is not installed, each function provides a download link to the corresponding support package in the Add-On Explorer. A recommended practice is to download the support package to the default location for the version of MATLAB you are running. You can also directly download the support packages from File Exchange.
Autogenerated Custom Layers
The importNetworkFromTensorFlow, importNetworkFromPyTorch, and importNetworkFromONNX functions can automatically generate custom layers, or custom layers with placeholder functions, when you import TensorFlow, PyTorch, or ONNX layers that the software cannot convert into equivalent built-in MATLAB functions or layers.
Each import function tries these steps, in order, to import an external platform layer into MATLAB:
1. The function imports the external layer as a built-in MATLAB layer.
2. The function imports the external layer as a built-in MATLAB function (for TensorFlow and PyTorch only).
3. The function imports the external layer as a custom layer.
4. The function imports the external layer as a custom layer with a placeholder function.
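The fallback order can be modeled as a small decision function. This Python sketch is a conceptual illustration only and does not describe how the MathWorks import functions are implemented. The `ImportResult` enum, the `import_outcome` function, and its boolean flags are hypothetical names describing what a converter could handle for one external layer.

```python
from enum import Enum


class ImportResult(Enum):
    """Possible outcomes for one external layer, in the order they are tried."""
    BUILTIN_LAYER = 1
    BUILTIN_FUNCTION = 2
    CUSTOM_LAYER = 3
    PLACEHOLDER = 4


def import_outcome(platform, has_builtin_layer, has_builtin_function, can_generate_custom):
    """Conceptual model of the fallback order (hypothetical, not MathWorks code).

    platform: "tensorflow", "pytorch", or "onnx".
    The boolean flags are assumed inputs describing what is supported for one layer.
    """
    # Step 1: prefer an equivalent built-in MATLAB layer.
    if has_builtin_layer:
        return ImportResult.BUILTIN_LAYER
    # Step 2: built-in MATLAB functions are a fallback for TensorFlow and PyTorch only.
    if has_builtin_function and platform in ("tensorflow", "pytorch"):
        return ImportResult.BUILTIN_FUNCTION
    # Step 3: generate a fully functional custom layer.
    if can_generate_custom:
        return ImportResult.CUSTOM_LAYER
    # Step 4: last resort, a custom layer whose placeholder function must be filled in.
    return ImportResult.PLACEHOLDER


print(import_outcome("onnx", False, True, False))  # ImportResult.PLACEHOLDER
```

Note that an ONNX layer skips step 2 even when a matching built-in function exists, which is why the ONNX call above falls through to the placeholder case.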
Input Dimension Ordering
The dimension ordering of the input data differs between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX. This table compares input dimension ordering between platforms for different input types.
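[Table: input dimension ordering in Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX. Recoverable input types include 2-D image sequence and 3-D image sequence.]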
Variable names in the table:
N — Number of observations
C — Number of features or channels
H — Height of images
W — Width of images
D — Depth of images
S — Sequence length
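Because the orderings differ, you usually need to permute array dimensions when moving data between platforms. This Python sketch computes the axis order from layout strings built from the letters above; `permutation` and `permute_shape` are hypothetical helper names. It assumes the array's axes already appear in the source order wherever you apply the permutation. The example converts NCHW, the common PyTorch ordering for a batch of 2-D images, to HWCN, the height-by-width-by-channel-by-observation ordering MATLAB uses for image batches.

```python
def permutation(src_layout, dst_layout):
    """Return the 0-based axis order that rearranges src_layout into dst_layout.

    Each layout is a string of unique dimension letters, for example "NCHW" or "HWCN".
    """
    if sorted(src_layout) != sorted(dst_layout):
        raise ValueError("layouts must contain the same dimension letters")
    if len(set(src_layout)) != len(src_layout):
        raise ValueError("dimension letters must be unique")
    # Output axis i takes the source axis that carries the letter dst_layout[i].
    return [src_layout.index(dim) for dim in dst_layout]


def permute_shape(shape, src_layout, dst_layout):
    """Apply the permutation to a shape tuple, which is handy for checking sizes."""
    return tuple(shape[i] for i in permutation(src_layout, dst_layout))


perm = permutation("NCHW", "HWCN")
print(perm)                                            # [2, 3, 1, 0]
print([p + 1 for p in perm])                           # [3, 4, 2, 1]
print(permute_shape((8, 3, 224, 224), "NCHW", "HWCN"))  # (224, 224, 3, 8)
```

The 0-based result is the form numpy.transpose expects; adding 1 to each entry gives the dimension order for the MATLAB permute function.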
Data Formats for Prediction with dlnetwork
The importNetworkFromTensorFlow function imports a TensorFlow network as an initialized dlnetwork object. For an example, see Import TensorFlow Network and Classify Image. If the network does not have a fixed input size, the function imports the model as an uninitialized dlnetwork object without an input layer. For an example of how to initialize this network, see Import and Initialize TensorFlow Network.
The importNetworkFromPyTorch function imports a PyTorch network as an uninitialized or initialized dlnetwork object. If the imported network is uninitialized, do one of the following before you use the network:
- Add an input layer to the imported network and initialize the network by using the addInputLayer function. For an example, see Import Network from PyTorch and Add Input Layer.
- Import the PyTorch network as an initialized dlnetwork object by using the PyTorchInputSizes name-value argument. For an example, see Import Network from PyTorch using PyTorchInputSizes.
To predict using a dlnetwork object, you must convert the input data to a dlarray object with the appropriate data format. For an example, see Import TensorFlow Network and Classify Image. Use this table to choose the right data format for each input type and layer.
[Table: input layer (**) and input data format (*) for each input type. Recoverable input types include 2-D image sequence and 3-D image sequence.]
* In Deep Learning Toolbox, each character of a data format must be one of these labels: S (spatial), C (channel), B (batch observations), T (time or sequence), or U (unspecified).
** A dlnetwork object does not require an input layer. The network can infer the input layer type from the input data format.
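The dimension letters in the ordering table and the dlarray format labels overlap but do not mean the same thing: S is sequence length in the table, whereas S is a spatial dimension as a format label. This Python sketch translates a Deep Learning Toolbox layout string into a format string. It assumes that height, width, and depth map to the spatial label S, channels to C, observations to B, and sequence length to T; `TO_DLARRAY_LABEL` and `data_format` are hypothetical names.

```python
# Assumed mapping from dimension-ordering letters to dlarray format labels.
TO_DLARRAY_LABEL = {
    "H": "S",  # height -> spatial
    "W": "S",  # width -> spatial
    "D": "S",  # depth -> spatial
    "C": "C",  # channels or features -> channel
    "N": "B",  # number of observations -> batch
    "S": "T",  # sequence length -> time
}


def data_format(layout):
    """Translate a layout string such as "HWCN" into a format string such as "SSCB"."""
    try:
        return "".join(TO_DLARRAY_LABEL[dim] for dim in layout)
    except KeyError as err:
        raise ValueError(f"unknown dimension letter: {err.args[0]}") from None


print(data_format("HWCN"))   # SSCB  (batch of 2-D images)
print(data_format("HWDCN"))  # SSSCB (batch of 3-D images)
print(data_format("CN"))     # CB    (feature data)
print(data_format("CSN"))    # CTB   (sequence length S becomes label T)
```

The resulting string is the kind of format argument you would pass when creating the dlarray, for example "SSCB" for a batch of 2-D images.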
For more information on data formats, see
Input Data Preprocessing
To use a pretrained network for prediction or transfer learning on new images, you must preprocess your images in the same way as the images that were used to train the imported model. The most common preprocessing steps are resizing images, subtracting image average values, and converting the images from BGR format to RGB format.
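Two of these steps, channel reordering and mean subtraction, are shown in this dependency-free Python sketch; resizing is left to an image library. The function name `bgr_to_rgb_minus_mean` and the mean values are hypothetical: the correct means, and whether the model expects RGB or BGR input at all, depend on how the original model was trained.

```python
def bgr_to_rgb_minus_mean(image_bgr, mean_rgb):
    """Convert an H x W x 3 image from BGR to RGB and subtract a per-channel mean.

    image_bgr: nested lists indexed [row][col][channel], channels in B, G, R order.
    mean_rgb:  (R, G, B) means from the original training pipeline (assumed known).
    """
    out = []
    for row in image_bgr:
        new_row = []
        for b, g, r in row:
            # Reorder to RGB first, then subtract the mean of the matching channel.
            new_row.append([r - mean_rgb[0], g - mean_rgb[1], b - mean_rgb[2]])
        out.append(new_row)
    return out


# 1 x 2 image in BGR order, with hypothetical channel means.
image = [[[10, 20, 30], [40, 50, 60]]]
print(bgr_to_rgb_minus_mean(image, (1.0, 2.0, 3.0)))
# [[[29.0, 18.0, 7.0], [59.0, 48.0, 37.0]]]
```

Reordering before subtracting keeps each mean paired with its own channel; subtracting an RGB-ordered mean from BGR data would silently swap the red and blue offsets.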
For more information about preprocessing images for training and prediction, see Preprocess Images for Deep Learning.
- Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX
- Pretrained Deep Neural Networks
- Select Function to Import ONNX Pretrained Network