为什么我的识别准确率比较低?

9 views (last 30 days)
吉永磊
吉永磊 on 25 Apr 2025
Answered: cdarling on 4 Jun 2025
clear; clc; rng(0);
%% 参数设置
N_subcarriers = 64; % 子载波数量
N_symbols = 100; % OFDM符号数
CP_length = 16; % 循环前缀长度
SNR_dB = 30; % 信噪比
mod_types = {'BPSK', 'QPSK', '16QAM', '64QAM', 'PAM4', 'GFSK', 'CPFSK', '8PSK'}; % 调制类型列表
train_samples = 2000; % 每种调制方式的训练样本数
test_samples = 200; % 每种调制方式的测试样本数
%% 1. 生成训练数据集
train_data = {};
train_labels = [];
for mod_idx = 1:length(mod_types)
mod_type = mod_types{mod_idx};
for i = 1:train_samples
[feat, ~] = generate_ofdm_features(mod_type, N_subcarriers, N_symbols, CP_length, SNR_dB);
% 将每个样本的特征从列向量转换为行向量,以适应LSTM输入
train_data{end+1} = feat'; % 每个特征是一个时间序列,添加到元胞数组
train_labels = [train_labels; mod_idx-1]; % 标签: BPSK(0), QPSK(1), 16QAM(2), 64QAM(3)
end
end
%% 2. 数据预处理:将标签转换为分类
train_labels = categorical(train_labels, 0:length(mod_types)-1, mod_types);
%% 3. 生成测试数据集并评估准确度
test_data = {};
test_labels = [];
for mod_idx = 1:length(mod_types)
mod_type = mod_types{mod_idx};
for i = 1:test_samples
[feat, ~] = generate_ofdm_features(mod_type, N_subcarriers, N_symbols, CP_length, SNR_dB);
% 将每个样本的特征从列向量转换为行向量,以适应LSTM输入
test_data{end+1} = feat'; % 每个特征是一个时间序列,添加到元胞数组
test_labels = [test_labels; mod_idx-1];
end
end
% 转换为分类数组并指定类别名称
test_labels = categorical(test_labels, 0:length(mod_types)-1, mod_types);
%% 4. 设计LSTM网络
numFeatures = size(train_data{1}, 1); % 特征维度
numClasses = length(mod_types); % 类别数量
layers = [
sequenceInputLayer(numFeatures) % 特征维度作为LSTM输入
% 特征增强层
fullyConnectedLayer(256)
reluLayer
dropoutLayer(0.3)
% 时序特征提取
lstmLayer(128, 'OutputMode','sequence')
lstmLayer(64, 'OutputMode','last')
% 分类层
fullyConnectedLayer(32)
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer
];
%% 5. 训练LSTM模型
options = trainingOptions('adam', ...
'MaxEpochs',15, ...
'MiniBatchSize',64, ...
'Shuffle', 'every-epoch', ...
'VerboseFrequency', 100, ...
'LearnRateSchedule', 'piecewise', ...
'ExecutionEnvironment', 'gpu',...
'InitialLearnRate', 1e-3, ...
'LearnRateDropFactor', 0.5, ...
'LearnRateDropPeriod', 15,...
'ValidationData', {test_data, test_labels}, ...
'Plots', 'training-progress');
% 训练模型
lstm_net = trainNetwork(train_data, train_labels, layers, options);
%% 6. 进行预测
predicted_labels = classify(lstm_net, test_data);
%% 7. 计算准确度
confusion_mat = confusionmat(test_labels, predicted_labels);
accuracy = sum(diag(confusion_mat)) / sum(confusion_mat(:)) * 100;
fprintf('整体识别准确度: %.2f%%\n', accuracy);
% 绘制混淆矩阵
% 重命名分类标签的类别名称
train_labels = renamecats(train_labels, mod_types);
test_labels = renamecats(test_labels, mod_types);
predicted_labels = renamecats(predicted_labels, mod_types);
figure
cm = confusionchart(test_labels, predicted_labels);
cm.Title = '混淆矩阵';
cm.RowSummary = 'row-normalized';
%% 辅助函数:生成特征向量(支持BPSK/64QAM)
function [features, rx_freq_eq] = generate_ofdm_features(mod_type, N_subcarriers, N_symbols, CP_length, SNR_dB)
persistent fskModulator_gfsk fskModulator_cpfsk;
if isempty(fskModulator_gfsk)
fskModulator_gfsk = comm.CPMModulator(...
'ModulationOrder', 2, ...
'ModulationIndex', 1, ...
'SamplesPerSymbol', 1, ...
'FrequencyPulse', 'Gaussian', ...
'BitInput', true,...
'BandwidthTimeProduct', 0.3);
end
if isempty(fskModulator_cpfsk)
fskModulator_cpfsk = comm.CPFSKModulator(...
'ModulationOrder', 2, ...
'ModulationIndex', 0.5, ...
'BitInput', true,...
'SamplesPerSymbol', 1);
end
% 生成调制数据
switch mod_type
case 'BPSK'
data = randi([0 1], N_subcarriers*N_symbols, 1);
mod_data = pskmod(data, 2, pi); % BPSK调制
case 'QPSK'
data = randi([0 3], N_subcarriers*N_symbols, 1);
mod_data = pskmod(data, 4, pi/4);
case '16QAM'
data = randi([0 15], N_subcarriers*N_symbols, 1);
mod_data = qammod(data, 16, 'UnitAveragePower', true);
case '64QAM'
data = randi([0 63], N_subcarriers*N_symbols, 1);
mod_data = qammod(data, 64, 'UnitAveragePower', true);
case 'PAM4'
data = randi([0 3], N_subcarriers*N_symbols, 1);
mod_data = pammod(data, 4);
case 'GFSK'
data = randi([0 1], N_subcarriers*N_symbols, 1);
mod_data = fskModulator_gfsk(data);
case 'CPFSK'
data = randi([0 1], N_subcarriers*N_symbols, 1);
mod_data = fskModulator_cpfsk(data);
case '8PSK'
data = randi([0 7], N_subcarriers*N_symbols, 1);
mod_data = pskmod(data, 8, pi/8);
end
% OFDM调制
ofdm_symbols = reshape(mod_data, N_subcarriers, N_symbols);
ofdm_time = ifft(ofdm_symbols, N_subcarriers, 1);
ofdm_cp = [ofdm_time(end-CP_length+1:end, :); ofdm_time];
tx_signal = ofdm_cp(:);
% 加噪声
persistent ricianChannel; % 避免重复创建信道对象
if isempty(ricianChannel)
ricianChannel = comm.RicianChannel(...
'SampleRate', 1e6, ...
'PathDelays', [0, 1e-6, 2e-6], ...
'AveragePathGains', [0, -2, -8], ...
'KFactor', 15, ...
'MaximumDopplerShift', 5, ...
'RandomStream', 'mt19937ar with seed', ...
'Seed', randi(1000));
end
tx_signal = ricianChannel(tx_signal); % 通过信道
% 加高斯噪声
rx_signal = awgn(tx_signal, SNR_dB, 'measured');
% 去CP & FFT
rx_symbols = reshape(rx_signal, N_subcarriers+CP_length, N_symbols);
rx_symbols = rx_symbols(CP_length+1:end, :);
rx_freq_eq = fft(rx_symbols, N_subcarriers, 1); % 假设理想均衡
% 提取特征(新增六阶累积量)
symbols = rx_freq_eq(:);
abs_symbols = abs(symbols);
phase_symbols = angle(symbols);
% 新增特征计算
C20 = mean(abs_symbols.^2);
C21 = mean(abs_symbols.^2 .* exp(1i*2*phase_symbols));
C40 = mean(abs_symbols.^4) - 3*C20^2;
C41 = mean(abs_symbols.^4 .* exp(1i*1*phase_symbols)) - 3*C20*C21;
C42 = mean(abs_symbols.^4 .* exp(1i*2*phase_symbols)) - abs(C20)^2;
% 新增六阶累积量
C60 = mean(abs_symbols.^6) - 15*C20*mean(abs_symbols.^4) + 30*C20^3;
% 幅度方差
var_abs = var(abs_symbols);
% 新增统计量
skew = skewness(abs_symbols);
kurt = kurtosis(abs_symbols);
% 组合特征向量(扩展至12维)
features = [real(C20), imag(C20), real(C21), imag(C21), ...
real(C40), imag(C40), real(C41), imag(C41), ...
real(C42), imag(C42), C60, ...
var_abs, skew, kurt];
% 添加归一化
features = (features - mean(features)) ./ std(features);
end

Answers (1)

cdarling
cdarling on 4 Jun 2025
机器学习/深度学习需要有能够区分开的特征,才能学到正确的分类
低维度的数据比较容易理解,比如人的高低,水的温度等,这些只需要一列数据(一个特征)即可区分
要区分一列特征明显数据,可能不需要太多的机器学习方法,按照统计给出一个分界的值(阈值)即可
而要区分更复杂的数据,可能需要做特征,比如时域数据没有频域特征,多维振动没有汇总的位移量,这些需要为机器学习提供,或者在深度学习中有适合生成它的网络结构
比如频域特征可以使用fft计算,或者使用audioFeatureExtractor等方法生成,或者在网络中使用小波散射变换waveletScattering等
有了频域特征,机器学习,或者深度学习中后面的网络结构,就能学习到与频域相关的内容了
除了频率以外,也可能你的数据中需要对其他特征进行学习,那么也同样需要处理这些特征

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!