How are the functions "train" and "trainNetwork" different underneath?

How are the functions "train" and "trainNetwork" different underneath? When should I use "train" instead of "trainNetwork" or vice-versa?

Accepted Answer

MathWorks Support Team on 26 Jul 2018
"train" and "trainNetwork", together with their associated functions, belong to two independent network-training workflows. "train" and its associated functions are used to create and train a "Shallow Neural Network", which is useful for function approximation and clustering. "trainNetwork" and its associated functions, on the other hand, are used to create and train a "Deep Neural Network", which is predominantly used for image classification.
The functions “train” and “trainNetwork” sit on top of completely separate code bases. Both are part of the same toolbox (Neural Network Toolbox), but they are independent of each other.
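As a rough illustration of the shallow workflow described above, the following sketch fits a small feedforward network with “train”. It assumes the toolbox's bundled simplefit_dataset example data is available; the hidden-layer size of 10 is an arbitrary choice for illustration, not a recommendation.

```matlab
% Minimal sketch of the shallow-network workflow (assumes the toolbox's
% bundled example data set simplefit_dataset is available).
[x, t] = simplefit_dataset;          % inputs and targets for a 1-D curve fit

% Feedforward network with one hidden layer of 10 neurons (arbitrary size).
% The default training function for feedforwardnet is 'trainlm'.
net = feedforwardnet(10);
net.trainParam.showWindow = false;   % suppress the training window

% train works on the whole data set held in memory.
[net, tr] = train(net, x, t);

% Evaluate the trained network on the training inputs.
y = net(x);
perf = perform(net, t, y);           % mean squared error by default
```

Here the entire data set is passed to “train” at once, which is practical for small function-approximation problems like this one.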
While you can train deep or shallow networks with both functions, there are a few reasons why training deep networks is easier with “trainNetwork”:
  • Deep networks usually require a lot of training data, which often does not fit in RAM (or in GPU memory when training on a GPU). The function “trainNetwork” was designed with that in mind: it uses algorithms such as stochastic gradient descent and the Adam optimizer, which work on mini-batches while keeping the rest of the training data out of memory.
  • The function “trainNetwork” (and all the associated functions) is more tightly integrated with cuDNN, NVIDIA’s low-level library for running neural networks on its GPUs (e.g., to support fast operations such as convolutions).
  • The ecosystem around “trainNetwork” closely follows new trends in deep learning, while the one for “train” focuses on the classic algorithms used for training neural networks.
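The mini-batch training described in the list above can be sketched as follows. This is only an illustrative example, assuming the toolbox's bundled digit data (digitTrain4DArrayData) is available; the layer sizes, mini-batch size, and number of epochs are arbitrary values chosen for the sketch.

```matlab
% Minimal sketch of the deep-network workflow (assumes the toolbox's
% bundled digit data set is available).
[XTrain, YTrain] = digitTrain4DArrayData;   % 28x28 grayscale digit images with categorical labels

% Small convolutional network; the layer sizes are arbitrary.
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3, 8, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% Adam solver working on mini-batches of 128 images per iteration.
options = trainingOptions('adam', ...
    'MiniBatchSize', 128, ...
    'MaxEpochs', 4, ...
    'Shuffle', 'every-epoch', ...
    'Verbose', false);

net = trainNetwork(XTrain, YTrain, layers, options);
```

This small data set fits in memory, so it is passed as an array. For data that does not fit in memory, “trainNetwork” also accepts a datastore (for example an imageDatastore) as its first input, so that only the current mini-batch needs to be read at a time.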
Even if you set up the same network architecture for both functions, the code that optimizes the network parameters is different, and so are the optimization algorithms. The function “train” offers a wider variety of algorithms. The function “trainNetwork” offers the algorithms used in recent state-of-the-art deep learning research.
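To make the difference in available algorithms concrete, the sketch below shows how an algorithm is selected in each workflow. The particular algorithms listed are a small selection for illustration, not the full set either function supports.

```matlab
% Shallow workflow: the training function is a property of the network
% and can be passed when the network is created.
netLM  = feedforwardnet(10, 'trainlm');    % Levenberg-Marquardt
netSCG = feedforwardnet(10, 'trainscg');   % scaled conjugate gradient
netBR  = feedforwardnet(10, 'trainbr');    % Bayesian regularization

% Deep workflow: the solver is chosen by name in trainingOptions.
optsSGDM = trainingOptions('sgdm');        % SGD with momentum
optsRMS  = trainingOptions('rmsprop');     % RMSProp
optsAdam = trainingOptions('adam');        % Adam
```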
Boris Hristovski on 18 May 2021
Edited: Boris Hristovski on 18 May 2021
I wish the User Guide were clearer on important matters like this. A structure diagram or map of the language would help.
