How can I implement cross-attention for feature fusion using attentionLayer?
Show older comments
I want to replace the dual-branch merge section of the model in the following link with cross-attention fusion, but it has not been successful. Is my approach incorrect? I have written a standalone example, but I still do not understand how to embed it into the model from the link.
Network one (failure: the loss does not decrease):
% Shared input stage: a single-channel complex sequence, z-score normalized
% with real/imag split into separate channels, followed by a short 1-D conv.
initialLayers = [
    sequenceInputLayer(1, MinLength=numSamples, Name="input", ...
        Normalization="zscore", SplitComplexInputs=true)
    convolution1dLayer(7, 2, Stride=1)   % auto-named "conv1d" for connectLayers
    ];
% STFT branch: fixed (non-learnable) spectrogram front end followed by three
% conv/norm/relu/pool stages and a 100-unit embedding.
stftBranchLayers = [
    stftLayer(TransformMode="squaremag", Window=hann(64), OverlapLength=52, ...
        FFTLength=256, WeightLearnRateFactor=0, Name="stft")
    % Re-label the spectrogram dims so 2-D conv layers accept the data.
    functionLayer(@(x)dlarray(x, 'SCBS'), Formattable=true, ...
        Acceleratable=true, Name="stft_reformat")
    convolution2dLayer([4, 8], 16, Padding="same", Name="stft_conv_1")
    layerNormalizationLayer(Name="stft_layernorm_1")
    reluLayer(Name="stft_relu_1")
    maxPooling2dLayer([4, 8], Stride=[1 2], Name="stft_maxpool_1")
    convolution2dLayer([4, 8], 24, Padding="same", Name="stft_conv_2")
    layerNormalizationLayer(Name="stft_layernorm_2")
    reluLayer(Name="stft_relu_2")
    maxPooling2dLayer([4, 8], Stride=[1 2], Name="stft_maxpool_2")
    convolution2dLayer([4, 8], 32, Padding="same", Name="stft_conv_3")
    layerNormalizationLayer(Name="stft_layernorm_3")
    reluLayer(Name="stft_relu_3")
    maxPooling2dLayer([4, 8], Stride=[1 2], Name="stft_maxpool_3")
    flattenLayer(Name="stft_flatten")
    dropoutLayer(0.5)
    fullyConnectedLayer(100, Name="fc_stft")   % branch embedding, 100 features
    ];
% CWT branch: fixed (non-learnable) scalogram front end; mirrors the STFT
% branch but pools with stride [1 4] along the second spatial dimension.
cwtBranchLayers = [
    cwtLayer(SignalLength=numSamples, TransformMode="squaremag", ...
        WeightLearnRateFactor=0, Name="cwt")
    % Re-label the scalogram dims so 2-D conv layers accept the data.
    functionLayer(@(x)dlarray(x, 'SCBS'), Formattable=true, ...
        Acceleratable=true, Name="cwt_reformat")
    convolution2dLayer([4, 8], 16, Padding="same", Name="cwt_conv_1")
    layerNormalizationLayer(Name="cwt_layernorm_1")
    reluLayer(Name="cwt_relu_1")
    maxPooling2dLayer([4, 8], Stride=[1 4], Name="cwt_maxpool_1")
    convolution2dLayer([4, 8], 24, Padding="same", Name="cwt_conv_2")
    layerNormalizationLayer(Name="cwt_layernorm_2")
    reluLayer(Name="cwt_relu_2")
    maxPooling2dLayer([4, 8], Stride=[1 4], Name="cwt_maxpool_2")
    convolution2dLayer([4, 8], 32, Padding="same", Name="cwt_conv_3")
    layerNormalizationLayer(Name="cwt_layernorm_3")
    reluLayer(Name="cwt_relu_3")
    maxPooling2dLayer([4, 8], Stride=[1 4], Name="cwt_maxpool_3")
    flattenLayer(Name="cwt_flatten")
    dropoutLayer(0.5)
    fullyConnectedLayer(100, Name="fc_cwt")   % branch embedding, 100 features
    ];
% Fusion head: 4-head attention over the two branch embeddings, then a
% classifier (48 -> numel(waveformClasses) -> softmax).
% NOTE(review): fc_stft/fc_cwt follow flattenLayer + fullyConnectedLayer, so
% their outputs carry no sequence/time dimension. attentionLayer attends over
% a sequence; with effectively one step, the attention softmax is trivially 1
% and the layer passes "value" straight through, leaving the query branch
% with almost no gradient path -- a plausible cause of the stalled loss.
% TODO confirm: keep a sequence dimension (e.g. reshape each branch output to
% C-by-B-by-T, or drop the flatten/fc) before this attention layer.
finalLayers = [
attentionLayer(4,"Name","attention")
layerNormalizationLayer("Name","layernorm")
fullyConnectedLayer(48,"Name","fc_1")
fullyConnectedLayer(numel(waveformClasses),"Name","fc_2")
softmaxLayer("Name","softmax")
];
% Assemble the full graph: shared input stage, both branches, fusion head.
dlLayers2 = dlnetwork(initialLayers);
dlLayers2 = addLayers(dlLayers2, stftBranchLayers);
dlLayers2 = addLayers(dlLayers2, cwtBranchLayers);
dlLayers2 = addLayers(dlLayers2, finalLayers);
% "conv1d" is the auto-assigned name of the unnamed convolution1dLayer above.
dlLayers2 = connectLayers(dlLayers2, "conv1d", "stft");
dlLayers2 = connectLayers(dlLayers2, "conv1d", "cwt");
% Cross-attention wiring: STFT branch supplies key/value, CWT branch the query.
dlLayers2 = connectLayers(dlLayers2,"fc_stft","attention/key");
dlLayers2 = connectLayers(dlLayers2,"fc_stft","attention/value");
dlLayers2 = connectLayers(dlLayers2,"fc_cwt","attention/query");
% NOTE(review): consider dlLayers2 = initialize(dlLayers2) after the final
% connection so the assembled network's learnables are initialized -- TODO
% confirm against the training call used downstream.
My standalone example (is it correct?):
% Toy data for the cross-attention example, laid out C-by-B-by-T ('CBT').
numChannels = 10;      % feature/channel dimension C
numObservations = 128; % batch dimension B
numTimeSteps = 100;    % sequence dimension T
X = rand(numChannels,numObservations,numTimeSteps);
X = dlarray(X);        % query source (unformatted; format given at the call)
Y = rand(numChannels,numObservations,numTimeSteps);
Y = dlarray(Y);        % key/value source
numHeads = 8;
outputSize = numChannels*numHeads;  % projected feature size (C per head * heads)
% Random projection weights; trailing singleton dims are harmless.
WQ = rand(outputSize, numChannels, 1, 1);
WK = rand(outputSize, numChannels, 1, 1);
WV = rand(outputSize, numChannels, 1, 1);
WO = rand(outputSize, outputSize, 1, 1);
% NOTE(review): see the comment inside crossAttention -- the projections must
% use pagemtimes, since mtimes (*) does not accept these 3-D operands.
Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO);
function Z = crossAttention(X, Y, numHeads, WQ, WK, WV, WO)
% crossAttention  Multi-head cross-attention between two 'CBT' sequences.
%   X        - query source, numChannels-by-numObservations-by-numTimeSteps
%   Y        - key/value source, same layout as X
%   numHeads - number of attention heads
%   WQ/WK/WV - projection weights, outputSize-by-numChannels(-by-1-by-1)
%   WO       - output projection, outputSize-by-outputSize(-by-1-by-1)
%   Z        - fused output, outputSize-by-numObservations-by-numTimeSteps
%
% Bug fix: the original used mtimes (WQ * X), which only supports 2-D
% operands and errors on the 3-D 'CBT' inputs. pagemtimes applies the 2-D
% weight matrix to every batch/time page, which is the intended per-step
% linear projection.
queries = pagemtimes(WQ, X);
keys = pagemtimes(WK, Y);
values = pagemtimes(WV, Y);
% Scaled dot-product attention over the T dimension of 'CBT' data.
A = attention(queries, keys, values, numHeads, 'DataFormat', 'CBT');
Z = pagemtimes(WO, A);
end
Accepted Answer
More Answers (0)
Categories
Find more on Deep Learning Toolbox in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!