Skip to content

Commit 3828f4f

Browse files
GhassenJedward-bot
authored andcommitted
Refactoring the routing layer to enable noisy gating at training time.
PiperOrigin-RevId: 352144176
1 parent f36188e commit 3828f4f

File tree

3 files changed

+234
-0
lines changed

3 files changed

+234
-0
lines changed

edward2/tensorflow/layers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from edward2.tensorflow.layers.convolutional import Conv2DVariationalDropout
3232
from edward2.tensorflow.layers.convolutional import DepthwiseCondConv2D
3333
from edward2.tensorflow.layers.convolutional import DepthwiseConv2DBatchEnsemble
34+
from edward2.tensorflow.layers.dense import CondDense
3435
from edward2.tensorflow.layers.dense import DenseBatchEnsemble
3536
from edward2.tensorflow.layers.dense import DenseDVI
3637
from edward2.tensorflow.layers.dense import DenseFlipout
@@ -70,13 +71,15 @@
7071
from edward2.tensorflow.layers.recurrent import LSTMCellFlipout
7172
from edward2.tensorflow.layers.recurrent import LSTMCellRank1
7273
from edward2.tensorflow.layers.recurrent import LSTMCellReparameterization
74+
from edward2.tensorflow.layers.routing import RoutingLayer
7375
from edward2.tensorflow.layers.stochastic_output import MixtureLogistic
7476

7577
__all__ = [
7678
"ActNorm",
7779
"Attention",
7880
"BayesianLinearModel",
7981
"CondConv2D",
82+
"CondDense",
8083
"Conv1DBatchEnsemble",
8184
"Conv1DFlipout",
8285
"Conv1DRank1",
@@ -122,6 +125,7 @@
122125
"NeuralProcess",
123126
"RandomFeatureGaussianProcess",
124127
"Reverse",
128+
"RoutingLayer",
125129
"SinkhornAutoregressiveFlow",
126130
"SparseGaussianProcess",
127131
"SpectralNormalization",

edward2/tensorflow/layers/routing.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# coding=utf-8
2+
# Copyright 2021 The Edward2 Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Routing layer for mixture of experts."""
17+
18+
import tensorflow as tf
19+
from edward2.tensorflow.layers import routing_utils
20+
21+
class RoutingLayer(tf.keras.layers.Layer):
22+
23+
def __init__(self,
24+
num_experts, routing_pooling, routing_fn, k, normalize_routing,
25+
noise_epsilon, **kwargs):
26+
super().__init__(**kwargs)
27+
self.num_experts = num_experts
28+
self.routing_pooling = routing_pooling
29+
self.routing_fn = routing_fn
30+
self.k = k
31+
self.normalize_routing = normalize_routing
32+
self.noise_epsilon = noise_epsilon
33+
self.use_noisy_routing = 'noisy' in routing_fn
34+
self.use_softmax_top_k = routing_fn in [
35+
'softmax_top_k', 'noisy_softmax_top_k'
36+
]
37+
self.use_onehot_top_k = routing_fn in ['onehot_top_k', 'noisy_onehot_top_k']
38+
self.use_sigmoid_activation = routing_fn == 'sigmoid'
39+
self.use_softmax_routing = routing_fn in ['softmax', 'noisy_softmax']
40+
41+
def build(self, input_shape):
42+
input_shape = tf.TensorShape(input_shape)
43+
self.input_size = input_shape[1]
44+
self.kernel_shape = [self.input_size, self.num_experts]
45+
46+
self.w_gate = self.add_weight(
47+
name='w_gate',
48+
shape=self.kernel_shape,
49+
initializer=tf.keras.initializers.Zeros(),
50+
regularizer=None,
51+
constraint=None,
52+
trainable=True,
53+
dtype=self.dtype)
54+
55+
if self.use_noisy_routing:
56+
self.w_noise = self.add_weight(
57+
name='w_gate',
58+
shape=self.kernel_shape,
59+
initializer=tf.keras.initializers.Zeros(),
60+
regularizer=None,
61+
constraint=None,
62+
trainable=True,
63+
dtype=self.dtype)
64+
65+
if self.routing_pooling == 'global_average':
66+
self.pooling_layer = tf.keras.layers.GlobalAveragePooling2D()
67+
elif self.routing_pooling == 'global_max':
68+
self.pooling_layer = tf.keras.layers.GlobalMaxPool2D()
69+
elif self.routing_pooling == 'average_8':
70+
self.pooling_layer = tf.keras.Sequential([
71+
tf.keras.layers.AveragePooling2D(pool_size=8),
72+
tf.keras.layers.Flatten(),
73+
])
74+
elif self.routing_pooling == 'max_8':
75+
self.pooling_layer = tf.keras.Sequential([
76+
tf.keras.layers.MaxPool2D(pool_size=8),
77+
tf.keras.layers.Flatten(),
78+
])
79+
else:
80+
self.pooling_layer = tf.keras.layers.Flatten()
81+
82+
self.built = True
83+
84+
def call(self, inputs, training=None):
85+
pooled_inputs = self.pooling_layer(inputs)
86+
routing_weights = tf.linalg.matmul(pooled_inputs, self.w_gate)
87+
88+
if self.use_noisy_routing and training:
89+
raw_noise_stddev = tf.linalg.matmul(pooled_inputs, self.w_noise)
90+
noise_stddev = tf.nn.softplus(raw_noise_stddev) + self.noise_epsilon
91+
routing_weights += tf.random.normal(tf.shape(routing_weights)) * noise_stddev
92+
93+
if self.use_sigmoid_activation:
94+
routing_weights = tf.nn.sigmoid(routing_weights)
95+
elif self.use_softmax_routing:
96+
routing_weights = tf.nn.softmax(routing_weights)
97+
elif self.use_softmax_top_k:
98+
top_values, top_indices = tf.math.top_k(routing_weights,
99+
min(self.k + 1, self.num_experts))
100+
# top k logits has shape [batch, k]
101+
top_k_values = tf.slice(top_values, [0, 0], [-1, self.k])
102+
top_k_indices = tf.slice(top_indices, [0, 0], [-1, self.k])
103+
top_k_gates = tf.nn.softmax(top_k_values)
104+
# This returns a [batch, n] Tensor with 0's in the positions of non-top-k
105+
# expert values.
106+
routing_weights = routing_utils.rowwise_unsorted_segment_sum(top_k_gates,
107+
top_k_indices,
108+
self.num_experts)
109+
elif self.use_onehot_top_k:
110+
top_values, top_indices = tf.math.top_k(routing_weights, k=self.k)
111+
one_hot_tensor = tf.one_hot(top_indices, depth=self.num_experts)
112+
mask = tf.reduce_sum(one_hot_tensor, axis=1)
113+
routing_weights *= mask
114+
115+
if self.normalize_routing:
116+
normalization = tf.math.reduce_sum(
117+
routing_weights, axis=-1, keepdims=True)
118+
routing_weights /= normalization
119+
120+
return routing_weights
121+
122+
def get_config(self):
123+
config = {
124+
'num_experts': self.num_experts,
125+
'routing_pooling': self.routing_pooling,
126+
'routing_fn': self.routing_fn,
127+
'k': self.k,
128+
'normalize_routing': self.normalize_routing,
129+
'noise_epsilon': self.noise_epsilon,
130+
}
131+
new_config = super().get_config()
132+
new_config.update(config)
133+
return new_config
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# coding=utf-8
2+
# Copyright 2021 The Edward2 Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Routing utils."""
17+
import tensorflow as tf
18+
19+
20+
def rowwise_unsorted_segment_sum(values, indices, n):
21+
"""UnsortedSegmentSum on each row.
22+
23+
Args:
24+
values: a `Tensor` with shape `[batch_size, k]`.
25+
indices: an integer `Tensor` with shape `[batch_size, k]`.
26+
n: an integer.
27+
28+
Returns:
29+
A `Tensor` with the same type as `values` and shape `[batch_size, n]`.
30+
"""
31+
batch, k = tf.unstack(tf.shape(indices), num=2)
32+
indices_flat = tf.reshape(indices, [-1]) + tf.cast(
33+
tf.math.divide(tf.range(batch * k), k) * n, tf.int32)
34+
ret_flat = tf.math.unsorted_segment_sum(
35+
tf.reshape(values, [-1]), indices_flat, batch * n)
36+
return tf.reshape(ret_flat, [batch, n])
37+
38+
39+
def normal_distribution_cdf(x, stddev):
40+
"""Evaluates the CDF of the normal distribution.
41+
42+
Normal distribution with mean 0 and standard deviation stddev,
43+
evaluated at x=x.
44+
input and output `Tensor`s have matching shapes.
45+
Args:
46+
x: a `Tensor`
47+
stddev: a `Tensor` with the same shape as `x`.
48+
49+
Returns:
50+
a `Tensor` with the same shape as `x`.
51+
"""
52+
return 0.5 * (1.0 + tf.erf(x / (tf.math.sqrt(2) * stddev + 1e-20)))
53+
54+
55+
def prob_in_top_k(clean_values, noisy_values, noise_stddev, noisy_top_values,
56+
k):
57+
"""Helper function to NoisyTopKGating.
58+
59+
Computes the probability that value is in top k, given different random noise.
60+
This gives us a way of backpropagating from a loss that balances the number
61+
of times each expert is in the top k experts per example.
62+
In the case of no noise, pass in None for noise_stddev, and the result will
63+
not be differentiable.
64+
Args:
65+
clean_values: a `Tensor` of shape [batch, n].
66+
noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
67+
normally distributed noise with standard deviation noise_stddev.
68+
noise_stddev: a `Tensor` of shape [batch, n], or None
69+
noisy_top_values: a `Tensor` of shape [batch, m]. "values" Output of
70+
tf.top_k(noisy_top_values, m). m >= k+1
71+
k: an integer.
72+
73+
Returns:
74+
a `Tensor` of shape [batch, n].
75+
"""
76+
batch = tf.shape(clean_values)[0]
77+
m = tf.shape(noisy_top_values)[1]
78+
top_values_flat = tf.reshape(noisy_top_values, [-1])
79+
# we want to compute the threshold that a particular value would have to
80+
# exceed in order to make the top k. This computation differs depending
81+
# on whether the value is already in the top k.
82+
threshold_positions_if_in = tf.range(batch) * m + k
83+
threshold_if_in = tf.expand_dims(
84+
tf.gather(top_values_flat, threshold_positions_if_in), 1)
85+
is_in = tf.greater(noisy_values, threshold_if_in)
86+
if noise_stddev is None:
87+
return tf.to_float(is_in)
88+
threshold_positions_if_out = threshold_positions_if_in - 1
89+
threshold_if_out = tf.expand_dims(
90+
tf.gather(top_values_flat, threshold_positions_if_out), 1)
91+
# is each value currently in the top k.
92+
prob_if_in = normal_distribution_cdf(clean_values - threshold_if_in,
93+
noise_stddev)
94+
prob_if_out = normal_distribution_cdf(clean_values - threshold_if_out,
95+
noise_stddev)
96+
prob = tf.where(is_in, prob_if_in, prob_if_out)
97+
return prob

0 commit comments

Comments
 (0)