How can I customize the training loop for a DQN agent?
Following this question https://it.mathworks.com/matlabcentral/answers/2142801-how-to-normalize-the-rewards-in-rl, I wanted to add reward normalization to the code in which I train a DQN agent. However, the only documentation I found on custom training loops covers an actor-based policy: https://www.mathworks.com/help/releases/R2024a/reinforcement-learning/ug/train-reinforcement-learning-policy-using-custom-training.html. I would like to know how to adapt this code to a DQN algorithm, in particular steps 5, 6, and 7 of the example code:
% Learn the set of aggregated trajectories.
if mod(episodeCt,trajectoriesForLearning) == 0

    % Get the indices of each action taken in the action buffer.
    actionIndicationMatrix = dlarray(single(actionBuffer(:,:) == actionSet));

    % 5. Compute the gradient of the loss with respect to the actor
    %    learnable parameters.
    actorGradient = dlfeval(actorGradFcn, ...
        actor,{observationBuffer},actionIndicationMatrix,returnBuffer,maskBuffer);

    % 6. Update the actor using the computed gradients.
    [actor,actorOptimizer] = update( ...
        actorOptimizer, ...
        actor, ...
        actorGradient);

    % 7. Update the policy from the actor.
    policy = rlStochasticActorPolicy(actor);

    % Flush the mask and reward buffers.
    maskBuffer(:) = 0;
    rewardBuffer(:) = 0;
end
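For reference, my understanding is that a DQN agent has no actor: it learns a critic (Q-value function), so the rlStochasticActorPolicy call in step 7 has no direct counterpart. My guess (untested) is that the policy objects would instead be built from the critic, something like:

critic = rlVectorQValueFunction(criticNet,obsInfo,actInfo); % criticNet is my Q-network
policy = rlEpsilonGreedyPolicy(critic); % exploration policy used during training
% policy = rlMaxQPolicy(critic);        % greedy policy, e.g. for evaluation

but I am not sure how steps 5 and 6 translate.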
Answers (1)
Darshak
on 3 Feb 2025
Hi Camilla,
I came across another piece of documentation on custom training loops for reinforcement learning while looking into custom training myself, and it includes a DQN agent setup too. You can refer to it here: https://www.mathworks.com/help/reinforcement-learning/ug/model-based-reinforcement-learning-using-custom-training-loop.html.
The code below shows the modifications, for an example case, that I would make to implement a custom training loop for a DQN agent-based system, starting from the sample code you shared.
% Running statistics for reward normalization.
rewardMean = 0;
rewardStd = 1;
alpha = 0.01; % Update rate for the running mean and standard deviation

for episodeCt = 1:numEpisodes
    episodeOffset = mod(episodeCt-1,trajectoriesForLearning)*maxStepsPerEpisode;

    % Reset the environment at the start of the episode.
    obs = reset(env);
    episodeReward = zeros(maxStepsPerEpisode,1);

    for stepCt = 1:maxStepsPerEpisode

        % Compute an action using the policy based on the current observation.
        action = getAction(policy,{obs});

        % Apply the action to the environment and obtain the resulting
        % observation and reward.
        [nextObs,reward,isdone] = step(env,action{1});
        episodeReward(stepCt) = reward; % Keep the raw reward for logging.

        % Update the running statistics and normalize the reward.
        rewardMean = (1-alpha)*rewardMean + alpha*reward;
        rewardStd = sqrt((1-alpha)*rewardStd^2 + alpha*(reward-rewardMean)^2);
        normalizedReward = (reward - rewardMean)/(rewardStd + eps); % Avoid division by zero

        % Store the action, observation, and normalized reward in the buffers.
        j = episodeOffset + stepCt;
        observationBuffer(:,:,j) = obs;
        actionBuffer(:,:,j) = action{1};
        rewardBuffer(:,j) = normalizedReward; % Use the normalized reward
        maskBuffer(:,j) = 1;
        obs = nextObs;

        % Stop if a terminal condition is reached.
        if isdone
            break;
        end
    end

    % Update the return buffer and the cumulative reward for this episode.
    episodeElements = episodeOffset + (1:maxStepsPerEpisode);
    episodeCumulativeReward = extractdata(sum(rewardBuffer(episodeElements)));

    % Compute the discounted future reward.
    returnBuffer(episodeElements) = rewardBuffer(episodeElements)*discountWeights;

    % Learn the set of aggregated trajectories.
    if mod(episodeCt,trajectoriesForLearning) == 0

        % Get the indices of each action taken in the action buffer.
        actionIndicationMatrix = dlarray(single(actionBuffer(:,:) == actionSet));

        % Compute the gradient of the loss with respect to the actor
        % learnable parameters.
        actorGradient = dlfeval(actorGradFcn, ...
            actor,{observationBuffer},actionIndicationMatrix,returnBuffer,maskBuffer);

        % Update the actor using the computed gradients.
        [actor,actorOptimizer] = update(actorOptimizer,actor,actorGradient);

        % Update the policy from the actor.
        policy = rlStochasticActorPolicy(actor);

        % Flush the mask and reward buffers.
        maskBuffer(:) = 0;
        rewardBuffer(:) = 0;
    end
end
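Note that the learning block above still performs the actor update from the policy-gradient example, since I only added the reward normalization. If you want the update itself to follow DQN, steps 5, 6, and 7 would instead compute temporal-difference (TD) targets from a target critic and minimize the mean-squared TD error. Below is a rough, untested sketch of what that block could look like; critic, targetCritic, criticOptimizer, nextObservationBuffer, isdoneBuffer, discountFactor, and targetSmoothFactor are placeholder names you would need to define and fill in your own loop.

% Learn the set of aggregated trajectories (DQN-style critic update).
if mod(episodeCt,trajectoriesForLearning) == 0

    % Indicator of the action taken at each step (numActions-by-numSteps).
    actionIndicationMatrix = single(actionBuffer(:,:) == actionSet);

    % TD targets: r + gamma*max(Qtarget(s',:)) for non-terminal steps.
    % nextObservationBuffer and isdoneBuffer would be filled in the step
    % loop, next to observationBuffer.
    nextQ = getValue(targetCritic,{nextObservationBuffer});
    maxNextQ = reshape(max(nextQ,[],1),1,[]); % 1-by-numSteps
    lossData.targets = rewardBuffer + ...
        discountFactor*maxNextQ.*(1 - isdoneBuffer);
    lossData.actionIndicationMatrix = actionIndicationMatrix;
    lossData.mask = maskBuffer;

    % 5. Gradient of the mean-squared TD error with respect to the
    %    critic learnable parameters.
    criticGradient = gradient(critic,@dqnLossFunction, ...
        {observationBuffer},lossData);

    % 6. Update the critic using the computed gradients.
    [critic,criticOptimizer] = update(criticOptimizer,critic,criticGradient);

    % 7. Refresh the exploration policy from the updated critic and
    %    smooth the target critic toward it.
    policy = rlEpsilonGreedyPolicy(critic);
    targetCritic = syncParameters(targetCritic,critic,targetSmoothFactor);

    % Flush the mask and reward buffers.
    maskBuffer(:) = 0;
    rewardBuffer(:) = 0;
end

function loss = dqnLossFunction(modelOutput,lossData)
    % modelOutput{1} holds the critic Q-values for all actions.
    qValues = reshape(modelOutput{1}, ...
        size(lossData.actionIndicationMatrix,1),[]); % numActions-by-numSteps
    % Q-value of the action actually taken at each step.
    qTaken = sum(qValues.*lossData.actionIndicationMatrix,1);
    % Mean-squared TD error over the valid (unmasked) steps only.
    tdError = (lossData.targets - qTaken).*lossData.mask;
    loss = sum(tdError.^2,"all")/max(sum(lossData.mask,"all"),1);
end

With a critic update like this, the rlStochasticActorPolicy and returnBuffer parts of the sample are no longer needed, because DQN learns from one-step TD targets rather than discounted Monte Carlo returns.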
The documentation linked above also covers other details that I used in my own case and that might be helpful to you, so I suggest going through that page.
I hope this helps you.