
Commit 31e5aa9

Merge pull request #171 from cpnota/release/0.6.0
Release/0.6.0

2 parents: 3ec67d5 + 67fcf2c


90 files changed (+1309, -710 lines)

.pylintrc

Lines changed: 24 additions & 9 deletions
@@ -423,23 +423,38 @@ function-naming-style=snake_case
 #function-rgx=

 # Good variable names which should always be accepted, separated by a comma.
-good-names=i,
+good-names=a,
+           b,
+           c,
+           d,
+           e,
+           f,
+           g,
+           h,
+           i,
            j,
            k,
-           ex,
-           Run,
+           l,
+           m,
+           n,
+           o,
+           p,
            q,
+           r,
+           s,
+           t,
+           u,
            v,
-           _,
+           w,
            x,
            y,
+           z
+           _,
            lr,
-           n,
-           t,
-           e,
-           u,
            kl,
-           ax
+           ax,
+           ex,
+           Run,

 # Include a hint for the correct naming format with invalid-name.
 include-naming-hint=no

.travis.yml

Lines changed: 2 additions & 3 deletions
@@ -1,13 +1,12 @@
 language: python
 python:
-  - "3.6"
+  - "3.7"
 branches:
   only:
     - master
     - develop
 install:
-  - pip install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
-  - pip install torchvision
+  - pip install torch==1.5.1+cpu torchvision==0.6.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
   - pip install -q -e .["dev"]
 script:
   - make lint

all/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1 +1,4 @@
 import all.nn
+from all.core import State, StateArray
+
+__all__ = ['nn', 'State', 'StateArray']

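The new core State type is now exported at the package level. A minimal sketch of constructing one, assuming a dict-based constructor and an 'observation' key (those two details are assumptions; the state.reward and state.done fields are the ones the agent diffs below actually read):

# Hedged sketch: build a State carrying observation, reward, and done.
# The dict constructor and 'observation' key are assumptions, not shown in this commit.
import torch
from all.core import State

state = State({
    'observation': torch.zeros(4),
    'reward': 0.0,
    'done': False,
})
print(state.reward, state.done)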
all/agents/_agent.py

Lines changed: 2 additions & 4 deletions
@@ -13,7 +13,7 @@ class Agent(ABC, Schedulable):
     """

     @abstractmethod
-    def act(self, state, reward):
+    def act(self, state):
         """
         Select an action for the current timestep and update internal parameters.

@@ -27,14 +27,13 @@ def act(self, state, reward):

         Args:
             state (all.environment.State): The environment state at the current timestep.
-            reward (torch.Tensor): The reward from the previous timestep.

         Returns:
             torch.Tensor: The action to take at the current timestep.
         """

     @abstractmethod
-    def eval(self, state, reward):
+    def eval(self, state):
         """
         Select an action for the current timestep in evaluation mode.

@@ -45,7 +44,6 @@ def eval(self, state, reward):

         Args:
             state (all.environment.State): The environment state at the current timestep.
-            reward (torch.Tensor): The reward from the previous timestep.

         Returns:
             torch.Tensor: The action to take at the current timestep.

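The signature change above is the core of this release: Agent.act and Agent.eval no longer receive the reward as a separate argument; the reward from the previous timestep now arrives on the state itself as state.reward. A hedged sketch of an interaction loop under the new interface (the env object and its reset/step methods are illustrative assumptions, not part of this diff):

# Illustrative loop only; env.reset()/env.step() are assumed to return State
# objects whose .reward and .done fields describe the previous transition.
state = env.reset()
while True:
    action = agent.act(state)   # the agent reads state.reward internally
    if state.done:
        break
    state = env.step(action)    # the next State carries the reward for `action`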
all/agents/a2c.py

Lines changed: 3 additions & 3 deletions
@@ -53,14 +53,14 @@ def __init__(
         self._batch_size = n_envs * n_steps
         self._buffer = self._make_buffer()

-    def act(self, states, rewards):
-        self._buffer.store(self._states, self._actions, rewards)
+    def act(self, states):
+        self._buffer.store(self._states, self._actions, states.reward)
         self._train(states)
         self._states = states
         self._actions = self.policy.no_grad(self.features.no_grad(states)).sample()
         return self._actions

-    def eval(self, states, _):
+    def eval(self, states):
         return self.policy.eval(self.features.eval(states))

     def _train(self, next_states):

all/agents/c51.py

Lines changed: 8 additions & 10 deletions
@@ -53,22 +53,20 @@ def __init__(
         self._action = None
         self._frames_seen = 0

-    def act(self, state, reward):
-        self.replay_buffer.store(self._state, self._action, reward, state)
+    def act(self, state):
+        self.replay_buffer.store(self._state, self._action, state)
         self._train()
         self._state = state
         self._action = self._choose_action(state)
         return self._action

-    def eval(self, state, _):
-        return self._best_actions(self.q_dist.eval(state))
+    def eval(self, state):
+        return self._best_actions(self.q_dist.eval(state)).item()

     def _choose_action(self, state):
         if self._should_explore():
-            return torch.randint(
-                self.q_dist.n_actions, (len(state),), device=self.q_dist.device
-            )
-        return self._best_actions(self.q_dist.no_grad(state))
+            return np.random.randint(0, self.q_dist.n_actions)
+        return self._best_actions(self.q_dist.no_grad(state)).item()

     def _should_explore(self):
         return (
@@ -77,8 +75,8 @@ def _should_explore(self):
         )

     def _best_actions(self, probs):
-        q_values = (probs * self.q_dist.atoms).sum(dim=2)
-        return torch.argmax(q_values, dim=1)
+        q_values = (probs * self.q_dist.atoms).sum(dim=-1)
+        return torch.argmax(q_values, dim=-1)

     def _train(self):
         if self._should_train():

all/agents/ddpg.py

Lines changed: 3 additions & 3 deletions
@@ -54,14 +54,14 @@ def __init__(self,
         self._action = None
         self._frames_seen = 0

-    def act(self, state, reward):
-        self.replay_buffer.store(self._state, self._action, reward, state)
+    def act(self, state):
+        self.replay_buffer.store(self._state, self._action, state)
         self._train()
         self._state = state
         self._action = self._choose_action(state)
         return self._action

-    def eval(self, state, _):
+    def eval(self, state):
         return self.policy.eval(state)

     def _choose_action(self, state):

all/agents/ddqn.py

Lines changed: 4 additions & 4 deletions
@@ -38,7 +38,7 @@ def __init__(self,
         self.q = q
         self.policy = policy
         self.replay_buffer = replay_buffer
-        self.loss = staticmethod(loss)
+        self.loss = loss
         # hyperparameters
         self.replay_start_size = replay_start_size
         self.update_frequency = update_frequency
@@ -49,14 +49,14 @@ def __init__(self,
         self._action = None
         self._frames_seen = 0

-    def act(self, state, reward):
-        self.replay_buffer.store(self._state, self._action, reward, state)
+    def act(self, state):
+        self.replay_buffer.store(self._state, self._action, state)
         self._train()
         self._state = state
         self._action = self.policy.no_grad(state)
         return self._action

-    def eval(self, state, _):
+    def eval(self, state):
         return self.policy.eval(state)

     def _train(self):

all/agents/dqn.py

Lines changed: 4 additions & 4 deletions
@@ -39,7 +39,7 @@ def __init__(self,
         self.q = q
         self.policy = policy
         self.replay_buffer = replay_buffer
-        self.loss = staticmethod(loss)
+        self.loss = loss
         # hyperparameters
         self.discount_factor = discount_factor
         self.minibatch_size = minibatch_size
@@ -50,14 +50,14 @@ def __init__(self,
         self._action = None
         self._frames_seen = 0

-    def act(self, state, reward):
-        self.replay_buffer.store(self._state, self._action, reward, state)
+    def act(self, state):
+        self.replay_buffer.store(self._state, self._action, state)
         self._train()
         self._state = state
         self._action = self.policy.no_grad(state)
         return self._action

-    def eval(self, state, _):
+    def eval(self, state):
         return self.policy.eval(state)

     def _train(self):

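As the DQN-family diffs above show, replay_buffer.store now takes (state, action, next_state) with no explicit reward; the reward travels with next_state. A hedged sketch of a compatible store method (the _add helper and the done-state guard are assumptions about the buffer implementation, not code from this commit):

def store(self, state, action, next_state):
    # Reward is no longer passed separately: next_state.reward holds it.
    if state is not None and not state.done:
        self._add((state, action, next_state))  # hypothetical internal append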
all/agents/ppo.py

Lines changed: 3 additions & 3 deletions
@@ -63,14 +63,14 @@ def __init__(
         self._batch_size = n_envs * n_steps
         self._buffer = self._make_buffer()

-    def act(self, states, rewards):
-        self._buffer.store(self._states, self._actions, rewards)
+    def act(self, states):
+        self._buffer.store(self._states, self._actions, states.reward)
         self._train(states)
         self._states = states
         self._actions = self.policy.no_grad(self.features.no_grad(states)).sample()
         return self._actions

-    def eval(self, states, _):
+    def eval(self, states):
         return self.policy.eval(self.features.eval(states))

     def _train(self, next_states):

all/agents/sac.py

Lines changed: 4 additions & 4 deletions
@@ -63,15 +63,15 @@ def __init__(self,
         self._action = None
         self._frames_seen = 0

-    def act(self, state, reward):
-        self.replay_buffer.store(self._state, self._action, reward, state)
+    def act(self, state):
+        self.replay_buffer.store(self._state, self._action, state)
         self._train()
         self._state = state
         self._action = self.policy.no_grad(state)[0]
         return self._action

-    def eval(self, state, _):
-        return self.policy.eval(state)[0]
+    def eval(self, state):
+        return self.policy.eval(state)

     def _train(self):
         if self._should_train():

all/agents/vac.py

Lines changed: 3 additions & 3 deletions
@@ -28,14 +28,14 @@ def __init__(self, features, v, policy, discount_factor=1):
         self._distribution = None
         self._action = None

-    def act(self, state, reward):
-        self._train(state, reward)
+    def act(self, state):
+        self._train(state, state.reward)
         self._features = self.features(state)
         self._distribution = self.policy(self._features)
         self._action = self._distribution.sample()
         return self._action

-    def eval(self, state, _):
+    def eval(self, state):
         return self.policy.eval(self.features.eval(state))

     def _train(self, state, reward):

all/agents/vpg.py

Lines changed: 10 additions & 10 deletions
@@ -1,6 +1,6 @@
 import torch
 from torch.nn.functional import mse_loss
-from all.environments import State
+from all.core import State
 from ._agent import Agent

 class VPG(Agent):
@@ -43,38 +43,38 @@ def __init__(
         self._log_pis = []
         self._rewards = []

-    def act(self, state, reward):
+    def act(self, state):
         if not self._features:
             return self._initial(state)
         if not state.done:
-            return self._act(state, reward)
-        return self._terminal(state, reward)
+            return self._act(state, state.reward)
+        return self._terminal(state, state.reward)

-    def eval(self, state, _):
+    def eval(self, state):
         return self.policy.eval(self.features.eval(state))

     def _initial(self, state):
         features = self.features(state)
         distribution = self.policy(features)
         action = distribution.sample()
-        self._features = [features.features]
+        self._features = [features]
         self._log_pis.append(distribution.log_prob(action))
         return action

     def _act(self, state, reward):
         features = self.features(state)
         distribution = self.policy(features)
         action = distribution.sample()
-        self._features.append(features.features)
+        self._features.append(features)
         self._rewards.append(reward)
         self._log_pis.append(distribution.log_prob(action))
         return action

     def _terminal(self, state, reward):
         self._rewards.append(reward)
-        features = torch.cat(self._features)
+        features = State.array(self._features)
         rewards = torch.tensor(self._rewards, device=features.device)
-        log_pis = torch.cat(self._log_pis)
+        log_pis = torch.stack(self._log_pis)
         self._trajectories.append((features, rewards, log_pis))
         self._current_batch_size += len(features)
         self._features = []
@@ -90,7 +90,7 @@ def _terminal(self, state, reward):
     def _train(self):
         # forward pass
         values = torch.cat([
-            self.v(State(features))
+            self.v(features)
             for (features, _, _)
             in self._trajectories
         ])

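The VPG changes above collect whole feature State objects and collate them with State.array into a batched StateArray, rather than concatenating raw feature tensors, and they switch log-probability collation from torch.cat to torch.stack. A small hedged illustration of that collation pattern (the dict-based State construction is an assumption; the torch.stack/torch.cat behavior is standard PyTorch):

# Hedged illustration of the collation used in _terminal above.
import torch
from all.core import State

feature_states = [State({'observation': torch.randn(8)}) for _ in range(5)]
log_pis = [torch.tensor(0.1 * i) for i in range(5)]   # 0-d tensors

features = State.array(feature_states)   # batched StateArray built from the list
log_pis = torch.stack(log_pis)           # torch.stack adds the batch dimension;
                                         # torch.cat would require 1-d inputs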
all/agents/vqn.py

Lines changed: 3 additions & 3 deletions
@@ -25,14 +25,14 @@ def __init__(self, q, policy, discount_factor=0.99):
         self._state = None
         self._action = None

-    def act(self, state, reward):
-        self._train(reward, state)
+    def act(self, state):
+        self._train(state.reward, state)
         action = self.policy.no_grad(state)
         self._state = state
         self._action = action
         return action

-    def eval(self, state, _):
+    def eval(self, state):
         return self.policy.eval(state)

     def _train(self, reward, next_state):

all/agents/vsarsa.py

Lines changed: 3 additions & 3 deletions
@@ -22,14 +22,14 @@ def __init__(self, q, policy, discount_factor=0.99):
         self._state = None
         self._action = None

-    def act(self, state, reward):
+    def act(self, state):
         action = self.policy.no_grad(state)
-        self._train(reward, state, action)
+        self._train(state.reward, state, action)
         self._state = state
         self._action = action
         return action

-    def eval(self, state, _):
+    def eval(self, state):
         return self.policy.eval(state)

     def _train(self, reward, next_state, next_action):
