File tree 1 file changed +12
-0
lines changed
1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change @@ -48,12 +48,24 @@ def fit(self, episodes):
48
48
# sequence_length * batch_size x 1
49
49
returns = episodes .returns .view (- 1 , 1 )
50
50
51
+ # Remove blank (all-zero) episodes that only exist because episode lengths vary
52
+ flat_mask = episodes .mask .flatten ()
53
+ flat_mask_nnz = torch .nonzero (flat_mask )
54
+ featmat = featmat [flat_mask_nnz ].view (- 1 , self .feature_size )
55
+ returns = returns [flat_mask_nnz ].view (- 1 , 1 )
56
+
51
57
reg_coeff = self ._reg_coeff
52
58
XT_y = torch .matmul (featmat .t (), returns )
53
59
XT_X = torch .matmul (featmat .t (), featmat )
54
60
for _ in range (5 ):
55
61
try :
56
62
coeffs , _ = torch .lstsq (XT_y , XT_X + reg_coeff * self ._eye )
63
+
64
+ # An extra round of increasing regularization eliminated
65
+ # inf or nan in the least-squares solution most of the time
66
+ if torch .isnan (coeffs ).any () or torch .isinf (coeffs ).any ():
67
+ raise RuntimeError
68
+
57
69
break
58
70
except RuntimeError :
59
71
reg_coeff *= 10
You can’t perform that action at this time.
0 commit comments