Skip to content

fix scaling bug in SoftDeterministicPolicy #140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions all/policies/soft_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,28 @@ def _normal(self, outputs):

def _sample(self, normal):
raw = normal.rsample()
action = self._squash(raw)
log_prob = self._log_prob(normal, raw)
return self._squash(raw), log_prob

def _log_prob(self, normal, raw):
'''
Compute the log probability of a raw action after the action is squashed.
Both inputs act on the raw underlying distribution.
Because tanh_mean does not affect the density, we can ignore it.
However, tanh_scale will affect the relative contribution of each component.'
See Appendix C in the Soft Actor-Critic paper

Args:
normal (torch.distributions.normal.Normal): The "raw" normal distribution.
raw (torch.Tensor): The "raw" action.

Returns:
torch.Tensor: The probability of the raw action, accounting for the affects of tanh.
'''
log_prob = normal.log_prob(raw)
log_prob -= torch.log(1 - action.pow(2) + 1e-6)
log_prob = log_prob.sum(1)
return action, log_prob
log_prob -= torch.log(1 - torch.tanh(raw).pow(2) + 1e-6)
log_prob /= self._tanh_scale
return log_prob.sum(1)

def _squash(self, x):
return torch.tanh(x) * self._tanh_scale + self._tanh_mean
Expand Down
68 changes: 68 additions & 0 deletions all/policies/soft_deterministic_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import unittest
import torch
import numpy as np
import torch_testing as tt
from gym.spaces import Box
from all import nn
from all.environments import State
from all.policies import SoftDeterministicPolicy

STATE_DIM = 2
ACTION_DIM = 3

class TestSoftDeterministic(unittest.TestCase):
def setUp(self):
torch.manual_seed(2)
self.model = nn.Sequential(
nn.Linear0(STATE_DIM, ACTION_DIM * 2)
)
self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.01)
self.space = Box(np.array([-1, -1, -1]), np.array([1, 1, 1]))
self.policy = SoftDeterministicPolicy(
self.model,
self.optimizer,
self.space
)

def test_output_shape(self):
state = State(torch.randn(1, STATE_DIM))
action, log_prob = self.policy(state)
self.assertEqual(action.shape, (1, ACTION_DIM))
self.assertEqual(log_prob.shape, torch.Size([1]))

state = State(torch.randn(5, STATE_DIM))
action, log_prob = self.policy(state)
self.assertEqual(action.shape, (5, ACTION_DIM))
self.assertEqual(log_prob.shape, torch.Size([5]))

def test_step_one(self):
state = State(torch.randn(1, STATE_DIM))
self.policy(state)
self.policy.step()

def test_converge(self):
state = State(torch.randn(1, STATE_DIM))
target = torch.tensor([0.25, 0.5, -0.5])

for _ in range(0, 200):
action, _ = self.policy(state)
loss = ((target - action) ** 2).mean()
loss.backward()
self.policy.step()

self.assertLess(loss, 0.2)

def test_scaling(self):
self.space = Box(np.array([-10, -5, 100]), np.array([10, -2, 200]))
self.policy = SoftDeterministicPolicy(
self.model,
self.optimizer,
self.space
)
state = State(torch.randn(1, STATE_DIM))
action, log_prob = self.policy(state)
tt.assert_allclose(action, torch.tensor([[-3.09055, -4.752777, 188.98222]]))
tt.assert_allclose(log_prob, torch.tensor([-0.397002]), rtol=1e-4)

if __name__ == '__main__':
unittest.main()