From 3e0b6a219b03c4d3d708498c2bc399abaceddf40 Mon Sep 17 00:00:00 2001 From: Michal Gregor Date: Tue, 11 Aug 2020 14:29:53 +0200 Subject: [PATCH] Added support for specifying custom models under all existing presets. --- all/presets/atari/a2c.py | 13 ++++++++++--- all/presets/atari/c51.py | 5 ++++- all/presets/atari/ddqn.py | 5 ++++- all/presets/atari/dqn.py | 5 ++++- all/presets/atari/ppo.py | 13 ++++++++++--- all/presets/atari/rainbow.py | 5 ++++- all/presets/atari/vac.py | 13 ++++++++++--- all/presets/atari/vpg.py | 13 ++++++++++--- all/presets/atari/vqn.py | 5 ++++- all/presets/atari/vsarsa.py | 5 ++++- all/presets/classic_control/a2c.py | 13 ++++++++++--- all/presets/classic_control/c51.py | 7 +++++-- all/presets/classic_control/ddqn.py | 5 ++++- all/presets/classic_control/dqn.py | 5 ++++- all/presets/classic_control/ppo.py | 13 ++++++++++--- all/presets/classic_control/rainbow.py | 5 ++++- all/presets/classic_control/vac.py | 13 ++++++++++--- all/presets/classic_control/vpg.py | 13 ++++++++++--- all/presets/classic_control/vqn.py | 5 ++++- all/presets/classic_control/vsarsa.py | 5 ++++- all/presets/continuous/ddpg.py | 9 +++++++-- all/presets/continuous/ppo.py | 5 ++++- all/presets/continuous/sac.py | 17 +++++++++++++---- 23 files changed, 153 insertions(+), 44 deletions(-) diff --git a/all/presets/atari/a2c.py b/all/presets/atari/a2c.py index 4cc27893..cdd71451 100644 --- a/all/presets/atari/a2c.py +++ b/all/presets/atari/a2c.py @@ -23,6 +23,10 @@ def a2c( # Batch settings n_envs=16, n_steps=5, + # Model construction + feature_model_constructor=nature_features, + value_model_constructor=nature_value_head, + policy_model_constructor=nature_policy_head ): """ A2C Atari preset. @@ -39,14 +43,17 @@ def a2c( value_loss_scaling (float): Coefficient for the value function loss. n_envs (int): Number of parallel environments. n_steps (int): Length of each rollout. + feature_model_constructor (function): The function used to construct the neural feature model. + value_model_constructor (function): The function used to construct the neural value model. + policy_model_constructor (function): The function used to construct the neural policy model. """ def _a2c(envs, writer=DummyWriter()): env = envs[0] final_anneal_step = last_frame / (n_steps * n_envs * 4) - value_model = nature_value_head().to(device) - policy_model = nature_policy_head(env).to(device) - feature_model = nature_features().to(device) + value_model = value_model_constructor().to(device) + policy_model = policy_model_constructor(env).to(device) + feature_model = feature_model_constructor().to(device) feature_optimizer = Adam(feature_model.parameters(), lr=lr, eps=eps) value_optimizer = Adam(value_model.parameters(), lr=lr, eps=eps) diff --git a/all/presets/atari/c51.py b/all/presets/atari/c51.py index ee22cbff..d09884b8 100644 --- a/all/presets/atari/c51.py +++ b/all/presets/atari/c51.py @@ -31,6 +31,8 @@ def c51( atoms=51, v_min=-10, v_max=10, + # Model construction + model_constructor=nature_c51 ): """ C51 Atari preset. @@ -53,13 +55,14 @@ def c51( the distributional value function. v_min (int): The expected return corresponding to the smallest atom. v_max (int): The expected return correspodning to the larget atom. + model_constructor (function): The function used to construct the neural model. 
""" def _c51(env, writer=DummyWriter()): action_repeat = 4 last_timestep = last_frame / action_repeat last_update = (last_timestep - replay_start_size) / update_frequency - model = nature_c51(env, atoms=atoms).to(device) + model = model_constructor(env, atoms=atoms).to(device) optimizer = Adam( model.parameters(), lr=lr, diff --git a/all/presets/atari/ddqn.py b/all/presets/atari/ddqn.py index 4baccb59..fb23027a 100644 --- a/all/presets/atari/ddqn.py +++ b/all/presets/atari/ddqn.py @@ -32,6 +32,8 @@ def ddqn( # Prioritized replay settings alpha=0.5, beta=0.5, + # Model construction + model_constructor=nature_ddqn ): """ Dueling Double DQN with Prioritized Experience Replay (PER). @@ -55,6 +57,7 @@ def ddqn( (0 = no prioritization, 1 = full prioritization) beta (float): The strength of the importance sampling correction for prioritized experience replay. (0 = no correction, 1 = full correction) + model_constructor (function): The function used to construct the neural model. """ def _ddqn(env, writer=DummyWriter()): action_repeat = 4 @@ -62,7 +65,7 @@ def _ddqn(env, writer=DummyWriter()): last_update = (last_timestep - replay_start_size) / update_frequency final_exploration_step = final_exploration_frame / action_repeat - model = nature_ddqn(env).to(device) + model = model_constructor(env).to(device) optimizer = Adam( model.parameters(), lr=lr, diff --git a/all/presets/atari/dqn.py b/all/presets/atari/dqn.py index 1ddf7a5e..53b1209a 100644 --- a/all/presets/atari/dqn.py +++ b/all/presets/atari/dqn.py @@ -30,6 +30,8 @@ def dqn( initial_exploration=1., final_exploration=0.01, final_exploration_frame=4000000, + # Model construction + model_constructor=nature_dqn ): """ DQN Atari preset. @@ -49,6 +51,7 @@ def dqn( decayed until final_exploration_frame. final_exploration (int): Final probability of choosing a random action. final_exploration_frame (int): The frame where the exploration decay stops. + model_constructor (function): The function used to construct the neural model. """ def _dqn(env, writer=DummyWriter()): action_repeat = 4 @@ -56,7 +59,7 @@ def _dqn(env, writer=DummyWriter()): last_update = (last_timestep - replay_start_size) / update_frequency final_exploration_step = final_exploration_frame / action_repeat - model = nature_dqn(env).to(device) + model = model_constructor(env).to(device) optimizer = Adam( model.parameters(), diff --git a/all/presets/atari/ppo.py b/all/presets/atari/ppo.py index 60a8d2bd..aba6720c 100644 --- a/all/presets/atari/ppo.py +++ b/all/presets/atari/ppo.py @@ -30,6 +30,10 @@ def ppo( n_steps=128, # GAE settings lam=0.95, + # Model construction + feature_model_constructor=nature_features, + value_model_constructor=nature_value_head, + policy_model_constructor=nature_policy_head ): """ PPO Atari preset. @@ -51,6 +55,9 @@ def ppo( n_envs (int): Number of parallel actors. n_steps (int): Length of each rollout. lam (float): The Generalized Advantage Estimate (GAE) decay parameter. + feature_model_constructor (function): The function used to construct the neural feature model. + value_model_constructor (function): The function used to construct the neural value model. + policy_model_constructor (function): The function used to construct the neural policy model. 
""" def _ppo(envs, writer=DummyWriter()): env = envs[0] @@ -60,9 +67,9 @@ def _ppo(envs, writer=DummyWriter()): # with n_envs and 4 frames per step final_anneal_step = last_frame * epochs * minibatches / (n_steps * n_envs * 4) - value_model = nature_value_head().to(device) - policy_model = nature_policy_head(env).to(device) - feature_model = nature_features().to(device) + value_model = value_model_constructor().to(device) + policy_model = policy_model_constructor(env).to(device) + feature_model = feature_model_constructor().to(device) feature_optimizer = Adam( feature_model.parameters(), lr=lr, eps=eps diff --git a/all/presets/atari/rainbow.py b/all/presets/atari/rainbow.py index 051485ba..8004035d 100644 --- a/all/presets/atari/rainbow.py +++ b/all/presets/atari/rainbow.py @@ -38,6 +38,8 @@ def rainbow( v_max=10, # Noisy Nets sigma=0.5, + # Model construction + model_constructor=nature_rainbow ): """ Rainbow Atari Preset. @@ -66,13 +68,14 @@ def rainbow( v_min (int): The expected return corresponding to the smallest atom. v_max (int): The expected return correspodning to the larget atom. sigma (float): Initial noisy network noise. + model_constructor (function): The function used to construct the neural model. """ def _rainbow(env, writer=DummyWriter()): action_repeat = 4 last_timestep = last_frame / action_repeat last_update = (last_timestep - replay_start_size) / update_frequency - model = nature_rainbow(env, atoms=atoms, sigma=sigma).to(device) + model = model_constructor(env, atoms=atoms, sigma=sigma).to(device) optimizer = Adam(model.parameters(), lr=lr, eps=eps) q = QDist( model, diff --git a/all/presets/atari/vac.py b/all/presets/atari/vac.py index 9ef50e8b..1418559f 100644 --- a/all/presets/atari/vac.py +++ b/all/presets/atari/vac.py @@ -20,6 +20,10 @@ def vac( value_loss_scaling=0.25, # Parallel actors n_envs=16, + # Model construction + feature_model_constructor=nature_features, + value_model_constructor=nature_value_head, + policy_model_constructor=nature_policy_head ): """ Vanilla Actor-Critic Atari preset. @@ -35,11 +39,14 @@ def vac( Set to 0 to disable. value_loss_scaling (float): Coefficient for the value function loss. n_envs (int): Number of parallel environments. + feature_model_constructor (function): The function used to construct the neural feature model. + value_model_constructor (function): The function used to construct the neural value model. + policy_model_constructor (function): The function used to construct the neural policy model. """ def _vac(envs, writer=DummyWriter()): - value_model = nature_value_head().to(device) - policy_model = nature_policy_head(envs[0]).to(device) - feature_model = nature_features().to(device) + value_model = value_model_constructor().to(device) + policy_model = policy_model_constructor(envs[0]).to(device) + feature_model = feature_model_constructor().to(device) value_optimizer = Adam(value_model.parameters(), lr=lr_v, eps=eps) policy_optimizer = Adam(policy_model.parameters(), lr=lr_pi, eps=eps) diff --git a/all/presets/atari/vpg.py b/all/presets/atari/vpg.py index 961bc7c9..62a18d04 100644 --- a/all/presets/atari/vpg.py +++ b/all/presets/atari/vpg.py @@ -20,6 +20,10 @@ def vpg( clip_grad=0.5, value_loss_scaling=0.25, min_batch_size=1000, + # Model construction + feature_model_constructor=nature_features, + value_model_constructor=nature_value_head, + policy_model_constructor=nature_policy_head ): """ Vanilla Policy Gradient Atari preset. 
@@ -35,13 +39,16 @@ def vpg( value_loss_scaling (float): Coefficient for the value function loss. min_batch_size (int): Continue running complete episodes until at least this many states have been seen since the last update. + feature_model_constructor (function): The function used to construct the neural feature model. + value_model_constructor (function): The function used to construct the neural value model. + policy_model_constructor (function): The function used to construct the neural policy model. """ final_anneal_step = last_frame / (min_batch_size * 4) def _vpg_atari(env, writer=DummyWriter()): - value_model = nature_value_head().to(device) - policy_model = nature_policy_head(env).to(device) - feature_model = nature_features().to(device) + value_model = value_model_constructor().to(device) + policy_model = policy_model_constructor(env).to(device) + feature_model = feature_model_constructor().to(device) feature_optimizer = Adam(feature_model.parameters(), lr=lr, eps=eps) value_optimizer = Adam(value_model.parameters(), lr=lr, eps=eps) diff --git a/all/presets/atari/vqn.py b/all/presets/atari/vqn.py index 1d0a0903..1e261fff 100644 --- a/all/presets/atari/vqn.py +++ b/all/presets/atari/vqn.py @@ -20,6 +20,8 @@ def vqn( final_exploration_frame=1000000, # Parallel actors n_envs=64, + # Model construction + model_constructor=nature_ddqn ): """ Vanilla Q-Network Atari preset. @@ -34,13 +36,14 @@ def vqn( final_exploration (int): Final probability of choosing a random action. final_exploration_frame (int): The frame where the exploration decay stops. n_envs (int): Number of parallel environments. + model_constructor (function): The function used to construct the neural model. """ def _vqn(envs, writer=DummyWriter()): action_repeat = 4 final_exploration_timestep = final_exploration_frame / action_repeat env = envs[0] - model = nature_ddqn(env).to(device) + model = model_constructor(env).to(device) optimizer = Adam(model.parameters(), lr=lr, eps=eps) q = QNetwork( model, diff --git a/all/presets/atari/vsarsa.py b/all/presets/atari/vsarsa.py index 80ee5991..6fefcc6c 100644 --- a/all/presets/atari/vsarsa.py +++ b/all/presets/atari/vsarsa.py @@ -20,6 +20,8 @@ def vsarsa( initial_exploration=1., # Parallel actors n_envs=64, + # Model construction + model_constructor=nature_ddqn ): """ Vanilla SARSA Atari preset. @@ -34,13 +36,14 @@ def vsarsa( final_exploration (int): Final probability of choosing a random action. final_exploration_frame (int): The frame where the exploration decay stops. n_envs (int): Number of parallel environments. + model_constructor (function): The function used to construct the neural model. """ def _vsarsa(envs, writer=DummyWriter()): action_repeat = 4 final_exploration_timestep = final_exploration_frame / action_repeat env = envs[0] - model = nature_ddqn(env).to(device) + model = model_constructor(env).to(device) optimizer = Adam(model.parameters(), lr=lr, eps=eps) q = QNetwork( model, diff --git a/all/presets/classic_control/a2c.py b/all/presets/classic_control/a2c.py index 0637dad3..5480459e 100644 --- a/all/presets/classic_control/a2c.py +++ b/all/presets/classic_control/a2c.py @@ -18,6 +18,10 @@ def a2c( # Batch settings n_envs=4, n_steps=32, + # Model construction + feature_model_constructor=fc_relu_features, + value_model_constructor=fc_value_head, + policy_model_constructor=fc_policy_head ): """ A2C classic control preset. @@ -30,12 +34,15 @@ def a2c( entropy_loss_scaling (float): Coefficient for the entropy term in the total loss. 
n_envs (int): Number of parallel environments. n_steps (int): Length of each rollout. + feature_model_constructor (function): The function used to construct the neural feature model. + value_model_constructor (function): The function used to construct the neural value model. + policy_model_constructor (function): The function used to construct the neural policy model. """ def _a2c(envs, writer=DummyWriter()): env = envs[0] - feature_model = fc_relu_features(env).to(device) - value_model = fc_value_head().to(device) - policy_model = fc_policy_head(env).to(device) + feature_model = feature_model_constructor(env).to(device) + value_model = value_model_constructor().to(device) + policy_model = policy_model_constructor(env).to(device) feature_optimizer = Adam(feature_model.parameters(), lr=lr) value_optimizer = Adam(value_model.parameters(), lr=lr) diff --git a/all/presets/classic_control/c51.py b/all/presets/classic_control/c51.py index 6eedf661..793bd140 100644 --- a/all/presets/classic_control/c51.py +++ b/all/presets/classic_control/c51.py @@ -26,7 +26,9 @@ def c51( # Distributional RL atoms=101, v_min=-100, - v_max=100 + v_max=100, + # Model construction + model_constructor=fc_relu_dist_q ): """ C51 classic control preset. @@ -47,9 +49,10 @@ def c51( the distributional value function. v_min (int): The expected return corresponding to the smallest atom. v_max (int): The expected return correspodning to the larget atom. + model_constructor (function): The function used to construct the neural model. """ def _c51(env, writer=DummyWriter()): - model = fc_relu_dist_q(env, atoms=atoms).to(device) + model = model_constructor(env, atoms=atoms).to(device) optimizer = Adam(model.parameters(), lr=lr) q = QDist( model, diff --git a/all/presets/classic_control/ddqn.py b/all/presets/classic_control/ddqn.py index 4c1fde34..75cc76d6 100644 --- a/all/presets/classic_control/ddqn.py +++ b/all/presets/classic_control/ddqn.py @@ -28,6 +28,8 @@ def ddqn( # Prioritized replay settings alpha=0.2, beta=0.6, + # Model construction + model_constructor=dueling_fc_relu_q ): """ Dueling Double DQN with Prioritized Experience Replay (PER). @@ -50,9 +52,10 @@ def ddqn( (0 = no prioritization, 1 = full prioritization) beta (float): The strength of the importance sampling correction for prioritized experience replay. (0 = no correction, 1 = full correction) + model_constructor (function): The function used to construct the neural model. """ def _ddqn(env, writer=DummyWriter()): - model = dueling_fc_relu_q(env).to(device) + model = model_constructor(env).to(device) optimizer = Adam(model.parameters(), lr=lr) q = QNetwork( model, diff --git a/all/presets/classic_control/dqn.py b/all/presets/classic_control/dqn.py index 7e511de0..0f6c5b97 100644 --- a/all/presets/classic_control/dqn.py +++ b/all/presets/classic_control/dqn.py @@ -25,6 +25,8 @@ def dqn( initial_exploration=1., final_exploration=0., final_exploration_frame=10000, + # Model construction + model_constructor=fc_relu_q ): """ DQN classic control preset. @@ -42,9 +44,10 @@ def dqn( decayed until final_exploration_frame. final_exploration (int): Final probability of choosing a random action. final_exploration_frame (int): The frame where the exploration decay stops. + model_constructor (function): The function used to construct the neural model. 
""" def _dqn(env, writer=DummyWriter()): - model = fc_relu_q(env).to(device) + model = model_constructor(env).to(device) optimizer = Adam(model.parameters(), lr=lr) q = QNetwork( model, diff --git a/all/presets/classic_control/ppo.py b/all/presets/classic_control/ppo.py index 36a2d6e7..72fce453 100644 --- a/all/presets/classic_control/ppo.py +++ b/all/presets/classic_control/ppo.py @@ -22,6 +22,10 @@ def ppo( n_steps=8, # GAE settings lam=0.95, + # Model construction + feature_model_constructor=fc_relu_features, + value_model_constructor=fc_value_head, + policy_model_constructor=fc_policy_head ): """ PPO classic control preset. @@ -39,12 +43,15 @@ def ppo( n_envs (int): Number of parallel actors. n_steps (int): Length of each rollout. lam (float): The Generalized Advantage Estimate (GAE) decay parameter. + feature_model_constructor (function): The function used to construct the neural feature model. + value_model_constructor (function): The function used to construct the neural value model. + policy_model_constructor (function): The function used to construct the neural policy model. """ def _ppo(envs, writer=DummyWriter()): env = envs[0] - feature_model = fc_relu_features(env).to(device) - value_model = fc_value_head().to(device) - policy_model = fc_policy_head(env).to(device) + feature_model = feature_model_constructor(env).to(device) + value_model = value_model_constructor().to(device) + policy_model = policy_model_constructor(env).to(device) feature_optimizer = Adam(feature_model.parameters(), lr=lr) value_optimizer = Adam(value_model.parameters(), lr=lr) diff --git a/all/presets/classic_control/rainbow.py b/all/presets/classic_control/rainbow.py index 7159651d..199b1a42 100644 --- a/all/presets/classic_control/rainbow.py +++ b/all/presets/classic_control/rainbow.py @@ -32,6 +32,8 @@ def rainbow( v_max=100, # Noisy Nets sigma=0.5, + # Model construction + model_constructor=fc_relu_rainbow ): """ Rainbow classic control preset. @@ -54,9 +56,10 @@ def rainbow( v_min (int): The expected return corresponding to the smallest atom. v_max (int): The expected return correspodning to the larget atom. sigma (float): Initial noisy network noise. + model_constructor (function): The function used to construct the neural model. """ def _rainbow(env, writer=DummyWriter()): - model = fc_relu_rainbow(env, atoms=atoms, sigma=sigma).to(device) + model = model_constructor(env, atoms=atoms, sigma=sigma).to(device) optimizer = Adam(model.parameters(), lr=lr) q = QDist( model, diff --git a/all/presets/classic_control/vac.py b/all/presets/classic_control/vac.py index b2bffefd..e56def30 100644 --- a/all/presets/classic_control/vac.py +++ b/all/presets/classic_control/vac.py @@ -14,6 +14,10 @@ def vac( lr_v=5e-3, lr_pi=1e-3, eps=1e-5, + # Model construction + feature_model_constructor=fc_relu_features, + value_model_constructor=fc_value_head, + policy_model_constructor=fc_policy_head ): """ Vanilla Actor-Critic classic control preset. @@ -24,11 +28,14 @@ def vac( lr_v (float): Learning rate for value network. lr_pi (float): Learning rate for policy network and feature network. eps (float): Stability parameters for the Adam optimizer. + feature_model_constructor (function): The function used to construct the neural feature model. + value_model_constructor (function): The function used to construct the neural value model. + policy_model_constructor (function): The function used to construct the neural policy model. 
""" def _vac(env, writer=DummyWriter()): - value_model = fc_value_head().to(device) - policy_model = fc_policy_head(env).to(device) - feature_model = fc_relu_features(env).to(device) + value_model = value_model_constructor().to(device) + policy_model = policy_model_constructor(env).to(device) + feature_model = feature_model_constructor(env).to(device) value_optimizer = Adam(value_model.parameters(), lr=lr_v, eps=eps) policy_optimizer = Adam(policy_model.parameters(), lr=lr_pi, eps=eps) diff --git a/all/presets/classic_control/vpg.py b/all/presets/classic_control/vpg.py index 86cadb83..05a050ce 100644 --- a/all/presets/classic_control/vpg.py +++ b/all/presets/classic_control/vpg.py @@ -14,6 +14,10 @@ def vpg( lr=5e-3, # Batch settings min_batch_size=500, + # Model construction + feature_model_constructor=fc_relu_features, + value_model_constructor=fc_value_head, + policy_model_constructor=fc_policy_head ): """ Vanilla Policy Gradient classic control preset. @@ -25,11 +29,14 @@ def vpg( lr (float): Learning rate for the Adam optimizer. min_batch_size (int): Continue running complete episodes until at least this many states have been seen since the last update. + feature_model_constructor (function): The function used to construct the neural feature model. + value_model_constructor (function): The function used to construct the neural value model. + policy_model_constructor (function): The function used to construct the neural policy model. """ def _vpg(env, writer=DummyWriter()): - feature_model = fc_relu_features(env).to(device) - value_model = fc_value_head().to(device) - policy_model = fc_policy_head(env).to(device) + feature_model = feature_model_constructor(env).to(device) + value_model = value_model_constructor().to(device) + policy_model = policy_model_constructor(env).to(device) feature_optimizer = Adam(feature_model.parameters(), lr=lr) value_optimizer = Adam(value_model.parameters(), lr=lr) diff --git a/all/presets/classic_control/vqn.py b/all/presets/classic_control/vqn.py index f5d21de1..3c26150d 100644 --- a/all/presets/classic_control/vqn.py +++ b/all/presets/classic_control/vqn.py @@ -16,6 +16,8 @@ def vqn( epsilon=0.1, # Parallel actors n_envs=1, + # Model construction + model_constructor=fc_relu_q ): """ Vanilla Q-Network classic control preset. @@ -27,10 +29,11 @@ def vqn( eps (float): Stability parameters for the Adam optimizer. epsilon (int): Probability of choosing a random action. n_envs (int): Number of parallel environments. + model_constructor (function): The function used to construct the neural model. """ def _vqn(envs, writer=DummyWriter()): env = envs[0] - model = fc_relu_q(env).to(device) + model = model_constructor(env).to(device) optimizer = Adam(model.parameters(), lr=lr, eps=eps) q = QNetwork(model, optimizer, writer=writer) policy = GreedyPolicy(q, env.action_space.n, epsilon=epsilon) diff --git a/all/presets/classic_control/vsarsa.py b/all/presets/classic_control/vsarsa.py index 42394586..6349409f 100644 --- a/all/presets/classic_control/vsarsa.py +++ b/all/presets/classic_control/vsarsa.py @@ -16,6 +16,8 @@ def vsarsa( epsilon=0.1, # Parallel actors n_envs=1, + # Model construction + model_constructor=fc_relu_q ): """ Vanilla SARSA classic control preset. @@ -27,10 +29,11 @@ def vsarsa( eps (float): Stability parameters for the Adam optimizer. epsilon (int): Probability of choosing a random action. n_envs (int): Number of parallel environments. + model_constructor (function): The function used to construct the neural model. 
""" def _vsarsa(envs, writer=DummyWriter()): env = envs[0] - model = fc_relu_q(env).to(device) + model = model_constructor(env).to(device) optimizer = Adam(model.parameters(), lr=lr, eps=eps) q = QNetwork(model, optimizer, writer=writer) policy = GreedyPolicy(q, env.action_space.n, epsilon=epsilon) diff --git a/all/presets/continuous/ddpg.py b/all/presets/continuous/ddpg.py index 0b172cc3..11c6c25d 100644 --- a/all/presets/continuous/ddpg.py +++ b/all/presets/continuous/ddpg.py @@ -26,6 +26,9 @@ def ddpg( replay_buffer_size=1e6, # Exploration settings noise=0.1, + # Model construction + q_model_constructor=fc_q, + policy_model_constructor=fc_deterministic_policy ): """ DDPG continuous control preset. @@ -42,11 +45,13 @@ def ddpg( replay_start_size (int): Number of experiences in replay buffer when training begins. replay_buffer_size (int): Maximum number of experiences to store in the replay buffer. noise (float): The amount of exploration noise to add. + q_model_constructor (function): The function used to construct the neural q model. + policy_model_constructor (function): The function used to construct the neural policy model. """ def _ddpg(env, writer=DummyWriter()): final_anneal_step = (last_frame - replay_start_size) // update_frequency - q_model = fc_q(env).to(device) + q_model = q_model_constructor(env).to(device) q_optimizer = Adam(q_model.parameters(), lr=lr_q) q = QContinuous( q_model, @@ -59,7 +64,7 @@ def _ddpg(env, writer=DummyWriter()): writer=writer ) - policy_model = fc_deterministic_policy(env).to(device) + policy_model = policy_model_constructor(env).to(device) policy_optimizer = Adam(policy_model.parameters(), lr=lr_pi) policy = DeterministicPolicy( policy_model, diff --git a/all/presets/continuous/ppo.py b/all/presets/continuous/ppo.py index 03f3dd7b..fd68964e 100644 --- a/all/presets/continuous/ppo.py +++ b/all/presets/continuous/ppo.py @@ -31,6 +31,8 @@ def ppo( n_steps=128, # GAE settings lam=0.95, + # Model construction + ac_model_constructor=fc_actor_critic ): """ PPO continuous control preset. @@ -51,12 +53,13 @@ def ppo( n_envs (int): Number of parallel actors. n_steps (int): Length of each rollout. lam (float): The Generalized Advantage Estimate (GAE) decay parameter. + ac_model_constructor (function): The function used to construct the neural feature, value and policy model. """ def _ppo(envs, writer=DummyWriter()): final_anneal_step = last_frame * epochs * minibatches / (n_steps * n_envs) env = envs[0] - feature_model, value_model, policy_model = fc_actor_critic(env) + feature_model, value_model, policy_model = ac_model_constructor(env) feature_model.to(device) value_model.to(device) policy_model.to(device) diff --git a/all/presets/continuous/sac.py b/all/presets/continuous/sac.py index f85e84cd..5cf23331 100644 --- a/all/presets/continuous/sac.py +++ b/all/presets/continuous/sac.py @@ -29,6 +29,11 @@ def sac( temperature_initial=0.1, lr_temperature=1e-5, entropy_target_scaling=1., + # Model construction + q1_model_constructor=fc_q, + q2_model_constructor=fc_q, + v_model_constructor=fc_v, + policy_model_constructor=fc_soft_policy ): """ SAC continuous control preset. @@ -48,11 +53,15 @@ def sac( temperature_initial (float): Initial value of the temperature parameter. lr_temperature (float): Learning rate for the temperature. Should be low compared to other learning rates. 
entropy_target_scaling (float): The target entropy will be -(entropy_target_scaling * env.action_space.shape[0]) + q1_model_constructor(function): The function used to construct the neural q1 model. + q2_model_constructor(function): The function used to construct the neural q2 model. + v_model_constructor(function): The function used to construct the neural v model. + policy_model_constructor(function): The function used to construct the neural policy model. """ def _sac(env, writer=DummyWriter()): final_anneal_step = (last_frame - replay_start_size) // update_frequency - q_1_model = fc_q(env).to(device) + q_1_model = q1_model_constructor(env).to(device) q_1_optimizer = Adam(q_1_model.parameters(), lr=lr_q) q_1 = QContinuous( q_1_model, @@ -65,7 +74,7 @@ def _sac(env, writer=DummyWriter()): name='q_1' ) - q_2_model = fc_q(env).to(device) + q_2_model = q2_model_constructor(env).to(device) q_2_optimizer = Adam(q_2_model.parameters(), lr=lr_q) q_2 = QContinuous( q_2_model, @@ -78,7 +87,7 @@ def _sac(env, writer=DummyWriter()): name='q_2' ) - v_model = fc_v(env).to(device) + v_model = v_model_constructor(env).to(device) v_optimizer = Adam(v_model.parameters(), lr=lr_v) v = VNetwork( v_model, @@ -92,7 +101,7 @@ def _sac(env, writer=DummyWriter()): name='v', ) - policy_model = fc_soft_policy(env).to(device) + policy_model = policy_model_constructor(env).to(device) policy_optimizer = Adam(policy_model.parameters(), lr=lr_pi) policy = SoftDeterministicPolicy( policy_model,
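Usage sketch (not part of the patch): with the new model_constructor argument, a single-model preset such as the classic control DQN can be given a custom network. The preset calls model_constructor(env), exactly as in the hunks above; the layer sizes and the use of env.state_space / env.action_space below mirror the bundled fc_relu_q model and are illustrative assumptions rather than part of this change.

import torch.nn as nn
from all.presets.classic_control import dqn

def my_q_model(env):
    # map a flat observation to one Q-value per discrete action
    return nn.Sequential(
        nn.Linear(env.state_space.shape[0], 256),
        nn.ReLU(),
        nn.Linear(256, env.action_space.n),
    )

# returns the usual agent constructor, to be passed to an experiment
# (or called directly with an environment) as before
make_agent = dqn(device='cpu', model_constructor=my_q_model)

The same pattern applies to the other single-model presets (ddqn, vqn, vsarsa and their Atari counterparts); the distributional presets are covered in the sketch further below.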
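The actor-critic presets take separate constructors, and their call signatures follow the hunks above: the Atari presets call the feature and value constructors with no arguments and the policy constructor with env, the classic control presets also pass env to the feature constructor, the continuous PPO preset instead takes a single ac_model_constructor returning a (feature, value, policy) triple, and SAC takes four independent constructors. A sketch for the Atari A2C preset, assuming the 4 x 84 x 84 stacked-frame input that the bundled nature_* models expect (the exact layer sizes and zero-initialised heads of the originals are simplified here):

import torch.nn as nn
from all.presets.atari import a2c

FEATURES = 256

def my_features():
    # 4 stacked 84x84 greyscale frames -> FEATURES-dimensional feature vector
    return nn.Sequential(
        nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(),
        nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
        nn.Flatten(),
        nn.Linear(64 * 9 * 9, FEATURES), nn.ReLU(),
    )

def my_value_head():
    # feature vector -> scalar state value
    return nn.Linear(FEATURES, 1)

def my_policy_head(env):
    # feature vector -> one logit per discrete action
    return nn.Linear(FEATURES, env.action_space.n)

make_agent = a2c(
    feature_model_constructor=my_features,
    value_model_constructor=my_value_head,
    policy_model_constructor=my_policy_head,
)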
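The distributional presets forward extra keyword arguments to the constructor, model_constructor(env, atoms=atoms) for c51 and model_constructor(env, atoms=atoms, sigma=sigma) for rainbow, so a custom constructor must accept them. A minimal sketch for the classic control c51 preset, assuming (as in the bundled fc_relu_dist_q) that the network outputs env.action_space.n * atoms values for QDist to reshape:

import torch.nn as nn
from all.presets.classic_control import c51

def my_dist_q_model(env, atoms=51):
    # one output per (action, atom) pair; the preset passes its own atoms value
    return nn.Sequential(
        nn.Linear(env.state_space.shape[0], 128),
        nn.ReLU(),
        nn.Linear(128, env.action_space.n * atoms),
    )

make_agent = c51(model_constructor=my_dist_q_model)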