Skip to content

Commit dd0b928

Browse files
authored
Implement autoregressive PPO for sb3 applied to plado PDDL/PPDDL domains (airbus#478)
* Move base sb3 classes for buffers and algos so that they can be used by several customized ones. They are to be used by GNN-based algos and later by algos for composite actions with autoregressive prediction of components. * Implement a masked version of MultiDiscreteSpace. This will be useful to model variable-length multidiscrete actions, e.g. parametric actions whose arity depends on their first component, the action type (for PDDL domains). An element of this space can have -1 components corresponding to non-existing components. Otherwise each component is still between 0 and nvec[i]-1. * Implement state_sample in plado domain. As transition probability and cost are computed together by the plado engine, we override _state_sample(), and update get_transition_value and get_next_state/get_next_state_transition that may be directly called by algos (e.g. lazy-A*). get_next_state caches the last value computed, and usually get_transition_value is called afterwards. But if the value is not cached anymore, we call the plado engine again. * Update plado domain for autoregressive action prediction. We add a multidiscrete encoding of the actions to be used by our autoregressive sb3 algo. Beware that this is not a regular MultiDiscrete space, as we allow -1 values, indicating that the component is actually not needed (for actions with fewer parameters than others). * Implement autoregressive sb3 PPO for PDDL domains.
1 parent 0f379c8 commit dd0b928

File tree

22 files changed

+2126
-439
lines changed

22 files changed

+2126
-439
lines changed

skdecide/hub/domain/plado/plado.py

+206-68
Large diffs are not rendered by default.

skdecide/hub/solver/stable_baselines/autoregressive/__init__.py

Whitespace-only changes.

skdecide/hub/solver/stable_baselines/autoregressive/common/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from typing import Optional
2+
3+
import numpy as np
4+
import torch as th
5+
from sb3_contrib.common.maskable.buffers import MaskableRolloutBuffer
6+
7+
from skdecide.hub.solver.stable_baselines.common.buffers import (
8+
MaskableScikitDecideRolloutBufferMixin,
9+
ScikitDecideRolloutBuffer,
10+
)
11+
12+
13+
class ApplicableActionsRolloutBuffer(
    MaskableScikitDecideRolloutBufferMixin,
    ScikitDecideRolloutBuffer,
    MaskableRolloutBuffer,
):
    """Rollout buffer that also keeps the applicable actions of every step.

    The applicable actions of one step form a numpy array of shape (N, M) where

    - N is the number of applicable actions,
    - M is the flattened dimension of the action space.

    N changes from one step to the next, so the per-step arrays cannot be
    stacked into a single numpy array: they are kept in a python list instead.

    (They first arrive with an extra leading dimension because of the sb3
    vectorized environment; that dimension is dropped when adding them.)

    """

    # Despite its name, this attribute holds applicable actions, one array per step.
    action_masks: list[np.ndarray]

    def reset(self) -> None:
        """Reset the parent buffers and start from an empty list of applicable actions."""
        super().reset()
        self.action_masks = []

    def _add_action_masks(self, action_masks: Optional[np.ndarray]) -> None:
        """Store the applicable actions of the current step.

        Args:
            action_masks: applicable actions with a leading dimension of size 1
                (a single environment).

        Raises:
            NotImplementedError: if nothing is given or if the leading dimension
                is greater than 1.

        """
        if action_masks is None:
            raise NotImplementedError()
        if action_masks.shape[0] > 1:
            raise NotImplementedError()

        # drop the leading dimension before storing
        single_env_actions = action_masks[0]
        self.action_masks.append(single_env_actions)

    def _swap_and_flatten_action_masks(self) -> None:
        """Do nothing: the first dimension is already squeezed in _add_action_masks()."""

    def _get_action_masks_samples(self, batch_inds: np.ndarray) -> list[th.Tensor]:
        """Return, as torch tensors, the applicable actions stored at the given indices."""
        stored = self.action_masks
        return [self.to_torch(stored[i_step]) for i_step in batch_inds]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
from typing import Optional, Tuple, Union
2+
3+
import torch as th
4+
from sb3_contrib.common.maskable.distributions import (
5+
MaskableCategorical,
6+
MaskableCategoricalDistribution,
7+
MaskableDistribution,
8+
MaybeMasks,
9+
)
10+
from stable_baselines3.common.distributions import Distribution, SelfDistribution
11+
from torch import nn
12+
13+
14+
class MultiMaskableCategoricalDistribution(Distribution):
    """Distribution for variable-length multidiscrete actions with partial masking on each component.

    This is meant for autoregressive prediction.

    The distribution is considered as the joint distribution of discrete distributions
    (MaskableCategoricalDistribution) with the possibility to mask each marginal.
    It is meant to be used for autoregressive actions:
    - Each component is sampled sequentially.
    - The partial mask for the next component is conditioned by the previous components.
    - It is possible to have missing components when they have no meaning for the action.
      This corresponds in the simulation to
        - either a marginal that is not initialized (if all samples discard the component),
        - or a 0-mask for the given sample (the partial mask row corresponding to the sample
          has only 0's).

    When computing the entropy of the distribution or the log-probability of an action, only the
    marginal distributions for which there is an actual component contribute (the ones with a
    0-mask are dropped).

    As this distribution is used to sample component by component, the sample() and mode() methods
    are left unimplemented.

    """

    def __init__(self, distributions: list[MaskableCategoricalDistribution]):
        """Wrap the marginal distributions, one per action component.

        Args:
            distributions: marginal distributions, indexed by component.

        """
        super().__init__()
        self.distributions = distributions
        n_components = len(distributions)
        # Per component: indices of the samples having a non-zero mask row.
        # Only set when some, but not all, samples are valid.
        self._ind_valid_samples_by_distributions: list[
            Optional[tuple[th.Tensor, ...]]
        ] = [None] * n_components
        # Per component: True if every sample has at least one allowed value.
        self._all_valid_samples_by_distributions: list[bool] = [False] * n_components
        # Per component: True if at least one sample has at least one allowed value.
        self._any_valid_samples_by_distributions: list[bool] = [False] * n_components

    def get_actions_component(
        self, i_component: int, deterministic: bool = False
    ) -> th.Tensor:
        """Get the actions of the given component from its marginal distribution."""
        return self.distributions[i_component].get_actions(deterministic=deterministic)

    def apply_masking_component(
        self, i_component: int, component_masks: MaybeMasks
    ) -> None:
        """Mask the marginal of a component and record which samples remain valid.

        A sample is valid for the component if its mask row contains at least one non-zero.

        Args:
            i_component: index of the component to mask.
            component_masks: mask whose last dimension indexes the component values.
                If None, the mask is forwarded as is to the marginal and all samples are
                considered valid (sb3_contrib documents None as "remove masking").

        """
        self.distributions[i_component].apply_masking(masks=component_masks)
        if component_masks is None:
            self._any_valid_samples_by_distributions[i_component] = True
            self._all_valid_samples_by_distributions[i_component] = True
            self._ind_valid_samples_by_distributions[i_component] = None
            return

        # valid samples: at least one 1 in the corresponding mask
        # (as_tensor is a no-op on tensors and lets numpy masks support nonzero(as_tuple=True))
        valid_samples = th.as_tensor(component_masks).sum(-1) > 0
        any_valid = bool(valid_samples.any())
        all_valid = bool(valid_samples.all())
        self._any_valid_samples_by_distributions[i_component] = any_valid
        self._all_valid_samples_by_distributions[i_component] = all_valid
        # store valid sample indices if not all valid, else drop any stale indices
        if any_valid and not all_valid:
            self._ind_valid_samples_by_distributions[
                i_component
            ] = valid_samples.nonzero(as_tuple=True)
        else:
            self._ind_valid_samples_by_distributions[i_component] = None

    def set_proba_distribution_component(
        self, i_component: int, action_component_logits: th.Tensor
    ) -> None:
        """Initialize the marginal of a component from logits.

        Until a mask is applied, all samples are considered valid for this component.

        """
        self.distributions[i_component].proba_distribution(
            action_logits=action_component_logits
        )
        self._any_valid_samples_by_distributions[i_component] = True
        self._all_valid_samples_by_distributions[i_component] = True
        self._ind_valid_samples_by_distributions[i_component] = None

    def get_proba_distribution_component_for_valid_samples(
        self, i_component: int
    ) -> Optional[Union[MaskableCategoricalDistribution, MaskableCategorical]]:
        """Get the marginal of a component restricted to its valid samples.

        Returns:
            - None if no sample is valid for this component,
            - the stored marginal itself if all samples are valid,
            - else a new MaskableCategorical built from the logits and masks rows
              of the valid samples only.

        """
        if not self._any_valid_samples_by_distributions[i_component]:
            return None
        distribution = self.distributions[i_component]
        if self._all_valid_samples_by_distributions[i_component]:
            return distribution
        ind_valid_samples = self._ind_valid_samples_by_distributions[i_component]
        return MaskableCategorical(
            logits=distribution.distribution.logits[ind_valid_samples],
            masks=distribution.distribution.masks[ind_valid_samples],
        )

    def get_proba_distribution_component_batch_shape(
        self, i_component: int
    ) -> Optional[tuple[int, ...]]:
        """Get the batch shape (logits shape without last dim) of a component marginal.

        Returns:
            None if the marginal has not been initialized yet.

        """
        distribution = self.distributions[i_component]
        if distribution.distribution is None:
            return None
        return distribution.distribution.logits.shape[:-1]

    def _scatter_on_valid_samples(
        self, i_component: int, values: th.Tensor
    ) -> th.Tensor:
        """Spread per-valid-sample values over the full batch, with 0 for invalid samples."""
        full_values = th.zeros(
            self.get_proba_distribution_component_batch_shape(i_component),
            # same dtype/device as the values to avoid truncation or device mismatch
            dtype=values.dtype,
            device=values.device,
        )
        full_values[self._ind_valid_samples_by_distributions[i_component]] = values
        return full_values

    def log_prob(self, x: th.Tensor) -> th.Tensor:
        """Compute the log-probability of actions as the sum of the marginal log-probabilities.

        Args:
            x: actions, indexed as x[:, i_component].

        Returns:
            The sum over components of the marginal log-probabilities, where a component
            contributes 0 for samples on which it is 0-masked and nothing at all if it is
            not valid for any sample. (If no component contributes, this is the int 0.)

        """
        marginal_logps = []
        # loop over marginals but no contribution if not initialized or 0-masked
        for i_component, _ in enumerate(self.distributions):
            marginal_dist = self.get_proba_distribution_component_for_valid_samples(
                i_component
            )
            if marginal_dist is None:
                continue
            component_actions = x[:, i_component]
            if self._all_valid_samples_by_distributions[i_component]:
                marginal_logp = marginal_dist.log_prob(component_actions)
            else:
                # add only contribution for valid samples
                ind_valid_samples = self._ind_valid_samples_by_distributions[
                    i_component
                ]
                marginal_logp = self._scatter_on_valid_samples(
                    i_component,
                    marginal_dist.log_prob(component_actions[ind_valid_samples]),
                )
            marginal_logps.append(marginal_logp)

        return sum(marginal_logps)

    def entropy(self) -> Optional[th.Tensor]:
        """Compute the entropy as the sum of the marginal entropies.

        Returns:
            The sum over components of the marginal entropies, where a component contributes 0
            for samples on which it is 0-masked and nothing at all if it is not valid for any
            sample. (If no component contributes, this is the int 0.)

        """
        marginal_entropies = []
        # loop over marginals but no contribution if not initialized or 0-masked
        for i_component, _ in enumerate(self.distributions):
            marginal_dist = self.get_proba_distribution_component_for_valid_samples(
                i_component
            )
            if marginal_dist is None:
                continue
            if self._all_valid_samples_by_distributions[i_component]:
                marginal_entropy = marginal_dist.entropy()
            else:
                # add only contribution for valid samples
                marginal_entropy = self._scatter_on_valid_samples(
                    i_component, marginal_dist.entropy()
                )
            marginal_entropies.append(marginal_entropy)

        return sum(marginal_entropies)

    def sample(self) -> th.Tensor:
        """Not implemented: components are sampled one by one via get_actions_component()."""
        raise NotImplementedError()

    def mode(self) -> th.Tensor:
        """Not implemented: components are predicted one by one via get_actions_component()."""
        raise NotImplementedError()

    def actions_from_params(self, *args, **kwargs) -> th.Tensor:
        """Not implemented for this distribution."""
        raise NotImplementedError()

    def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
        """Not implemented for this distribution."""
        raise NotImplementedError()

    def proba_distribution_net(
        self, *args, **kwargs
    ) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
        """Not implemented for this distribution."""
        raise NotImplementedError()

    def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribution:
        """Not implemented: use set_proba_distribution_component() for each component."""
        raise NotImplementedError()

0 commit comments

Comments
 (0)