Commit aaa5403

Merge pull request #257 from cpnota/release/0.7.2
Release/0.7.2
2 parents 9c84581 + bb4fc1e commit aaa5403

File tree

6 files changed: +38 -19 lines changed


.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ jobs:
       run: |
         sudo apt-get install swig
         sudo apt-get install unrar
-        pip install torch==1.8.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+        pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
         make install
         AutoROM -v
         python -m atari_py.import_roms $(python -c 'import site; print(site.getsitepackages()[0])')/multi_agent_ale_py/ROM

all/approximation/approximation.py

Lines changed: 14 additions & 6 deletions
@@ -103,7 +103,7 @@ def eval(self, *inputs):
         with torch.no_grad():
             # check current mode
             mode = self.model.training
-            # switch to eval mode
+            # switch model to eval mode
             self.model.eval()
             # run forward pass
             result = self.model(*inputs)
@@ -144,14 +144,11 @@ def step(self):
         Returns:
             self: The current Approximation object
         '''
-        if self._clip_grad != 0:
-            utils.clip_grad_norm_(self.model.parameters(), self._clip_grad)
+        self._clip_grad_norm()
         self._optimizer.step()
         self._optimizer.zero_grad()
+        self._step_lr_scheduler()
         self._target.update()
-        if self._scheduler:
-            self._writer.add_schedule(self._name + '/lr', self._optimizer.param_groups[0]['lr'])
-            self._scheduler.step()
         self._checkpointer()
         return self

@@ -164,3 +161,14 @@ def zero_grad(self):
         '''
         self._optimizer.zero_grad()
         return self
+
+    def _clip_grad_norm(self):
+        '''Clip the gradient norm if set. Raises RuntimeError if the norm is non-finite.'''
+        if self._clip_grad != 0:
+            utils.clip_grad_norm_(self.model.parameters(), self._clip_grad, error_if_nonfinite=True)
+
+    def _step_lr_scheduler(self):
+        '''Step the learning rate scheduler, if one is set, and log the current learning rate.'''
+        if self._scheduler:
+            self._writer.add_schedule(self._name + '/lr', self._optimizer.param_groups[0]['lr'])
+            self._scheduler.step()
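The refactor above splits gradient clipping and learning-rate scheduling out of step() into the private helpers _clip_grad_norm() and _step_lr_scheduler(), and clipping now passes error_if_nonfinite=True so that a non-finite gradient norm raises a RuntimeError instead of silently corrupting the weights (that keyword requires PyTorch 1.9, consistent with the torch version bumps elsewhere in this commit). As a rough sketch of the same clip-then-step pattern in plain PyTorch, using a toy model, optimizer, and scheduler that are illustrative only and not part of the library:

# Illustrative only: a toy model and synthetic loss, not the library's Approximation class.
import torch
from torch import nn
from torch.nn import utils
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
clip_grad = 0.5  # 0 disables clipping, mirroring the _clip_grad check above

for _ in range(20):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    if clip_grad != 0:
        # error_if_nonfinite=True (PyTorch >= 1.9) raises on a NaN/inf gradient norm
        utils.clip_grad_norm_(model.parameters(), clip_grad, error_if_nonfinite=True)
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # the library also logs optimizer.param_groups[0]['lr'] at this point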

all/policies/soft_deterministic.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def _log_prob(self, normal, raw):
         '''
         log_prob = normal.log_prob(raw)
         log_prob -= torch.log(1 - torch.tanh(raw).pow(2) + 1e-6)
-        log_prob /= self._tanh_scale
+        log_prob -= torch.log(self._tanh_scale)
         return log_prob.sum(-1)

     def _squash(self, x):
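This one-line change fixes the density correction for rescaled actions: dividing the log-probability by _tanh_scale has no probabilistic meaning, whereas an affine rescaling of a tanh-squashed sample by a factor of scale requires subtracting log(scale) from each dimension's log-density (the usual change-of-variables term). Below is a small sketch that checks this correction against the built-in transforms in torch.distributions; the scale and loc values are illustrative stand-ins for the policy's _tanh_scale and the center of its Box space, not values taken from the library.

# Verify the subtract-log(scale) correction against TransformedDistribution.
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, TanhTransform

torch.manual_seed(0)
scale = torch.tensor([2.0, 1.0])  # half-widths of a hypothetical Box([-2, -1], [2, 1])
loc = torch.zeros(2)              # center of that Box

normal = Normal(torch.zeros(2), torch.ones(2))
raw = normal.sample()
action = loc + scale * torch.tanh(raw)

# Manual correction, mirroring _log_prob after this change:
log_prob = normal.log_prob(raw)
log_prob -= torch.log(1 - torch.tanh(raw).pow(2) + 1e-6)
log_prob -= torch.log(scale)  # the fix: subtract log(scale) rather than divide by it
log_prob = log_prob.sum(-1)

# Reference log-density of the squashed-and-scaled action:
reference = TransformedDistribution(
    Normal(torch.zeros(2), torch.ones(2)),
    [TanhTransform(), AffineTransform(loc, scale)],
).log_prob(action).sum(-1)

print(torch.allclose(log_prob, reference, atol=1e-3))  # True, up to the 1e-6 epsilon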

all/policies/soft_deterministic_test.py

Lines changed: 18 additions & 6 deletions
@@ -56,16 +56,28 @@ def test_converge(self):
         self.assertLess(loss, 0.2)

     def test_scaling(self):
-        self.space = Box(np.array([-10, -5, 100]), np.array([10, -2, 200]))
-        self.policy = SoftDeterministicPolicy(
+        torch.manual_seed(0)
+        state = State(torch.randn(1, STATE_DIM))
+        policy1 = SoftDeterministicPolicy(
             self.model,
             self.optimizer,
-            self.space
+            Box(np.array([-1., -1., -1.]), np.array([1., 1., 1.]))
         )
+        action1, log_prob1 = policy1(state)
+
+        # reset seed and sample same thing, but with different scaling
+        torch.manual_seed(0)
         state = State(torch.randn(1, STATE_DIM))
-        action, log_prob = self.policy(state)
-        tt.assert_allclose(action, torch.tensor([[-3.09055, -4.752777, 188.98222]]))
-        tt.assert_allclose(log_prob, torch.tensor([-0.397002]), rtol=1e-4)
+        policy2 = SoftDeterministicPolicy(
+            self.model,
+            self.optimizer,
+            Box(np.array([-2., -1., -1.]), np.array([2., 1., 1.]))
+        )
+        action2, log_prob2 = policy2(state)
+
+        # check scaling was correct
+        tt.assert_allclose(action1 * torch.tensor([2, 1, 1]), action2)
+        tt.assert_allclose(log_prob1 - np.log(2), log_prob2)


 if __name__ == '__main__':

docs/source/conf.py

Lines changed: 1 addition & 2 deletions
@@ -22,8 +22,7 @@
 author = 'Chris Nota'

 # The full version, including alpha/beta/rc tags
-release = '0.7.1'
-
+release = '0.7.2'

 # -- General configuration ---------------------------------------------------

setup.py

Lines changed: 3 additions & 3 deletions
@@ -38,7 +38,7 @@

 setup(
     name="autonomous-learning-library",
-    version="0.7.1",
+    version="0.7.2",
     description=("A library for building reinforcement learning agents in Pytorch"),
     packages=find_packages(),
     url="https://github.com/cpnota/autonomous-learning-library.git",
@@ -61,8 +61,8 @@
         "gym~=0.18.0", # common environment interface
         "numpy>=1.18.0", # math library
         "matplotlib>=3.3.0", # plotting library
-        "opencv-python~=3.4.0", # used by atari wrappers
-        "torch~=1.8.0", # core deep learning library
+        "opencv-python~=3.4.0", # used by atari wrappers
+        "torch~=1.9.0", # core deep learning library
         "tensorboard>=2.3.0", # logging and visualization
         "tensorboardX>=2.1.0", # tensorboard/pytorch compatibility
         "cloudpickle>=1.2.0", # used to copy environments
