Commit 478472b

fix gradient accumulation

1 parent c1d13df

File tree

3 files changed: +70 −26 lines changed

examples/mnist/mnist-common.cpp (+5, −4)
@@ -530,13 +530,16 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co
 void mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nepoch, const float val_split) {
     const int64_t t_start_us = ggml_time_us();
 
+    // gf == graph forward, forward pass only.
     struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
     ggml_build_forward_expand(gf, model.loss);
 
-    struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf); // Backward pass, gradients.
+    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
+    struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf);
     ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ true, false);
 
-    struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gf); // Backward pass, gradients + optimizer.
+    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
+    struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad);
     ggml_build_opt_adamw(model.ctx_compute, gf, gb_opt, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);
 
     model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
@@ -557,8 +560,6 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
             ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT,   0, ggml_nbytes(model.images));
             ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
 
-            ggml_backend_graph_compute(model.backend, gf); // Always compute forward pass.
-
             // With a period of nbatch_logical/nbatch_physical iterations:
             if ((iex0 + model.nbatch_physical) % model.nbatch_logical != 0) {
                 // For the first nbatch_logical/nbatch_physical - 1 iterations, only calculate gradients and accumulate them:
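
With gb_opt now duplicated from gb_grad instead of gf, both backward graphs already contain the forward pass, so the unconditional gf computation above became redundant and was removed. A minimal sketch of the branch that follows this hunk (hedged: it assumes the surrounding mnist-common.cpp setup, with gf, gb_grad, and gb_opt built as above):

    // Inside the batch loop, once per physical batch of size nbatch_physical:
    if ((iex0 + model.nbatch_physical) % model.nbatch_logical != 0) {
        // First nbatch_logical/nbatch_physical - 1 iterations: forward +
        // backward only; gradients accumulate in place across iterations.
        ggml_backend_graph_compute(model.backend, gb_grad);
    } else {
        // Last physical batch of the logical batch: forward + backward +
        // AdamW optimizer step in one graph.
        ggml_backend_graph_compute(model.backend, gb_opt);
        ggml_graph_reset(gb_grad); // zero accumulated gradients, keep optimizer state
    }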

include/ggml.h (+6, −4)
@@ -570,11 +570,13 @@ extern "C" {
         GGML_LOG_LEVEL_DEBUG = 5
     };
 
+    // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  = 1,
-        GGML_TENSOR_FLAG_OUTPUT = 2,
-        GGML_TENSOR_FLAG_PARAM  = 4,
-        GGML_TENSOR_FLAG_LOSS   = 8,
+        GGML_TENSOR_FLAG_INPUT    =  1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT   =  2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM    =  4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_GRAD_ACC =  8, // ...is an accumulator for gradients
+        GGML_TENSOR_FLAG_LOSS     = 16, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
     // ggml object
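
GGML_TENSOR_FLAG_LOSS moves from 8 to 16 to make room for the new GGML_TENSOR_FLAG_GRAD_ACC bit. Since the flags are independent bits they compose freely on one tensor; a minimal sketch of how they combine (hedged plain C, with t standing for any tensor in a compute graph):

    struct ggml_tensor * t = /* any tensor in the compute graph */;

    t->flags |= GGML_TENSOR_FLAG_PARAM; // mark as trainable
    t->flags |= GGML_TENSOR_FLAG_INPUT; // bits combine: t is also a graph input

    if (t->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
        // t is a gradient accumulator: backward passes add into it in place
        // instead of overwriting it (see src/ggml.c below)
    }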

src/ggml.c (+59, −18)
@@ -18123,36 +18123,75 @@ void ggml_build_backward_gradient_checkpointing(
     ggml_hash_map_free(replacements);
 }
 
-// functions to change gradients considering the case that input a might be initial gradient with zero value
+// utility functions to change gradients
+// by default, just add/subtract/etc. the gradients
+// if a is in zero_table and not a gradient accumulator, replace a
+// if a is in zero_table and a gradient accumulator, modify gradients in-place and mark result as gradient accumulator
 
 static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        return b;
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            return b;
+        }
     } else {
         return ggml_add_impl(ctx, a, b, false);
     }
 }
 
 static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
-        return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
+            return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+        }
     } else {
         return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
     }
 }
 
 static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        return ggml_repeat(ctx, b, a);
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            return ggml_repeat(ctx, b, a);
+        }
     } else {
         return ggml_add1_impl(ctx, a, b, false);
     }
 }
 
 static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        return ggml_neg(ctx, b);
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            return ggml_neg(ctx, b);
+        }
     } else {
         return ggml_sub_impl(ctx, a, b, false);
     }
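
Each helper now follows the same three-way logic: a gradient not in zero_table is combined normally; one in zero_table but without GGML_TENSOR_FLAG_GRAD_ACC is replaced outright (the old behavior for freshly zeroed gradients); one flagged as an accumulator is updated in place and re-inserted into the table. A hedged plain-C illustration of the accumulator semantics, independent of ggml:

    // Four physical batches contributing to one logical batch.
    float grad = 0.0f; // accumulator, zeroed between logical batches
    const float micro_grad[4] = {0.1f, -0.2f, 0.3f, 0.05f};
    for (int i = 0; i < 4; i++) {
        grad += micro_grad[i]; // in-place add, like ggml_add_impl(..., /*inplace=*/ true)
    }
    // grad == 0.25f: the same value one backward pass over the whole
    // logical batch would produce (up to the averaging convention).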
@@ -19136,22 +19175,25 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }
 
-    // hash table of original gradients that should be overwritten instead of incremented
+    // keep table of original gradients for replacement/accumulation logic
     struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
+    for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];
 
-    // when accumulating gradients the table is empty -> gradients always incremented
-    if (!accumulate) {
-        for (int i = 0; i < gf->n_nodes; i++) {
-            if (gf->grads[i]) {
-                ggml_hash_insert(&zero_table, gf->grads[i]);
+        if (node->grad) {
+            // only gradients of trainable parameters should be accumulated
+            if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
+                node->grad->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
             }
+
+            ggml_hash_insert(&zero_table, node->grad);
         }
     }
 
     for (int i = gf->n_nodes - 1; i >= 0; i--) {
         struct ggml_tensor * node = gf->nodes[i];
 
-        // inplace operations to add gradients are not created by ggml_compute_backward
+        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
         if (node->grad) {
             ggml_compute_backward(ctx, node, &zero_table);
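
zero_table is now always populated; whether an entry gets replaced or accumulated is decided per tensor by GGML_TENSOR_FLAG_GRAD_ACC. At the call site this is driven by the accumulate argument, as in mnist-common.cpp above. A hedged sketch of the two modes (the final bool is left as in that call):

    struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);

    // accumulate == false: parameter gradients are overwritten on each backward pass
    // ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ false, false);

    // accumulate == true: gradients of GGML_TENSOR_FLAG_PARAM tensors are tagged
    // GGML_TENSOR_FLAG_GRAD_ACC and summed in place across backward passes
    ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ true, false);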
@@ -19319,19 +19361,18 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
-        struct ggml_tensor * grad = cgraph->grads[i];
 
         // initial gradients of loss should be 1, 0 otherwise
-        if (grad) {
+        if (node->grad) {
             if (node->flags & GGML_TENSOR_FLAG_LOSS) {
-                GGML_ASSERT(grad->buffer);
+                GGML_ASSERT(node->grad->buffer);
                 GGML_ASSERT(node->type == GGML_TYPE_F32);
                 GGML_ASSERT(ggml_is_scalar(node));
 
                 const float onef = 1.0f;
-                ggml_backend_tensor_set(grad, &onef, 0, ggml_nbytes(grad));
+                ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
             } else {
-                ggml_set_zero(grad);
+                ggml_set_zero(node->grad);
             }
         }
 