Refactoring the routing layer. #449

Open · wants to merge 1 commit into main

4 changes: 4 additions & 0 deletions edward2/tensorflow/layers/__init__.py
@@ -31,6 +31,7 @@
from edward2.tensorflow.layers.convolutional import Conv2DVariationalDropout
from edward2.tensorflow.layers.convolutional import DepthwiseCondConv2D
from edward2.tensorflow.layers.convolutional import DepthwiseConv2DBatchEnsemble
from edward2.tensorflow.layers.dense import CondDense
from edward2.tensorflow.layers.dense import DenseBatchEnsemble
from edward2.tensorflow.layers.dense import DenseDVI
from edward2.tensorflow.layers.dense import DenseFlipout
@@ -70,13 +71,15 @@
from edward2.tensorflow.layers.recurrent import LSTMCellFlipout
from edward2.tensorflow.layers.recurrent import LSTMCellRank1
from edward2.tensorflow.layers.recurrent import LSTMCellReparameterization
from edward2.tensorflow.layers.routing import RoutingLayer
from edward2.tensorflow.layers.stochastic_output import MixtureLogistic

__all__ = [
"ActNorm",
"Attention",
"BayesianLinearModel",
"CondConv2D",
"CondDense",
"Conv1DBatchEnsemble",
"Conv1DFlipout",
"Conv1DRank1",
@@ -122,6 +125,7 @@
"NeuralProcess",
"RandomFeatureGaussianProcess",
"Reverse",
"RoutingLayer",
"SinkhornAutoregressiveFlow",
"SparseGaussianProcess",
"SpectralNormalization",
133 changes: 133 additions & 0 deletions edward2/tensorflow/layers/routing.py
@@ -0,0 +1,133 @@
# coding=utf-8
# Copyright 2021 The Edward2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Routing layer for mixture of experts."""

from edward2.tensorflow.layers import routing_utils
import tensorflow as tf


class RoutingLayer(tf.keras.layers.Layer):
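  """Routing layer (gating network) for a mixture-of-experts model.

  Pools its inputs, applies a learned linear gate, and returns per-example
  routing weights over `num_experts` experts according to `routing_fn`
  (sigmoid, softmax, or a top-k variant, each optionally with learned
  input-dependent noise).
  """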

def __init__(self, num_experts, routing_pooling, routing_fn, k,
normalize_routing, noise_epsilon, **kwargs):
super().__init__(**kwargs)
self.num_experts = num_experts
self.routing_pooling = routing_pooling
self.routing_fn = routing_fn
self.k = k
self.normalize_routing = normalize_routing
self.noise_epsilon = noise_epsilon
self.use_noisy_routing = 'noisy' in routing_fn
self.use_softmax_top_k = routing_fn in [
'softmax_top_k', 'noisy_softmax_top_k'
]
self.use_onehot_top_k = routing_fn in ['onehot_top_k', 'noisy_onehot_top_k']
self.use_sigmoid_activation = routing_fn == 'sigmoid'
self.use_softmax_routing = routing_fn in ['softmax', 'noisy_softmax']

  def build(self, input_shape):
    input_shape = tf.TensorShape(input_shape)

    if self.routing_pooling == 'global_average':
      self.pooling_layer = tf.keras.layers.GlobalAveragePooling2D()
    elif self.routing_pooling == 'global_max':
      self.pooling_layer = tf.keras.layers.GlobalMaxPool2D()
    elif self.routing_pooling == 'average_8':
      self.pooling_layer = tf.keras.Sequential([
          tf.keras.layers.AveragePooling2D(pool_size=8),
          tf.keras.layers.Flatten(),
      ])
    elif self.routing_pooling == 'max_8':
      self.pooling_layer = tf.keras.Sequential([
          tf.keras.layers.MaxPool2D(pool_size=8),
          tf.keras.layers.Flatten(),
      ])
    else:
      self.pooling_layer = tf.keras.layers.Flatten()

    # The gating weights act on the pooled inputs, so size them from the
    # pooling layer's output rather than from the raw input shape.
    self.input_size = self.pooling_layer.compute_output_shape(input_shape)[-1]
    self.kernel_shape = [self.input_size, self.num_experts]

    self.w_gate = self.add_weight(
        name='w_gate',
        shape=self.kernel_shape,
        initializer=tf.keras.initializers.Zeros(),
        regularizer=None,
        constraint=None,
        trainable=True,
        dtype=self.dtype)

    if self.use_noisy_routing:
      self.w_noise = self.add_weight(
          name='w_noise',
          shape=self.kernel_shape,
          initializer=tf.keras.initializers.Zeros(),
          regularizer=None,
          constraint=None,
          trainable=True,
          dtype=self.dtype)

    self.built = True

def call(self, inputs, training=None):
pooled_inputs = self.pooling_layer(inputs)
routing_weights = tf.linalg.matmul(pooled_inputs, self.w_gate)

if self.use_noisy_routing and training:
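      # Perturb the routing logits with learned, input-dependent Gaussian
      # noise, as in noisy top-k gating (Shazeer et al., 2017).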
raw_noise_stddev = tf.linalg.matmul(pooled_inputs, self.w_noise)
noise_stddev = tf.nn.softplus(raw_noise_stddev) + self.noise_epsilon
routing_weights += tf.random.normal(
tf.shape(routing_weights)) * noise_stddev

if self.use_sigmoid_activation:
routing_weights = tf.nn.sigmoid(routing_weights)
elif self.use_softmax_routing:
routing_weights = tf.nn.softmax(routing_weights)
elif self.use_softmax_top_k:
top_values, top_indices = tf.math.top_k(routing_weights,
min(self.k + 1, self.num_experts))
      # The top-k logits have shape [batch, k].
top_k_values = tf.slice(top_values, [0, 0], [-1, self.k])
top_k_indices = tf.slice(top_indices, [0, 0], [-1, self.k])
top_k_gates = tf.nn.softmax(top_k_values)
# This returns a [batch, n] Tensor with 0's in the positions of non-top-k
# expert values.
routing_weights = routing_utils.rowwise_unsorted_segment_sum(
top_k_gates, top_k_indices, self.num_experts)
elif self.use_onehot_top_k:
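      # Keep each example's raw logits for its top-k experts and zero out all
      # other experts.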
top_values, top_indices = tf.math.top_k(routing_weights, k=self.k)
one_hot_tensor = tf.one_hot(top_indices, depth=self.num_experts)
mask = tf.reduce_sum(one_hot_tensor, axis=1)
routing_weights *= mask

if self.normalize_routing:
normalization = tf.math.reduce_sum(
routing_weights, axis=-1, keepdims=True)
routing_weights /= normalization

return routing_weights

def get_config(self):
config = {
'num_experts': self.num_experts,
'routing_pooling': self.routing_pooling,
'routing_fn': self.routing_fn,
'k': self.k,
'normalize_routing': self.normalize_routing,
'noise_epsilon': self.noise_epsilon,
}
new_config = super().get_config()
new_config.update(config)
return new_config
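
A quick usage sketch for reviewers. The shapes and argument values are illustrative assumptions, not part of this PR, and it assumes the layer is exposed as ed.layers.RoutingLayer via the __init__.py change above:

import edward2 as ed
import tensorflow as tf

# Hypothetical example: route a batch of 8 feature maps across 4 experts,
# keeping the top 2 experts per example.
router = ed.layers.RoutingLayer(
    num_experts=4,
    routing_pooling='global_average',
    routing_fn='noisy_softmax_top_k',
    k=2,
    normalize_routing=True,
    noise_epsilon=1e-2)
inputs = tf.random.normal([8, 32, 32, 16])
routing_weights = router(inputs, training=True)  # Shape [8, 4].
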
97 changes: 97 additions & 0 deletions edward2/tensorflow/layers/routing_utils.py
@@ -0,0 +1,97 @@
# coding=utf-8
# Copyright 2021 The Edward2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Routing utils."""
import tensorflow as tf


def rowwise_unsorted_segment_sum(values, indices, n):
"""UnsortedSegmentSum on each row.
Args:
values: a `Tensor` with shape `[batch_size, k]`.
indices: an integer `Tensor` with shape `[batch_size, k]`.
n: an integer.
Returns:
A `Tensor` with the same type as `values` and shape `[batch_size, n]`.
"""
batch, k = tf.unstack(tf.shape(indices), num=2)
  # Offset each row's indices into its own length-n segment so one call to
  # unsorted_segment_sum handles all rows at once.
  indices_flat = tf.reshape(indices, [-1]) + (tf.range(batch * k) // k) * n
ret_flat = tf.math.unsorted_segment_sum(
tf.reshape(values, [-1]), indices_flat, batch * n)
return tf.reshape(ret_flat, [batch, n])


def normal_distribution_cdf(x, stddev):
"""Evaluates the CDF of the normal distribution.
Normal distribution with mean 0 and standard deviation stddev,
evaluated at x=x.
input and output `Tensor`s have matching shapes.
Args:
x: a `Tensor`
stddev: a `Tensor` with the same shape as `x`.
Returns:
a `Tensor` with the same shape as `x`.
"""
  return 0.5 * (1.0 + tf.math.erf(x / (tf.math.sqrt(2.0) * stddev + 1e-20)))


def prob_in_top_k(clean_values, noisy_values, noise_stddev, noisy_top_values,
k):
"""Helper function to NoisyTopKGating.
Computes the probability that value is in top k, given different random noise.
This gives us a way of backpropagating from a loss that balances the number
of times each expert is in the top k experts per example.
In the case of no noise, pass in None for noise_stddev, and the result will
not be differentiable.
Args:
clean_values: a `Tensor` of shape [batch, n].
noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
normally distributed noise with standard deviation noise_stddev.
noise_stddev: a `Tensor` of shape [batch, n], or None
noisy_top_values: a `Tensor` of shape [batch, m]. "values" Output of
tf.top_k(noisy_top_values, m). m >= k+1
k: an integer.
Returns:
a `Tensor` of shape [batch, n].
"""
batch = tf.shape(clean_values)[0]
m = tf.shape(noisy_top_values)[1]
top_values_flat = tf.reshape(noisy_top_values, [-1])
# we want to compute the threshold that a particular value would have to
# exceed in order to make the top k. This computation differs depending
# on whether the value is already in the top k.
threshold_positions_if_in = tf.range(batch) * m + k
threshold_if_in = tf.expand_dims(
tf.gather(top_values_flat, threshold_positions_if_in), 1)
is_in = tf.greater(noisy_values, threshold_if_in)
if noise_stddev is None:
    return tf.cast(is_in, tf.float32)
threshold_positions_if_out = threshold_positions_if_in - 1
threshold_if_out = tf.expand_dims(
tf.gather(top_values_flat, threshold_positions_if_out), 1)
  # Probability that each value clears its threshold, conditioned on whether
  # the value is currently in or out of the top k.
prob_if_in = normal_distribution_cdf(clean_values - threshold_if_in,
noise_stddev)
prob_if_out = normal_distribution_cdf(clean_values - threshold_if_out,
noise_stddev)
prob = tf.where(is_in, prob_if_in, prob_if_out)
return prob
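
A small worked example of rowwise_unsorted_segment_sum, the scatter step used by the softmax top-k routing above (the numbers are illustrative):

import tensorflow as tf

# Per-row gates for k=2 experts and the expert indices they belong to.
top_k_gates = tf.constant([[0.7, 0.3],
                           [0.6, 0.4]])
top_k_indices = tf.constant([[2, 0],
                             [1, 3]])
# Each row's gates are scattered into a length-4 row of expert weights.
print(rowwise_unsorted_segment_sum(top_k_gates, top_k_indices, 4))
# tf.Tensor(
# [[0.3 0.  0.7 0. ]
#  [0.  0.6 0.  0.4]], shape=(2, 4), dtype=float32)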