@@ -18123,36 +18123,75 @@ void ggml_build_backward_gradient_checkpointing(
     ggml_hash_map_free(replacements);
 }
 
-// functions to change gradients considering the case that input a might be initial gradient with zero value
+// utility functions to change gradients
+// by default, just add/subtract/etc. the gradients
+// if a is in zero_table and not a gradient accumulator, replace a
+// if a is in zero_table and a gradient accumulator, modify gradients in-place and mark result as gradient accumulator
 
 static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        return b;
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            return b;
+        }
     } else {
         return ggml_add_impl(ctx, a, b, false);
     }
 }
 
 static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
-        return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
+            return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+        }
     } else {
         return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
     }
 }
 
 static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        return ggml_repeat(ctx, b, a);
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            return ggml_repeat(ctx, b, a);
+        }
     } else {
         return ggml_add1_impl(ctx, a, b, false);
     }
 }
 
 static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        return ggml_neg(ctx, b);
+        if (a->flags & GGML_TENSOR_FLAG_GRAD_ACC) {
+            struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
+            ret->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
+            const size_t insert_result = ggml_hash_insert(zero_table, ret);
+            GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+            GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            return ret;
+        } else {
+            return ggml_neg(ctx, b);
+        }
     } else {
         return ggml_sub_impl(ctx, a, b, false);
     }
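
All four *_or_set helpers above implement the same rule, so it is worth spelling it out once: the first gradient contribution to a tensor still in zero_table replaces the (conceptually zero) gradient outright, unless that gradient is a persistent accumulator, in which case the contribution is added in place and the result stays eligible for further in-place adds. A minimal standalone sketch of that rule on plain floats (toy types and names invented for illustration, not ggml API):

#include <stdbool.h>
#include <stdio.h>

// toy stand-ins for a tensor's gradient state; not ggml structures
typedef struct {
    float value;
    bool  in_zero_table; // no contribution received yet this backward pass
    bool  is_grad_acc;   // persists across passes (like GGML_TENSOR_FLAG_GRAD_ACC)
} toy_grad;

// mirrors the branch structure of ggml_add_or_set above
static float toy_add_or_set(toy_grad * a, float b) {
    if (a->in_zero_table) {
        if (a->is_grad_acc) {
            a->value += b;           // in-place add, like ggml_add_impl(..., true)
            return a->value;         // result remains an accumulator
        }
        a->in_zero_table = false;
        return a->value = b;         // plain replacement, like "return b;"
    }
    return a->value += b;            // already touched: normal add
}

int main(void) {
    toy_grad g = {3.0f, true, true}; // accumulator carrying 3.0 from earlier passes
    toy_add_or_set(&g, 1.5f);
    toy_add_or_set(&g, 0.5f);
    printf("accumulated gradient: %.1f\n", g.value); // prints 5.0
    return 0;
}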
@@ -19136,22 +19175,25 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }
 
-    // hash table of original gradients that should be overwritten instead of incremented
+    // keep table of original gradients for replacement/accumulation logic
     struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
+    for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];
 
-    // when accumulating gradients the table is empty -> gradients always incremented
-    if (!accumulate) {
-        for (int i = 0; i < gf->n_nodes; i++) {
-            if (gf->grads[i]) {
-                ggml_hash_insert(&zero_table, gf->grads[i]);
+        if (node->grad) {
+            // only gradients of trainable parameters should be accumulated
+            if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
+                node->grad->flags |= GGML_TENSOR_FLAG_GRAD_ACC;
             }
+
+            ggml_hash_insert(&zero_table, node->grad);
         }
     }
 
     for (int i = gf->n_nodes - 1; i >= 0; i--) {
         struct ggml_tensor * node = gf->nodes[i];
 
-        // inplace operations to add gradients are not created by ggml_compute_backward
+        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
         if (node->grad) {
             ggml_compute_backward(ctx, node, &zero_table);
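
For context on how the accumulate flag is driven from the outside, here is a hedged usage sketch. ctx, loss, and the exact tail of the ggml_build_backward_expand signature are assumptions (the hunk header above is truncated; the body clearly takes a bool accumulate); the graph functions used are from ggml's public API:

// build the forward graph with gradient bookkeeping enabled
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
ggml_build_forward_expand(gf, loss);

// duplicate it and expand with backward ops; with accumulate == true the loop
// above tags node->grad of every GGML_TENSOR_FLAG_PARAM tensor with
// GGML_TENSOR_FLAG_GRAD_ACC, so the *_or_set helpers add into those gradients
// in place instead of replacing them
struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ true);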
@@ -19319,19 +19361,18 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
-        struct ggml_tensor * grad = cgraph->grads[i];
 
         // initial gradients of loss should be 1, 0 otherwise
-        if (grad) {
+        if (node->grad) {
             if (node->flags & GGML_TENSOR_FLAG_LOSS) {
-                GGML_ASSERT(grad->buffer);
+                GGML_ASSERT(node->grad->buffer);
                 GGML_ASSERT(node->type == GGML_TYPE_F32);
                 GGML_ASSERT(ggml_is_scalar(node));
 
                 const float onef = 1.0f;
-                ggml_backend_tensor_set(grad, &onef, 0, ggml_nbytes(grad));
+                ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
             } else {
-                ggml_set_zero(grad);
+                ggml_set_zero(node->grad);
             }
         }
     }
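
ggml_graph_reset pairs naturally with the accumulation flags introduced above: reset once per optimizer step, then run the backward graph several times so that accumulator gradients keep summing. A sketch of that loop, assuming gb, backend, and n_accum are set up elsewhere and the loss tensor carries GGML_TENSOR_FLAG_LOSS (e.g. set via ggml_set_loss):

ggml_graph_reset(gb); // loss gradient <- 1.0f, all other gradients <- 0
for (int step = 0; step < n_accum; step++) {
    // ... load the inputs for this micro-batch ...
    ggml_backend_graph_compute(backend, gb); // GRAD_ACC gradients keep summing
}
// apply the optimizer using the accumulated parameter gradients, then reset again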