Commit 4931c2a

botevKfacJaxDev authored and committed
Making ScaleAndShift blocks capable of having parameters that are broadcast by construction, e.g. batch norm with scale parameters of shape [1, 1, 1, d].
PiperOrigin-RevId: 456070961
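
For context, a minimal sketch of what "broadcast by construction" means here; the shapes and values below are illustrative assumptions, not taken from kfac_jax:

import jax.numpy as jnp

# A batch-norm style scale parameter stored with explicit singleton axes,
# i.e. shape [1, 1, 1, d] rather than [d], as described in the commit message.
x = jnp.ones([8, 4, 4, 16])            # NHWC activations: [batch, h, w, d]
scale = jnp.full([1, 1, 1, 16], 2.0)   # broadcast-by-construction scale

# The layer's forward pass relies on ordinary broadcasting over the singleton axes.
y = x * scale
print(y.shape)  # (8, 4, 4, 16)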
1 parent 1ace327 commit 4931c2a

File tree

1 file changed: +31 -12 lines changed


kfac_jax/_src/curvature_blocks.py

Lines changed: 31 additions & 12 deletions
@@ -1509,6 +1509,25 @@ def update_curvature_matrix_estimate(
 #
 
 
+def compatible_shapes(ref_shape, target_shape):
+  if len(target_shape) > len(ref_shape):
+    raise ValueError("Target shape should be smaller.")
+  for ref_d, target_d in zip(reversed(ref_shape), reversed(target_shape)):
+    if ref_d != target_d and target_d != 1:
+      raise ValueError(f"{target_shape} is incompatible with {ref_shape}.")
+
+
+def compatible_sum(tensor, target_shape, skip_axes):
+  compatible_shapes(tensor.shape, target_shape)
+  n = tensor.ndim - len(target_shape)
+  axis = [i + n for i, t in enumerate(target_shape)
+          if t == 1 and i + n not in skip_axes]
+  tensor = jnp.sum(tensor, axis=axis, keepdims=True)
+  axis = [i for i in range(tensor.ndim - len(target_shape))
+          if i not in skip_axes]
+  return jnp.sum(tensor, axis=axis)
+
+
 class ScaleAndShiftDiagonal(Diagonal):
   """A diagonal approximation specifically for a scale and shift layers."""
 
@@ -1539,18 +1558,20 @@ def _update_curvature_matrix_estimate(
       assert (state.diagonal_factors[0].raw_value.shape ==
               self.parameters_shapes[0])
       scale_shape = estimation_data["params"][0].shape
-      axis = range(x.ndim)[1:(x.ndim - len(scale_shape))]
-      d_scale = jnp.sum(x * dy, axis=tuple(axis))
-      scale_diag_update = jnp.sum(d_scale * d_scale, axis=0) / batch_size
+      d_scale = compatible_sum(x * dy, scale_shape, skip_axes=[0])
+      scale_diag_update = jnp.sum(
+          d_scale * d_scale,
+          axis=0, keepdims=d_scale.ndim == len(scale_shape)
+      ) / batch_size
       state.diagonal_factors[0].update(scale_diag_update, ema_old, ema_new)
 
     if self.has_shift:
-      assert (state.diagonal_factors[-1].raw_value.shape ==
-              self.parameters_shapes[-1])
       shift_shape = estimation_data["params"][-1].shape
-      axis = range(x.ndim)[1:(x.ndim - len(shift_shape))]
-      d_shift = jnp.sum(dy, axis=tuple(axis))
-      shift_diag_update = jnp.sum(d_shift * d_shift, axis=0) / batch_size
+      d_shift = compatible_sum(dy, shift_shape, skip_axes=[0])
+      shift_diag_update = jnp.sum(
+          d_shift * d_shift,
+          axis=0, keepdims=d_shift.ndim == len(shift_shape)
+      ) / batch_size
       state.diagonal_factors[-1].update(shift_diag_update, ema_old, ema_new)
 
     return state
@@ -1587,16 +1608,14 @@ def update_curvature_matrix_estimate(
     if self._has_scale:
       # Scale tangent
      scale_shape = estimation_data["params"][0].shape
-      axis = range(x.ndim)[1:(x.ndim - len(scale_shape))]
-      d_scale = jnp.sum(x * dy, axis=tuple(axis))
+      d_scale = compatible_sum(x * dy, scale_shape, skip_axes=[0])
       d_scale = d_scale.reshape([batch_size, -1])
       tangents.append(d_scale)
 
     if self._has_shift:
       # Shift tangent
       shift_shape = estimation_data["params"][-1].shape
-      axis = range(x.ndim)[1:(x.ndim - len(shift_shape))]
-      d_shift = jnp.sum(dy, axis=tuple(axis))
+      d_shift = compatible_sum(dy, shift_shape, skip_axes=[0])
       d_shift = d_shift.reshape([batch_size, -1])
       tangents.append(d_shift)
 
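The new compatible_sum helper is what lets both the broadcast [1, 1, 1, d] case and the usual [d] case reduce a per-example gradient contribution down to the parameter's own shape. The sketch below copies the two helpers from the hunk above and exercises them on illustrative shapes; the shapes and variable names are assumptions, not part of the commit:

import jax.numpy as jnp

# Helpers copied verbatim from the diff above.
def compatible_shapes(ref_shape, target_shape):
  if len(target_shape) > len(ref_shape):
    raise ValueError("Target shape should be smaller.")
  for ref_d, target_d in zip(reversed(ref_shape), reversed(target_shape)):
    if ref_d != target_d and target_d != 1:
      raise ValueError(f"{target_shape} is incompatible with {ref_shape}.")

def compatible_sum(tensor, target_shape, skip_axes):
  compatible_shapes(tensor.shape, target_shape)
  n = tensor.ndim - len(target_shape)
  # Sum out axes that the parameter broadcasts over (singleton in the target),
  # except for the axes listed in skip_axes (the batch axis in the callers).
  axis = [i + n for i, t in enumerate(target_shape)
          if t == 1 and i + n not in skip_axes]
  tensor = jnp.sum(tensor, axis=axis, keepdims=True)
  # Then sum out any leading axes that the target shape does not cover at all.
  axis = [i for i in range(tensor.ndim - len(target_shape))
          if i not in skip_axes]
  return jnp.sum(tensor, axis=axis)

# Illustrative per-example gradient contribution for an NHWC layer.
x_times_dy = jnp.ones([8, 4, 4, 16])  # [batch, h, w, d]

# Broadcast-by-construction scale of shape [1, 1, 1, d]: the spatial axes are
# summed out, while skip_axes=[0] keeps the batch axis for per-example tangents.
d_scale = compatible_sum(x_times_dy, (1, 1, 1, 16), skip_axes=[0])
print(d_scale.shape)  # (8, 1, 1, 16)

# Ordinary scale of shape [d]: the leading batch/spatial axes not covered by
# the target shape are summed out instead, again keeping the batch axis.
d_scale = compatible_sum(x_times_dy, (16,), skip_axes=[0])
print(d_scale.shape)  # (8, 16)

In the diagonal block the squared tangent is then summed over the batch axis, with keepdims enabled exactly when the tangent still carries the parameter's singleton axes, so the diagonal update matches the parameter shape in both cases: [1, 1, 1, d] for the broadcast parameter and [d] for the ordinary one.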
