Tips on Importing Models from TensorFlow, PyTorch, and ONNX
This topic provides tips on how to overcome common hurdles in importing a model from TensorFlow™, PyTorch®, or ONNX™ as a MATLAB® network or layer graph. You can read each section of this topic independently. For a high-level overview of the import and export functions in Deep Learning Toolbox™, see Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX.
Import Functions of Deep Learning Toolbox
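This table lists the Deep Learning Toolbox import functions. Use these functions to import networks or layer graphs from TensorFlow, PyTorch, and ONNX.
|Framework||Import Functions|
|TensorFlow||importTensorFlowNetwork, importTensorFlowLayers, importKerasNetwork, importKerasLayers|
|PyTorch||importNetworkFromPyTorch|
|ONNX||importONNXNetwork, importONNXLayers, importONNXFunction|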
You must have the relevant support package to run these import functions. If the support package is not installed, each function provides a download link to the corresponding support package in the Add-On Explorer. A recommended practice is to download the support package to the default location for the version of MATLAB you are running. You can also directly download the support packages from File Exchange.
Recommended Functions to Import TensorFlow Models
The Deep Learning Toolbox Converter for TensorFlow Models support package offers these functions:
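importTensorFlowNetwork: Import a TensorFlow model as a network.
importTensorFlowLayers: Import a TensorFlow model as a layer graph.
importKerasNetwork: Import a TensorFlow-Keras model as a network.
importKerasLayers: Import a TensorFlow-Keras model as a layer graph.
The importTensorFlowNetwork and importTensorFlowLayers functions are recommended over the importKerasNetwork and importKerasLayers functions, respectively.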
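This table compares the Deep Learning Toolbox Converter for TensorFlow Models functions. The comparison highlights why the importTensorFlowNetwork and importTensorFlowLayers functions are recommended over the importKerasNetwork and importKerasLayers functions.
|Feature||importTensorFlowNetwork and importTensorFlowLayers||importKerasNetwork and importKerasLayers|
|Automatically generates custom layers||Yes||No|
|Supports TensorFlow 2||Yes||Limited|
|Can import network as dlnetwork||Yes||No|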
For more information on the advantages of migrating from TensorFlow 1 to TensorFlow 2, see Migrate from TensorFlow 1.x to TensorFlow 2. For more information on the TensorFlow versions that the import functions support, see the Limitations sections of the importTensorFlowNetwork and importTensorFlowLayers reference pages and of the importKerasNetwork and importKerasLayers reference pages.
To import a TensorFlow model that is in the HDF5 format, do not use importKerasNetwork to import the model as a Deep Learning Toolbox network. Instead, convert the TensorFlow model to the SavedModel format and use the importTensorFlowNetwork function.
Autogenerated Custom Layers
The importTensorFlowNetwork and importTensorFlowLayers functions can automatically generate custom layers when you import custom TensorFlow layers or when the software cannot convert TensorFlow layers into equivalent built-in MATLAB layers. For an example, see Import TensorFlow Network with Autogenerated Custom Layers. For a list of layers for which the software supports conversion, see TensorFlow-Keras Layers Supported for Conversion into Built-In MATLAB Layers.
The importONNXNetwork and importONNXLayers functions can also generate custom layers when the software cannot convert ONNX operators into equivalent built-in MATLAB layers. For an example, see Import ONNX Network with Autogenerated Custom Layers. For a list of operators for which the software supports conversion, see ONNX Operators Supported for Conversion into Built-In MATLAB Layers.
In rare cases, when the importONNXNetwork and importONNXLayers functions cannot import an ONNX model into layers, you can use importONNXFunction to import the model as a function. For more information on how to select an ONNX import function, see Select Function to Import ONNX Pretrained Network.
The importNetworkFromPyTorch function imports a PyTorch layer into MATLAB by trying these steps in order:
The function tries to import the PyTorch layer as a built-in MATLAB layer. For more information, see Conversion of PyTorch Layers.
The function tries to import the PyTorch layer as a built-in MATLAB function. For more information, see Conversion of PyTorch Layers.
The function tries to import the PyTorch layer as a custom layer. For an example, see Import Network from PyTorch and Find Generated Custom Layers.
The function imports the PyTorch layer as a custom layer with a placeholder function. For more information, see Placeholder Functions.
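The importTensorFlowNetwork, importTensorFlowLayers, importONNXNetwork, importONNXLayers, and importNetworkFromPyTorch functions save the automatically generated custom layers to a package in the current folder. For more information on the custom layers package, see the PackageName name-value argument of each function.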
Placeholder Layers
The importTensorFlowLayers and importONNXLayers functions insert placeholder layers in place of TensorFlow layers or ONNX operators when these conditions apply:
The function cannot convert the TensorFlow layers or ONNX operators to built-in MATLAB layers. For lists of TensorFlow layers and ONNX operators for which the functions support conversion, see TensorFlow-Keras Layers Supported for Conversion into Built-In MATLAB Layers and ONNX Operators Supported for Conversion into Built-In MATLAB Layers, respectively.
The function cannot generate custom layers in place of the TensorFlow layers or ONNX operators that the function cannot convert to built-in MATLAB layers.
If these conditions apply, the importTensorFlowNetwork and importONNXNetwork functions return an error instead. These flowcharts describe the workflows.
To find the names and indices of the placeholder layers in the layer graph, use the findPlaceholderLayers function. You can then replace a placeholder layer with a built-in MATLAB layer, a custom layer, or a functionLayer object. For more information about custom layers, see Define Custom Deep Learning Layers. For an example that uses a functionLayer object, see Replace Unsupported Keras Layer with Function Layer. To replace a layer, use the replaceLayer function. For an example, see Import ONNX Model as Layer Graph with Placeholder Layers.
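The following is a minimal sketch of the replacement step. It uses a hand-built layer graph in which a reluLayer named placeholderStandIn stands in for a placeholder layer, because a real placeholder layer only appears after an actual import. It also assumes that the unsupported operation is a swish activation (x times sigmoid of x). The layer names and the replacement function are hypothetical. The findPlaceholderLayers call appears only in a comment because the stand-in graph contains no real placeholder layers.

```matlab
% Stand-in layer graph. In a real workflow, importTensorFlowLayers or
% importONNXLayers returns the graph with a placeholder layer in it.
layers = [
    imageInputLayer([28 28 1], Normalization="none", Name="input")
    convolution2dLayer(3, 8, Padding="same", Name="conv")
    reluLayer(Name="placeholderStandIn")   % hypothetical stand-in for a placeholder layer
    fullyConnectedLayer(10, Name="fc")
    softmaxLayer(Name="softmax")];
lgraph = layerGraph(layers);

% With a real imported graph, you would get the name like this:
%   placeholders = findPlaceholderLayers(lgraph);
%   name = placeholders(1).Name;
name = "placeholderStandIn";

% Hypothetical replacement: implement the unsupported operation
% (assumed here to be swish) as a function layer.
newLayer = functionLayer(@(X) X .* sigmoid(X), Name="swish");

% replaceLayer swaps the layer and keeps its input and output connections.
lgraph = replaceLayer(lgraph, name, newLayer);

% List the layer names to confirm the swap.
layerNames = string({lgraph.Layers.Name});
disp(layerNames)
```

A function layer suits simple element-wise operations like this one. For operations that need learnable parameters or state, a custom layer is the more appropriate replacement.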
The importNetworkFromPyTorch function generates a custom layer with a placeholder function instead of a placeholder layer. For more information, see Placeholder Functions.
Input Dimension Ordering
The dimension ordering of the input data differs between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX. This table compares the input dimension ordering on each platform for different input types.
|Input Type||Deep Learning Toolbox||TensorFlow||PyTorch||ONNX|
|2-D image sequence||HWCSN||NSWHC||NCSHW||NSCHW|
|3-D image sequence||HWDCSN||NSWHDC||NCSDHW||NSCHWD|
Variable names in the table:
N — Number of observations
C — Number of features or channels
H — Height of images
W — Width of images
D — Depth of images
S — Sequence length
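If your data is already in MATLAB but still follows another platform's dimension ordering, you can rearrange it with permute. This sketch assumes a batch of 2-D image sequences stored in the ONNX order (NSCHW) or the PyTorch order (NCSHW) from the table, and converts each one to the Deep Learning Toolbox order HWCSN. The sizes are arbitrary. The sketch does not handle any row-major to column-major conversion that your data transfer route might require.

```matlab
% Hypothetical sizes for a batch of 2-D image sequences.
N = 2; S = 3; C = 4; H = 5; W = 6;

% Data in the ONNX order from the table: N x S x C x H x W.
Xonnx = rand(N, S, C, H, W, "single");
% ONNX dimensions are 1:N, 2:S, 3:C, 4:H, 5:W.
% Pick them in the Deep Learning Toolbox order H W C S N.
XdltFromOnnx = permute(Xonnx, [4 5 3 2 1]);

% Data in the PyTorch order from the table: N x C x S x H x W.
Xpt = rand(N, C, S, H, W, "single");
% PyTorch dimensions are 1:N, 2:C, 3:S, 4:H, 5:W.
XdltFromPt = permute(Xpt, [4 5 2 3 1]);

% Both results are H x W x C x S x N, that is, 5 6 4 3 2.
disp(size(XdltFromOnnx))
disp(size(XdltFromPt))

% Spot-check that an element moved to the matching position.
disp(isequal(Xonnx(2, 1, 3, 4, 5), XdltFromOnnx(4, 5, 3, 1, 2)))
```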
Data Formats for Prediction with dlnetwork
The importTensorFlowNetwork and importONNXNetwork functions can import a TensorFlow or ONNX model as a dlnetwork object. Specify the type of imported network by setting the TargetNetwork name-value argument.
The importNetworkFromPyTorch function imports a PyTorch model as an uninitialized dlnetwork object. Before you use the network, do one of the following:
Add an input layer to the imported network and initialize the network by using the addInputLayer function. For an example, see Import Network from PyTorch and Add Input Layer.
Initialize the network by passing example input data in the appropriate format to the initialize function. For an example, see Import Network from PyTorch and Initialize.
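The initialization option can be sketched as follows. There is no PyTorch model to import here, so the sketch builds a small uninitialized dlnetwork from a hypothetical two-layer array. This network stands in for the output of importNetworkFromPyTorch. The sketch then initializes the network with an example dlarray input in "SSCB" format, which represents one 32-by-32 RGB image.

```matlab
% Hypothetical stand-in for an imported, uninitialized network with no input layer.
layers = [
    convolution2dLayer(3, 8, Padding="same", Name="conv")
    reluLayer(Name="relu")];
net = dlnetwork(layers, Initialize=false);

% Example input: one 32-by-32 RGB image, labeled spatial, spatial, channel, batch.
X = dlarray(rand(32, 32, 3, 1, "single"), "SSCB");

% initialize sizes the learnable parameters from the example input.
net = initialize(net, X);
disp(net.Initialized)

% The initialized network accepts formatted input for prediction.
Y = predict(net, X);
disp(dims(Y))
```

Both the initialization input and the prediction input must use a format that matches the data layout the network expects.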
To predict using a dlnetwork object, you must convert the input data to a dlarray object with the appropriate data format. For an example, see Import TensorFlow Network as dlnetwork to Classify Image. Use this table to choose the right data format for each input type and layer.
|Input Type||Input Layer **||Input Format *|
|2-D image sequence||sequenceInputLayer||"SSCBT"|
|3-D image sequence||sequenceInputLayer||"SSSCBT"|
* In Deep Learning Toolbox, each character of a data format must be one of these labels: S (spatial), C (channel), B (batch observations), T (time or sequence), or U (unspecified).
** A dlnetwork object does not require an input layer. The network can infer the input layer type from the input data format.
For more information on data formats, see
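The following sketch shows how the labels map onto an array stored in the Deep Learning Toolbox dimension order. It assumes a hypothetical batch of 2-D image sequences stored as H x W x C x S x N. Labeling the dimensions "SSCTB" is enough. dlarray then reorders the dimensions into its standard S, C, B, T, U order, so the formatted array reports "SSCBT".

```matlab
% Hypothetical 2-D image sequence batch in Deep Learning Toolbox order H x W x C x S x N.
H = 5; W = 6; C = 4; S = 3; N = 2;
X = rand(H, W, C, S, N, "single");

% Label the dimensions as spatial, spatial, channel, time, batch.
dlX = dlarray(X, "SSCTB");

% dlarray stores dimensions in the order S, C, B, T, U,
% so the batch dimension now precedes the time dimension.
disp(dims(dlX))
disp(size(dlX))
```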
Input Data Preprocessing
Preprocessing data is a common first step in the deep learning workflow to prepare data in a format that the network can accept. You must preprocess the input data in the same way as the training data.
The input layers of the pretrained deep learning networks available in Deep Learning Toolbox perform some of the input data preprocessing. For example, the input layer of the pretrained mobilenetv2 network normalizes the image input data. Display the Normalization property of the network input layer.
net = mobilenetv2;
net.Layers(1).Normalization
ans = 'zscore'
Networks that you import from TensorFlow or ONNX might not have built-in preprocessing in the input layer. For example, the input layer of the imported TensorFlow MobileNetV2 model does not normalize the input image. Import the model and display the Normalization property of the network input layer.
net = importTensorFlowNetwork("MobileNetV2", ...
    OutputLayerType="classification");
net.Layers(1).Normalization
ans = 'none'
Often, open-source repositories provide information about the required input data preprocessing. For example, see tf.keras.applications.mobilenet_v2.preprocess_input and ShuffleNet in ONNX Model Zoo. To learn more about how to preprocess images and other types of data in Deep Learning Toolbox, see Preprocess Images for Deep Learning and Preprocess Data for Deep Neural Networks.
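The following is a minimal sketch of that preprocessing. It assumes that the imported MobileNetV2 expects the same input scaling as tf.keras.applications.mobilenet_v2.preprocess_input, which maps pixel values from [0, 255] to [-1, 1]. The image batch is random placeholder data. As an alternative, an imageInputLayer with Normalization="rescale-symmetric", Min=0, and Max=255 applies the same scaling inside the network.

```matlab
% Placeholder uint8 image batch: 224 x 224 x 3 x 2 (H x W x C x N).
I = randi([0 255], 224, 224, 3, 2, "uint8");

% Scale [0, 255] to [-1, 1], assuming the TensorFlow MobileNetV2 convention.
Xp = single(I) / 127.5 - 1;
disp([min(Xp(:)) max(Xp(:))])

% Alternative: an input layer that performs the same rescaling itself.
inLayer = imageInputLayer([224 224 3], ...
    Normalization="rescale-symmetric", Min=0, Max=255);
disp(inLayer.Normalization)
```

Scaling the data outside the network keeps the imported graph unchanged. Putting the scaling in the input layer means callers cannot forget the step. Check the model's source repository to confirm which convention applies.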
- Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX
- Pretrained Deep Neural Networks
- Select Function to Import ONNX Pretrained Network