为什么我的识别准确率比较低?
9 views (last 30 days)
Show older comments
clear; clc; rng(0);
%% 参数设置
N_subcarriers = 64; % 子载波数量
N_symbols = 100; % OFDM符号数
CP_length = 16; % 循环前缀长度
SNR_dB = 30; % 信噪比
mod_types = {'BPSK', 'QPSK', '16QAM', '64QAM', 'PAM4', 'GFSK', 'CPFSK', '8PSK'}; % 调制类型列表
train_samples = 2000; % 每种调制方式的训练样本数
test_samples = 200; % 每种调制方式的测试样本数
%% 1. 生成训练数据集
train_data = {};
train_labels = [];
for mod_idx = 1:length(mod_types)
mod_type = mod_types{mod_idx};
for i = 1:train_samples
[feat, ~] = generate_ofdm_features(mod_type, N_subcarriers, N_symbols, CP_length, SNR_dB);
% 将每个样本的特征从列向量转换为行向量,以适应LSTM输入
train_data{end+1} = feat'; % 每个特征是一个时间序列,添加到元胞数组
train_labels = [train_labels; mod_idx-1]; % 标签: BPSK(0), QPSK(1), 16QAM(2), 64QAM(3)
end
end
%% 2. 数据预处理:将标签转换为分类
train_labels = categorical(train_labels, 0:length(mod_types)-1, mod_types);
%% 3. 生成测试数据集并评估准确度
test_data = {};
test_labels = [];
for mod_idx = 1:length(mod_types)
mod_type = mod_types{mod_idx};
for i = 1:test_samples
[feat, ~] = generate_ofdm_features(mod_type, N_subcarriers, N_symbols, CP_length, SNR_dB);
% 将每个样本的特征从列向量转换为行向量,以适应LSTM输入
test_data{end+1} = feat'; % 每个特征是一个时间序列,添加到元胞数组
test_labels = [test_labels; mod_idx-1];
end
end
% 转换为分类数组并指定类别名称
test_labels = categorical(test_labels, 0:length(mod_types)-1, mod_types);
%% 4. 设计LSTM网络
numFeatures = size(train_data{1}, 1); % 特征维度
numClasses = length(mod_types); % 类别数量
layers = [
sequenceInputLayer(numFeatures) % 特征维度作为LSTM输入
% 特征增强层
fullyConnectedLayer(256)
reluLayer
dropoutLayer(0.3)
% 时序特征提取
lstmLayer(128, 'OutputMode','sequence')
lstmLayer(64, 'OutputMode','last')
% 分类层
fullyConnectedLayer(32)
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer
];
%% 5. 训练LSTM模型
options = trainingOptions('adam', ...
'MaxEpochs',15, ...
'MiniBatchSize',64, ...
'Shuffle', 'every-epoch', ...
'VerboseFrequency', 100, ...
'LearnRateSchedule', 'piecewise', ...
'ExecutionEnvironment', 'gpu',...
'InitialLearnRate', 1e-3, ...
'LearnRateDropFactor', 0.5, ...
'LearnRateDropPeriod', 15,...
'ValidationData', {test_data, test_labels}, ...
'Plots', 'training-progress');
% 训练模型
lstm_net = trainNetwork(train_data, train_labels, layers, options);
%% 6. 进行预测
predicted_labels = classify(lstm_net, test_data);
%% 7. 计算准确度
confusion_mat = confusionmat(test_labels, predicted_labels);
accuracy = sum(diag(confusion_mat)) / sum(confusion_mat(:)) * 100;
fprintf('整体识别准确度: %.2f%%\n', accuracy);
% 绘制混淆矩阵
% 重命名分类标签的类别名称
train_labels = renamecats(train_labels, mod_types);
test_labels = renamecats(test_labels, mod_types);
predicted_labels = renamecats(predicted_labels, mod_types);
figure
cm = confusionchart(test_labels, predicted_labels);
cm.Title = '混淆矩阵';
cm.RowSummary = 'row-normalized';
%% 辅助函数:生成特征向量(支持BPSK/64QAM)
function [features, rx_freq_eq] = generate_ofdm_features(mod_type, N_subcarriers, N_symbols, CP_length, SNR_dB)
persistent fskModulator_gfsk fskModulator_cpfsk;
if isempty(fskModulator_gfsk)
fskModulator_gfsk = comm.CPMModulator(...
'ModulationOrder', 2, ...
'ModulationIndex', 1, ...
'SamplesPerSymbol', 1, ...
'FrequencyPulse', 'Gaussian', ...
'BitInput', true,...
'BandwidthTimeProduct', 0.3);
end
if isempty(fskModulator_cpfsk)
fskModulator_cpfsk = comm.CPFSKModulator(...
'ModulationOrder', 2, ...
'ModulationIndex', 0.5, ...
'BitInput', true,...
'SamplesPerSymbol', 1);
end
% 生成调制数据
switch mod_type
case 'BPSK'
data = randi([0 1], N_subcarriers*N_symbols, 1);
mod_data = pskmod(data, 2, pi); % BPSK调制
case 'QPSK'
data = randi([0 3], N_subcarriers*N_symbols, 1);
mod_data = pskmod(data, 4, pi/4);
case '16QAM'
data = randi([0 15], N_subcarriers*N_symbols, 1);
mod_data = qammod(data, 16, 'UnitAveragePower', true);
case '64QAM'
data = randi([0 63], N_subcarriers*N_symbols, 1);
mod_data = qammod(data, 64, 'UnitAveragePower', true);
case 'PAM4'
data = randi([0 3], N_subcarriers*N_symbols, 1);
mod_data = pammod(data, 4);
case 'GFSK'
data = randi([0 1], N_subcarriers*N_symbols, 1);
mod_data = fskModulator_gfsk(data);
case 'CPFSK'
data = randi([0 1], N_subcarriers*N_symbols, 1);
mod_data = fskModulator_cpfsk(data);
case '8PSK'
data = randi([0 7], N_subcarriers*N_symbols, 1);
mod_data = pskmod(data, 8, pi/8);
end
% OFDM调制
ofdm_symbols = reshape(mod_data, N_subcarriers, N_symbols);
ofdm_time = ifft(ofdm_symbols, N_subcarriers, 1);
ofdm_cp = [ofdm_time(end-CP_length+1:end, :); ofdm_time];
tx_signal = ofdm_cp(:);
% 加噪声
persistent ricianChannel; % 避免重复创建信道对象
if isempty(ricianChannel)
ricianChannel = comm.RicianChannel(...
'SampleRate', 1e6, ...
'PathDelays', [0, 1e-6, 2e-6], ...
'AveragePathGains', [0, -2, -8], ...
'KFactor', 15, ...
'MaximumDopplerShift', 5, ...
'RandomStream', 'mt19937ar with seed', ...
'Seed', randi(1000));
end
tx_signal = ricianChannel(tx_signal); % 通过信道
% 加高斯噪声
rx_signal = awgn(tx_signal, SNR_dB, 'measured');
% 去CP & FFT
rx_symbols = reshape(rx_signal, N_subcarriers+CP_length, N_symbols);
rx_symbols = rx_symbols(CP_length+1:end, :);
rx_freq_eq = fft(rx_symbols, N_subcarriers, 1); % 假设理想均衡
% 提取特征(新增六阶累积量)
symbols = rx_freq_eq(:);
abs_symbols = abs(symbols);
phase_symbols = angle(symbols);
% 新增特征计算
C20 = mean(abs_symbols.^2);
C21 = mean(abs_symbols.^2 .* exp(1i*2*phase_symbols));
C40 = mean(abs_symbols.^4) - 3*C20^2;
C41 = mean(abs_symbols.^4 .* exp(1i*1*phase_symbols)) - 3*C20*C21;
C42 = mean(abs_symbols.^4 .* exp(1i*2*phase_symbols)) - abs(C20)^2;
% 新增六阶累积量
C60 = mean(abs_symbols.^6) - 15*C20*mean(abs_symbols.^4) + 30*C20^3;
% 幅度方差
var_abs = var(abs_symbols);
% 新增统计量
skew = skewness(abs_symbols);
kurt = kurtosis(abs_symbols);
% 组合特征向量(扩展至12维)
features = [real(C20), imag(C20), real(C21), imag(C21), ...
real(C40), imag(C40), real(C41), imag(C41), ...
real(C42), imag(C42), C60, ...
var_abs, skew, kurt];
% 添加归一化
features = (features - mean(features)) ./ std(features);
end
0 Comments
Answers (1)
cdarling
on 4 Jun 2025
机器学习/深度学习需要有能够区分开的特征,才能学到正确的分类
低维度的数据比较容易理解,比如人的高低,水的温度等,这些只需要一列数据(一个特征)即可区分
要区分一列特征明显数据,可能不需要太多的机器学习方法,按照统计给出一个分界的值(阈值)即可
而要区分更复杂的数据,可能需要做特征,比如时域数据没有频域特征,多维振动没有汇总的位移量,这些需要为机器学习提供,或者在深度学习中有适合生成它的网络结构
比如频域特征可以使用fft计算,或者使用audioFeatureExtractor等方法生成,或者在网络中使用小波散射变换waveletScattering等
有了频域特征,机器学习,或者深度学习中后面的网络结构,就能学习到与频域相关的内容了
除了频率以外,也可能你的数据中需要对其他特征进行学习,那么也同样需要处理这些特征
0 Comments
See Also
Categories
Find more on Test and Measurement in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!