Skip to content

Commit 4931c2a

Browse files
botevKfacJaxDev
authored and
KfacJaxDev
committed
Making ScaleAndShift blocks being capable of having parameters that are broadcast by construction, e.g. batch norm with scale parameters [1, 1, 1, d].
PiperOrigin-RevId: 456070961
1 parent 1ace327 commit 4931c2a

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

kfac_jax/_src/curvature_blocks.py

+31-12
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,25 @@ def update_curvature_matrix_estimate(
15091509
#
15101510

15111511

1512+
def compatible_shapes(ref_shape, target_shape):
    """Validate that ``target_shape`` lines up with ``ref_shape``.

    The two shapes are compared from their trailing dimension backwards.
    Every dimension of ``target_shape`` must either equal the matching
    dimension of ``ref_shape`` or be exactly ``1``.

    Args:
        ref_shape: Sequence of dimension sizes used as the reference.
        target_shape: Sequence of dimension sizes to validate; it may not
            have more dimensions than ``ref_shape``.

    Raises:
        ValueError: If ``target_shape`` has more dimensions than
            ``ref_shape``, or if any aligned pair of dimensions differs
            while the target dimension is not ``1``.
    """
    if len(ref_shape) < len(target_shape):
        raise ValueError("Target shape should be smaller.")

    # Pair dimensions right-aligned; extra leading dimensions of the
    # reference have no counterpart and are not checked.
    aligned_dims = zip(reversed(ref_shape), reversed(target_shape))
    has_mismatch = any(
        target_dim != 1 and ref_dim != target_dim
        for ref_dim, target_dim in aligned_dims
    )
    if has_mismatch:
        raise ValueError(f"{target_shape} is incompatible with {ref_shape}.")
1518+
1519+
1520+
def compatible_sum(tensor, target_shape, skip_axes):
    """Sum ``tensor`` over the axes along which ``target_shape`` is broadcast.

    The reduction happens in two steps:

    1. Trailing axes of ``tensor`` that line up with a size-``1`` entry of
       ``target_shape`` are summed with ``keepdims=True``.
    2. Leading axes of ``tensor`` that have no counterpart in
       ``target_shape`` are summed away.

    Axes listed in ``skip_axes`` are left untouched in both steps.

    Args:
        tensor: Array to reduce; must expose ``shape`` and ``ndim``.
        target_shape: Sequence of dimension sizes that must be compatible
            with ``tensor.shape`` (checked via ``compatible_shapes``).
        skip_axes: Container of axis indices of ``tensor`` that must never
            be summed over. The indices are compared as-is, so they are
            presumably expected to be non-negative -- TODO confirm.

    Returns:
        The reduced array produced by the two sums.

    Raises:
        ValueError: Propagated from ``compatible_shapes`` when the shapes
            do not line up.
    """
    compatible_shapes(tensor.shape, target_shape)

    # Number of leading axes of `tensor` with no counterpart in the target.
    num_leading = tensor.ndim - len(target_shape)

    # Step 1: collapse (but keep) the axes where the target has size 1.
    broadcast_axes = []
    for position, size in enumerate(target_shape, start=num_leading):
        if size == 1 and position not in skip_axes:
            broadcast_axes.append(position)
    reduced = jnp.sum(tensor, axis=broadcast_axes, keepdims=True)

    # Step 2: drop the extra leading axes. `keepdims=True` above left the
    # rank unchanged, so the leading-axis count is still `num_leading`.
    leading_axes = [ax for ax in range(num_leading) if ax not in skip_axes]
    return jnp.sum(reduced, axis=leading_axes)
1529+
1530+
15121531
class ScaleAndShiftDiagonal(Diagonal):
15131532
"""A diagonal approximation specifically for a scale and shift layers."""
15141533

@@ -1539,18 +1558,20 @@ def _update_curvature_matrix_estimate(
15391558
assert (state.diagonal_factors[0].raw_value.shape ==
15401559
self.parameters_shapes[0])
15411560
scale_shape = estimation_data["params"][0].shape
1542-
axis = range(x.ndim)[1:(x.ndim - len(scale_shape))]
1543-
d_scale = jnp.sum(x * dy, axis=tuple(axis))
1544-
scale_diag_update = jnp.sum(d_scale * d_scale, axis=0) / batch_size
1561+
d_scale = compatible_sum(x * dy, scale_shape, skip_axes=[0])
1562+
scale_diag_update = jnp.sum(
1563+
d_scale * d_scale,
1564+
axis=0, keepdims=d_scale.ndim == len(scale_shape)
1565+
) / batch_size
15451566
state.diagonal_factors[0].update(scale_diag_update, ema_old, ema_new)
15461567

15471568
if self.has_shift:
1548-
assert (state.diagonal_factors[-1].raw_value.shape ==
1549-
self.parameters_shapes[-1])
15501569
shift_shape = estimation_data["params"][-1].shape
1551-
axis = range(x.ndim)[1:(x.ndim - len(shift_shape))]
1552-
d_shift = jnp.sum(dy, axis=tuple(axis))
1553-
shift_diag_update = jnp.sum(d_shift * d_shift, axis=0) / batch_size
1570+
d_shift = compatible_sum(dy, shift_shape, skip_axes=[0])
1571+
shift_diag_update = jnp.sum(
1572+
d_shift * d_shift,
1573+
axis=0, keepdims=d_shift.ndim == len(shift_shape)
1574+
) / batch_size
15541575
state.diagonal_factors[-1].update(shift_diag_update, ema_old, ema_new)
15551576

15561577
return state
@@ -1587,16 +1608,14 @@ def update_curvature_matrix_estimate(
15871608
if self._has_scale:
15881609
# Scale tangent
15891610
scale_shape = estimation_data["params"][0].shape
1590-
axis = range(x.ndim)[1:(x.ndim - len(scale_shape))]
1591-
d_scale = jnp.sum(x * dy, axis=tuple(axis))
1611+
d_scale = compatible_sum(x * dy, scale_shape, skip_axes=[0])
15921612
d_scale = d_scale.reshape([batch_size, -1])
15931613
tangents.append(d_scale)
15941614

15951615
if self._has_shift:
15961616
# Shift tangent
15971617
shift_shape = estimation_data["params"][-1].shape
1598-
axis = range(x.ndim)[1:(x.ndim - len(shift_shape))]
1599-
d_shift = jnp.sum(dy, axis=tuple(axis))
1618+
d_shift = compatible_sum(dy, shift_shape, skip_axes=[0])
16001619
d_shift = d_shift.reshape([batch_size, -1])
16011620
tangents.append(d_shift)
16021621

0 commit comments

Comments
 (0)