File tree Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Original file line number Diff line number Diff line change @@ -300,7 +300,7 @@ def __init__(self, zerosum_axes):
300
300
301
301
@staticmethod
302
302
def extend_axis (array , axis ):
303
- n = (array .shape [axis ] + 1 ). astype ( "floatX" )
303
+ n = pt . cast (array .shape [axis ] + 1 , "floatX" )
304
304
sum_vals = array .sum (axis , keepdims = True )
305
305
norm = sum_vals / (pt .sqrt (n ) + n )
306
306
fill_val = norm - sum_vals / pt .sqrt (n )
@@ -312,7 +312,7 @@ def extend_axis(array, axis):
312
312
def extend_axis_rev (array , axis ):
313
313
normalized_axis = normalize_axis_tuple (axis , array .ndim )[0 ]
314
314
315
- n = array .shape [normalized_axis ]. astype ( "floatX" )
315
+ n = pt . cast ( array .shape [normalized_axis ], "floatX" )
316
316
last = pt .take (array , [- 1 ], axis = normalized_axis )
317
317
318
318
sum_vals = - last * pt .sqrt (n )
Original file line number Diff line number Diff line change @@ -170,6 +170,17 @@ def test_sum_to_1():
170
170
)
171
171
172
172
173
+ def test_zerosumtransform ():
174
+ zst = tr .ZeroSumTransform ([0 ])
175
+
176
+ # Check numpy input works, as it is not always converted to pytensor before
177
+ # Case where it failed was when setting initvals in model
178
+ val = np .array ([1 , 2 , 3 , 4 ])
179
+ zval = zst .backward (val )
180
+ assert np .allclose (zval .eval ().sum (), 0.0 )
181
+ assert np .allclose (zst .forward (zval ).eval (), val )
182
+
183
+
173
184
def test_log ():
174
185
check_transform (tr .log , Rplusbig )
175
186
You can’t perform that action at this time.
0 commit comments