Skip to content

Commit 8c49273

Browse files
GhassenJedward-bot
authored andcommitted
Refactoring the routing layer.
PiperOrigin-RevId: 352144176
1 parent f36188e commit 8c49273

File tree

2 files changed

+154
-0
lines changed

2 files changed

+154
-0
lines changed

edward2/tensorflow/layers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from edward2.tensorflow.layers.convolutional import Conv2DVariationalDropout
3232
from edward2.tensorflow.layers.convolutional import DepthwiseCondConv2D
3333
from edward2.tensorflow.layers.convolutional import DepthwiseConv2DBatchEnsemble
34+
from edward2.tensorflow.layers.dense import CondDense
3435
from edward2.tensorflow.layers.dense import DenseBatchEnsemble
3536
from edward2.tensorflow.layers.dense import DenseDVI
3637
from edward2.tensorflow.layers.dense import DenseFlipout
@@ -70,13 +71,15 @@
7071
from edward2.tensorflow.layers.recurrent import LSTMCellFlipout
7172
from edward2.tensorflow.layers.recurrent import LSTMCellRank1
7273
from edward2.tensorflow.layers.recurrent import LSTMCellReparameterization
74+
from edward2.tensorflow.layers.routing import RoutingLayer
7375
from edward2.tensorflow.layers.stochastic_output import MixtureLogistic
7476

7577
__all__ = [
7678
"ActNorm",
7779
"Attention",
7880
"BayesianLinearModel",
7981
"CondConv2D",
82+
"CondDense",
8083
"Conv1DBatchEnsemble",
8184
"Conv1DFlipout",
8285
"Conv1DRank1",
@@ -122,6 +125,7 @@
122125
"NeuralProcess",
123126
"RandomFeatureGaussianProcess",
124127
"Reverse",
128+
"RoutingLayer",
125129
"SinkhornAutoregressiveFlow",
126130
"SparseGaussianProcess",
127131
"SpectralNormalization",

edward2/tensorflow/layers/routing.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# coding=utf-8
2+
# Copyright 2021 The Edward2 Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Routing layer for mixture of experts."""
17+
18+
import tensorflow as tf
19+
20+
21+
class RoutingLayer(tf.keras.layers.Layer):
22+
23+
def __init__(num_experts, routing_pooling, routing_fn, k, normalize_routing,
24+
noise_epsilon, **kwargs):
25+
super().__init__(**kwargs)
26+
self.num_experts = num_experts
27+
self.routing_pooling = routing_pooling
28+
self.routing_fn = routing_fn
29+
self.k = k
30+
self.normalize_routing = normalize_routing
31+
self.noise_epsilon = noise_epsilon
32+
self.use_noisy_routing = 'noisy' in routing_fn
33+
self.use_softmax_top_k = routing_fn in [
34+
'softmax_top_k', 'noisy_softmax_top_k'
35+
]
36+
self.use_onehot_top_k = routing_fn in ['onehot_top_k', 'noisy_onehot_top_k']
37+
self.use_sigmoid_activation = routing_fn == 'sigmoid'
38+
self.use_softmax_routing = routing_fn in ['softmax', 'noisy_softmax']
39+
40+
def build(self, input_shape):
41+
input_shape = tf.TensorShape(input_shape)
42+
self.input_size = input_shape[1]
43+
self.kernel_shape = [self.input_size, self.num_experts]
44+
45+
self.w_gate = self.add_weight(
46+
name='w_gate',
47+
shape=self.kernel_shape,
48+
initializer=tf.keras.initializers.Zeros(),
49+
regularizer=None,
50+
constraint=None,
51+
trainable=True,
52+
dtype=self.dtype)
53+
54+
if self.use_noisy_routing:
55+
self.w_noise = self.add_weight(
56+
name='w_gate',
57+
shape=self.kernel_shape,
58+
initializer=tf.keras.initializers.Zeros(),
59+
regularizer=None,
60+
constraint=None,
61+
trainable=True,
62+
dtype=self.dtype)
63+
64+
if self.routing_pooling == 'global_average':
65+
self.pooling_layer = tf.keras.layers.GlobalAveragePooling2D()
66+
elif self.routing_pooling == 'global_max':
67+
self.pooling_layer = tf.keras.layers.GlobalMaxPool2D()
68+
elif self.routing_pooling == 'average_8':
69+
self.pooling_layer = tf.keras.Sequential([
70+
tf.keras.layers.AveragePooling2D(pool_size=8),
71+
tf.keras.layers.Flatten(),
72+
])
73+
elif self.routing_pooling == 'max_8':
74+
self.pooling_layer = tf.keras.Sequential([
75+
tf.keras.layers.MaxPool2D(pool_size=8),
76+
tf.keras.layers.Flatten(),
77+
])
78+
else:
79+
self.pooling_layer = tf.keras.layers.Flatten()
80+
81+
self.built = True
82+
83+
def _rowwise_unsorted_segment_sum(values, indices, n):
84+
"""UnsortedSegmentSum on each row.
85+
86+
Args:
87+
values: a `Tensor` with shape `[batch_size, k]`.
88+
indices: an integer `Tensor` with shape `[batch_size, k]`.
89+
n: an integer.
90+
91+
Returns:
92+
A `Tensor` with the same type as `values` and shape `[batch_size, n]`.
93+
"""
94+
batch, k = tf.unstack(tf.shape(indices), num=2)
95+
indices_flat = tf.reshape(indices, [-1]) + tf.cast(
96+
tf.math.divide(tf.range(batch * k), k) * n, tf.int32)
97+
ret_flat = tf.math.unsorted_segment_sum(
98+
tf.reshape(values, [-1]), indices_flat, batch * n)
99+
return tf.reshape(ret_flat, [batch, n])
100+
101+
def call(self, inputs, training=None):
102+
pooled_inputs = self.pooling_layer(inputs)
103+
routing_weights = tf.linalg.matmul(pooled_inputs, self.w_gate)
104+
105+
if self.use_noisy_routing and training:
106+
raw_noise_stddev = tf.linalg.matmul(pooled_inputs, self.w_noise)
107+
noise_stddev = tf.nn.softplus(raw_noise_stddev) + self.noise_epsilon
108+
routing_weights += tf.random.normal(tf.shape(routing_weights)) * noise_stddev
109+
110+
if self.use_sigmoid_activation:
111+
routing_weights = tf.nn.sigmoid(routing_weights)
112+
elif self.use_softmax_routing:
113+
routing_weights = tf.nn.softmax(routing_weights)
114+
elif self.use_softmax_top_k:
115+
top_values, top_indices = tf.math.top_k(logits,
116+
min(k + 1, self.num_experts))
117+
# top k logits has shape [batch, k]
118+
top_k_values = tf.slice(top_values, [0, 0], [-1, k])
119+
top_k_indices = tf.slice(top_indices, [0, 0], [-1, k])
120+
top_k_gates = tf.nn.softmax(top_k_values)
121+
# This returns a [batch, n] Tensor with 0's in the positions of non-top-k
122+
# expert values.
123+
routing_weights = _rowwise_unsorted_segment_sum(top_k_gates,
124+
top_k_indices,
125+
self.num_experts)
126+
elif self.use_onehot_top_k:
127+
top_values, top_indices = tf.math.top_k(routing_weights, k=self.k)
128+
one_hot_tensor = tf.one_hot(top_indices, depth=self.num_experts)
129+
mask = tf.reduce_sum(one_hot_tensor, axis=1)
130+
routing_weights *= mask
131+
132+
if self.normalize_routing:
133+
normalization = tf.math.reduce_sum(
134+
routing_weights, axis=-1, keepdims=True)
135+
routing_weights /= normalization
136+
137+
return routing_weights
138+
139+
def get_config(self):
140+
config = {
141+
'num_experts': self.num_experts,
142+
'routing_pooling': self.routing_pooling,
143+
'routing_fn': self.routing_fn,
144+
'k': self.k,
145+
'normalize_routing': self.normalize_routing,
146+
'noise_epsilon': self.noise_epsilon,
147+
}
148+
new_config = super().get_config()
149+
new_config.update(config)
150+
return new_config

0 commit comments

Comments
 (0)