Skip to content

Commit f36188e

Browse files
jtainslieedward-bot
authored andcommitted
Minor fix in SpectralNormalization to ensure sigma is scalar.
PiperOrigin-RevId: 351168514
1 parent 614cfa1 commit f36188e

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

edward2/tensorflow/layers/normalization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,8 @@ def update_weights(self):
377377
u_hat = tf.nn.l2_normalize(tf.matmul(v_hat, w_reshaped))
378378

379379
sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))
380+
# Convert sigma from a 1x1 matrix to a scalar.
381+
sigma = tf.reshape(sigma, [])
380382
u_update_op = self.u.assign(u_hat)
381383
v_update_op = self.v.assign(v_hat)
382384

0 commit comments

Comments
 (0)