Skip to content

Commit e11b245

Browse files
Fix gradient checkpointing (#2997)
1 parent 32f474d commit e11b245

File tree

12 files changed

+134
-50
lines changed

12 files changed

+134
-50
lines changed

Cargo.lock

+16-16
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,9 @@ portable-atomic = { version = "1.11.0" }
157157
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
158158

159159
### For the main burn branch. ###
160-
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "eaf4b4b4814c6f5f8ea6d07184e3b4d4dba3b3ae" }
161-
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "eaf4b4b4814c6f5f8ea6d07184e3b4d4dba3b3ae" }
162-
cubecl-std = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "eaf4b4b4814c6f5f8ea6d07184e3b4d4dba3b3ae" }
160+
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c93be864495d1016b4223eb014c64b01d79aa56a" }
161+
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c93be864495d1016b4223eb014c64b01d79aa56a" }
162+
cubecl-std = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c93be864495d1016b4223eb014c64b01d79aa56a" }
163163
### For local development. ###
164164
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
165165
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }

crates/burn-autodiff/src/checkpoint/builder.rs

+2-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
22
collections::HashMap,
3-
graph::{ComputingProperty, NodeID, NodeSteps},
3+
graph::{ComputingProperty, NodeID},
44
tensor::AutodiffTensor,
55
};
66
use alloc::{boxed::Box, sync::Arc, vec::Vec};
@@ -108,8 +108,7 @@ impl CheckpointerBuilder {
108108
}
109109
}
110110

111-
pub(crate) fn build(self, graph: &NodeSteps) -> Checkpointer {
112-
let node_tree = self.make_tree(graph);
111+
pub(crate) fn build(self, node_tree: NodeTree) -> Checkpointer {
113112
let mut backward_states_map = HashMap::new();
114113
let mut retro_forwards_map = HashMap::new();
115114

@@ -247,14 +246,6 @@ impl CheckpointerBuilder {
247246
}
248247
}
249248

250-
fn make_tree(&self, graph: &NodeSteps) -> NodeTree {
251-
let mut tree = HashMap::default();
252-
for (id, step) in graph {
253-
tree.insert(*id, step.parents());
254-
}
255-
NodeTree::new(tree)
256-
}
257-
258249
fn update_n_required_of_parents(
259250
id: NodeID,
260251
n_required_map: &mut HashMap<NodeID, usize>,

crates/burn-autodiff/src/checkpoint/strategy.rs

+35-4
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,23 @@ pub trait CheckpointStrategy: Clone + Copy + Debug + Default + Send + Sync + 'st
1616
fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingProperty;
1717

1818
/// Checkpoints parents if necessary in the strategy
19-
fn checkpoint_parents<'a, B2, A>(parents: A, builder: &mut CheckpointerBuilder)
19+
fn checkpoint_parents<'a, B2, A>(
20+
parents: A,
21+
builder: &mut CheckpointerBuilder,
22+
) -> Result<(), CheckpointingError>
2023
where
2124
B2: Backend,
2225
A: IntoIterator<Item = &'a AutodiffTensor<B2>>;
2326
}
2427

28+
#[derive(Debug)]
29+
/// Error that can happen when trying to checkpoint a tensor.
30+
pub enum CheckpointingError {
31+
/// When a parent is untracked, we can't easily checkpoint its state, since we don't know the
32+
/// requirements in advanced.
33+
UntrackedParent,
34+
}
35+
2536
#[derive(Clone, Copy, Debug, Default)]
2637
/// All operations are considered compute bound, notwithstanding how they are marked
2738
pub struct NoCheckpointing {}
@@ -34,12 +45,16 @@ impl CheckpointStrategy for NoCheckpointing {
3445

3546
/// An operation marked as memory bound is actually compute bound.
3647
/// It's therefore useless to checkpoint the parents
37-
fn checkpoint_parents<'a, B2, A>(_parents: A, _builder: &mut CheckpointerBuilder)
48+
fn checkpoint_parents<'a, B2, A>(
49+
_parents: A,
50+
_builder: &mut CheckpointerBuilder,
51+
) -> Result<(), CheckpointingError>
3852
where
3953
B2: Backend,
4054
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
4155
{
4256
// Nothing to do here
57+
Ok(())
4358
}
4459
}
4560

@@ -59,13 +74,29 @@ impl CheckpointStrategy for BalancedCheckpointing {
5974
/// An operation marked as memory bound is really memory bound.
6075
/// Since the operation may not checkpoint its parents but may need them indirectly
6176
/// if asked to recompute itself, the method needs to know the parent tensors to maybe checkpoint them
62-
fn checkpoint_parents<'a, B2, A>(parents: A, builder: &mut CheckpointerBuilder)
77+
fn checkpoint_parents<'a, B2, A>(
78+
parents: A,
79+
builder: &mut CheckpointerBuilder,
80+
) -> Result<(), CheckpointingError>
6381
where
6482
B2: Backend,
6583
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
6684
{
85+
let mut can_checkpoint = true;
86+
6787
for tensor in parents.into_iter() {
68-
builder.checkpoint(tensor, ActionType::Backup);
88+
if let crate::graph::Requirement::None = tensor.node.requirement {
89+
can_checkpoint = false;
90+
} else {
91+
builder.checkpoint(tensor, ActionType::Backup);
92+
}
6993
}
94+
95+
if !can_checkpoint {
96+
*builder = CheckpointerBuilder::default();
97+
return Err(CheckpointingError::UntrackedParent);
98+
}
99+
100+
Ok(())
70101
}
71102
}

crates/burn-autodiff/src/graph/base.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::NodeID;
2-
use crate::{checkpoint::base::Checkpointer, collections::HashMap, grads::Gradients};
2+
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
33
use alloc::{boxed::Box, vec::Vec};
44

55
/// Backward step for reverse mode autodiff.
@@ -15,4 +15,3 @@ pub trait Step: Send + core::fmt::Debug {
1515
}
1616

1717
pub type StepBoxed = Box<dyn Step>;
18-
pub type NodeSteps = HashMap<NodeID, StepBoxed>;

crates/burn-autodiff/src/graph/requirement.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::NodeRef;
22

33
/// Requirement for each tensor in the graph.
4-
#[derive(Debug, Clone, Copy)]
4+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
55
pub enum Requirement {
66
/// Operations that require gradients.
77
Grad,

crates/burn-autodiff/src/ops/base.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,17 @@ where
106106
B2: Backend,
107107
A: IntoIterator<Item = &'a AutodiffTensor<B2>>,
108108
{
109-
C::checkpoint_parents(parents, &mut self.checkpointer_builder);
109+
let compute_property = match C::checkpoint_parents(parents, &mut self.checkpointer_builder)
110+
{
111+
Ok(..) => self.compute_property,
112+
Err(..) => ComputingProperty::ComputeBound,
113+
};
110114

111115
OpsPrep::new(
112116
self.nodes,
113117
self.requirement,
114118
self.backward,
115-
self.compute_property,
119+
compute_property,
116120
self.checkpointer_builder,
117121
)
118122
}

crates/burn-autodiff/src/runtime/server.rs

+15-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use super::memory_management::GraphMemoryManagement;
22
use crate::{
33
NodeID,
4-
checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},
4+
checkpoint::{
5+
base::{Checkpointer, NodeTree},
6+
builder::CheckpointerBuilder,
7+
},
58
collections::HashMap,
69
grads::Gradients,
710
graph::{StepBoxed, traversal::BreadthFirstSearch},
@@ -34,8 +37,7 @@ impl AutodiffServer {
3437
);
3538
let builder = self.actions_builder.remove(&node_id).unwrap();
3639

37-
let (tape, builder) = self.build_tape(node_id, step, builder);
38-
let checkpointer = builder.build(&self.steps);
40+
let (tape, checkpointer) = self.build_tape(node_id, step, builder);
3941

4042
let gradients = Self::execute_steps(tape, grads, checkpointer);
4143

@@ -54,20 +56,25 @@ impl AutodiffServer {
5456
node: NodeID,
5557
node_step: StepBoxed,
5658
mut builder: CheckpointerBuilder,
57-
) -> (Vec<Vec<StepBoxed>>, CheckpointerBuilder) {
59+
) -> (Vec<Vec<StepBoxed>>, Checkpointer) {
5860
let mut tape = (0..node_step.depth())
5961
.map(|_| Vec::with_capacity(1))
6062
.collect::<Vec<_>>();
6163

64+
let mut tree = HashMap::default();
65+
6266
BreadthFirstSearch.traverse(node, node_step, &mut self.steps, |id, step| {
6367
self.memory_management.consume_node(id);
6468

6569
let depth = step.depth();
70+
6671
if depth == 0 {
6772
return;
6873
}
6974

7075
if let Some(steps) = tape.get_mut(depth - 1) {
76+
let parents = step.parents().into_iter().filter(|s| *s != id);
77+
tree.insert(id, parents.collect());
7178
steps.push(step);
7279
}
7380

@@ -76,7 +83,9 @@ impl AutodiffServer {
7683
}
7784
});
7885

79-
(tape, builder)
86+
let checkpointer = builder.build(NodeTree::new(tree));
87+
88+
(tape, checkpointer)
8089
}
8190

8291
fn execute_steps(
@@ -93,6 +102,7 @@ impl AutodiffServer {
93102
#[cfg(feature = "export_tests")]
94103
// For checkpointing tests
95104
assert!(checkpointer.is_empty());
105+
96106
grads
97107
}
98108
}

0 commit comments

Comments
 (0)