Skip to content

Commit 5c13e76

Browse files
fehiepsi authored and edward-bot committed
Fix incorrect matmul op in jax spectral normalization. Resolves #534
PiperOrigin-RevId: 470364802
1 parent 1abb22e commit 5c13e76

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

edward2/jax/nn/normalization.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -140,7 +140,7 @@ def __call__(self, inputs: Array, training: bool = True) -> Array:
140140

141141
if self.kernel_apply_kwargs is None:
142142
# By default, we use the implementation in SN-GAN.
143-
kernel_apply = lambda x: w.reshape(-1, w.shape[-1]) @ x
143+
kernel_apply = lambda x: x @ w.reshape(-1, w.shape[-1])
144144
in_shape = (np.prod(w.shape[:-1]),)
145145
else:
146146
# Otherwise, we extract the actual kernel transformation in the input

edward2/jax/nn/normalization_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def setUp(self):
5050
self.norm_multiplier = 0.95
5151

5252
@parameterized.named_parameters(
53-
("Dense", (None, 10), DenseLayer, ed.nn.SpectralNormalization),
53+
("Dense", (None, 3), DenseLayer, ed.nn.SpectralNormalization),
5454
("Conv2D",
5555
(None, 32, 32, 3), Conv2DLayer, ed.nn.SpectralNormalizationConv2D))
5656
def test_spec_norm_magnitude(self, input_shape, layer, norm_wrapper):

0 commit comments

Comments (0)