
Commit ed0e7aa

* Implement capability to send dynamic hyperparameters (other than learning
  rate) from the TensorCore to the TPUEmbedding.
* Implement frequency aware Adagrad optimizer for TPUEmbedding that uses the
  above capability.

PiperOrigin-RevId: 687122317
1 parent c00def8 commit ed0e7aa

8 files changed: 229 additions & 119 deletions

tensorflow/core/protobuf/tpu/optimization_parameters.proto

Lines changed: 69 additions & 27 deletions
@@ -34,45 +34,40 @@ message SimulatedQuantization {
   int32 num_buckets = 3;
 }
 
-// Dynamic learning rate specification in the TPUEmbeddingConfiguration. The
-// actual learning rates are provided as a scalar input list to the
+// Dynamic input specification for optimizers in the TPUEmbeddingConfiguration.
+// The actual dynamic inputs are provided as a scalar input list to the
 // SendTPUEmbeddingGradients Op indexed by their tag specified through the
 // following proto.
-message DynamicLearningRate {
-  // For tables where learning rates are dynamically computed and communicated
-  // to the TPU embedding program, a tag must be specified for the learning
-  // rate.
+message OptimizerDynamicInput {
+  // For tables where dynamic inputs are needed (e.g., learning rates or other
+  // dynamic hyperparameters used in optimizers), a tag must be specified for
+  // the input.
   //
-  // The tag must be a non-negative integer. The total number of unique tags
-  // must be less than or equal to the number of tables in the TPU embedding
-  // configuration (a table does not specify any tag if it uses a constant
-  // learning rate, and specifies exactly one tag if it uses dynamic learning
-  // rates).
-  //
-  // All tags in the range [0, number_of_unique_tags) must be present in the TPU
-  // embedding configuration, i.e. a tag cannot be skipped if a different tag
-  // numerically greater than it is used in the configuration.
+  // The tag must be a non-negative integer. All tags in the range
+  // [0, number_of_unique_tags) must be present in the TPU embedding
+  // configuration, i.e. a tag cannot be skipped if a different tag numerically
+  // greater than it is used in the configuration.
   //
   // If multiple tables specify the same tag, they *MUST* have
-  // the same dynamic learning rate, for example, their dynamic learning rate
-  // could be computed by the same TensorFlow sub-graph. The partitioning of the
+  // the same dynamic input, for example, their dynamic learning rate could be
+  // computed by the same TensorFlow sub-graph. The partitioning of the
   // embedding layer would be more optimal if the number_of_unique_tags is as
   // *LOW* as possible, i.e., if many tables share the same tag.
   //
-  // The learning_rate input of the SendTPUEmbeddingGradients op is used to
-  // communicate dynamic learning rates to the TPU embedding program.
-  // The learning_rate input is a list of scalars where the size of the list is
-  // equal to the number of unique tags. The learning rate associated with a
-  // particular tag is specified by populating its corresponding index in the
-  // list of learning_rate scalars.
+  // The hyper_parameters input of the SendTPUEmbeddingGradients op is used to
+  // communicate dynamic hyper-parameters to the TPU embedding program.
+  // The hyper_parameters input is a list of scalars where the size of the list
+  // is equal to the number of unique tags. The hyper-parameter associated with
+  // a particular tag is specified by populating its corresponding index in the
+  // list of scalars.
   int32 tag = 1;
 }
 
 // Source of learning rate to use.
 message LearningRate {
   oneof learning_rate {
     float constant = 1;
-    DynamicLearningRate dynamic = 2;
+    OptimizerDynamicInput dynamic = 2;
   }
 }
 
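To make the tag mechanism above concrete, here is a small illustrative Python sketch of the tag-to-scalar lookup the comment describes. The table names and the resolve_dynamic_inputs helper are hypothetical, invented for illustration; only the tag-indexed scalar list and the "no skipped tags" rule come from the proto comment.

# Illustrative sketch of the tag-indexed dynamic-input lookup; the real
# mechanism lives in the TPU embedding program fed by SendTPUEmbeddingGradients.

# Each table that uses dynamic inputs names exactly one tag; tables sharing a
# tag receive the same scalar. Tags must densely cover [0, num_unique_tags).
table_tags = {"user_table": 0, "item_table": 0, "context_table": 1}

# One scalar per unique tag, in tag order (e.g., two dynamic learning rates).
hyper_parameters = [0.01, 0.001]


def resolve_dynamic_inputs(table_tags, hyper_parameters):
    """Maps each table to the scalar selected by its tag."""
    tags = set(table_tags.values())
    # A tag cannot be skipped: the tags must be exactly 0..len-1.
    assert tags == set(range(len(hyper_parameters))), "tags must be dense"
    return {table: hyper_parameters[tag] for table, tag in table_tags.items()}


print(resolve_dynamic_inputs(table_tags, hyper_parameters))
# {'user_table': 0.01, 'item_table': 0.01, 'context_table': 0.001}

Sharing one tag across many tables keeps number_of_unique_tags low, which, per the comment above, helps the partitioning of the embedding layer.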
@@ -131,6 +126,53 @@ message BoundedAdagradParameters {
   float max_accumulator = 3;
 }
 
+// Frequency Aware Adagrad optimizer. This optimizer implements the AdaGrad
+// algorithm and further allows one to:
+// * Scale the learning rate based on the frequency of the update. Sparsely
+//   updated rows are updated with a higher effective learning rate, and
+//   frequently updated rows are updated with a lower effective learning rate.
+// * Decay the growth of the accumulator values.
+// * Use L1 / L2 regularization for the weight updates.
+//
+// The optimization algorithm is shown below.
+// counter(new) = counter(old) + 1
+// accum(new) = max(accumulator_decay * accum(old) + grad^2,
+//                  initial_accumulator_value)
+// lr_scale = min((step_counter / accum(new)) ^ probability_exponent,
+//                max_lr_multiplier)
+// update = grad * lr_scale / sqrt(accum(new))
+// if (l1_regularization_strength > 0.0):
+//   update = update + l1_regularization_strength * sign(var(old))
+// if (l2_regularization_strength > 0.0):
+//   update = update + l2_regularization_strength * var(old)
+// var(new) = var(old) - lr_scale * grad * update
+message FrequencyAwareAdagradParameters {
+  // The L1 regularization parameter for adjusting the update based on the sign
+  // of the variable.
+  float l1_regularization_strength = 1;
+
+  // The L2 regularization parameter for adjusting the update based on the
+  // variable.
+  float l2_regularization_strength = 2;
+
+  // The exponent used for scaling the learning rate based on the sparsity of
+  // updates.
+  float probability_exponent = 4;
+
+  // The maximum value of the learning rate scale.
+  float max_lr_multiplier = 3;
+
+  // The decay for the Adagrad accumulator.
+  float accumulator_decay = 5;
+
+  // The initial and minimum value for the Adagrad accumulator.
+  float initial_accumulator_value = 6;
+
+  // The tag for identifying the step counter used for the frequency aware
+  // Adagrad optimizer.
+  OptimizerDynamicInput step_counter = 7;
+}
+
 // https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD
 // https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L629
 message StochasticGradientDescentParameters {}
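Read as array code, the comment's update rule translates line for line into NumPy. The sketch below is a literal, hedged transcription of that pseudo-code; the function name and shapes are illustrative, and the actual implementation is the TPU embedding kernel, not this.

# Literal NumPy reading of the frequency-aware Adagrad pseudo-code above;
# not the actual TPU embedding kernel.
import numpy as np


def frequency_aware_adagrad_update(var, accum, grad, step_counter,
                                   probability_exponent, max_lr_multiplier,
                                   accumulator_decay, initial_accumulator_value,
                                   l1_regularization_strength=0.0,
                                   l2_regularization_strength=0.0):
    # step_counter is supplied externally: per the message definition, it
    # arrives through an OptimizerDynamicInput tag rather than being stored.
    # accum(new) = max(accumulator_decay * accum(old) + grad^2,
    #                  initial_accumulator_value)
    accum = np.maximum(accumulator_decay * accum + grad ** 2,
                       initial_accumulator_value)
    # Rarely updated rows have small accumulators relative to the step
    # counter, so they get a larger scale, capped at max_lr_multiplier.
    lr_scale = np.minimum((step_counter / accum) ** probability_exponent,
                          max_lr_multiplier)
    update = grad * lr_scale / np.sqrt(accum)
    if l1_regularization_strength > 0.0:
        update = update + l1_regularization_strength * np.sign(var)
    if l2_regularization_strength > 0.0:
        update = update + l2_regularization_strength * var
    # Final step exactly as written in the comment above.
    var = var - lr_scale * grad * update
    return var, accum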
@@ -502,7 +544,6 @@ message HotIdReplicationConfiguration {
 message OptimizationParameters {
   // Learning rate used for updating the embedding layer parameters.
   LearningRate learning_rate = 13;
-  reserved 1;  // Old learning rate tag.
 
   // Limits to which to clip the weight values after the backward pass; not
   // present means no limits are applied.
@@ -550,6 +591,7 @@ message OptimizationParameters {
     AdagradParameters adagrad = 3;
     AdagradMomentumParameters adagrad_momentum = 26;
     BoundedAdagradParameters bounded_adagrad = 19;
+    FrequencyAwareAdagradParameters frequency_aware_adagrad = 30;
     StochasticGradientDescentParameters stochastic_gradient_descent = 4;
     FtrlParameters ftrl = 5;
     AdamParameters adam = 6;
@@ -567,9 +609,9 @@ message OptimizationParameters {
     AssignParameters assign = 25;
   }
 
-  reserved 15;  // Old use_gradient_accumulation.
+  reserved 1, 15;
 
-  // NEXT_ID: 30
+  // NEXT_ID: 31
 }
 
 // Specification of an optimization algorithm's state variables (both the main
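Finally, a speculative usage sketch showing how the pieces fit together: populating OptimizationParameters with the new optimizer via the generated Python bindings. The module path follows TensorFlow's layout for generated protos, and every numeric value here is a placeholder, not a recommended setting.

# Speculative sketch, assuming the generated bindings in
# tensorflow/core/protobuf/tpu/optimization_parameters_pb2; placeholder values.
from tensorflow.core.protobuf.tpu import optimization_parameters_pb2 as pb

params = pb.OptimizationParameters()
# The base learning rate arrives dynamically via tag 0 of the scalar list.
params.learning_rate.dynamic.tag = 0

# Mutating frequency_aware_adagrad selects it within the optimizer oneof
# (field number 30, per the diff above).
faa = params.frequency_aware_adagrad
faa.l1_regularization_strength = 0.0
faa.l2_regularization_strength = 0.001
faa.probability_exponent = 0.5
faa.max_lr_multiplier = 10.0
faa.accumulator_decay = 0.999
faa.initial_accumulator_value = 0.1
# The step counter is itself an OptimizerDynamicInput, delivered through its
# own tag in the same scalar input list.
faa.step_counter.tag = 1

Note that learning_rate.dynamic and step_counter each consume one tag here, so the scalar list fed to SendTPUEmbeddingGradients would carry two entries.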

tensorflow/core/tpu/BUILD

Lines changed: 3 additions & 0 deletions
@@ -25,8 +25,10 @@ cc_library(
     hdrs = ["tpu_embedding_configuration_utils.h"],
     visibility = ["//visibility:public"],
     deps = [
+        ":tpu_embedding_optimization_parameters_utils",
         "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc",
         "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
     ],
@@ -72,6 +74,7 @@ cc_library(
         "//tensorflow/core:lib_proto_parsing",
         "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc",
         "@com_google_absl//absl/base",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@local_xla//xla:xla_data_proto_cc",

tensorflow/core/tpu/graph_rewrite/BUILD

Lines changed: 1 addition & 1 deletion
@@ -398,6 +398,7 @@ cc_library(
     srcs = ["tpu_embedding_software_deduplication_rewrite_pass.cc"],
     hdrs = ["tpu_embedding_software_deduplication_rewrite_pass.h"],
     deps = [
+        ":tpu_embedding_rewrite_pass_utils",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
@@ -406,7 +407,6 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
         "//tensorflow/core/tpu:tpu_embedding_configuration_utils",
-        "//tensorflow/core/tpu/graph_rewrite:tpu_embedding_rewrite_pass_utils",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log",
