I'm trying to import a custom PyTorch model into MATLAB using "importNetworkFromPyTorch", but I'm running into problems with "pyUnsqueeze", a function that MATLAB generates automatically during the import.
My model takes as input an array of shape [1,10,8] and outputs one of shape [1,10,1].
I import the network with:
net = importNetworkFromPyTorch("path_to_my_model", 'PyTorchInputSizes', [1,10,8])
Then, to initialize it, I run the following, as suggested by the warning in the Command Window:
dlX1 = dlarray(rand([1,10,8]), 'UUU');
and then:
net = initialize(net, dlX1);
which produces the following error:
Error using dlnetwork/initialize (line 600)
Layer 'TopLevelModule': Invalid network.
Layer 'ATEN3': Error using the predict function in layer empty_model_traced.TopLevelModule_ATEN3. The function threw an
error and could not be executed.
Unrecognized function or variable 'newShape'.
Error in empty_model_traced.ops.pyUnsqueeze (line 41)
Yval = reshape(Xval, newShape);
Error in empty_model_traced.TopLevelModule_ATEN3/tracedPyTorchFunction (line 91)
[unsqueeze_input_1] = empty_model_traced.ops.pyUnsqueeze(arange_38, Constant_30);
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error in empty_model_traced.TopLevelModule_ATEN3/predict (line 46)
[unsqueeze_input_1] = tracedPyTorchFunction(obj,size_batchu_1,false,"predict");
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The relevant part of the generated pyUnsqueeze.m function is:
function Y = pyUnsqueeze(X, dim)
import empty_model_traced.ops.*
newShape = flip(size(Xval));
newShape = ones(1, Yrank);
knownSizes = setdiff(1:Yrank, dim);
newShape(knownSizes) = size(Xval, 1:numel(knownSizes));
Yval = reshape(Xval, newShape);
Yval = dlarray(Yval, repmat('U', 1, max(2,Yrank)));
Y = struct('value', Yval, 'rank', Yrank);
While debugging, I found that Yrank = 2 and dim = 2 at this point. The if-else chain has no branch for that combination, so newShape is never defined.
I can trace this back to the following line of Python code:
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
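For reference, the shape arithmetic involved can be sketched in plain Python. This is a minimal illustration, not MathWorks code: it assumes the importer stores tensors in reversed (MATLAB) dimension order and maps a 0-based PyTorch dim d to the 1-based dim Yrank - d. The helper names (torch_unsqueeze_shape, torch_to_matlab_dim, reversed_unsqueeze_shape) are hypothetical. Under that assumption, unsqueeze(0) on the length-t pos tensor gives Yrank = 2 and dim = 2, the same values seen while debugging, and a generic rule in the style of the ones/setdiff lines above still yields a valid shape for that case.

```python
# Sketch only: mirrors the shape arithmetic of pyUnsqueeze, ASSUMING the importer
# stores tensors in reversed (MATLAB) dimension order. Not the generated MATLAB code.


def torch_unsqueeze_shape(shape, dim):
    """Shape produced by torch.unsqueeze(x, dim) for an input of the given shape."""
    out_rank = len(shape) + 1
    if dim < 0:
        dim += out_rank  # PyTorch counts negative dims from the end of the output
    if not 0 <= dim < out_rank:
        raise IndexError(f"dim {dim} out of range for output rank {out_rank}")
    return tuple(shape[:dim]) + (1,) + tuple(shape[dim:])


def torch_to_matlab_dim(torch_dim, out_rank):
    """HYPOTHETICAL mapping: 0-based PyTorch dim -> 1-based reversed-order dim."""
    if torch_dim < 0:
        torch_dim += out_rank
    return out_rank - torch_dim


def reversed_unsqueeze_shape(rev_shape, matlab_dim, out_rank):
    """Generic rule like the ones/setdiff lines: every output position except
    matlab_dim takes the next input size in order; matlab_dim gets a 1."""
    new_shape = [1] * out_rank
    known = [i for i in range(1, out_rank + 1) if i != matlab_dim]
    for k, i in enumerate(known):
        new_shape[i - 1] = rev_shape[k]
    return tuple(new_shape)


if __name__ == "__main__":
    t = 10  # sequence length taken from the [1,10,8] input
    torch_shape = torch_unsqueeze_shape((t,), 0)  # arange(0, t).unsqueeze(0)
    out_rank = len(torch_shape)
    mdim = torch_to_matlab_dim(0, out_rank)
    rev = reversed_unsqueeze_shape((t,), mdim, out_rank)
    print(torch_shape, out_rank, mdim, rev)
```

Running it prints (1, 10) 2 2 (10, 1). The reversed-order result is just the PyTorch shape (1, t) flipped, which suggests that a Yrank = 2, dim = 2 case would be well-defined if the generated code handled it.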
Is this the intended behavior of the MATLAB function, meaning I'm doing something wrong in my code, or is this a bug in the auto-generated function?