CSI Feedback with Autoencoders
This example shows how to use an autoencoder neural network to compress downlink channel state information (CSI) over a clustered delay line (CDL) channel. CSI feedback is in the form of a raw channel estimate array.
Introduction
In conventional 5G radio networks, CSI parameters are quantities related to the state of a channel that are extracted from the channel estimate array. The CSI feedback includes several parameters, such as the channel quality indicator (CQI), the precoding matrix indicator (PMI) with different codebook sets, and the rank indicator (RI). The user equipment (UE) uses the CSI reference signal (CSI-RS) to measure and compute the CSI parameters, and reports them to the access network node (gNB) as feedback. Upon receiving the CSI parameters, the gNB schedules downlink data transmissions with attributes such as modulation scheme, code rate, number of transmission layers, and MIMO precoding. This figure shows an overview of a CSI-RS transmission, CSI feedback, and the transmission of downlink data that is scheduled based on the CSI parameters.
Conventionally, the UE processes the channel estimate to reduce the amount of CSI feedback data. As an alternative approach, the UE compresses the channel estimate array and feeds it back. After receipt, the gNB decompresses and processes the channel estimate to determine downlink data link parameters. The compression and decompression can be achieved using an autoencoder neural network [1, 2]. This approach eliminates the use of the existing quantized codebooks and can improve overall system performance.
This example uses a 5G downlink channel with these system parameters.
txAntennaSize = [2 2 2 1 1]; % rows, columns, polarizations, panels
rxAntennaSize = [2 1 1 1 1]; % rows, columns, polarizations, panels
rmsDelaySpread = 300e-9;     % s
maxDoppler = 5;              % Hz
nSizeGrid = 52;              % Number of resource blocks (RB)
                             % 12 subcarriers per RB
subcarrierSpacing = 15;      % 15, 30, 60, 120 kHz
numTrainingChEst = 15000;

% Carrier definition
carrier = nrCarrierConfig;
carrier.NSizeGrid = nSizeGrid;
carrier.SubcarrierSpacing = subcarrierSpacing
carrier = 
  nrCarrierConfig with properties:

                NCellID: 1
      SubcarrierSpacing: 15
           CyclicPrefix: 'normal'
              NSizeGrid: 52
             NStartGrid: 0
                  NSlot: 0
                 NFrame: 0
    IntraCellGuardBands: [0×2 double]

   Read-only properties:
         SymbolsPerSlot: 14
       SlotsPerSubframe: 1
          SlotsPerFrame: 10
autoEncOpt.NumSubcarriers = carrier.NSizeGrid*12;
autoEncOpt.NumSymbols = carrier.SymbolsPerSlot;
autoEncOpt.NumTxAntennas = prod(txAntennaSize);
autoEncOpt.NumRxAntennas = prod(rxAntennaSize);
Generate and Preprocess Data
The first step of designing an AI-based system is to prepare training and testing data. For this example, generate simulated channel estimates and preprocess the data. Use 5G Toolbox™ functions to configure a CDL-C channel. For more information on the data generation, see the Prepare Data for CSI Processing example. Define the CDL-C channel.
channel = nrCDLChannel;
channel.DelayProfile = 'CDL-C';
channel.DelaySpread = rmsDelaySpread;        % s
channel.MaximumDopplerShift = maxDoppler;    % Hz
channel.RandomStream = "Global stream";
channel.TransmitAntennaArray.Size = txAntennaSize;
channel.ReceiveAntennaArray.Size = rxAntennaSize;
channel.ChannelFiltering = false;            % No filtering for perfect channel estimate
Set the number of samples to generate for the dataset. For shorter runtime, this example uses 1500 samples. Saved results use 15,000 samples.
numSamples = 1500;
Select domain for preprocessed data preparation, truncation factor, and timing offset.
autoEncOpt.DataDomain = "Frequency-Spatial";
autoEncOpt.TruncationFactor = 10;
autoEncOpt.ZeroTimingOffset = true;
If Parallel Computing Toolbox™ is available, set the autoEncOpt.UseParallel variable to true to enable parallel data generation. Data generation takes about six minutes for 15,000 samples on a PC with an Intel® Xeon® W-2133 CPU @ 3.60GHz, running in parallel on six workers.
autoEncOpt.UseParallel = true;
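If you want the script to degrade gracefully on machines without the toolbox, you can guard the setting with a license check instead of hard-coding it. This is a minimal sketch under that assumption; it is not part of the original example.

% Optional sketch: enable parallel data generation only when Parallel
% Computing Toolbox is installed and licensed.
hasPCT = license('test','Distrib_Computing_Toolbox') && ~isempty(ver('parallel'));
autoEncOpt.UseParallel = hasPCT;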
Enable autoEncOpt.SaveData to save preprocessed channel estimates to .mat files.

autoEncOpt.SaveData = true;
autoEncOpt.DataDir = "Data";
autoEncOpt.DataFilePrefix = "CH_est";
Generate Samples
The helperCSINetGenerateData helper function generates numSamples preprocessed channel estimates by using the process described in the Prepare Data for CSI Processing example. When you enable autoEncOpt.SaveData, the function saves each channel estimate as an individual file in the autoEncOpt.DataDir directory with the prefix autoEncOpt.DataFilePrefix.
[HtruncReal,autoEncOpt] = helperCSINetGenerateData(numSamples,channel,carrier,autoEncOpt);
Starting CSI data generation
6 worker(s) running
00:00:30 - 0% Completed
00:00:37 - 0% Completed
00:00:38 - 0% Completed
00:00:50 - 0% Completed
00:00:50 - 0% Completed
00:00:53 - 0% Completed
00:00:58 - 0% Completed
00:01:05 - 100% Completed
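If you saved the preprocessed estimates, you can gather them later without regenerating the data. The sketch below assumes the saved layout implied by the options above (one MAT-file per estimate in the Data directory, named with the CH_est prefix); adapt the read function to the variable name that the helper actually stores.

% Hypothetical sketch: collect saved channel estimate files with a datastore.
% The file pattern and the ReadFcn are assumptions about the saved format.
ds = fileDatastore(fullfile("Data","CH_est*.mat"), ...
    "ReadFcn",@(fname)struct2cell(load(fname)));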
Preprocess Samples
The HtruncReal variable contains Nframes frames. Each frame has data for nRx receive antennas, which are independent. Because the receive antennas are independent, each counts as a separate sample, giving nRx*Nframes samples in total.
[maxDelay,nTx,Niq,nRx,Nframes] = size(HtruncReal)
maxDelay = 28
nTx = 8
Niq = 2
nRx = 2
Nframes = 750
Combine the frame and antenna dimensions. Then calculate the mean and standard deviation of the combined data, and use these values to normalize the data.
HtruncReal = reshape(HtruncReal,maxDelay,nTx,Niq,nRx*Nframes);
meanVal = mean(HtruncReal,'all')
meanVal = single
-2.5427e-04
stdVal = std(HtruncReal,[],'all')
stdVal = single
16.1309
Separate the data into training, validation, and test sets. Normalize the data to zero mean and a target standard deviation of 0.0212, which restricts most of the data to the range [-0.5, 0.5]. Adding an offset of 0.5 then shifts most values into the range [0, 1], which matches the sigmoid output layers of the autoencoder.
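Written out, the normalization that the code below applies to each set is

$$H_{\text{norm}} = \frac{H - \mu}{\sigma}\,\sigma_{\text{target}} + 0.5$$

where $\mu$ is meanVal, $\sigma$ is stdVal, and $\sigma_{\text{target}}$ is the target standard deviation of 0.0212.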
N = size(HtruncReal,4);
numTrain = floor(N*10/15)
numTrain = 1000
numVal = floor(N*3/15)
numVal = 300
numTest = floor(N*2/15)
numTest = 200
targetStd = 0.0212;
HTReal = (HtruncReal(:,:,:,1:numTrain)-meanVal) ...
    /stdVal*targetStd+0.5;
HVReal = (HtruncReal(:,:,:,numTrain+(1:numVal))-meanVal) ...
    /stdVal*targetStd+0.5;
HTestReal = (HtruncReal(:,:,:,numTrain+numVal+(1:numTest))-meanVal) ...
    /stdVal*targetStd+0.5;
autoEncOpt.MeanVal = meanVal;
autoEncOpt.StdValue = stdVal;
autoEncOpt.TargetSTDValue = targetStd;
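As a quick, optional sanity check (a sketch, not part of the original example), confirm that the normalized training set has the intended statistics:

% Optional sanity check: the mean should be close to 0.5 and the
% standard deviation close to targetStd.
mean(HTReal,'all')   % expected: ~0.5
std(HTReal,[],'all') % expected: ~0.0212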
Define and Train Neural Network Model
The second step of designing an AI-based system is to define and train the neural network model.
Define Neural Network
This example uses a modified version of the autoencoder neural network proposed in [1].
inputSize = [autoEncOpt.MaxDelay nTx 2]; % Third dimension is real and imaginary parts
nLinear = prod(inputSize);
nEncoded = 64;

autoencoderNet = dlnetwork([ ...
    % Encoder
    imageInputLayer(inputSize,"Name","Htrunc","Normalization","none")

    convolution2dLayer([3 3],2,"Padding","same","Name","Enc_Conv")
    batchNormalizationLayer("Epsilon",0.001,"Name","Enc_BN")
    leakyReluLayer(0.3,"Name","Enc_leakyRelu")

    flattenLayer("Name","Enc_flatten")
    fullyConnectedLayer(nEncoded,"Name","Enc_FC")
    sigmoidLayer("Name","Enc_Sigmoid")

    % Decoder
    fullyConnectedLayer(nLinear,"Name","Dec_FC")
    functionLayer(@(x)dlarray(reshape(x,maxDelay,nTx,2,[]),'SSCB'), ...
        "Formattable",true,"Acceleratable",true,"Name","Dec_Reshape")
    ]);

autoencoderNet = ...
    helperCSINetAddResidualLayers(autoencoderNet,"Dec_Reshape");

autoencoderNet = addLayers(autoencoderNet, ...
    [convolution2dLayer([3 3],2,"Padding","same","Name","Dec_Conv") ...
    sigmoidLayer("Name","Dec_Sigmoid")]);
autoencoderNet = ...
    connectLayers(autoencoderNet,"leakyRelu_2_3","Dec_Conv");

figure
plot(autoencoderNet)
title('CSI Compression Autoencoder')
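Optionally, you can sanity-check the assembled network by passing a random frame through it. Because the decoder reconstructs the input, the output size matches inputSize. This check is a sketch added for illustration, not part of the original example.

% Optional check (sketch): one frame in, one reconstructed frame out.
x = rand([inputSize 1],"single");
y = predict(autoencoderNet,x);
size(y) % expected: 28x8x2, matching inputSize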
Train Neural Network
Set the training options for the autoencoder neural network and train the network using the trainnet (Deep Learning Toolbox) function. Training takes less than three minutes on an Intel® Xeon® W-2133 CPU @ 3.60GHz with an NVIDIA® TITAN V GPU with a compute capability of 7.0 and 12 GB of memory. Set trainNow to false to load the pretrained network. Note that the saved network works only for the following settings. If you change any of these settings, set trainNow to true.
txAntennaSize = [2 2 2 1 1]; % rows, columns, polarizations, panels
rxAntennaSize = [2 1 1 1 1]; % rows, columns, polarizations, panels
rmsDelaySpread = 300e-9;     % s
maxDoppler = 5;              % Hz
nSizeGrid = 52;              % Number of resource blocks (RB)
                             % 12 subcarriers per RB
subcarrierSpacing = 15;
trainNow = false;

miniBatchSize = 1000;
options = trainingOptions("adam", ...
    InitialLearnRate=0.01, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropPeriod=156, ...
    LearnRateDropFactor=0.5916, ...
    Epsilon=1e-7, ...
    MaxEpochs=1000, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    ValidationData={HVReal,HVReal}, ...
    ValidationFrequency=20, ...
    Metrics="rmse", ...
    Verbose=true, ...
    ValidationPatience=20, ...
    OutputNetwork="best-validation-loss", ...
    ExecutionEnvironment="auto", ...
    Plots='training-progress')
options = 
  TrainingOptionsADAM with properties:

             GradientDecayFactor: 0.9000
                       MaxEpochs: 1000
                InitialLearnRate: 0.0100
               LearnRateSchedule: 'piecewise'
             LearnRateDropFactor: 0.5916
             LearnRateDropPeriod: 156
                   MiniBatchSize: 1000
                         Shuffle: 'every-epoch'
         CheckpointFrequencyUnit: 'epoch'
        PreprocessingEnvironment: 'serial'
                         Verbose: 1
                VerboseFrequency: 50
                  ValidationData: {[28×8×2×300 single] [28×8×2×300 single]}
             ValidationFrequency: 20
              ValidationPatience: 20
                         Metrics: 'rmse'
             ObjectiveMetricName: 'loss'
            ExecutionEnvironment: 'auto'
                           Plots: 'training-progress'
                       OutputFcn: []
                  SequenceLength: 'longest'
            SequencePaddingValue: 0
        SequencePaddingDirection: 'right'
                InputDataFormats: "auto"
               TargetDataFormats: "auto"
         ResetInputNormalization: 1
    BatchNormalizationStatistics: 'auto'
                   OutputNetwork: 'best-validation-loss'
                    Acceleration: "auto"
                  CheckpointPath: ''
             CheckpointFrequency: 1
        CategoricalInputEncoding: 'integer'
       CategoricalTargetEncoding: 'auto'
                L2Regularization: 1.0000e-04
         GradientThresholdMethod: 'l2norm'
               GradientThreshold: Inf
      SquaredGradientDecayFactor: 0.9990
                         Epsilon: 1.0000e-07
Use the normalized mean squared error (NMSE) between the network inputs and outputs in dB as the training loss function to find the best set of weights for the autoencoder.

lossFunc = @(x,t) nmseLossdB(x,t);
if trainNow
    [net,trainInfo] = ...
        trainnet(HTReal,HTReal,autoencoderNet,lossFunc,options); %#ok<UNRCH>
    savedOptions = options;
    savedOptions.ValidationData = [];
    save("dCSITrainedNetwork_" ...
        + string(datetime("now","Format","dd_MM_HH_mm")), ...
        'net','trainInfo','autoEncOpt','savedOptions')
else
    autoEncOptCached = autoEncOpt;
    load("dCSITrainedNetwork",'net','trainInfo','autoEncOpt','savedOptions')
    if autoEncOpt.NumSubcarriers ~= autoEncOptCached.NumSubcarriers ...
            || autoEncOpt.NumSymbols ~= autoEncOptCached.NumSymbols ...
            || autoEncOpt.NumTxAntennas ~= autoEncOptCached.NumTxAntennas ...
            || autoEncOpt.NumRxAntennas ~= autoEncOptCached.NumRxAntennas ...
            || autoEncOpt.MaxDelay ~= autoEncOptCached.MaxDelay
        error("CSIExample:Mismatch", ...
            "Saved network does not match settings. Set trainNow to true.")
    end
end
Test Trained Network
Use the predict (Deep Learning Toolbox) function to process the test data.
HTestRealHat = predict(net,HTestReal);
Calculate the correlation and NMSE between the input and output of the autoencoder network. The correlation is defined as

$$\rho = \frac{1}{N}\sum_{n=1}^{N} \frac{\left| \mathbf{h}_n^H \hat{\mathbf{h}}_n \right|}{\left\| \mathbf{h}_n \right\|_2 \left\| \hat{\mathbf{h}}_n \right\|_2}$$

where $\mathbf{h}_n$ is the channel estimate at the input of the autoencoder and $\hat{\mathbf{h}}_n$ is the channel estimate at the output of the autoencoder. NMSE is defined as

$$\mathrm{NMSE} = 10 \log_{10} \frac{\left\| \mathbf{H} - \hat{\mathbf{H}} \right\|_2^2}{\left\| \mathbf{H} \right\|_2^2}$$

where $\mathbf{H}$ is the channel estimate at the input of the autoencoder and $\hat{\mathbf{H}}$ is the channel estimate at the output of the autoencoder.
rho = zeros(numTest,1);
nmse = zeros(numTest,1);
for n=1:numTest
    in = HTestReal(:,:,1,n) + 1i*(HTestReal(:,:,2,n));
    out = HTestRealHat(:,:,1,n) + 1i*(HTestRealHat(:,:,2,n));

    % Calculate correlation
    n1 = sqrt(sum(conj(in).*in,'all'));
    n2 = sqrt(sum(conj(out).*out,'all'));
    aa = abs(sum(conj(in).*out,'all'));
    rho(n) = aa / (n1*n2);

    % Calculate NMSE
    mse = mean(abs(in-out).^2,'all');
    nmse(n) = 10*log10(mse / mean(abs(in).^2,'all'));
end

figure
tiledlayout(2,1)
nexttile
histogram(rho,"Normalization","probability")
grid on
title(sprintf("Autoencoder Correlation (Mean \\rho = %1.5f)", ...
    mean(rho)))
xlabel("\rho"); ylabel("PDF")
nexttile
histogram(nmse,"Normalization","probability")
grid on
title(sprintf("Autoencoder NMSE (Mean NMSE = %1.2f dB)",mean(nmse)))
xlabel("NMSE (dB)"); ylabel("PDF")
End-to-End CSI Feedback System
This figure shows the end-to-end processing of channel estimates for CSI feedback. The UE uses the CSI-RS signal to estimate the channel response for one slot, $\mathbf{H}_{est}$. The preprocessed channel estimate, $\mathbf{H}_{tr}$, is encoded by using the encoder portion of the autoencoder to produce a 1-by-$N_{enc}$ compressed array. The compressed array is decompressed by the decoder portion of the autoencoder to obtain $\hat{\mathbf{H}}_{tr}$. Postprocessing produces $\hat{\mathbf{H}}_{est}$.
To obtain the encoded array, split the autoencoder into two parts: the encoder network and the decoder network.
[encNet,decNet] = helperCSINetSplitEncoderDecoder(net,"Enc_Sigmoid");
plotNetwork(net,encNet,decNet)
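As an optional check (a sketch, not part of the original example), run one preprocessed test frame through the encoder alone to confirm that it produces an nEncoded-element codeword:

% Optional check (sketch): the encoder compresses one 28x8x2 frame into
% a 64-element codeword.
cw = predict(encNet,HTestReal(:,:,:,1));
numel(cw) % expected: nEncoded = 64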
Generate channel estimates for 100 slots. Each frame contains one slot, and the channel is reset after each frame.
numFrames = 100; [autoEncOpt,channel] = addSimOptions(autoEncOpt,channel,carrier); Hest = helperCSIGenerateData(numFrames,channel,carrier,autoEncOpt);
Encode and decode the channel estimates with Normalization set to true.
autoEncOpt.Normalization = true; codeword = helperCSINetEncode(encNet, Hest, autoEncOpt); Hhat = helperCSINetDecode(decNet, codeword, autoEncOpt);
Calculate the correlation and NMSE for the end-to-end CSI feedback system.
H = squeeze(mean(Hest,2));
rhoE2E = zeros(nRx,numFrames);
nmseE2E = zeros(nRx,numFrames);
for rx=1:nRx
    for n=1:numFrames
        out = Hhat(:,rx,:,n);
        in = H(:,rx,:,n);
        rhoE2E(rx,n) = helperCSINetCorrelation(in,out);
        nmseE2E(rx,n) = helperNMSE(in,out);
    end
end

figure
tiledlayout(2,1)
nexttile
histogram(rhoE2E,"Normalization","probability")
grid on
title(sprintf("End-to-End Correlation (Mean \\rho = %1.5f)", ...
    mean(rhoE2E,'all')))
xlabel("\rho"); ylabel("PDF")
nexttile
histogram(nmseE2E,"Normalization","probability")
grid on
title(sprintf("End-to-End NMSE (Mean NMSE = %1.2f dB)", ...
    mean(nmseE2E,'all')))
xlabel("NMSE (dB)"); ylabel("PDF")
Effect of Quantized Codewords
Practical systems require quantizing the encoded codeword by using a small number of bits. Simulate the effect of quantization across the range of [2, 10] bits. The results show that 6 bits are enough to closely approximate the single-precision performance.
maxVal = 1;
minVal = -1;
idxBits = 1;
nBitsVec = 2:10;
rhoQ = zeros(nRx,numFrames,length(nBitsVec));
nmseQ = zeros(nRx,numFrames,length(nBitsVec));
for numBits = nBitsVec
    disp("Running for " + numBits + " bit quantization")

    % Quantize between 0:2^n-1 to get bits
    qCodeword = uencode(double(codeword*2-1), numBits);
    % Get back the floating point, quantized numbers
    codewordRx = (single(udecode(qCodeword,numBits))+1)/2;

    Hhat = helperCSINetDecode(decNet, codewordRx, autoEncOpt);
    H = squeeze(mean(Hest,2));

    for rx=1:nRx
        for n=1:numFrames
            out = Hhat(:,rx,:,n);
            in = H(:,rx,:,n);
            rhoQ(rx,n,idxBits) = helperCSINetCorrelation(in,out);
            nmseQ(rx,n,idxBits) = helperNMSE(in,out);
        end
    end
    idxBits = idxBits + 1;
end
Running for 2 bit quantization
Running for 3 bit quantization
Running for 4 bit quantization
Running for 5 bit quantization
Running for 6 bit quantization
Running for 7 bit quantization
Running for 8 bit quantization
Running for 9 bit quantization
Running for 10 bit quantization
figure
tiledlayout(2,1)
nexttile
plot(nBitsVec,squeeze(mean(rhoQ,[1 2])),'*-')
title("Correlation (Codeword-" + size(codeword,3) + ")")
xlabel("Number of Quantization Bits"); ylabel("\rho")
grid on
nexttile
plot(nBitsVec,squeeze(mean(nmseQ,[1 2])),'*-')
title("NMSE (Codeword-" + size(codeword,3) + ")")
xlabel("Number of Quantization Bits"); ylabel("NMSE (dB)")
grid on
Further Exploration
The autoencoder is able to compress a [624 8] single-precision complex channel estimate array into a [64 1] single-precision array with a mean correlation factor of 0.99 and an NMSE of –19.55 dB. Using 6-bit quantization requires only 384 bits of CSI feedback data, which equates to a compression ratio of 832:1.
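The ratio follows directly from the bit counts: the uncompressed estimate uses 624 subcarriers × 8 transmit antennas × 2 (real and imaginary parts) × 32 bits, while the quantized codeword uses 64 elements × 6 bits:

$$\frac{624 \times 8 \times 2 \times 32}{64 \times 6} = \frac{319488}{384} = 832$$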
display("Compression ratio is " + (624*8*32*2)/(64*6) + ":" + 1)
"Compression ratio is 832:1"
Investigate the effect of the truncation factor, autoEncOpt.TruncationFactor, on the system performance. Vary the 5G system parameters, the channel parameters, and the number of encoded symbols, and then find the optimum values for the defined channel.
The NR PDSCH Throughput Using Channel State Information Feedback example shows how to use channel state information (CSI) feedback to adjust the physical downlink shared channel (PDSCH) parameters and measure throughput. Replace the CSI feedback algorithm with the CSI compression autoencoder and compare performance.
Helper Functions
Explore the helper functions to see the detailed implementation of the system.
Training Data Generation
helperCSINetGenerateData
helperCSIGenerateData
helperCSIChannelEstimate
Network Definition and Manipulation
helperCSINetDLNetwork
helperCSINetAddResidualLayers
helperCSINetSplitEncoderDecoder
CSI Processing
helperCSIPreprocessChannelEstimate
helperCSINetPostprocessChannelEstimate
helperCSINetEncode
helperCSINetDecode
Performance Measurement
helperCSINetCorrelation
helperNMSE
Appendix: Optimize Hyperparameters with Experiment Manager
Use the Experiment Manager app to find the optimal parameters. CSITrainingProject.mlproj is a preconfigured project. Extract the project.

projectName = "CSITrainingProject";
if ~exist(projectName,"dir")
    projRoot = helperCSINetExtractProject(projectName);
else
    projRoot = fullfile(exRoot(),projectName);
end
To open the project, start the Experiment Manager app and open the following file.
disp(fullfile(".","CSITrainingProject","CSITrainingProject.prj"))
.\CSITrainingProject\CSITrainingProject.prj
The Optimize Hyperparameters experiment uses Bayesian optimization with hyperparameter search ranges specified as in the following figure. After you open the project, you can examine the experiment: the setup function is CSIAutoEncNN_setup and the custom metric function is E2E_NMSE.
The optimal parameters are 0.01 for the initial learning rate, 156 iterations for the learning rate drop period, and 0.5916 for the learning rate drop factor. After finding the optimal hyperparameters, train the network with the same parameters multiple times to find the best trained network.
The ninth trial produced the best E2E_NMSE. This example uses this trained network as the saved network.
Configuring Batch Mode
When Execution Mode is set to Batch Sequential or Batch Simultaneous, training data must be accessible to the workers in a location defined by the dataDir variable in the Prepare Data in Bulk section. Set dataDir to a network location that is accessible by the workers. For more information, see Offload Experiments as Batch Jobs to a Cluster (Deep Learning Toolbox).
Local Functions
function plotNetwork(net,encNet,decNet)
%plotNetwork Plot autoencoder network
%   plotNetwork(NET,ENC,DEC) plots the full autoencoder network together
%   with encoder and decoder networks.
fig = figure;
t1 = tiledlayout(1,2,'TileSpacing','Compact');
t2 = tiledlayout(t1,1,1,'TileSpacing','Tight');
t3 = tiledlayout(t1,2,1,'TileSpacing','Tight');
t3.Layout.Tile = 2;
nexttile(t2)
plot(net)
title("Autoencoder")
nexttile(t3)
plot(encNet)
title("Encoder")
nexttile(t3)
plot(decNet)
title("Decoder")
pos = fig.Position;
pos(3) = pos(3) + 200;
pos(4) = pos(4) + 300;
pos(2) = pos(2) - 300;
fig.Position = pos;
end

function rootDir = exRoot()
%exRoot Example root directory
rootDir = fileparts(which("helperCSINetDLNetwork"));
end

function loss = nmseLossdB(x,xHat)
%nmseLossdB NMSE loss in dB
in = complex(x(:,:,1,:),x(:,:,2,:));
out = complex(xHat(:,:,1,:),xHat(:,:,2,:));
nmsePerObservation = helperNMSE(in,out);
loss = mean(nmsePerObservation);
end

function [opt,channel] = addSimOptions(opt,channel,carrier)
opt.SaveData = false;
opt.Preprocess = false;
if isa(channel,"nrCDLChannel")
    % Make sure that this is high enough for nrPerfectChannelEstimate to
    % return the full number of symbols worth of channel estimates
    opt.ChannelSampleDensity = 64*4;
end
waveInfo = nrOFDMInfo(carrier);
channel.SampleRate = waveInfo.SampleRate;
numSubCarriers = carrier.NSizeGrid*12; % 12 subcarriers per RB
Tdelay = 1/(numSubCarriers*carrier.SubcarrierSpacing*1e3);
opt.MaxDelay = round((channel.DelaySpread/Tdelay)*opt.TruncationFactor/2)*2;
opt.NumSlotsPerFrame = 1;
opt.Preprocess = false;
opt.ResetChannelPerFrame = true;
opt.Normalization = false;
opt.Verbose = false;
end
References
[1] Wen, Chao-Kai, Wan-Ting Shih, and Shi Jin. “Deep Learning for Massive MIMO CSI Feedback.” IEEE Wireless Communications Letters 7, no. 5 (October 2018): 748–51. https://doi.org/10.1109/LWC.2018.2818160.
[2] Zimaglia, Elisa, Daniel G. Riviello, Roberto Garello, and Roberto Fantini. “A Novel Deep Learning Approach to CSI Feedback Reporting for NR 5G Cellular Systems.” In 2020 IEEE Microwave Theory and Techniques in Wireless Communications (MTTW), 47–52. Riga, Latvia: IEEE, 2020. https://doi.org/10.1109/MTTW51045.2020.9245055.