In acme.agents.tf.mcts.types, Observation is hinted to be a NumPy ndarray:

```python
# Assumption: observations are array-like.
Observation = np.ndarray
```
Based on this assumption, the MCTSActor._forward() method in acme.agents.tf.mcts.acting hard-codes a tf.expand_dims call on the observation to add a batch dimension. This makes it impossible to pass nested structures (such as dicts or tuples of arrays) as observations to MCTS.
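To fix this, we can redefine _forward so that it applies tf.expand_dims to every leaf of the observation using dm-tree's tree.map_structure:

```python
import tensorflow as tf
import tree

(...)

def _forward(self, observation):
    # Add a leading batch dimension to each leaf of the (possibly nested)
    # observation, instead of assuming the observation is a single array.
    batched_observation = tree.map_structure(
        lambda o: tf.expand_dims(o, axis=0), observation)
    logits, value = self._network(batched_observation)
    (...)
```

This should solve the problem.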
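For illustration, here is a framework-free sketch of the same per-leaf batching idea, using NumPy only. np.expand_dims stands in for tf.expand_dims, and map_structure is a hypothetical, simplified stand-in for dm-tree's tree.map_structure: it only recurses into dicts, lists and plain tuples (not namedtuples). add_batch_dim is likewise a hypothetical helper name, not part of Acme.

```python
import numpy as np


def map_structure(fn, structure):
    """Apply fn to every leaf of a nested dict/list/tuple structure.

    Hypothetical, simplified stand-in for dm-tree's tree.map_structure:
    it recurses only into dicts, lists and plain tuples, and treats
    everything else (e.g. np.ndarray, NumPy scalars) as a leaf.
    """
    if isinstance(structure, dict):
        return {key: map_structure(fn, value) for key, value in structure.items()}
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, value) for value in structure)
    return fn(structure)


def add_batch_dim(observation):
    """Add a leading batch axis of size 1 to every leaf (hypothetical helper)."""
    return map_structure(lambda o: np.expand_dims(o, axis=0), observation)


# A nested observation: an image plus a tuple of auxiliary features.
observation = {
    "pixels": np.zeros((84, 84, 3), dtype=np.float32),
    "aux": (np.zeros(4), np.float32(1.0)),
}

batched = add_batch_dim(observation)
print(batched["pixels"].shape)  # (1, 84, 84, 3)
print(batched["aux"][0].shape)  # (1, 4)
print(batched["aux"][1].shape)  # (1,)
```

Each leaf receives its own batch axis while the surrounding structure is preserved, which is what the proposed _forward does before calling the network. The real tree.map_structure additionally supports namedtuples and the other structures dm-tree recognizes.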
Is there anything I am missing?