
End-to-End Deep Speaker Separation

This example showcases an end-to-end deep learning network for speaker-independent speech separation.

Introduction

Speaker separation is a challenging and critical speech processing task. A number of speaker separation methods based on deep learning have been proposed recently, most of which rely on time-frequency transformations of the time-domain audio mixture (See Cocktail Party Source Separation Using Deep Learning Networks (Audio Toolbox) for an implementation of such a deep learning system).

Solutions based on time-frequency methods suffer from two main drawbacks:

  • The conversion of the time-frequency representations back to the time domain requires phase estimation, which introduces errors and leads to imperfect reconstruction.

  • Relatively long windows are required to yield high-resolution frequency representations, which leads to high computational complexity and unacceptable latency for real-time scenarios.

In this example, you train a deep learning speech separation network based on the Conv-TasNet architecture [1]. The Conv-TasNet model acts directly on the audio signal and bypasses the issues arising from time-frequency transformations.

To use a pretrained speaker separation network, see separateSpeakers (Audio Toolbox). The separateSpeakers function separates speakers using either the (transformer-based) SepFormer architecture or a Conv-TasNet architecture.

For a comparison of the performance of the different models, see Compare Speaker Separation Models (Audio Toolbox).

Optionally Reduce Data Set

To train the network with the entire data set and achieve the highest possible accuracy, set speedupExample to false. To run this example more quickly, set speedupExample to true.

speedupExample = false;

Train Speech Separation Network

Examine Network Architecture


The network is based on [1] and consists of three stages: encoding, mask estimation (separation), and decoding.

  • The encoder transforms the time-domain input mixture signals into an intermediate representation using convolutional layers.

  • The mask estimator computes one mask per speaker. Multiplying the encoder output by a speaker's mask yields that speaker's intermediate representation. The mask estimator consists of 32 blocks of convolutional and normalization layers with skip connections between blocks.

  • The decoder transforms the intermediate representations to time-domain separated speech signals using transposed convolutional layers.

The operation of the network is encapsulated in separateSpeakersConvTasNet.
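The encoder in separateSpeakersConvTasNet is a convolution with a filter length of 20 samples and a stride of 10, and the decoder is a transposed convolution with the same filter length and stride. The following sketch works through the resulting signal lengths for one 5-second training signal at 8 kHz using plain arithmetic only; the numbers come from this example, and the variable names are illustrative.

```matlab
% Sketch: signal length through the encoder (strided convolution, no
% padding) and the decoder (transposed convolution, same filter and stride).
fs         = 8000;        % sample rate of the training data (Hz)
numSamples = 5*fs;        % 5-second training signal (duration = 40000)
filterLen  = 20;          % encoder/decoder filter length (samples)
stride     = 10;          % encoder/decoder stride (samples)

% Number of encoder frames for a convolution without padding
numFrames = floor((numSamples - filterLen)/stride) + 1;

% Length of the transposed-convolution (decoder) output
outLen = (numFrames - 1)*stride + filterLen;

% Duration of one encoder window and of the hop between frames
windowMs = 1000*filterLen/fs;
hopMs    = 1000*stride/fs;

fprintf("%d samples -> %d frames -> %d samples\n",numSamples,numFrames,outLen)
fprintf("window = %.1f ms, hop = %.2f ms\n",windowMs,hopMs)
```

Each analysis window spans only 2.5 ms, and for this signal length the decoder reproduces exactly as many samples as the input. When (numSamples - filterLen) is not a multiple of the stride, the decoder output is shorter than the input by the remainder.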

To calculate the loss, use the supporting function uPIT, which implements utterance-level permutation-invariant training (uPIT), and the supporting function SISNR, which calculates the scale-invariant signal-to-noise ratio (SI-SNR) [1]. SI-SNR encourages the network to learn to separate the signals regardless of their relative energy. Without scale invariance, the network would learn to recover the higher-energy signal at the expense of the weaker one. Because there is no a priori way to associate each network output with a particular target speaker, uPIT computes the SI-SNR for every assignment of outputs to targets and keeps the best one. The training loss is the negative of this best SI-SNR, so minimizing the loss maximizes the SI-SNR.
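Both properties can be checked numerically. The following sketch reimplements SI-SNR and the two-speaker uPIT objective with anonymous functions on random placeholder signals (the names sisnr, zm, and proj are hypothetical and are not the supporting functions of this example). The estimates are deliberately returned in the wrong order and with arbitrary gains.

```matlab
% Sketch: SI-SNR and two-speaker uPIT on plain numeric column vectors.
zm    = @(v) v - mean(v);                        % remove the mean
proj  = @(e,r) (sum(e.*r)/sum(r.^2))*r;          % projection of e onto r
sisnr = @(est,ref) 10*log10( ...
    sum(proj(zm(est),zm(ref)).^2) ./ ...
    sum((zm(est) - proj(zm(est),zm(ref))).^2));  % SI-SNR in dB

rng(0)                                  % reproducible random signals
s1 = randn(8000,1);                     % placeholder target, speaker 1
s2 = randn(8000,1);                     % placeholder target, speaker 2

% Estimates in swapped order, with arbitrary gains and a little noise
y1 = 3.0*s2 + 0.1*randn(8000,1);
y2 = 0.2*s1 + 0.01*randn(8000,1);

% Scale invariance: rescaling an estimate leaves its SI-SNR unchanged
a = sisnr(y1,s2);
b = sisnr(10*y1,s2);

% uPIT: mean SI-SNR of each output-to-target assignment, keep the best
perm1 = mean([sisnr(y1,s1) sisnr(y2,s2)]);   % y1->s1, y2->s2
perm2 = mean([sisnr(y2,s1) sisnr(y1,s2)]);   % y2->s1, y1->s2
[best,bestPerm] = max([perm1 perm2]);

fprintf("SI-SNR %.2f dB, %.2f dB after scaling by 10\n",a,b)
fprintf("best assignment = %d, mean SI-SNR = %.2f dB\n",bestPerm,best)
```

The swapped assignment wins, so the network is not penalized for emitting the speakers in a different order than the targets.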

Download Train Data Set

Use a subset of the LibriSpeech data set [2] to train the network. The LibriSpeech data set is a large corpus of read English speech sampled at 16 kHz. The data is derived from audiobooks read from the LibriVox project.

Download the LibriSpeech data set. If speedupExample is true, download the approximately 322 MB dev-clean set. If speedupExample is false, download the approximately 28 GB train-clean-360 set.

downloadDatasetFolder = tempdir;

if speedupExample
    filename = "dev-clean.tar.gz";
    datasetFolder = fullfile(downloadDatasetFolder,"LibriSpeech","dev-clean");
else
    filename = "train-clean-360.tar.gz";
    datasetFolder = fullfile(downloadDatasetFolder,"LibriSpeech","train-clean-360");
end

url = "http://www.openSLR.org/resources/12/" + filename;
if ~datasetExists(datasetFolder)
    gunzip(url,downloadDatasetFolder);
    unzippedFile = fullfile(downloadDatasetFolder,filename);
    untar(unzippedFile{1}(1:end-3),downloadDatasetFolder);
end

Preprocess Data Set

The LibriSpeech data set consists of many audio files, each containing a single speaker. It does not contain mixture signals in which two or more people speak simultaneously.

Process the original data set to create a new data set that is suitable for training the speech separation network.

The steps for creating the training data set are encapsulated in createTrainingDataset. The function creates mixture signals consisting of utterances from two randomly paired speakers. The function returns three audio datastores:

  • mixDatastore points to mixture files (where two speakers are talking simultaneously).

  • speaker1Datastore points to files containing the isolated speech of the first speaker in the mixture.

  • speaker2Datastore points to files containing the isolated speech of the second speaker in the mixture.
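createTrainingDataset pairs speakers by shuffling the speaker IDs and splitting them into two disjoint groups, then truncates each speaker's file list to the length of the shorter one. The following sketch reproduces that pairing logic on hypothetical speaker IDs and file counts (not LibriSpeech data), so it runs without downloading anything.

```matlab
% Sketch of the speaker pairing in createTrainingDataset, using
% hypothetical speaker IDs and file counts.
rng(1)                                       % reproducible shuffle
names = "spk" + string(1:9)';                % 9 hypothetical speaker IDs
names = names(randperm(numel(names)));       % randomize the order

numPairs = min(200,floor(numel(names)/2));   % at most 200 pairs
n1 = names(1:numPairs);                      % first speaker of each pair
n2 = names(numPairs+1:2*numPairs);           % second speaker of each pair

% Each pair uses only as many utterances as the speaker with fewer files
numFiles1 = randi([5 20],numPairs,1);        % hypothetical file counts
numFiles2 = randi([5 20],numPairs,1);
numMixtures = min(numFiles1,numFiles2);

disp(table(n1,n2,numMixtures))
```

Because the two groups are disjoint, no speaker is paired with itself. With an odd number of speakers, one speaker is left unused.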

Define the mini-batch size and the length of each training signal in samples (5 seconds at 8 kHz). Longer utterances are cropped to this length and shorter ones are zero-padded.

miniBatchSize = 4;
duration = 5*8000;

Create the training data set.

[mixDatastore,speaker1Datastore,speaker2Datastore] = createTrainingDataset(datasetFolder,downloadDatasetFolder,duration);
Starting parallel pool (parpool) using the 'local' profile ...
Connected to parallel pool with 6 workers.

Combine the datastores so that each mixture file stays aligned with its corresponding speaker files when you shuffle the data at the start of each epoch in the training loop.

ds = combine(mixDatastore,speaker1Datastore,speaker2Datastore);

Train on a GPU if one is available. Using a GPU requires Parallel Computing Toolbox™.

executionEnvironment = "auto"; % Change to "cpu" to train on a CPU.

Create a minibatch queue from the datastore.

mqueue = minibatchqueue(ds, ...
    MiniBatchSize=miniBatchSize, ...
    OutputEnvironment=executionEnvironment, ...
    OutputAsDlarray=true, ...
    MiniBatchFormat="SCB", ...
    MiniBatchFcn=@preprocessMiniBatch);
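The mini-batch function preprocessMiniBatch receives each variable as a cell array with one observation per cell and stacks the observations along the third dimension, which matches the "SCB" format (time, channel, batch). The following sketch shows that concatenation on random placeholder signals, assuming each datastore read returns a 40000-by-1 column vector.

```matlab
% Sketch of the concatenation performed in preprocessMiniBatch.
miniBatchSize = 4;
numSamples = 40000;                                % 5 s at 8 kHz
mixCells = arrayfun(@(k) randn(numSamples,1,"single"), ...
    1:miniBatchSize,UniformOutput=false);          % placeholder signals

% Stack the observations along dimension 3:
% numSamples-by-1-by-miniBatchSize
mixBatch = cat(3,mixCells{:});

% With MiniBatchFormat="SCB", dimension 1 is time (S), dimension 2 is the
% channel (C), and dimension 3 is the batch (B).
disp(size(mixBatch))
```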

Specify Training Options

Define training parameters.

Train for 10 epochs. If speedupExample is true, train for 1 epoch.

if speedupExample
    numEpochs = 1;
else
    numEpochs = 10;
end

Specify the options for Adam optimization. Set the initial learning rate to 1e-3. Use a gradient decay factor of 0.9 and a squared gradient decay factor of 0.999.

learnRate = 1e-3;
averageGrad = [];
averageSqGrad = [];

gradDecay = 0.9;
sqGradDecay = 0.999;
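adamupdate applies the Adam rule to every learnable parameter. The following sketch writes out the standard bias-corrected Adam update for a single scalar parameter on the toy loss (w - 3)^2, using the learning rate and decay factors specified above. The value epsilon = 1e-8 is an assumption made for this illustration. The sketch shows that the first update moves the parameter by approximately learnRate, regardless of the gradient magnitude.

```matlab
% Sketch of the bias-corrected Adam update for one scalar parameter.
learnRate   = 1e-3;
gradDecay   = 0.9;      % beta1
sqGradDecay = 0.999;    % beta2
epsilon     = 1e-8;     % assumed small constant

w = 0; m = 0; v = 0;
wHistory = zeros(1,3);
for t = 1:3
    g = 2*(w - 3);                                  % gradient of (w-3)^2
    m = gradDecay*m + (1 - gradDecay)*g;            % first-moment average
    v = sqGradDecay*v + (1 - sqGradDecay)*g^2;      % second-moment average
    mHat = m/(1 - gradDecay^t);                     % bias corrections
    vHat = v/(1 - sqGradDecay^t);
    w = w - learnRate*mHat/(sqrt(vHat) + epsilon);  % parameter step
    wHistory(t) = w;
end
disp(wHistory)
```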

Set Up Validation Data

Create a validation mixture from the first two channels of a sample audio file to track progress during training.

If a GPU is available, move the validation signal to the GPU.

multipleSpeakersSignal = audioread("MultipleSpeakers-16-8-4channel-5secs.flac");
s1 = multipleSpeakersSignal(:,1);
s2 = multipleSpeakersSignal(:,2);

mix = s1 + s2;
mix = mix/max(abs(mix));

mix = dlarray(mix,"SCB");
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    mix = gpuArray(mix);
end

Define the number of iterations between validation SI-SNR computations.

numIterPerValidation = 100;

Define a vector to hold the validation SI-SNR from each validation computation.

valSNR = [];

Define a variable to hold the best validation SI-SNR.

bestSNR = -Inf;

Define a variable to hold the epoch in which the best validation score occurred.

bestEpoch = 1;

Initialize Network

Initialize the network parameters. learnables is a structure containing the learnable parameters from the network layers. states is a structure array containing the running mean and variance of the batch normalization layers in each block.

[learnables,states] = initializeNetworkParams;
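initializeNetworkParams draws the convolution weights with the initializeGlorot helper. As a reference point, the standard Glorot (Xavier) uniform scheme samples each weight from U(-b, b) with b = sqrt(6/(fanIn + fanOut)). The following sketch applies that standard formula to the first encoder convolution, assuming the usual convolutional convention fanIn = filterSize × numChannels and fanOut = filterSize × numFilters.

```matlab
% Sketch of standard Glorot uniform initialization for weights of size
% [filterSize numChannels numFilters] = [20 1 256].
filterSize  = 20;
numChannels = 1;
numFilters  = 256;

fanIn  = filterSize*numChannels;    % inputs feeding one output unit
fanOut = filterSize*numFilters;     % outputs fed by one input unit
bound  = sqrt(6/(fanIn + fanOut));  % weights are drawn from U(-bound,bound)

rng(0)
W = bound*(2*rand([filterSize numChannels numFilters],"single") - 1);

fprintf("bound = %.4f, max |W| = %.4f\n",bound,max(abs(W(:))))
```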

Train Network

Execute the training loop. This can take many hours to run.

The validation SI-SNR is computed periodically. If the SI-SNR is the best value so far, the network parameters are saved to params.mat.

iteration = 0;

monitor = trainingProgressMonitor( ...
    Metrics=["TrainingLoss","ValidationLoss"], ...
    Info=["Epoch","LearnRate"], ...
    XLabel="Iteration");
groupSubPlot(monitor,"Loss",["TrainingLoss","ValidationLoss"])


% Loop over epochs.
for jj = 1:numEpochs

    updateInfo(monitor,LearnRate=learnRate,Epoch=jj)

    % Shuffle the data
    shuffle(mqueue)

    while hasdata(mqueue)

        % Compute validation loss/SNR periodically
        if mod(iteration,numIterPerValidation)==0

            [z1,z2] = separateSpeakersConvTasNet(mix,learnables,states,false);

            l = uPIT(s1,z1,s2,z2); % targets first, estimates second, as in modelLoss
            valSNR(end+1) = l; %#ok

            recordMetrics(monitor,iteration,ValidationLoss=-l);

            if l > bestSNR
                bestSNR = l;
                bestEpoch = jj;
                filename = "params.mat";
                save(filename,"learnables","states");
            end
        end

        iteration = iteration + 1;

        % Get a new batch of training data
        [mixBatch,x1Batch,x2Batch] = next(mqueue);

        % Evaluate the model gradients and states using dlfeval and the modelLoss function.
        [loss,gradients,states] = dlfeval(@modelLoss,mixBatch,x1Batch,x2Batch,learnables,states,miniBatchSize);

        recordMetrics(monitor,iteration,TrainingLoss=loss);

        % Update the network parameters using the ADAM optimizer.
        [learnables,averageGrad,averageSqGrad] = adamupdate(learnables,gradients,averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);

        if monitor.Stop
            return
        end
    end

    % Reduce the learning rate if the validation SI-SNR did not improve
    % during the epoch
    if bestEpoch ~= jj
        learnRate = learnRate/2;
    end
    if monitor.Stop
        return
    end
end
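The learning-rate schedule in the training loop halves learnRate at the end of every epoch in which the validation SI-SNR did not reach a new best value. The following sketch replays that rule for a hypothetical pattern of improvements over six epochs.

```matlab
% Sketch of the learning-rate halving rule with a hypothetical
% improvement pattern (one flag per epoch).
learnRate = 1e-3;
improved  = [true true false true false false];

rates = zeros(size(improved));
for jj = 1:numel(improved)
    rates(jj) = learnRate;          % rate used during epoch jj
    if ~improved(jj)
        learnRate = learnRate/2;    % no new best SI-SNR in this epoch
    end
end
disp(rates)
disp(learnRate)                     % rate that a 7th epoch would use
```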

References

[1] Y. Luo and N. Mesgarani, "Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation," IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 27, no. 8, pp. 1256-1266, 2019.

[2] V. Panayotov, G. Chen, D. Povey and S. Khudanpur, "Librispeech: An ASR corpus based on public domain audio books," 2015 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Brisbane, QLD, 2015, pp. 5206-5210, doi: 10.1109/ICASSP.2015.7178964

Supporting Functions

Create Training Data Set

function [mixDatastore,speaker1Datastore,speaker2Datastore] = createTrainingDataset(datasetFolder,downloadDatasetFolder,duration)
%createTrainingDataset Create training data set

newDatasetPath = fullfile(downloadDatasetFolder,"speech-sep-dataset");

% Create the new data set folders.
if isfolder(newDatasetPath)
    rmdir(newDatasetPath,"s")
end
mkdir(newDatasetPath);
mkdir(fullfile(newDatasetPath,"sp1"));
mkdir(fullfile(newDatasetPath,"sp2"));
mkdir(fullfile(newDatasetPath,"mix"));

% Create an audioDatastore that points to the LibriSpeech data set.
ads = audioDatastore(datasetFolder,IncludeSubfolders=true);

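% The LibriSpeech data set consists of signals from different speakers.
% The unique speaker ID is encoded at the start of each audio file name
% (speakerID-chapterID-utteranceID).

% Extract the speaker IDs from the file names. The folder that contains
% each file is named after the chapter, so folder names are not used.
[~,fileNames] = cellfun(@fileparts,ads.Files,UniformOutput=false);
ads.Labels = categorical(extractBefore(fileNames,"-"));

% You will create mixture signals consisting of utterances from two random
% speakers. Randomize the IDs of all the speakers.
names = unique(ads.Labels);
names = names(randperm(length(names)));

% In this example, you create training signals based on up to 400 speakers.
% You generate mixture signals by combining utterances from up to 200 pairs
% of speakers (fewer if the data set contains fewer than 400 speakers).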

% Define the two groups of speakers.
numPairs = min(200,floor(numel(names)/2)); 
n1 = names(1:numPairs);
n2 = names(numPairs+1:2*numPairs);

% Create the new data set. For each pair of speakers: 
% * Use subset to create two audio datastores, each containing files
%   corresponding to their respective speaker.
% * Adjust the datastores so that they have the same number of files.
% * Combine the two datastores using combine. 
% * Use writeall to preprocess the files of the combined datastore and write
%   the new resulting signals to disk.

% The preprocessing steps performed to create the signals before writing
% them to disk are encapsulated in the function createTrainingFiles. For
% each pair of signals:
% * Downsample the signals from 16 kHz to 8 kHz. 
% * Randomly select a chunk of duration samples from each downsampled
%   signal, zero-padding signals that are too short.
% * Scale the first speaker so that the power ratio of the two chunks
%   matches a randomly selected SNR in the range [-5,5] dB.
% * Create the mixture by adding the 2 signal chunks, and normalize its
%   peak amplitude.
% * Write the 3 signals (corresponding to the first speaker, the second
%   speaker, and the mixture, respectively) to disk.
parfor index = 1:length(n1)
    spkInd1 = n1(index);
    spkInd2 = n2(index);
    spk1ds = subset(ads,ads.Labels==spkInd1);
    spk2ds = subset(ads,ads.Labels==spkInd2);
    L = min(length(spk1ds.Files),length(spk2ds.Files));
    spk1ds = subset(spk1ds,1:L);
    spk2ds = subset(spk2ds,1:L);
    pairds = combine(spk1ds,spk2ds);
    writeall(pairds,newDatasetPath, ...
        FolderLayout="flatten", ...
        WriteFcn=@(data,writeInfo,outputFmt)createTrainingFiles(data,writeInfo,outputFmt,duration));
end

% Create audio datastores pointing to the files corresponding to the individual speakers and the mixtures.
mixDatastore = audioDatastore(fullfile(newDatasetPath,"mix"));
speaker1Datastore = audioDatastore(fullfile(newDatasetPath,"sp1"));
speaker2Datastore = audioDatastore(fullfile(newDatasetPath,"sp2"));
end

Create Training Files

function mix = createTrainingFiles(data,writeInfo,~,varargin)
%createTrainingFiles Preprocess the training signals and write them to disk

duration = varargin{1};

x1 = data{1};
x2 = data{2};

% Resample from 16 kHz to 8 kHz
x1 = resample(x1,1,2);
x2 = resample(x2,1,2);

% Read a chunk from the first speaker signal
x1 = readSpeakerSignalChunk(duration,x1);

% Read a chunk from the second speaker signal
x2 = readSpeakerSignalChunk(duration,x2);

% SNR [-5 5] dB
s = snr(x1,x2);
targetSNR = 10*(rand - 0.5);
x1b = 10^((targetSNR-s)/20)*x1;
mix = x1b + x2;
mix = mix./max(abs(mix));

[~,s1] = fileparts(writeInfo.ReadInfo{1}.FileName);
[~,s2] = fileparts(writeInfo.ReadInfo{2}.FileName);
name = sprintf("%s-%s.wav",s1,s2);

audiowrite(sprintf("%s",fullfile(writeInfo.Location,"sp1",name)),x1,8000);
audiowrite(sprintf("%s",fullfile(writeInfo.Location,"sp2",name)),x2,8000);
audiowrite(sprintf("%s",fullfile(writeInfo.Location,"mix",name)),mix,8000);

end
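createTrainingFiles mixes the two chunks at a random SNR between -5 and 5 dB by applying an amplitude gain of 10^((targetSNR - s)/20) to the first speaker, where s is the initial SNR in dB. The following sketch checks that arithmetic on random placeholder signals. Unlike the helper, it computes the power ratio directly instead of calling snr, so it does not require Signal Processing Toolbox.

```matlab
% Sketch of mixing two chunks at a random target SNR in [-5,5] dB.
rng(2)
x1 = 0.5*randn(40000,1);                  % placeholder speaker-1 chunk
x2 = randn(40000,1);                      % placeholder speaker-2 chunk

s = 10*log10(sum(x1.^2)/sum(x2.^2));      % initial power ratio (dB)
targetSNR = 10*(rand - 0.5);              % uniform in [-5,5] dB
x1b = 10^((targetSNR - s)/20)*x1;         % amplitude gain on speaker 1

achieved = 10*log10(sum(x1b.^2)/sum(x2.^2));
mix = x1b + x2;
mix = mix./max(abs(mix));                 % peak-normalize to [-1,1]

fprintf("target %.2f dB, achieved %.2f dB\n",targetSNR,achieved)
```

The helper writes the unscaled x1 to the sp1 folder. Because SI-SNR is invariant to the scale of the reference signal, this does not change the training loss.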

Read Speaker Signal Chunk

function sequence = readSpeakerSignalChunk(duration,sequence)
%readSpeakerSignalChunk Read a chunk from the speaker signal
if length(sequence)<=duration
    sequence = [sequence;zeros(duration-length(sequence),1)];
else
    startInd = randi([1 length(sequence)-duration],1);
    endInd = startInd + duration - 1;
    sequence = sequence(startInd:endInd);
end
sequence = sequence./max(abs(sequence));
end

Model Loss

function [loss,gradients,states] = modelLoss(mix,x1,x2,learnables,states,miniBatchSize)
%modelLoss Compute the model loss, gradients, and states

[y1,y2,states] = separateSpeakersConvTasNet(mix,learnables,states,true);

m = uPIT(x1,y1,x2,y2);
l = sum(m);
loss = -l./miniBatchSize;

gradients = dlgradient(loss,learnables);

end

Utterance-Level Permutation Invariant Training (uPIT)

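function m = uPIT(x1,y1,x2,y2)
%uPIT Compute the utterance-level permutation invariant training objective.
% x1 and x2 are the target signals, and y1 and y2 are the network
% estimates. m is the mean SI-SNR of the better of the two assignments.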
v1 = SISNR(y1,x1);
v2 = SISNR(y2,x2);
m1 = mean([v1;v2]);

v1 = SISNR(y2,x1);
v2 = SISNR(y1,x2);
m2 = mean([v1;v2]);

m = max(m1,m2);
end

Scale Invariant Signal-To-Noise Ratio (SI-SNR)

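function z = SISNR(x,y)
%SISNR Compute the SI-SNR (in dB) of the estimate x with respect to the
% reference y

% Remove the mean of each signal
x = x - mean(x);
y = y - mean(y);

% Project the estimate onto the reference (target component)
t = sum(x.*y).*y./(sum(y.^2)+eps);

% Ratio of target energy to residual energy in dB, computed as
% 20*log10(norm(t)/norm(x-t)) using log(.)/log(10)
z = 20*log((sqrt(sum(t.^2))+eps)./sqrt((sum((x-t).^2))+eps))/log(10);

end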
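Ignoring the eps terms added for numerical stability, SISNR(x,y) computes the following quantity, where x is the estimate ŝ, y is the reference s, and t corresponds to s_target. uPIT and modelLoss then combine these values as shown, where B denotes miniBatchSize and ŝ₁, ŝ₂ are the two network outputs.

```latex
\begin{aligned}
\hat{s} &\leftarrow \hat{s} - \operatorname{mean}(\hat{s}), \qquad s \leftarrow s - \operatorname{mean}(s) \\
s_{\mathrm{target}} &= \frac{\langle \hat{s}, s \rangle}{\lVert s \rVert^{2}}\, s, \qquad e = \hat{s} - s_{\mathrm{target}} \\
\mathrm{SI\text{-}SNR}(\hat{s}, s) &= 20 \log_{10} \frac{\lVert s_{\mathrm{target}} \rVert}{\lVert e \rVert}
  = 10 \log_{10} \frac{\lVert s_{\mathrm{target}} \rVert^{2}}{\lVert e \rVert^{2}} \\
\mathcal{J}_{\mathrm{uPIT}} &= \max\Big\{ \tfrac{1}{2}\big[\mathrm{SI\text{-}SNR}(\hat{s}_1, s_1) + \mathrm{SI\text{-}SNR}(\hat{s}_2, s_2)\big],\;
  \tfrac{1}{2}\big[\mathrm{SI\text{-}SNR}(\hat{s}_2, s_1) + \mathrm{SI\text{-}SNR}(\hat{s}_1, s_2)\big] \Big\} \\
\mathcal{L} &= -\frac{1}{B} \sum_{b=1}^{B} \mathcal{J}_{\mathrm{uPIT}}^{(b)}
\end{aligned}
```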

Initialize Network Parameters

function [learnables,states] = initializeNetworkParams
%initializeNetworkParams Initialize the learnables and states of the
% network

learnables.Conv1W = initializeGlorot(20,1,256);
learnables.Conv1B = dlarray(zeros(256,1,"single"));

learnables.ln_weight = dlarray(ones(1,256,"single"));
learnables.ln_bias = dlarray(zeros(1,256,"single"));

learnables.Conv2W = initializeGlorot(1,256,256);
learnables.Conv2B = dlarray(zeros(256,1,"single"));

blk.Conv1B = dlarray(zeros(512,1,"single"));
blk.Prelu1 = dlarray(single(0.25));
blk.BN1Offset = dlarray(zeros(512,1,"single"));
blk.BN1Scale = dlarray(ones(512,1,"single"));
blk.Conv2B = dlarray(zeros(512,1,"single"));
blk.Prelu2 = dlarray(single(0.25));
blk.BN2Offset = dlarray(zeros(512,1,"single"));
blk.BN2Scale= dlarray(ones(512,1,"single"));
blk.Conv3B = dlarray(ones(256,1,"single"));

s.BN1Mean = dlarray(zeros(512,1,"single"));
s.BN1Var = dlarray(ones(512,1,"single"));
s.BN2Mean = dlarray(zeros(512,1,"single"));
s.BN2Var = dlarray(ones(512,1,"single"));

for index = 1:32
    blk.Conv1W = initializeGlorot(1,256,512);
    blk.Conv2W = initializeGlorot(3,1,512);
    blk.Conv2W = reshape(blk.Conv2W,[3 1 1 512]);
    blk.Conv3W = initializeGlorot(1,512,256); 
    learnables.Blocks(index) = blk;
    states(index) = s; %#ok
end

learnables.Conv3W = initializeGlorot(1,256,512);
learnables.Conv3B = dlarray(zeros(512,1,"single"));

learnables.TransConv1W = initializeGlorot(20,1,256);
learnables.TransConv1B = dlarray(zeros(1,1,"single"));

end

Glorot Initialization

function weights = initializeGlorot(filterSize,numChannels,numFilters)
% initializeGlorot - Perform Glorot initialization

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numChannels;

Z = 2*rand(sz,"single") - 1;
bound = sqrt(6 / (numIn + numOut));

weights = dlarray(bound*Z);
end

Separate Speakers Using Conv-TasNet

function [output1, output2, states] = separateSpeakersConvTasNet(input,learnables,states,training)
% separateSpeakersConvTasNet - Separate two speaker signals from a mixture input

if ~isdlarray(input)
    input = dlarray(input,"SCB");
end

x = dlconv(input,learnables.Conv1W,learnables.Conv1B,Stride=10);

x = relu(x);
x0 = x;

x = x-mean(x, 2);
x = x./sqrt(mean(x.^2,2) + 1e-5);
x = x.*learnables.ln_weight + learnables.ln_bias;

encoderOut = dlconv(x,learnables.Conv2W,learnables.Conv2B);

for index = 1:32
    [encoderOut,s] = convBlock(encoderOut,index-1,learnables.Blocks(index),states(index),training);
    states(index) = s;
end

masks = dlconv(encoderOut,learnables.Conv3W,learnables.Conv3B);
masks = relu(masks);

mask1 = masks(:,1:256,:);
mask2 = masks(:,257:512,:);

out1 = x0.*mask1;
out2 = x0.*mask2;

weights = learnables.TransConv1W;
bias = learnables.TransConv1B;
output1 = dltranspconv(out1,weights,bias,Stride=10);
output2 = dltranspconv(out2,weights,bias,Stride=10);

if ~training
    output1 = gather(extractdata(output1));
    output2 = gather(extractdata(output2));

    output1 = output1./max(abs(output1));
    output2 = output2./max(abs(output2));
end

end

Conv-TasNet - Convolutional Block

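function [output,state] = convBlock(input,count,learnables,state,training)
%convBlock Apply one Conv-TasNet convolutional block: 1-by-1 convolution,
% PReLU, batch normalization, dilated depthwise convolution, PReLU, batch
% normalization, 1-by-1 convolution, and a skip connection.
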
% Conv:
conv1Out = dlconv(input,learnables.Conv1W,learnables.Conv1B);

% PRelu:
conv1Out = relu(conv1Out) - learnables.Prelu1.*relu(-conv1Out);

% BatchNormalization:
offset = learnables.BN1Offset;
scale = learnables.BN1Scale;
datasetMean = state.BN1Mean;
datasetVariance = state.BN1Var;
if training
    [batchOut,dsmean,dsvar] = batchnorm(conv1Out,offset,scale,datasetMean,datasetVariance);
    state.BN1Mean = dsmean;
    state.BN1Var = dsvar;
else
    batchOut = batchnorm(conv1Out,offset,scale,datasetMean,datasetVariance);
end

% Conv:
padding = [1 1] * 2^(mod(count,8));
dilationFactor = 2^(mod(count,8));
convOut = dlconv(batchOut,learnables.Conv2W,learnables.Conv2B,DilationFactor=dilationFactor,Padding=padding);

% PRelu:
convOut = relu(convOut) - learnables.Prelu2.*relu(-convOut);

% BatchNormalization:
if training
    [batchOut,dsmean,dsvar] = batchnorm(convOut,learnables.BN2Offset,learnables.BN2Scale,state.BN2Mean,state.BN2Var);
    state.BN2Mean = dsmean;
    state.BN2Var = dsvar;
else
    batchOut = batchnorm(convOut,learnables.BN2Offset,learnables.BN2Scale,state.BN2Mean,state.BN2Var);
end

% Conv:
output = dlconv(batchOut,learnables.Conv3W,learnables.Conv3B);

% Skip connection
output = output + input;

end
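convBlock dilates its depthwise convolution by 2^mod(count,8) and pads by the same amount, so the 32 blocks form four repeats of the dilations 1, 2, ..., 128. The following sketch computes that schedule, checks that the padding preserves the number of frames, and estimates the receptive field of the block stack. It assumes the encoder filter length of 20 and stride of 10 used in this example and ignores the 1-by-1 convolutions, which do not widen the receptive field.

```matlab
% Sketch: dilation schedule, "same" padding, and receptive field of the
% 32 convolutional blocks (kernel size 3, dilation 2^mod(count,8)).
numBlocks  = 32;
kernelSize = 3;
count      = 0:numBlocks-1;           % block index passed to convBlock
dilations  = 2.^mod(count,8);         % 1,2,...,128 repeated 4 times

% With padding [d d] and kernel size 3, the output length equals the
% input length: L + 2d - d*(kernelSize-1) = L.
L = 3999;                             % encoder frames for a 5 s signal
outLens = L + 2*dilations - dilations*(kernelSize - 1);

% Each block widens the receptive field by d*(kernelSize-1) frames.
rfFrames = 1 + sum(dilations*(kernelSize - 1));
% Map frames back to samples (encoder filter 20, stride 10).
rfSamples = (rfFrames - 1)*10 + 20;

fprintf("receptive field: %d frames = %d samples = %.2f s at 8 kHz\n", ...
    rfFrames,rfSamples,rfSamples/8000)
```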

Preprocess Mini Batch

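function [mixBatch,x1Batch,x2Batch] = preprocessMiniBatch(mixBatch,x1Batch,x2Batch)
% preprocessMiniBatch - Concatenate the observations of each variable along
% the third (batch) dimension. The input order matches the combined
% datastore: mixture, speaker 1, speaker 2.
mixBatch = cat(3,mixBatch{:});
x1Batch = cat(3,x1Batch{:});
x2Batch = cat(3,x2Batch{:});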
end