
Initializing an LSTM imported using ONNX

Andreas on 17 Jul 2024
Answered: Andreas on 23 Jul 2024
Hi,
I am training an LSTM for RL using Ray in Python. I would like to export this model using ONNX and then import it into MATLAB. As far as I understand, I need to initialize the model in MATLAB after importing it. However, I cannot figure out the correct input shapes/formats in MATLAB to make this work.
Minimum working example:
Python code to train the LSTM:
import torch
import numpy as np
from ray.rllib.algorithms.ppo import PPOConfig

# Configure and build the PPO algorithm with an LSTM model
algo = (
    PPOConfig()
    .env_runners(num_env_runners=1)
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
    .training(model={"use_lstm": True})
    .build()
)
# Train for 2 iterations (each call to train() runs one training iteration)
for i in range(2):
    result = algo.train()
# Get the policy and its underlying torch model
ppo_policy = algo.get_policy()
model = ppo_policy.model
# Batch size
B = 1
# Example LSTM inputs: one CartPole observation (4 values), the two
# LSTM state tensors (256 units each), and one sequence length per batch item
input_dict = {"obs": torch.tensor(np.random.uniform(0, 1.0, size=(B, 4)).astype(np.float32))}
state_batches = [torch.zeros((B, 256), dtype=torch.float32), torch.zeros((B, 256), dtype=torch.float32)]
seq_lens = torch.ones([B], dtype=int)
# Apply the LSTM model to the inputs
print(model(input_dict, state=state_batches, seq_lens=seq_lens))
# Export the model to ONNX (opset 14) into the directory "onnx14"
ppo_policy.export_model('onnx14', onnx=14)
Code in Matlab:
% Import model from where I saved it
net = importNetworkFromONNX('path/to/onnx-model');
% input shapes
obs_size = [1,4];
state_size=[2,1,256];
seq_lens_size=[1];
% initialize input arrays
obs = dlarray(rand(obs_size),"BS");
state = dlarray(rand(state_size),"SBS");
seq_len = dlarray(rand(seq_lens_size),"SB");
% initialize net
net = initialize(net,obs,state,seq_len);
Error message:
I appreciate any help!
Best,
Andreas
  2 Comments
Nilesh on 17 Jul 2024
Edited: Nilesh on 17 Jul 2024
Hello Andreas,
Have you tried asking ChatGPT about your issue?


Answers (3)

Joss Knight on 18 Jul 2024
This code is suspect
% initialize input arrays
obs = dlarray(rand(obs_size),"BS");
state = dlarray(rand(state_size),"SBS");
seq_len = dlarray(rand(seq_lens_size),"SB");
% initialize net
net = initialize(net,obs,state,seq_len);
I think your network has a single input, so you need to pass a single input to initialize (along with the network): basically just an example input exactly like the one you want to pass to predict. I think you have two channels and a sequence length of 256? One of your dimensions is time, so you need a T dimension. And I don't think you have any spatial dimensions, so no S labels. So you need something like:
exampleInput = dlarray(rand(2,1,256),'CBT');
net = initialize(net, exampleInput);
Or, if you prefer, a permutation of that, such as:
exampleInput = dlarray(rand(256,2,1),'TCB');
net = initialize(net, exampleInput);
If this doesn't work, try running analyzeNetwork(net) to see where your inputs are and we can work out what to expect.
  1 Comment
Andreas on 23 Jul 2024
Hi,
the network does not have a single input. I managed to solve the issue; see my response below. Thank you for your help anyway!



Kaustab Pal on 19 Jul 2024
It seems you want to determine the input dimension of your imported network. You can easily find this information using the analyzeNetwork function. This function provides an interactive visualization of the network architecture and detailed information, including:
  • Layer types
  • Sizes and formats of layer learnable parameters
  • States and activations
  • Total number of learnable parameters
The activation size of the topmost layer will give you the input dimension.
Additionally, when you create a formatted dlarray object in MATLAB, the software automatically permutes the dimensions so that the format labels appear in this order:
  • "S" (Spatial)
  • "C" (Channel)
  • "B" (Batch)
  • "T" (Time)
  • "U" (Unspecified)
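To make the ordering above concrete, here is a small Python model of it (Python only because it is the other language used in this thread). The helper `canonical_permutation` and the constant `CANONICAL_ORDER` are hypothetical names introduced for illustration. The sketch assumes, as the dlarray documentation describes, that formatted dimensions are sorted into S, C, B, T, U order and that repeated labels keep their relative order; it does not model MATLAB details such as dropped trailing singleton dimensions.

```python
# Illustrative model of the documented dlarray reordering rule; this is NOT
# MATLAB's implementation and ignores details such as trailing singleton dims.
CANONICAL_ORDER = "SCBTU"  # S (spatial), C (channel), B (batch), T (time), U (unspecified)


def canonical_permutation(fmt, shape):
    """Hypothetical helper: return (fmt, shape) reordered into S, C, B, T, U order."""
    if len(fmt) != len(shape):
        raise ValueError("format and shape must have the same number of dimensions")
    unknown = set(fmt) - set(CANONICAL_ORDER)
    if unknown:
        raise ValueError(f"unknown dimension labels: {sorted(unknown)}")
    # sorted() is stable, so repeated labels (e.g. two 'S' dims) keep their relative order
    order = sorted(range(len(fmt)), key=lambda i: CANONICAL_ORDER.index(fmt[i]))
    return "".join(fmt[i] for i in order), tuple(shape[i] for i in order)


if __name__ == "__main__":
    print(canonical_permutation("TCB", (256, 2, 1)))  # ('CBT', (2, 1, 256))
    print(canonical_permutation("SBS", (2, 1, 256)))  # ('SSB', (2, 256, 1))
```

Under this model, Joss Knight's two suggestions, dlarray(rand(2,1,256),'CBT') and dlarray(rand(256,2,1),'TCB'), describe the same layout, while the question's "SBS" state would end up as "SSB", with both non-batch dimensions treated as spatial rather than as channel or time.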
For more details, you can refer to the following links:
  1. analyzeNetwork Documentation: https://www.mathworks.com/help/deeplearning/ref/analyzenetwork.html#mw_bdd24886-fa03-4540-a111-391541a0a684
  2. dlarray Documentation: https://www.mathworks.com/help/deeplearning/ref/dlarray.html#d126e57736:~:text=When%20you%20create%20a%20formatted%20dlarray%20object%2C%20the%20software%20automatically%20permutes%20the%20dimensions%20such%20that%20the%20format%20has%20dimensions%20in%20this%20order%3A
Hope this helps.

Andreas on 23 Jul 2024
Hello everyone,
thank you for your help. Unfortunately, I had to work around this issue, but I solved it in the end. I believe MATLAB struggles because the models inside Ray's RLlib carry a lot of complicated overhead. In particular, the network inputs are lists/dicts etc. that undergo quite a bit of reformatting, which seemed to cause the problems. In the end, I extracted the relevant torch models from the trained RLlib object and combined them in a new torch.nn.Module. Exporting that object with torch.onnx.export worked just fine.
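The workaround described above can be sketched roughly as follows. This is a minimal sketch under stated assumptions, not Andreas's actual code: the class ExportableLSTMPolicy, the helper export_to_onnx, the submodule names, and the layer layout (an encoder, an LSTM, and separate logits/value heads) are hypothetical stand-ins, and the sizes (4 observations, 256 units, 2 actions) are chosen to match CartPole and the 256-wide states in the question. In practice you would copy the trained submodules out of the RLlib model; their attribute names depend on the RLlib version, so inspect policy.model first. The point of the wrapper is that it takes plain tensors (obs, h, c) instead of RLlib's dict/list inputs, so the exported ONNX graph has a flat set of named inputs.

```python
import torch
from torch import nn


class ExportableLSTMPolicy(nn.Module):
    """HYPOTHETICAL plain-PyTorch wrapper: encoder -> LSTM -> logits/value heads.

    The submodules here are freshly constructed stand-ins; for a real export
    they would be replaced by the trained modules copied from the RLlib model.
    """

    def __init__(self, obs_dim=4, hidden=256, num_actions=2, cell_size=256):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh())
        self.lstm = nn.LSTM(hidden, cell_size, batch_first=True)
        self.logits_head = nn.Linear(cell_size, num_actions)
        self.value_head = nn.Linear(cell_size, 1)

    def forward(self, obs, h, c):
        # obs: (B, T, obs_dim); h and c: (B, cell_size)
        x = self.encoder(obs)
        # nn.LSTM expects states shaped (num_layers, B, cell_size), even with batch_first=True
        out, (h_n, c_n) = self.lstm(x, (h.unsqueeze(0), c.unsqueeze(0)))
        logits = self.logits_head(out)              # (B, T, num_actions)
        value = self.value_head(out).squeeze(-1)    # (B, T)
        return logits, value, h_n.squeeze(0), c_n.squeeze(0)


def export_to_onnx(model, path, batch=1, seq_len=1, obs_dim=4, cell_size=256):
    """Export the wrapper with three named tensor inputs and four named outputs."""
    model.eval()
    example_inputs = (
        torch.zeros(batch, seq_len, obs_dim),
        torch.zeros(batch, cell_size),
        torch.zeros(batch, cell_size),
    )
    torch.onnx.export(
        model,
        example_inputs,
        path,
        input_names=["obs", "h_in", "c_in"],
        output_names=["logits", "value", "h_out", "c_out"],
        opset_version=14,
    )
```

For example, export_to_onnx(ExportableLSTMPolicy(), "policy.onnx") would write a model whose inputs are obs, h_in and c_in. Depending on your PyTorch version, torch.onnx.export may need extra packages such as onnx or onnxscript installed.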
Thank you all for your help.
Best, Andreas

Categories

Sequence and Numeric Feature Data Workflows

Release

R2024a
