
SoftDeterministicPolicyNetwork: fixed inconsistent return types in training and eval mode. #165


Closed
michalgregor wants to merge 4 commits

Conversation

michalgregor
Contributor

I have discovered a slight bug on develop regarding SAC in eval mode. SoftDeterministicPolicyNetwork's forward returns a bare tensor in eval mode but a tuple in training mode. The agent then indexes the returned value with [0] in both cases – so in eval mode, SAC fails, returning a scalar instead of a vector as the action.

Also, this used to work by accident before, because the network returned a 2D tensor: in training mode, the [0] selected the action tensor from the tuple, and in eval mode it selected the first row of the 2D tensor (which still worked, even though it was not really intended).
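To make the mismatch concrete, here is a minimal stand-in sketch (simplified toy class, not the actual library code) of how indexing with [0] behaves under the two return types:

import torch

class PolicyBeforeFix(torch.nn.Module):
    # Stand-in for the old SoftDeterministicPolicyNetwork behavior.
    def forward(self, state):
        action = torch.tensor([0.1, -0.2])     # 1D action vector
        if self.training:
            return action, torch.tensor(-1.3)  # tuple of (action, log_prob)
        return action                          # bare tensor in eval mode

policy = PolicyBeforeFix()
policy.train()
assert policy(None)[0].shape == (2,)  # [0] selects the action from the tuple
policy.eval()
assert policy(None)[0].shape == ()    # [0] indexes into the action tensor: a scalar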

Finally, the pull request is showing me 3 commits instead of 1. I don't know precisely why, since the first 2 commits are already on develop as far as I can see. If you know why this is happening, please let me know and I will fix it before you merge.

michalgregor and others added 4 commits August 11, 2020 14:29
… mode.

* SoftDeterministicPolicyNetwork now returns a tuple in eval mode as
  well as in training mode, the way the interface in the SAC agent expects.
* The line self._name = env only worked when env was a string; the class
  name is now used in place of env when this is not the case.
@@ -43,7 +43,7 @@ def forward(self, state):
         if self.training:
             action, log_prob = self._sample(normal)
             return action, log_prob
-        return self._squash(normal.loc)
+        return self._squash(normal.loc), torch.as_tensor(0.0, device=self.device)
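After this change, both branches return an (action, log_prob) pair, so the agent's [0] indexing yields the action tensor in both modes. A minimal sketch mirroring the stand-in above:

import torch

class PolicyAfterFix(torch.nn.Module):
    # Stand-in mirroring the fixed forward above.
    def forward(self, state):
        action = torch.tensor([0.1, -0.2])
        if self.training:
            return action, torch.tensor(-1.3)  # sampled action and its log_prob
        return action, torch.as_tensor(0.0)    # tuple in eval mode as well

policy = PolicyAfterFix()
policy.eval()
action, log_prob = policy(None)  # unpacking now works in eval mode too
assert action.shape == (2,) and log_prob.item() == 0.0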
Owner


Is this the right fix, or would the right fix be to change SAC? The log_prob isn't really 0.0, per se. A third option would be to actually compute the log_prob for the greedy action? Still sort of misleading, perhaps.

Contributor Author


I can totally change that, but my reasoning was this: unless I am misreading the code, the policy is deterministic in eval mode, so the probability of the selected action is really 1, right? So its log would then be 0?
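(Spelled out: if the eval-mode policy puts all its mass on the greedy action a, then π(a|s) = 1 and log π(a|s) = log 1 = 0, reading "probability" in the discrete sense, which is presumably what is meant here.)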

# Model construction
feature_model_constructor=nature_features,
value_model_constructor=nature_value_head,
policy_model_constructor=nature_policy_head
Owner


Not sure why GitHub is showing these as changes; it looks okay in the file view, though.

    env = gym.make(env)
else:
    self._name = env.__class__.__name__
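For context, the surrounding constructor logic is presumably along these lines (a hypothetical reconstruction from the diff; the class name and attributes are assumptions):

import gym

class GymEnvironment:
    def __init__(self, env):
        if isinstance(env, str):
            self._name = env                     # a string id doubles as the name
            env = gym.make(env)                  # build the environment from its id
        else:
            self._name = env.__class__.__name__  # fall back to the class name
        self._env = env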

Owner


Nice catch!

Contributor Author


Yeah, I wanted to do this in a separate PR originally, but apparently I am missing something crucial as to how PRs work here on GitHub. :D :D I will need to look into that...

@cpnota
Owner

cpnota commented Sep 29, 2020

Ended up going with a slightly different fix. The fixes were handled by #169 and #170. Thanks again for identifying these bugs!

cpnota closed this Sep 29, 2020