@@ -60,19 +60,19 @@ def __call__(self, inputs, deterministic: Optional[bool] = None):
     """
     deterministic = merge_param(
         'deterministic', self.deterministic, deterministic)
-    if self.rate == 0.:
+
+    if (self.rate == 0.) or deterministic:
       return inputs
+
     # Prevent gradient NaNs in 1.0 edge-case.
     if self.rate == 1.0:
       return jnp.zeros_like(inputs)
+
     keep_prob = 1. - self.rate
-    if deterministic:
-      return inputs
-    else:
-      rng = self.make_rng(self.rng_collection)
-      broadcast_shape = list(inputs.shape)
-      for dim in self.broadcast_dims:
-        broadcast_shape[dim] = 1
-      mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
-      mask = jnp.broadcast_to(mask, inputs.shape)
-      return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
+    rng = self.make_rng(self.rng_collection)
+    broadcast_shape = list(inputs.shape)
+    for dim in self.broadcast_dims:
+      broadcast_shape[dim] = 1
+    mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
+    mask = jnp.broadcast_to(mask, inputs.shape)
+    return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
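For context, here is a minimal usage sketch (not part of the diff; it assumes flax and jax are installed) that exercises both paths through the refactored __call__: with deterministic=True or rate == 0. the module now returns the inputs before touching any RNG, while the training path draws a Bernoulli mask from the 'dropout' RNG collection and rescales the kept activations by 1 / keep_prob.

import jax
import jax.numpy as jnp
import flax.linen as nn

dropout = nn.Dropout(rate=0.5)
x = jnp.ones((2, 4))

# Deterministic/no-op path: early return, no 'dropout' RNG needed.
y_eval = dropout.apply({}, x, deterministic=True)

# Training path: make_rng(self.rng_collection) pulls from the 'dropout' collection.
y_train = dropout.apply({}, x, deterministic=False,
                        rngs={'dropout': jax.random.PRNGKey(0)})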