Skip to content

Commit 22556d3

Browse files
SandSnip3rtensorflower-gardener
authored andcommitted
Remove dead code from HloRematerialization and fix comment.
PiperOrigin-RevId: 653503548
1 parent f391d84 commit 22556d3

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

third_party/xla/xla/service/hlo_rematerialization.h

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -179,18 +179,10 @@ class HloRematerialization : public HloModulePass {
179179
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
180180

181181
protected:
182-
// Rematerializes instructions within the given computation. 'order' is the
183-
// order in which the computation's instructions will be emitted in the
184-
// backend. Rematerialized instructions will be added to the HLO computation
185-
// and inserted into 'order'.
186-
absl::StatusOr<bool> RematerializeComputation(HloComputation* computation,
187-
HloSchedule* schedule,
188-
int64_t memory_limit_bytes,
189-
int64_t min_remat_size) {
190-
return RematerializeComputation(computation, schedule, memory_limit_bytes,
191-
min_remat_size, /*execution_threads=*/{});
192-
}
193-
182+
// Rematerializes instructions within the given computation. 'schedule'
183+
// constains the order in which the computation's instructions will be emitted
184+
// in the backend. Rematerialized instructions will be added to the HLO
185+
// computation and inserted into 'schedule'.
194186
virtual absl::StatusOr<bool> RematerializeComputation(
195187
HloComputation* computation, HloSchedule* schedule,
196188
int64_t memory_limit_bytes, int64_t min_remat_size,

0 commit comments

Comments
 (0)