@@ -1509,6 +1509,25 @@ def update_curvature_matrix_estimate(
1509
1509
#
1510
1510
1511
1511
1512
def compatible_shapes(ref_shape, target_shape):
  """Validates that `target_shape` is broadcastable to `ref_shape`.

  The two shapes are aligned on their trailing dimensions. Each dimension of
  `target_shape` must either be equal to the aligned dimension of `ref_shape`
  or be exactly 1. Leading dimensions of `ref_shape` that have no counterpart
  in `target_shape` are not constrained.

  Args:
    ref_shape: Sequence of ints, the reference shape.
    target_shape: Sequence of ints, the shape to validate. It may have the
      same number of dimensions as `ref_shape` or fewer, but not more.

  Returns:
    None. The function only signals incompatibility by raising.

  Raises:
    ValueError: If `target_shape` has more dimensions than `ref_shape`, or if
      some aligned pair of dimensions differs while the target dimension is
      not 1.
  """
  if len(target_shape) > len(ref_shape):
    # Equal rank is allowed, so report the actual constraint together with
    # the offending shapes.
    raise ValueError(
        f"Target shape {target_shape} must not have more dimensions than "
        f"the reference shape {ref_shape}.")
  # Walk both shapes from the last dimension backwards; `zip` stops at the
  # shorter one, which is `target_shape` after the check above.
  for ref_d, target_d in zip(reversed(ref_shape), reversed(target_shape)):
    if ref_d != target_d and target_d != 1:
      raise ValueError(f"{target_shape} is incompatible with {ref_shape}.")
1518
+
1519
+
1520
def compatible_sum(tensor, target_shape, skip_axes):
  """Sums `tensor` down towards `target_shape`, leaving `skip_axes` intact.

  `target_shape` is aligned with the trailing dimensions of `tensor`. The
  reduction happens in two steps:

  1. Every axis of `tensor` aligned with a size-1 entry of `target_shape` is
     summed with `keepdims=True`, so it stays in the result with size 1.
  2. Every leading axis of `tensor` that has no counterpart in `target_shape`
     is summed and removed.

  Axes listed in `skip_axes` are excluded from both steps and keep their
  original size.

  Args:
    tensor: Array to reduce. Only `.shape`, `.ndim` and `jnp.sum` are used.
    target_shape: Sequence of ints, validated against `tensor.shape` with
      `compatible_shapes`.
    skip_axes: Collection of axis indices of `tensor` that must not be
      reduced. Negative indices count from the last axis.

  Returns:
    The reduced array. Its shape is the unreduced (skipped) leading axes of
    `tensor` followed by the axes aligned with `target_shape`.

  Raises:
    ValueError: Propagated from `compatible_shapes` when `target_shape` is
      not compatible with `tensor.shape`.
  """
  compatible_shapes(tensor.shape, target_shape)
  ndim = tensor.ndim
  # Normalise negative indices so that e.g. -1 really protects the last axis
  # instead of being silently ignored by the membership tests below.
  skipped = {a + ndim if a < 0 else a for a in skip_axes}
  # Number of leading axes of `tensor` with no counterpart in `target_shape`.
  n_leading = ndim - len(target_shape)

  # Step 1: collapse axes that the target holds at size 1, keeping the rank
  # so the leading-axis indices used in step 2 remain valid.
  size_one_axes = tuple(
      i + n_leading for i, d in enumerate(target_shape)
      if d == 1 and i + n_leading not in skipped)
  tensor = jnp.sum(tensor, axis=size_one_axes, keepdims=True)

  # Step 2: drop the extra leading axes.
  leading_axes = tuple(i for i in range(n_leading) if i not in skipped)
  return jnp.sum(tensor, axis=leading_axes)
1529
+
1530
+
1512
1531
class ScaleAndShiftDiagonal (Diagonal ):
1513
1532
"""A diagonal approximation specifically for a scale and shift layers."""
1514
1533
@@ -1539,18 +1558,20 @@ def _update_curvature_matrix_estimate(
1539
1558
assert (state .diagonal_factors [0 ].raw_value .shape ==
1540
1559
self .parameters_shapes [0 ])
1541
1560
scale_shape = estimation_data ["params" ][0 ].shape
1542
- axis = range (x .ndim )[1 :(x .ndim - len (scale_shape ))]
1543
- d_scale = jnp .sum (x * dy , axis = tuple (axis ))
1544
- scale_diag_update = jnp .sum (d_scale * d_scale , axis = 0 ) / batch_size
1561
+ d_scale = compatible_sum (x * dy , scale_shape , skip_axes = [0 ])
1562
+ scale_diag_update = jnp .sum (
1563
+ d_scale * d_scale ,
1564
+ axis = 0 , keepdims = d_scale .ndim == len (scale_shape )
1565
+ ) / batch_size
1545
1566
state .diagonal_factors [0 ].update (scale_diag_update , ema_old , ema_new )
1546
1567
1547
1568
if self .has_shift :
1548
- assert (state .diagonal_factors [- 1 ].raw_value .shape ==
1549
- self .parameters_shapes [- 1 ])
1550
1569
shift_shape = estimation_data ["params" ][- 1 ].shape
1551
- axis = range (x .ndim )[1 :(x .ndim - len (shift_shape ))]
1552
- d_shift = jnp .sum (dy , axis = tuple (axis ))
1553
- shift_diag_update = jnp .sum (d_shift * d_shift , axis = 0 ) / batch_size
1570
+ d_shift = compatible_sum (dy , shift_shape , skip_axes = [0 ])
1571
+ shift_diag_update = jnp .sum (
1572
+ d_shift * d_shift ,
1573
+ axis = 0 , keepdims = d_shift .ndim == len (shift_shape )
1574
+ ) / batch_size
1554
1575
state .diagonal_factors [- 1 ].update (shift_diag_update , ema_old , ema_new )
1555
1576
1556
1577
return state
@@ -1587,16 +1608,14 @@ def update_curvature_matrix_estimate(
1587
1608
if self ._has_scale :
1588
1609
# Scale tangent
1589
1610
scale_shape = estimation_data ["params" ][0 ].shape
1590
- axis = range (x .ndim )[1 :(x .ndim - len (scale_shape ))]
1591
- d_scale = jnp .sum (x * dy , axis = tuple (axis ))
1611
+ d_scale = compatible_sum (x * dy , scale_shape , skip_axes = [0 ])
1592
1612
d_scale = d_scale .reshape ([batch_size , - 1 ])
1593
1613
tangents .append (d_scale )
1594
1614
1595
1615
if self ._has_shift :
1596
1616
# Shift tangent
1597
1617
shift_shape = estimation_data ["params" ][- 1 ].shape
1598
- axis = range (x .ndim )[1 :(x .ndim - len (shift_shape ))]
1599
- d_shift = jnp .sum (dy , axis = tuple (axis ))
1618
+ d_shift = compatible_sum (dy , shift_shape , skip_axes = [0 ])
1600
1619
d_shift = d_shift .reshape ([batch_size , - 1 ])
1601
1620
tangents .append (d_shift )
1602
1621
0 commit comments