
Compare Agents on the Continuous Cart Pole

This example shows how to create and train frequently used default agents on a continuous action space cart-pole environment. This environment is modeled in MATLAB®, and represents a pole attached to an unactuated joint on a cart, which moves along a frictionless track. The agent can apply a force to the cart and its training goal is to balance the pole upright using minimal control effort. The example plots performance metrics such as the total training time and the total reward for each trained agent. The results that the agents obtain in this environment, with the selected initial conditions and random number generator seed, do not necessarily imply that specific agents are better than others. Also, note that the training times depend on the computer and operating system you use to run the example, and on other processes running in the background. Your training times might differ substantially from the training times shown in the example.

Fix Random Number Stream for Reproducibility

The example code might involve computation of random numbers at various stages. Fixing the random number stream at the beginning of various sections in the example code preserves the random number sequence in the section every time you run it, and increases the likelihood of reproducing the results. For more information, see Results Reproducibility.

Fix the random number stream with seed 0 and random number algorithm Mersenne Twister. For more information on controlling the seed used for random number generation, see rng.

previousRngState = rng(0,"twister");

The output previousRngState is a structure that contains information about the previous state of the stream. You will restore the state at the end of the example.
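
For example, you can verify which generator settings were saved before restoring them at the end of the example. The structure returned by rng contains the Type, Seed, and State fields.

% Optional: inspect the saved generator settings.
previousRngState.Type
previousRngState.Seed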

Continuous Action Space Cart Pole MATLAB Environment

For this example, the reinforcement learning environment is a pole attached to an unactuated revolute joint on a cart. The cart has an actuated prismatic joint connected to a one-dimensional frictionless track. The training goal in this environment is to balance the pole by applying forces (actions) to the prismatic joint.

For this environment:

  • The balanced, upright pole position is 0 radians and the downward hanging pole position is pi radians.

  • The pole starts upright with an initial angle between –0.05 radians and 0.05 radians.

  • The force action signal from the agent to the environment is from –10 N to 10 N.

  • The observations from the environment are the position and velocity of the cart, the pole angle (clockwise-positive), and the pole angle derivative.

  • The episode terminates if the pole is more than 12 degrees from vertical or if the cart moves more than 2.4 m from the original position.

  • A reward of +0.5 is provided for every time step that the pole remains upright. An additional reward is provided based on the distance between the cart and the origin. A penalty of –50 is applied when the pole falls.

For more information on this model, see Load Predefined Control System Environments.
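
The exact reward computation is part of the predefined environment. As a rough sketch, a per-step reward with the structure described in the preceding list might be computed as follows. The distanceWeight value and the linear form of the distance term are illustrative assumptions, not the environment's actual implementation.

% Illustrative sketch only: per-step reward with the structure described above.
x = 0.1;            % example cart position (m)
isFallen = false;   % true when the pole exceeds the termination thresholds
distanceWeight = 1; % assumed weight for the distance-based term
stepReward = 0.5 - distanceWeight*abs(x); % upright reward minus distance term
if isFallen
    stepReward = stepReward - 50;         % penalty applied when the pole falls
end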

Create Environment Object

Create a predefined environment object for the continuous cart-pole environment.

env = rlPredefinedEnv("CartPole-Continuous")
env = 
  CartPoleContinuousAction with properties:

                  Gravity: 9.8000
                 MassCart: 1
                 MassPole: 0.1000
                   Length: 0.5000
                 MaxForce: 10
                       Ts: 0.0200
    ThetaThresholdRadians: 0.2094
               XThreshold: 2.4000
      RewardForNotFalling: 1
        PenaltyForFalling: -50
                    State: [4×1 double]

The environment reset function initializes (randomly) and returns the environment state (linear and angular positions and velocities).

reset(env)
ans = 4×1

         0
         0
    0.0315
         0

You can visualize the cart-pole system using the plot function during training or simulation.

plot(env)

Figure Cart Pole Visualizer contains an axes object. The axes object contains 6 objects of type line, polygon.

Obtain the observation and action information for later use when creating agents.

obsInfo = getObservationInfo(env)
obsInfo = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "CartPole States"
    Description: "x, dx, theta, dtheta"
      Dimension: [4 1]
       DataType: "double"

actInfo = getActionInfo(env)
actInfo = 
  rlNumericSpec with properties:

     LowerLimit: -10
     UpperLimit: 10
           Name: "CartPole Action"
    Description: [0×0 string]
      Dimension: [1 1]
       DataType: "double"

Configure Training Options for All Agents

Set up an evaluator object to evaluate the agent 10 times without exploration every 100 training episodes.

evl = rlEvaluator(NumEpisodes=10,EvaluationFrequency=100);

Create a training options object. For this example, use the following options.

  • Run each training for a maximum of 5000 episodes, with each episode lasting at most 500 time steps.

  • To gain better insight into the agent's behavior during training, plot the training progress (default option). To achieve faster training times, set the Plots option to none, as shown after the training options code below.

  • Stop training when the average cumulative reward over the evaluation episodes is greater than 480. At this point, the agent can mostly control the position of the pole.

trainOpts = rlTrainingOptions(...
    MaxEpisodes=5000, ...
    MaxStepsPerEpisode=500, ...
    StopTrainingCriteria="EvaluationStatistic",...
    StopTrainingValue=480);

For more information on training options, see rlTrainingOptions.
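
If you prefer to train without plotting, you can turn off the training progress plot before calling train. The Plots training option accepts "none" in addition to the default "training-progress".

% Uncomment to disable the training progress plot and reduce overhead.
% trainOpts.Plots = "none";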

To simulate the trained agent, create a simulation options object and configure it to simulate for 500 steps.

simOptions = rlSimulationOptions(MaxSteps=500);

For more information on simulation options, see rlSimulationOptions.

Create, Train, and Simulate a PG Agent

The actor and critic networks are initialized randomly. Ensure reproducibility of the section by fixing the seed used for random number generation.

rng(0,"twister")

First, create a default rlPGAgent object using the environment specification objects.

pgAgent = rlPGAgent(obsInfo,actInfo);

Set a lower learning rate and a lower gradient threshold to promote smoother (though possibly slower) training.

pgAgent.AgentOptions.CriticOptimizerOptions.LearnRate = 1e-3;
pgAgent.AgentOptions.ActorOptimizerOptions.LearnRate = 1e-3;
pgAgent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;
pgAgent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;

Set the entropy loss weight to increase exploration.

pgAgent.AgentOptions.EntropyLossWeight = 0.005;

Train the agent, passing the agent, the environment, and the previously defined training options and evaluator objects to train. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % To avoid plotting in training, recreate the environment.
    env = rlPredefinedEnv("CartPole-Continuous");
    % Train the agent. Save the final agent and training results.
    tic
    pgTngRes = train(pgAgent,env,trainOpts,Evaluator=evl);
    pgTngTime = toc;
    % Extract the number of training episodes and the number of total steps.
    pgTngEps = pgTngRes.EpisodeIndex(end);
    pgTngSteps = sum(pgTngRes.TotalAgentSteps);
    % Uncomment to save the trained agent and the training metrics.
    % save("ccpBchPGAgent.mat", ...
    %    "pgAgent","pgTngEps","pgTngSteps","pgTngTime")
else
    % Load the pretrained agent and results for the example.
    load("ccpBchPGAgent.mat", ...
        "pgAgent","pgTngEps","pgTngSteps","pgTngTime")
end

For the PG agent, the training converges to a solution after 700 episodes. You can check the trained agent within the cart-pole environment.

Ensure reproducibility of the simulation by fixing the seed used for random number generation.

rng(0,"twister")

Visualize the environment.

plot(env)

Configure the agent to use a greedy policy (no exploration) in simulation.

pgAgent.UseExplorationPolicy = false;

Simulate the environment with the trained agent for 500 steps and display the total reward. For more information on agent simulation, see sim.

experience = sim(env,pgAgent,simOptions);

Figure Cart Pole Visualizer contains an axes object. The axes object contains 6 objects of type line, polygon.

pgTotalRwd = sum(experience.Reward)
pgTotalRwd = 
495.5890

The trained PG agent stabilizes the pole in the upright position.

Create, Train, and Simulate an AC Agent

The actor and critic networks are initialized randomly. Ensure reproducibility of the section by fixing the seed used for random number generation.

rng(0,"twister")

First, create a default rlACAgent object using the environment specification objects.

acAgent = rlACAgent(obsInfo,actInfo);

Set a lower learning rate and a lower gradient threshold to promote smoother (though possibly slower) training.

acAgent.AgentOptions.CriticOptimizerOptions.LearnRate = 1e-3;
acAgent.AgentOptions.ActorOptimizerOptions.LearnRate = 1e-3;
acAgent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;
acAgent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;

Set the entropy loss weight to increase exploration.

acAgent.AgentOptions.EntropyLossWeight = 0.005;

Train the agent, passing the agent, the environment, and the previously defined training options and evaluator objects to train. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % To avoid plotting in training, recreate the environment.
    env = rlPredefinedEnv("CartPole-Continuous");
    % Train the agent. Save the final agent and training results.
    tic
    acTngRes = train(acAgent,env,trainOpts,Evaluator=evl);
    acTngTime = toc;
    % Extract the number of training episodes and the number of total steps.
    acTngEps = acTngRes.EpisodeIndex(end);
    acTngSteps = sum(acTngRes.TotalAgentSteps);
    % Uncomment to save the trained agent and the training metrics.
    % save("ccpBchACAgent.mat", ...
    %     "acAgent","acTngEps","acTngSteps","acTngTime")
else
    % Load the pretrained agent and results for the example.
    load("ccpBchACAgent.mat", ...
        "acAgent","acTngEps","acTngSteps","acTngTime")
end

For the AC agent, the training converges to a solution after 1100 episodes. You can check the trained agent within the cart-pole environment.

Ensure reproducibility of the simulation by fixing the seed used for random number generation.

rng(0,"twister")

Visualize the environment.

plot(env)

Configure the agent to use a greedy policy (no exploration) in simulation.

acAgent.UseExplorationPolicy = false;

Simulate the environment with the trained agent for 500 steps and display the total reward. For more information on agent simulation, see sim.

experience = sim(env,acAgent,simOptions);

Figure Cart Pole Visualizer contains an axes object. The axes object contains 6 objects of type line, polygon.

acTotalRwd = sum(experience.Reward)
acTotalRwd = 
484.5387

The trained AC agent stabilizes the pole in the upright position.

Create, Train, and Simulate a PPO Agent

The actor and critic networks are initialized randomly. Ensure reproducibility of the section by fixing the seed used for random number generation.

rng(0,"twister")

First, create a default rlPPOAgent object using the environment specification objects.

ppoAgent = rlPPOAgent(obsInfo,actInfo);

Set a lower learning rate and a lower gradient threshold to promote smoother (though possibly slower) training.

ppoAgent.AgentOptions.CriticOptimizerOptions.LearnRate = 1e-3;
ppoAgent.AgentOptions.ActorOptimizerOptions.LearnRate = 1e-3;
ppoAgent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;
ppoAgent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;

Train the agent, passing the agent, the environment, and the previously defined training options and evaluator objects to train. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % To avoid plotting in training, recreate the environment.
    env = rlPredefinedEnv("CartPole-Continuous");
    % Train the agent. Save the final agent and training results.
    tic
    ppoTngRes = train(ppoAgent,env,trainOpts,Evaluator=evl);
    ppoTngTime = toc;
    % Extract the number of training episodes and the number of total steps.
    ppoTngEps = ppoTngRes.EpisodeIndex(end);
    ppoTngSteps = sum(ppoTngRes.TotalAgentSteps);
    % Uncomment to save the trained agent and the training metrics.
    % save("ccpBchPPOAgent.mat", ...
    %     "ppoAgent","ppoTngEps","ppoTngSteps","ppoTngTime")
else
    % Load the pretrained agent and results for the example.
    load("ccpBchPPOAgent.mat", ...
        "ppoAgent","ppoTngEps","ppoTngSteps","ppoTngTime")
end

For the PPO agent, the training converges to a solution after 300 episodes. You can check the trained agent within the cart-pole environment.

Ensure reproducibility of the simulation by fixing the seed used for random number generation.

rng(0,"twister")

Visualize the environment.

plot(env)

Configure the agent to use a greedy policy (no exploration) in simulation.

ppoAgent.UseExplorationPolicy = false;

Simulate the environment with the trained agent for 500 steps and display the total reward. For more information on agent simulation, see sim.

experience = sim(env,ppoAgent,simOptions);

Figure Cart Pole Visualizer contains an axes object. The axes object contains 6 objects of type line, polygon.

ppoTotalRwd = sum(experience.Reward)
ppoTotalRwd = 
494.6149

The trained PPO agent stabilizes the pole in the upright position.

Create, Train, and Simulate a DDPG Agent

The actor and critic networks are initialized randomly. Ensure reproducibility of the section by fixing the seed used for random number generation.

rng(0,"twister")

First, create a default rlDDPGAgent object using the environment specification objects.

ddpgAgent = rlDDPGAgent(obsInfo,actInfo);

Set a lower learning rate and a lower gradient threshold to promote smoother (though possibly slower) training.

ddpgAgent.AgentOptions.CriticOptimizerOptions.LearnRate = 1e-3;
ddpgAgent.AgentOptions.ActorOptimizerOptions.LearnRate = 1e-3;
ddpgAgent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;
ddpgAgent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;

Use a larger experience buffer to store more experiences, thereby decreasing the likelihood of catastrophic forgetting.

ddpgAgent.AgentOptions.ExperienceBufferLength = 1e6;

Train the agent, passing the agent, the environment, and the previously defined training options and evaluator objects to train. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % To avoid plotting in training, recreate the environment.
    env = rlPredefinedEnv("CartPole-Continuous");
    % Train the agent. Save the final agent and training results.
    tic
    ddpgTngRes = train(ddpgAgent,env,trainOpts,Evaluator=evl);
    ddpgTngTime = toc;
    % Extract the number of training episodes and the number of total steps.
    ddpgTngEps = ddpgTngRes.EpisodeIndex(end);
    ddpgTngSteps = sum(ddpgTngRes.TotalAgentSteps);
    % Uncomment to save the trained agent and the training metrics.
    % save("ccpBchDDPGAgent.mat", ...
    %     "ddpgAgent","ddpgTngEps","ddpgTngSteps","ddpgTngTime")
else
    % Load the pretrained agent and results for the example.
    load("ccpBchDDPGAgent.mat", ...
        "ddpgAgent","ddpgTngEps","ddpgTngSteps","ddpgTngTime")
end

For the DDPG agent, the training converges to a solution after 2100 episodes. You can check the trained agent within the cart-pole environment.

Ensure reproducibility of the simulation by fixing the seed used for random number generation.

rng(0,"twister")

Visualize the environment.

plot(env)

Configure the agent to use a greedy policy (no exploration) in simulation.

ddpgAgent.UseExplorationPolicy = false;

Simulate the environment with the trained agent for 500 steps and display the total reward. For more information on agent simulation, see sim.

experience = sim(env,ddpgAgent,simOptions);

Figure Cart Pole Visualizer contains an axes object. The axes object contains 6 objects of type line, polygon.

ddpgTotalRwd = sum(experience.Reward)
ddpgTotalRwd = 
489.5395

The trained DDPG agent stabilizes the pole in the upright position.

Create, Train, and Simulate a TD3 Agent

The actor and critic networks are initialized randomly. Ensure reproducibility of the section by fixing the seed used for random number generation.

rng(0,"twister")

First, create a default rlTD3Agent object using the environment specification objects.

td3Agent = rlTD3Agent(obsInfo,actInfo);

Set a lower learning rate and a lower gradient threshold to promote smoother (though possibly slower) training.

td3Agent.AgentOptions.CriticOptimizerOptions(1).LearnRate = 1e-3;
td3Agent.AgentOptions.CriticOptimizerOptions(2).LearnRate = 1e-3;
td3Agent.AgentOptions.ActorOptimizerOptions.LearnRate = 1e-3;
td3Agent.AgentOptions.CriticOptimizerOptions(1).GradientThreshold = 1;
td3Agent.AgentOptions.CriticOptimizerOptions(2).GradientThreshold = 1;
td3Agent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;

Use a larger experience buffer to store more experiences, thereby decreasing the likelihood of catastrophic forgetting.

td3Agent.AgentOptions.ExperienceBufferLength = 1e6;

Train the agent, passing the agent, the environment, and the previously defined training options and evaluator objects to train. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % To avoid plotting in training, recreate the environment.
    env = rlPredefinedEnv("CartPole-Continuous");
    % Train the agent. Save the final agent and training results.
    tic
    td3TngRes = train(td3Agent,env,trainOpts,Evaluator=evl);
    td3TngTime = toc;
    % Extract the number of training episodes and the number of total steps.
    td3TngEps = td3TngRes.EpisodeIndex(end);
    td3TngSteps = sum(td3TngRes.TotalAgentSteps);
    % Uncomment to save the trained agent and the training metrics.
    % save("ccpBchTD3Agent.mat", ...
    %    "td3Agent","td3TngEps","td3TngSteps","td3TngTime")
else
    % Load the pretrained agent and results for the example.
    load("ccpBchTD3Agent.mat", ...
        "td3Agent","td3TngEps","td3TngSteps","td3TngTime")
end

For the TD3 agent, the training converges to a solution after 800 episodes. You can check the trained agent within the cart-pole environment.

Ensure reproducibility of the simulation by fixing the seed used for random number generation.

rng(0,"twister")

Visualize the environment.

plot(env)

Configure the agent to use a greedy policy (no exploration) in simulation.

td3Agent.UseExplorationPolicy = false;

Simulate the environment with the trained agent for 500 steps and display the total reward. For more information on agent simulation, see sim.

experience = sim(env,td3Agent,simOptions);

Figure Cart Pole Visualizer contains an axes object. The axes object contains 6 objects of type line, polygon.

td3TotalRwd = sum(experience.Reward)
td3TotalRwd = 
486.8400

The trained TD3 agent stabilizes the pole in the upright position.

Create, Train, and Simulate a SAC Agent

The actor and critic networks are initialized randomly. Ensure reproducibility of the section by fixing the seed used for random number generation.

rng(0,"twister")

First, create a default rlSACAgent object using the environment specification objects.

sacAgent = rlSACAgent(obsInfo,actInfo);

Set a lower learning rate and a lower gradient threshold to promote smoother (though possibly slower) training.

sacAgent.AgentOptions.CriticOptimizerOptions(1).LearnRate = 1e-3;
sacAgent.AgentOptions.CriticOptimizerOptions(2).LearnRate = 1e-3;
sacAgent.AgentOptions.ActorOptimizerOptions.LearnRate = 1e-3;
sacAgent.AgentOptions.CriticOptimizerOptions(1).GradientThreshold = 1;
sacAgent.AgentOptions.CriticOptimizerOptions(2).GradientThreshold = 1;
sacAgent.AgentOptions.ActorOptimizerOptions.GradientThreshold = 1;

Use a larger experience buffer to store more experiences, thereby decreasing the likelihood of catastrophic forgetting.

sacAgent.AgentOptions.ExperienceBufferLength = 1e6;

Train the agent, passing the agent, the environment, and the previously defined training options and evaluator objects to train. Training is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % To avoid plotting in training, recreate the environment.
    env = rlPredefinedEnv("CartPole-Continuous");
    % Train the agent. Save the final agent and training results.
    tic
    sacTngRes = train(sacAgent,env,trainOpts,Evaluator=evl);
    sacTngTime = toc;
    % Extract the number of training episodes and the number of total steps.
    sacTngEps = sacTngRes.EpisodeIndex(end);
    sacTngSteps = sum(sacTngRes.TotalAgentSteps);
    % Uncomment to save the trained agent and the training metrics.
    % save("ccpBchSACAgent.mat", ...
    %    "sacAgent","sacTngEps","sacTngSteps","sacTngTime")
else
    % Load the pretrained agent and results for the example.
    load("ccpBchSACAgent.mat", ...
        "sacAgent","sacTngEps","sacTngSteps","sacTngTime")
end

For the SAC agent, the training converges to a solution after 1000 episodes. You can check the trained agent within the cart-pole environment.

Ensure reproducibility of the simulation by fixing the seed used for random number generation.

rng(0,"twister")

Visualize the environment.

plot(env)

Configure the agent to use a greedy policy (no exploration) in simulation.

sacAgent.UseExplorationPolicy = false;

Simulate the environment with the trained agent for 500 steps and display the total reward. For more information on agent simulation, see sim.

experience = sim(env,sacAgent,simOptions);

Figure Cart Pole Visualizer contains an axes object. The axes object contains 6 objects of type line, polygon.

sacTotalRwd = sum(experience.Reward)
sacTotalRwd = 
482.7768

The trained SAC agent stabilizes the pole in the upright position.

Plot Training and Simulation Metrics

For each agent, collect the total reward from the final simulation episode, the number of training episodes, the total number of agent steps, and the total training time as shown in the Reinforcement Learning Training Monitor.

simReward = [
    pgTotalRwd
    acTotalRwd
    ppoTotalRwd
    ddpgTotalRwd
    td3TotalRwd
    sacTotalRwd
    ];

tngEpisodes = [
    pgTngEps
    acTngEps
    ppoTngEps
    ddpgTngEps
    td3TngEps
    sacTngEps
    ];

tngSteps = [
    pgTngSteps
    acTngSteps
    ppoTngSteps
    ddpgTngSteps
    td3TngSteps
    sacTngSteps
    ];

tngTime = [
    pgTngTime
    acTngTime
    ppoTngTime
    ddpgTngTime
    td3TngTime
    sacTngTime
    ];

Plot the simulation reward, the number of training episodes, the number of training steps (that is, the number of interactions between the agent and the environment), and the training time. For better visualization, scale the data by dividing by the factors [1 5 1e6 10].

bar([simReward,tngEpisodes,tngSteps,tngTime]./[1 5 1e6 10])
xticklabels(["PG" "AC" "PPO" "DDPG" "TD3" "SAC"])
legend(["Simulation Reward","Training Episodes","Training Steps","Training Time"], ...
    "Location","northwest")

Figure contains an axes object. The axes object contains 4 objects of type bar. These objects represent Simulation Reward, Training Episodes, Training Steps, Training Time.

The plot shows that, for this environment, and with the selected random number generator seed and initial conditions, all agents perform satisfactorily in terms of total simulation reward. PPO uses much less training time, both because it is a fast training algorithm and because it converges in just 300 episodes. DDPG takes considerably more time than PPO and AC, mostly because it needs many more training steps to converge. TD3 needs more time than AC despite taking fewer training steps, because TD3 must compute more gradients than AC. Similarly, SAC takes almost as much time as DDPG to converge despite needing fewer training steps, again largely because SAC must compute more gradients. With a different random seed, the initial agent networks would be different, and therefore the convergence results might differ. For more information on the relative strengths and weaknesses of each agent, see Reinforcement Learning Agents.
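
To examine the underlying numbers alongside the bar chart, you can optionally collect the metrics into a table. This is a brief sketch using the variables defined above; the row and variable names are arbitrary.

% Optional: summarize the metrics in a table for numerical inspection.
agentNames = ["PG";"AC";"PPO";"DDPG";"TD3";"SAC"];
metricsSummary = table(simReward,tngEpisodes,tngSteps,tngTime, ...
    RowNames=agentNames, ...
    VariableNames=["SimReward","TngEpisodes","TngSteps","TngTimeSec"])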

Save all the variables created in this example, including the training results, for later use.

% Uncomment to save all the workspace variables
% save ccpAllVars.mat

Restore the random number stream using the information stored in previousRngState.

rng(previousRngState);
