Commit 96479c6

Replaced space.shape[0] with len(space), supporting Tuple spaces.
* Tuple spaces do not have a shape attribute. To better support them, references to space.shape[0] have been replaced with len(space), which is more generic.
1 parent c5aa83b commit 96479c6
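
A minimal sketch (not part of this commit) of the behaviour being relied on, assuming the standard gym spaces; exact details may vary across gym versions:

    # Illustrative only: why len(space) generalizes where space.shape[0] does not.
    from gym.spaces import Box, Discrete, Tuple

    box = Box(low=-1.0, high=1.0, shape=(3,))
    tup = Tuple((Discrete(2), Discrete(3)))

    print(box.shape[0])  # 3 -- Box exposes a shape, so shape[0] works here
    print(len(tup))      # 2 -- Tuple supports len() but has no usable shape
    # tup.shape[0] would raise, since Tuple spaces do not define a shape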

7 files changed, 20 insertions(+), 20 deletions(-)

all/agents/sac.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ class SAC(Agent):
         v (VNetwork): An Approximation of the state-value function.
         replay_buffer (ReplayBuffer): The experience replay buffer.
         discount_factor (float): Discount factor for future rewards.
-        entropy_target (float): The desired entropy of the policy. Usually -env.action_space.shape[0]
+        entropy_target (float): The desired entropy of the policy. Usually -len(env.action_space)
         minibatch_size (int): The number of experiences to sample in each training update.
         replay_start_size (int): Number of experiences in replay buffer when training begins.
         temperature_initial (float): The initial temperature used in the maximum entropy objective.

all/policies/deterministic.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def __init__(
 class DeterministicPolicyNetwork(RLNetwork):
     def __init__(self, model, space):
         super().__init__(model)
-        self._action_dim = space.shape[0]
+        self._action_dim = len(space)
         self._tanh_scale = torch.tensor((space.high - space.low) / 2).to(self.device)
         self._tanh_mean = torch.tensor((space.high + space.low) / 2).to(self.device)

all/policies/soft_deterministic.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def __init__(
 class SoftDeterministicPolicyNetwork(RLNetwork):
     def __init__(self, model, space):
         super().__init__(model)
-        self._action_dim = space.shape[0]
+        self._action_dim = len(space)
         self._tanh_scale = torch.tensor((space.high - space.low) / 2).to(self.device)
         self._tanh_mean = torch.tensor((space.high + space.low) / 2).to(self.device)

all/presets/classic_control/models/__init__.py

Lines changed: 6 additions & 6 deletions
@@ -4,7 +4,7 @@
 def fc_relu_q(env, hidden=64):
     return nn.Sequential(
         nn.Flatten(),
-        nn.Linear(env.state_space.shape[0], hidden),
+        nn.Linear(len(env.state_space), hidden),
         nn.ReLU(),
         nn.Linear(hidden, env.action_space.n),
     )
@@ -15,10 +15,10 @@ def dueling_fc_relu_q(env):
         nn.Flatten(),
         nn.Dueling(
             nn.Sequential(
-                nn.Linear(env.state_space.shape[0], 256), nn.ReLU(), nn.Linear(256, 1)
+                nn.Linear(len(env.state_space), 256), nn.ReLU(), nn.Linear(256, 1)
             ),
             nn.Sequential(
-                nn.Linear(env.state_space.shape[0], 256),
+                nn.Linear(len(env.state_space), 256),
                 nn.ReLU(),
                 nn.Linear(256, env.action_space.n),
             ),
@@ -28,7 +28,7 @@ def dueling_fc_relu_q(env):
 
 def fc_relu_features(env, hidden=64):
     return nn.Sequential(
-        nn.Flatten(), nn.Linear(env.state_space.shape[0], hidden), nn.ReLU()
+        nn.Flatten(), nn.Linear(len(env.state_space), hidden), nn.ReLU()
     )
 
 
@@ -43,7 +43,7 @@ def fc_policy_head(env, hidden=64):
 def fc_relu_dist_q(env, hidden=64, atoms=51):
     return nn.Sequential(
         nn.Flatten(),
-        nn.Linear(env.state_space.shape[0], hidden),
+        nn.Linear(len(env.state_space), hidden),
         nn.ReLU(),
         nn.Linear0(hidden, env.action_space.n * atoms),
     )
@@ -52,7 +52,7 @@ def fc_relu_dist_q(env, hidden=64, atoms=51):
 def fc_relu_rainbow(env, hidden=64, atoms=51, sigma=0.5):
     return nn.Sequential(
         nn.Flatten(),
-        nn.Linear(env.state_space.shape[0], hidden),
+        nn.Linear(len(env.state_space), hidden),
         nn.ReLU(),
         nn.CategoricalDueling(
             nn.NoisyFactorizedLinear(hidden, atoms, sigma_init=sigma),

all/presets/continuous/models/__init__.py

Lines changed: 8 additions & 8 deletions
@@ -10,7 +10,7 @@
 
 def fc_q(env, hidden1=400, hidden2=300):
     return nn.Sequential(
-        nn.Linear(env.state_space.shape[0] + env.action_space.shape[0] + 1, hidden1),
+        nn.Linear(len(env.state_space) + len(env.action_space) + 1, hidden1),
         nn.ReLU(),
         nn.Linear(hidden1, hidden2),
         nn.ReLU(),
@@ -19,7 +19,7 @@ def fc_q(env, hidden1=400, hidden2=300):
 
 def fc_v(env, hidden1=400, hidden2=300):
     return nn.Sequential(
-        nn.Linear(env.state_space.shape[0] + 1, hidden1),
+        nn.Linear(len(env.state_space) + 1, hidden1),
         nn.ReLU(),
         nn.Linear(hidden1, hidden2),
         nn.ReLU(),
@@ -28,25 +28,25 @@ def fc_v(env, hidden1=400, hidden2=300):
 
 def fc_deterministic_policy(env, hidden1=400, hidden2=300):
     return nn.Sequential(
-        nn.Linear(env.state_space.shape[0] + 1, hidden1),
+        nn.Linear(len(env.state_space) + 1, hidden1),
         nn.ReLU(),
         nn.Linear(hidden1, hidden2),
         nn.ReLU(),
-        nn.Linear0(hidden2, env.action_space.shape[0]),
+        nn.Linear0(hidden2, len(env.action_space)),
     )
 
 def fc_soft_policy(env, hidden1=400, hidden2=300):
     return nn.Sequential(
-        nn.Linear(env.state_space.shape[0] + 1, hidden1),
+        nn.Linear(len(env.state_space) + 1, hidden1),
         nn.ReLU(),
         nn.Linear(hidden1, hidden2),
         nn.ReLU(),
-        nn.Linear0(hidden2, env.action_space.shape[0] * 2),
+        nn.Linear0(hidden2, len(env.action_space) * 2),
     )
 
 def fc_actor_critic(env, hidden1=400, hidden2=300):
     features = nn.Sequential(
-        nn.Linear(env.state_space.shape[0] + 1, hidden1),
+        nn.Linear(len(env.state_space) + 1, hidden1),
         nn.ReLU(),
     )
 
@@ -59,7 +59,7 @@ def fc_actor_critic(env, hidden1=400, hidden2=300):
     policy = nn.Sequential(
         nn.Linear(hidden1, hidden2),
         nn.ReLU(),
-        nn.Linear(hidden2, env.action_space.shape[0] * 2)
+        nn.Linear(hidden2, len(env.action_space) * 2)
     )
 
     return features, v, policy

all/presets/continuous/sac.py

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ def sac(
         replay_buffer_size (int): Maximum number of experiences to store in the replay buffer.
         temperature_initial (float): Initial value of the temperature parameter.
         lr_temperature (float): Learning rate for the temperature. Should be low compared to other learning rates.
-        entropy_target_scaling (float): The target entropy will be -(entropy_target_scaling * env.action_space.shape[0])
+        entropy_target_scaling (float): The target entropy will be -(entropy_target_scaling * len(env.action_space))
         q1_model_constructor(function): The function used to construct the neural q1 model.
         q2_model_constructor(function): The function used to construct the neural q2 model.
         v_model_constructor(function): The function used to construct the neural v model.
@@ -126,7 +126,7 @@ def _sac(env, writer=DummyWriter()):
             v,
             replay_buffer,
             temperature_initial=temperature_initial,
-            entropy_target=(-env.action_space.shape[0] * entropy_target_scaling),
+            entropy_target=(-len(env.action_space) * entropy_target_scaling),
             lr_temperature=lr_temperature,
             replay_start_size=replay_start_size,
            discount_factor=discount_factor,
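
For concreteness, a hypothetical worked example of the entropy-target expression changed above (the action dimension and scaling value are illustrative, not taken from the repo):

    # Illustrative values only; mirrors entropy_target=(-len(env.action_space) * entropy_target_scaling)
    action_dim = 6                # stand-in for len(env.action_space)
    entropy_target_scaling = 1.0  # assumed scaling factor
    entropy_target = -action_dim * entropy_target_scaling
    print(entropy_target)         # -6.0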

docs/source/guide/basic_concepts.rst

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ In order to actually apply this agent to a problem, for example, a classic contr
     def _vqn(env, writer=DummyWriter()):
         # create a pytorch model
         model = nn.Sequential(
-            nn.Linear(env.state_space.shape[0], 64),
+            nn.Linear(len(env.state_space), 64),
             nn.ReLU(),
             nn.Linear(64, env.action_space.n),
         ).to(device)
