
Commit ffb6dbf

saberkun authored and tensorflower-gardener committed
Use sparse_categorical_crossentropy for the test, as the loss object's default reduction does not work on TPUStrategy and the single-task trainer already handles the reduction.

PiperOrigin-RevId: 367757677
1 parent e353e4e
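For context, here is a minimal sketch (not part of the commit) of the failure mode the message refers to: a tf.keras.losses.Loss object defaults to Reduction.AUTO, which raises a ValueError when the loss is computed inside a custom training loop under a tf.distribute strategy such as TPUStrategy, whereas the plain function performs no reduction at all and leaves averaging to the trainer.

    import tensorflow as tf

    y_true = tf.constant([0, 2])
    y_pred = tf.constant([[2.0, 0.5, 0.1],
                          [0.3, 0.2, 1.5]])

    # Loss object: default reduction is AUTO, which is rejected when the
    # loss is called inside strategy.run() in a custom training loop.
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # Plain function: returns a per-example loss vector of shape (2,);
    # the single-task trainer applies tf.reduce_mean and the replica
    # scaling itself.
    per_example = tf.keras.losses.sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=True)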

2 files changed (+7, -2 lines)


orbit/examples/single_task/single_task_trainer.py (+4 lines)
@@ -107,6 +107,10 @@ def train_fn(inputs):
       # replicas. This ensures that we don't end up multiplying our loss by
       # the number of workers - gradients are summed, not averaged, across
       # replicas during the apply_gradients call.
+      # Note: the reduction of the loss is handled explicitly here and scaled
+      # by num_replicas_in_sync, so we recommend passing a plain loss
+      # function. If you're using a tf.keras.losses.Loss object, you may
+      # need to set its reduction argument explicitly.
       loss = tf.reduce_mean(self.loss_fn(target, output))
       scaled_loss = loss / self.strategy.num_replicas_in_sync
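As a hedged illustration of the last sentence of that comment (an assumed usage, not code from this commit), a Loss object can still be passed as loss_fn if its reduction is set explicitly, so that the trainer's own tf.reduce_mean does the averaging:

    import tensorflow as tf

    # Reduction.NONE makes the object return per-example losses, matching
    # the behavior of the plain functional loss; the trainer then reduces
    # and scales them itself.
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.NONE)

    # What the trainer does with it (paraphrasing the lines above):
    #   loss = tf.reduce_mean(loss_obj(target, output))
    #   scaled_loss = loss / strategy.num_replicas_in_sync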

orbit/examples/single_task/single_task_trainer_test.py (+3, -1 lines)
@@ -30,14 +30,15 @@ def test_single_task_training(self):
         tf.keras.Input(shape=(4,), name='features'),
         tf.keras.layers.Dense(10, activation=tf.nn.relu),
         tf.keras.layers.Dense(10, activation=tf.nn.relu),
-        tf.keras.layers.Dense(3)
+        tf.keras.layers.Dense(3),
+        tf.keras.layers.Softmax(),
     ])

     trainer = single_task_trainer.SingleTaskTrainer(
         train_ds,
         label_key='label',
         model=model,
-        loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        loss_fn=tf.keras.losses.sparse_categorical_crossentropy,
         optimizer=tf.keras.optimizers.SGD(learning_rate=0.01))

     controller = orbit.Controller(
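A quick sanity check (a sketch, not from the commit) of why the test also adds a Softmax layer: the plain sparse_categorical_crossentropy defaults to from_logits=False, so the model must now emit probabilities rather than raw logits.

    import tensorflow as tf

    # Probabilities, as the new Softmax layer produces.
    y_true = tf.constant([0, 2])
    y_pred = tf.constant([[0.90, 0.05, 0.05],
                          [0.10, 0.20, 0.70]])

    # Default from_logits=False matches the probability outputs; the result
    # is an unreduced per-example vector that SingleTaskTrainer averages.
    per_example = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    print(per_example.numpy())  # two per-example loss values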
