Skip to content

Commit bb53c63

Browse files
kenfrankosonnet-copybara
authored andcommitted
Apply name change(experimental_run_v2 -> run) for all callers.
PiperOrigin-RevId: 301919882 Change-Id: I14c6ed85bdf50d619d1bc572e0fbcc5f1821c70b
1 parent de30fda commit bb53c63

File tree

7 files changed

+19
-14
lines changed

7 files changed

+19
-14
lines changed

examples/distributed_cifar10.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@
405405
"\n",
406406
"@tf.function\n",
407407
"def train_step(images, labels):\n",
408-
" per_replica_loss = strategy.experimental_run_v2(step, args=(images, labels))\n",
408+
" per_replica_loss = strategy.run(step, args=(images, labels))\n",
409409
" return strategy.reduce(\"sum\", per_replica_loss, axis=None)\n",
410410
"\n",
411411
"def train_epoch(dataset):\n",
@@ -465,7 +465,7 @@
465465
" total_correct = 0\n",
466466
"\n",
467467
" for images, labels in cifar10_test_dist:\n",
468-
" per_replica_correct = strategy.experimental_run_v2(is_predicted, args=(images, labels))\n",
468+
" per_replica_correct = strategy.run(is_predicted, args=(images, labels))\n",
469469
" total_correct += strategy.reduce(\"sum\", per_replica_correct, axis=0)\n",
470470
"\n",
471471
" return tf.cast(total_correct, tf.float32) / num_cifar10_test_examples\n",

sonnet/src/conformance/checkpoint_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def test_save_restore(self, golden, replicator_fn, use_function):
195195
variables = golden.create_all_variables(module)
196196

197197
def forward():
198-
per_replica = replicator.experimental_run_v2(
198+
per_replica = replicator.run(
199199
lambda: golden.forward(module))
200200
return tree.map_structure(
201201
lambda args: tf.stack(replicator.unwrap(args), axis=0), per_replica)
@@ -350,7 +350,7 @@ def test_restore_on_create_in_replica_context(self, golden, replicator_fn,
350350
module = golden.create_module()
351351

352352
def forward():
353-
return replicator.experimental_run_v2(lambda: golden.forward(module))
353+
return replicator.run(lambda: golden.forward(module))
354354

355355
if use_function:
356356
forward = tf.function(forward)

sonnet/src/conformance/distribute_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_variable_creation_in_replica_context(self, golden, replicator_fn):
4444
@tf.function
4545
def forward():
4646
step = lambda: golden.create_all_variables(mod)
47-
return replicator.experimental_run_v2(step)
47+
return replicator.run(step)
4848

4949
# TODO(b/132329316) Remove when `xla.compile` allows tf.device(TPU).
5050
with tf.device(None):
@@ -83,7 +83,7 @@ def forward():
8383
state = core.initial_state(input_shape[0])
8484
return unroll(core, sequence, state)
8585

86-
return replicator.experimental_run_v2(forward)
86+
return replicator.run(forward)
8787

8888
# TpuReplicator doesn't support pure eager mode.
8989
if isinstance(replicator, snt_replicator.TpuReplicator):

sonnet/src/distribute/batch_norm_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def foo():
6565
outputs = layer(inputs, True, False, scale, offset)
6666
return inputs, outputs
6767

68-
inputs, outputs = strategy.experimental_run_v2(foo)
68+
inputs, outputs = strategy.run(foo)
6969
local_mean_metric = strategy.experimental_local_results(mean_metric.value)
7070
local_var_metric = strategy.experimental_local_results(var_metric.value)
7171
self.assertAllEqual(local_mean_metric[0].numpy(),
@@ -101,7 +101,7 @@ def compute():
101101
outputs = layer(inputs, True, False, scale, offset)
102102
return inputs, outputs
103103

104-
return strategy.experimental_run_v2(compute)
104+
return strategy.run(compute)
105105
inputs, outputs = run()
106106

107107
local_mean_metric = strategy.experimental_local_results(mean_metric.value)

sonnet/src/distribute/replicator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class Replicator(tf.distribute.MirroredStrategy):
7676
Finally we use the run API to apply ``forward`` in parallel on all accelerator
7777
devices:
7878
79-
>>> per_replica_y = replicator.experimental_run_v2(forward)
79+
>>> per_replica_y = replicator.run(forward)
8080
"""
8181

8282
@contextlib.contextmanager
@@ -128,7 +128,7 @@ class TpuReplicator(tf.distribute.experimental.TPUStrategy):
128128
129129
>>> @tf.function(autograph=False)
130130
... def all_forward():
131-
... return replicator.experimental_run_v2(forward)
131+
... return replicator.run(forward)
132132
>>> per_replica_y = all_forward()
133133
"""
134134

@@ -139,6 +139,11 @@ def scope(self):
139139
stack.enter_context(tf.variable_creator_scope(replica_local_creator))
140140
yield
141141

142+
# TODO(tomhennigan) Remove this once TF 2.2 is released.
143+
for cls in (Replicator, TpuReplicator):
144+
if not hasattr(cls, "run"):
145+
cls.run = cls.experimental_run_v2
146+
142147

143148
def create_variables_eagerly(f: Callable[..., T]) -> Callable[..., T]:
144149
"""Wraps a function and attempts to create variables using eager mode.

sonnet/src/distribute/replicator_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _create_variable_in_replica_context(replicator):
4545
o = TrainableVariable()
4646

4747
def create_var():
48-
replicator.experimental_run_v2(o)
48+
replicator.run(o)
4949

5050
# TpuReplicator doesn't support pure eager mode.
5151
if isinstance(replicator, snt_replicator.TpuReplicator):
@@ -123,7 +123,7 @@ def test_assign(self, replicator_fn, method_name, value, cross_replica):
123123
# TpuReplicator doesn't support pure eager mode.
124124
if isinstance(replicator, snt_replicator.TpuReplicator):
125125
update_fn = tf.function(update_fn)
126-
replicator.experimental_run_v2(update_fn)
126+
replicator.run(update_fn)
127127
for component in v._values:
128128
self.assertAllEqual(component.read_value(), tf.ones_like(component))
129129

@@ -144,7 +144,7 @@ def test_read_value(self, replicator_fn, cross_replica):
144144
read_value_fn = tf.function(v.read_value)
145145
else:
146146
read_value_fn = v.read_value
147-
values = replicator.experimental_run_v2(read_value_fn)
147+
values = replicator.run(read_value_fn)
148148
values = replicator.experimental_local_results(values)
149149
for component in v._values:
150150
for value in values:

sonnet/src/optimizers/optimizer_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def testUnsuppportedStrategyError(self):
114114
with self.assertRaisesRegexp(
115115
ValueError,
116116
"Sonnet optimizers are not compatible with `MirroredStrategy`"):
117-
strategy.experimental_run_v2(lambda: optimizer.apply(updates, parameters))
117+
strategy.run(lambda: optimizer.apply(updates, parameters))
118118

119119

120120
# NOTE: Avoiding ABCMeta because of metaclass conflict.

0 commit comments

Comments
 (0)