Skip to content

Commit 21d4ba1

Browse files
authored
Merge pull request #49 from mfe7/master
mask zeroed episodes in baseline fit, add extra round of regularization…
2 parents 2608ec9 + 705f18a commit 21d4ba1

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

maml_rl/baseline.py

+12
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,24 @@ def fit(self, episodes):
4848
# sequence_length * batch_size x 1
4949
returns = episodes.returns.view(-1, 1)
5050

51+
# Remove blank (all-zero) padding rows (masked-out timesteps) that exist only because episode lengths vary
52+
flat_mask = episodes.mask.flatten()
53+
flat_mask_nnz = torch.nonzero(flat_mask)
54+
featmat = featmat[flat_mask_nnz].view(-1, self.feature_size)
55+
returns = returns[flat_mask_nnz].view(-1, 1)
56+
5157
reg_coeff = self._reg_coeff
5258
XT_y = torch.matmul(featmat.t(), returns)
5359
XT_X = torch.matmul(featmat.t(), featmat)
5460
for _ in range(5):
5561
try:
5662
coeffs, _ = torch.lstsq(XT_y, XT_X + reg_coeff * self._eye)
63+
64+
# An extra round of increasing regularization eliminated
65+
# inf or nan in the least-squares solution most of the time
66+
if torch.isnan(coeffs).any() or torch.isinf(coeffs).any():
67+
raise RuntimeError
68+
5769
break
5870
except RuntimeError:
5971
reg_coeff *= 10

0 commit comments

Comments
 (0)