
Commit 01836e0: Merge pull request #248 from cpnota/release/0.7.1

Release/0.7.1

Parents: 67b27aa, 074d0ca


46 files changed: +707 lines, -159 lines

.github/workflows/python-package.yml

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ jobs:
           pip install torch==1.8.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
           make install
           AutoROM -v
+          python -m atari_py.import_roms $(python -c 'import site; print(site.getsitepackages()[0])')/multi_agent_ale_py/ROM
     - name: Lint code
       run: |
         make lint

all/agents/a2c.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def _make_buffer(self):
         )


-class A2CTestAgent(Agent):
+class A2CTestAgent(Agent, ParallelAgent):
     def __init__(self, features, policy):
         self.features = features
         self.policy = policy

all/agents/dqn.py

Lines changed: 3 additions & 7 deletions
@@ -81,12 +81,8 @@ def _should_train(self):


 class DQNTestAgent(Agent):
-    def __init__(self, q, n_actions, exploration=0.):
-        self.q = q
-        self.n_actions = n_actions
-        self.exploration = 0.001
+    def __init__(self, policy):
+        self.policy = policy

     def act(self, state):
-        if np.random.rand() < self.exploration:
-            return np.random.randint(0, self.n_actions)
-        return torch.argmax(self.q.eval(state)).item()
+        return self.policy.eval(state)
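The DQNTestAgent now delegates action selection entirely to a policy object exposing eval(state) rather than implementing epsilon-greedy itself. A minimal sketch of the new usage, with a hypothetical ArgmaxPolicy wrapper (not part of this commit) that reproduces the old greedy behaviour around an existing q network:

import torch

class ArgmaxPolicy:
    # Hypothetical policy wrapper: any object with an eval(state) method works here.
    def __init__(self, q):
        self.q = q

    def eval(self, state):
        # Same behaviour as the deleted torch.argmax(self.q.eval(state)).item() line.
        return torch.argmax(self.q.eval(state)).item()

test_agent = DQNTestAgent(ArgmaxPolicy(q))  # q is assumed to be an existing Approximation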

all/agents/sac.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def _train(self):

            # adjust temperature
            temperature_grad = (_log_probs + self.entropy_target).mean()
-           self.temperature += self.lr_temperature * temperature_grad.detach()
+           self.temperature = max(0, self.temperature + self.lr_temperature * temperature_grad.detach())

            # additional debugging info
            self.writer.add_loss('entropy', -_log_probs.mean())
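The only behavioural change here is that the adjusted temperature is clamped at zero, so a large negative gradient step can no longer drive the entropy coefficient negative. A small numeric sketch with made-up values:

temperature = 0.01
lr_temperature = 0.1
temperature_grad = -0.5  # a strongly negative gradient

# old update: 0.01 + 0.1 * -0.5 = -0.04 (an invalid negative temperature)
# new update clamps at zero:
temperature = max(0, temperature + lr_temperature * temperature_grad)
print(temperature)  # 0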

all/agents/vqn.py

Lines changed: 6 additions & 1 deletion
@@ -50,4 +50,9 @@ def _train(self, reward, next_state):
         self.q.reinforce(loss)


-VQNTestAgent = DQNTestAgent
+class VQNTestAgent(Agent, ParallelAgent):
+    def __init__(self, policy):
+        self.policy = policy
+
+    def act(self, state):
+        return self.policy.eval(state)

all/agents/vsarsa.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 from torch.nn.functional import mse_loss
-from ._agent import Agent
 from ._parallel_agent import ParallelAgent
-from .dqn import DQNTestAgent
+from .vqn import VQNTestAgent


 class VSarsa(ParallelAgent):
@@ -47,4 +46,4 @@ def _train(self, reward, next_state, next_action):
         self.q.reinforce(loss)


-VSarsaTestAgent = DQNTestAgent
+VSarsaTestAgent = VQNTestAgent

all/bodies/atari.py

Lines changed: 3 additions & 2 deletions
@@ -6,7 +6,8 @@

 class DeepmindAtariBody(Body):
     def __init__(self, agent, lazy_frames=False, episodic_lives=True, frame_stack=4, clip_rewards=True):
-        agent = FrameStack(agent, lazy=lazy_frames, size=frame_stack)
+        if frame_stack > 1:
+            agent = FrameStack(agent, lazy=lazy_frames, size=frame_stack)
         if clip_rewards:
             agent = ClipRewards(agent)
         if episodic_lives:
@@ -19,7 +20,7 @@ def process_state(self, state):
         if 'life_lost' not in state:
             return state

-        if len(state) == 1:
+        if len(state.shape) == 0:
             if state['life_lost']:
                 return state.update('mask', 0.)
             return state
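With this change, passing frame_stack=1 skips the FrameStack wrapper entirely instead of stacking a single frame. A hedged usage sketch (the wrapped agent is assumed to exist; other wrappers behave as before):

# No FrameStack wrapper is applied when frame_stack=1.
body = DeepmindAtariBody(agent, frame_stack=1, clip_rewards=False, episodic_lives=False)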

all/bodies/vision.py

Lines changed: 2 additions & 3 deletions
@@ -69,10 +69,9 @@ def update(self, key, value):
         x = {}
         for k in self.keys():
             if not k == key:
-                x[k] = super().__getitem__(k)
+                x[k] = dict.__getitem__(self, k)
         x[key] = value
-        state = LazyState(x, device=self.device)
-        state.to_cache = self.to_cache
+        state = LazyState.from_state(x, x['observation'], self.to_cache)
         return state

     def to(self, device):

all/environments/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -1,9 +1,12 @@
 from ._environment import Environment
-from._multiagent_environment import MultiagentEnvironment
+from ._multiagent_environment import MultiagentEnvironment
+from ._vector_environment import VectorEnvironment
 from .gym import GymEnvironment
 from .atari import AtariEnvironment
 from .multiagent_atari import MultiagentAtariEnv
 from .multiagent_pettingzoo import MultiagentPettingZooEnv
+from .duplicate_env import DuplicateEnvironment
+from .vector_env import GymVectorEnvironment
 from .pybullet import PybulletEnvironment

 __all__ = [
@@ -13,5 +16,7 @@
     "AtariEnvironment",
     "MultiagentAtariEnv",
     "MultiagentPettingZooEnv",
+    "GymVectorEnvironment",
+    "DuplicateEnvironment",
     "PybulletEnvironment",
 ]
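With these additions the new vector-environment classes are importable directly from the package root, for example:

from all.environments import DuplicateEnvironment, GymVectorEnvironment, VectorEnvironment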
all/environments/_vector_environment.py (new file)

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+from abc import ABC, abstractmethod
+
+
+class VectorEnvironment(ABC):
+    """
+    A reinforcement learning vector Environment.
+
+    Similar to a regular RL environment except many environments are stacked together
+    in the observations, rewards, and dones, and the vector environment expects
+    an action to be given for each environment in step.
+
+    Also, since sub-environments are done at different times, you do not need to
+    manually reset the environments when they are done, rather the vector environment
+    automatically resets environments when they are complete.
+    """
+
+    @property
+    @abstractmethod
+    def name(self):
+        """
+        The name of the environment.
+        """
+
+    @abstractmethod
+    def reset(self):
+        """
+        Reset the environment and return a new initial state.
+
+        Returns
+        -------
+        State
+            The initial state for the next episode.
+        """
+
+    @abstractmethod
+    def step(self, action):
+        """
+        Apply an action and get the next state.
+
+        Parameters
+        ----------
+        action : Action
+            The action to apply at the current time step.
+
+        Returns
+        -------
+        all.environments.State
+            The State of the environment after the action is applied.
+            This State object includes both the done flag and any additional "info"
+        float
+            The reward achieved by the previous action
+        """
+
+    @abstractmethod
+    def close(self):
+        """
+        Clean up any extraneous environment objects.
+        """
+
+    @property
+    @abstractmethod
+    def state_array(self):
+        """
+        A StateArray of the Environments at the current timestep.
+        """
+
+    @property
+    @abstractmethod
+    def state_space(self):
+        """
+        The Space representing the range of observable states for each environment.
+
+        Returns
+        -------
+        Space
+            An object of type Space that represents possible states the agent may observe
+        """
+
+    @property
+    def observation_space(self):
+        """
+        Alias for Environment.state_space.
+
+        Returns
+        -------
+        Space
+            An object of type Space that represents possible states the agent may observe
+        """
+        return self.state_space
+
+    @property
+    @abstractmethod
+    def action_space(self):
+        """
+        The Space representing the range of possible actions for each environment.
+
+        Returns
+        -------
+        Space
+            An object of type Space that represents possible actions the agent may take
+        """
+
+    @property
+    @abstractmethod
+    def device(self):
+        """
+        The torch device the environment lives on.
+        """
+
+    @property
+    @abstractmethod
+    def num_envs(self):
+        """
+        Number of environments in vector. This is the number of actions step() expects as input
+        and the number of observations, dones, etc returned by the environment.
+        """

all/environments/atari.py

Lines changed: 3 additions & 2 deletions
@@ -8,6 +8,7 @@
     LifeLostEnv,
 )
 from all.core import State
+from .duplicate_env import DuplicateEnvironment


 class AtariEnvironment(GymEnvironment):
@@ -38,6 +39,6 @@ def reset(self):
         return self._state

     def duplicate(self, n):
-        return [
+        return DuplicateEnvironment([
             AtariEnvironment(self._name, *self._args, **self._kwargs) for _ in range(n)
-        ]
+        ])
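duplicate(n) now returns a DuplicateEnvironment instead of a plain Python list, so the result can be driven as a single vector environment. A hedged usage sketch, assuming the usual AtariEnvironment constructor arguments:

env = AtariEnvironment('Breakout')   # constructor arguments assumed, not shown in this diff
vec_env = env.duplicate(4)
print(vec_env.num_envs)              # 4
state = vec_env.reset()              # StateArray with one entry per copy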

all/environments/duplicate_env.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+import gym
+import torch
+from all.core import State
+from ._vector_environment import VectorEnvironment
+import numpy as np
+
+
+class DuplicateEnvironment(VectorEnvironment):
+    '''
+    Turns a list of ALL Environment objects into a VectorEnvironment object
+
+    This wrapper just takes the list of States the environments generate and outputs
+    a StateArray object containing all of the environment states. Like all vector
+    environments, the sub environments are automatically reset when done.
+
+    Args:
+        envs: A list of ALL environments
+        device (optional): the device on which tensors will be stored
+    '''
+
+    def __init__(self, envs, device=torch.device('cpu')):
+        self._name = envs[0].name
+        self._envs = envs
+        self._state = None
+        self._action = None
+        self._reward = None
+        self._done = True
+        self._info = None
+        self._device = device
+
+    @property
+    def name(self):
+        return self._name
+
+    def reset(self):
+        self._state = State.array([sub_env.reset() for sub_env in self._envs])
+        return self._state
+
+    def step(self, actions):
+        states = []
+        actions = actions.cpu().detach().numpy()
+        for sub_env, action in zip(self._envs, actions):
+            state = sub_env.reset() if sub_env.state.done else sub_env.step(action)
+            states.append(state)
+        self._state = State.array(states)
+        return self._state
+
+    def close(self):
+        return self._env.close()
+
+    def seed(self, seed):
+        for i, env in enumerate(self._envs):
+            env.seed(seed + i)
+
+    @property
+    def state_space(self):
+        return self._envs[0].observation_space
+
+    @property
+    def action_space(self):
+        return self._envs[0].action_space
+
+    @property
+    def state_array(self):
+        return self._state
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def num_envs(self):
+        return len(self._envs)
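A short usage sketch mirroring the new test file below: several single GymEnvironments are wrapped into one DuplicateEnvironment and stepped together.

import torch
from all.environments import DuplicateEnvironment, GymEnvironment

envs = [GymEnvironment('CartPole-v0') for _ in range(3)]
vec_env = DuplicateEnvironment(envs)
state = vec_env.reset()                                  # StateArray with shape (3,)
state = vec_env.step(torch.ones(3, dtype=torch.int32))   # one action per sub-environment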
all/environments/duplicate_env_test.py (new file)

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+import unittest
+import gym
+import torch
+from all.environments import DuplicateEnvironment, GymEnvironment
+
+
+def make_vec_env(num_envs=3):
+    env = [GymEnvironment('CartPole-v0') for i in range(num_envs)]
+    return env
+
+
+class DuplicateEnvironmentTest(unittest.TestCase):
+    def test_env_name(self):
+        env = DuplicateEnvironment(make_vec_env())
+        self.assertEqual(env.name, 'CartPole-v0')
+
+    def test_num_envs(self):
+        num_envs = 5
+        env = DuplicateEnvironment(make_vec_env(num_envs))
+        self.assertEqual(env.num_envs, num_envs)
+        self.assertEqual((num_envs,), env.reset().shape)
+
+    def test_reset(self):
+        num_envs = 5
+        env = DuplicateEnvironment(make_vec_env(num_envs))
+        state = env.reset()
+        self.assertEqual(state.observation.shape, (num_envs, 4))
+        self.assertTrue((state.reward == torch.zeros(num_envs, )).all())
+        self.assertTrue((state.done == torch.zeros(num_envs, )).all())
+        self.assertTrue((state.mask == torch.ones(num_envs, )).all())
+
+    def test_step(self):
+        num_envs = 5
+        env = DuplicateEnvironment(make_vec_env(num_envs))
+        env.reset()
+        state = env.step(torch.ones(num_envs, dtype=torch.int32))
+        self.assertEqual(state.observation.shape, (num_envs, 4))
+        self.assertTrue((state.reward == torch.ones(num_envs, )).all())
+        self.assertTrue((state.done == torch.zeros(num_envs, )).all())
+        self.assertTrue((state.mask == torch.ones(num_envs, )).all())
+
+    def test_step_until_done(self):
+        num_envs = 3
+        env = DuplicateEnvironment(make_vec_env(num_envs))
+        env.seed(5)
+        env.reset()
+        for _ in range(100):
+            state = env.step(torch.ones(num_envs, dtype=torch.int32))
+            if state.done[0]:
+                break
+        self.assertEqual(state[0].observation.shape, (4,))
+        self.assertEqual(state[0].reward, 1.)
+        self.assertTrue(state[0].done)
+        self.assertEqual(state[0].mask, 0)
