@@ -2850,9 +2850,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
+    "OPT_STEP_ADAM",
 };

-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2942,9 +2943,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
+    "adam(x)",
 };

-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -8103,6 +8105,26 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
     return result;
 }

+// opt_step_adam
+
+struct ggml_tensor * ggml_opt_step_adam(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 alpha) {
+    GGML_ASSERT(a->grad);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_OPT_STEP_ADAM;
+    result->grad   = NULL;
+    result->src[0] = a;
+    result->src[1] = a->grad;
+
+    ggml_set_op_params(result, &alpha, sizeof(alpha));
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////

 void ggml_set_param(
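As a rough usage sketch (not part of this patch; the ctx/graph variables and learning rate are hypothetical, and it assumes the usual ggml API of this era), the new constructor would be attached to a trainable tensor roughly like this:

    // hypothetical usage sketch: wire an optimizer-step node for one trainable tensor
    struct ggml_tensor * w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_param(ctx, w);  // mark w as a parameter (GGML_TENSOR_FLAG_PARAM)

    // ... build the forward graph gf from w, then call ggml_build_backward_expand(...)
    //     so that w->grad exists ...

    struct ggml_tensor * step = ggml_opt_step_adam(ctx, w, 0.001f); // alpha = learning rate
    ggml_build_forward_expand(gb, step); // evaluating gb now updates w in place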
@@ -17092,6 +17114,62 @@ static void ggml_compute_forward_cross_entropy_loss_back(
     }
 }

+static void ggml_compute_forward_opt_step_adam_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0      = dst->src[0];
+    const struct ggml_tensor * src0_grad = dst->src[1];
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float alpha = ggml_get_op_params_f32(dst, 0);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+        float       * weight_ptr = (float       *) ((char       *) src0->data      + offset);
+        const float * grad_ptr   = (const float *) ((const char *) src0_grad->data + offset);
+
+        ggml_vec_mad_f32(ne00, weight_ptr, grad_ptr, -alpha);
+    }
+}
+
+static void ggml_compute_forward_opt_step_adam(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_adam_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
 /////////////////////////////////

 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
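The per-row update applied above is weight -= alpha * grad (ggml_vec_mad_f32 with a factor of -alpha), i.e. a plain gradient-descent step, and the rows are split across threads with ggml's usual ceil-divide pattern. A small standalone illustration of that partitioning, with hypothetical nr/nth values:

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int nr  = 10;                 // total rows (hypothetical)
        const int nth = 4;                  // number of threads (hypothetical)
        const int dr  = (nr + nth - 1)/nth; // rows per thread, rounded up

        for (int ith = 0; ith < nth; ++ith) {
            const int ir0 = dr*ith;            // first row for this thread
            const int ir1 = MIN(ir0 + dr, nr); // one past the last row
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0;
    }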
@@ -17433,6 +17511,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
                ggml_compute_forward_cross_entropy_loss_back(params, tensor);
            }
            break;
+        case GGML_OP_OPT_STEP_ADAM:
+            {
+                ggml_compute_forward_opt_step_adam(params, tensor);
+            }
+            break;
        case GGML_OP_NONE:
            {
                // nop
@@ -18519,6 +18602,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            {
                GGML_ABORT("fatal error"); // not supported
            }
+        case GGML_OP_OPT_STEP_ADAM:
+            {
+                GGML_ABORT("fatal error"); // not supported
+            }
        case GGML_OP_NONE:
            {
                // nop
@@ -18651,6 +18738,16 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
        }
    }

+    for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+            GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+            struct ggml_tensor * opt_step = ggml_opt_step_adam(ctx, node, 0.001f);
+            ggml_build_forward_expand(gb, opt_step);
+        }
+    }
+
    ggml_hash_set_free(&zero_table);
}
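With the loop above, every tensor flagged GGML_TENSOR_FLAG_PARAM gets an OPT_STEP_ADAM node appended to the backward graph, so a single graph evaluation performs the forward pass, the backward pass, and the in-place parameter update. A minimal sketch of the intended flow, assuming a loss tensor built from parameters marked with ggml_set_param (the exact signature of ggml_build_backward_expand and the name of its trailing flag vary between ggml versions):

    // sketch only: typical graph construction around this code
    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
    ggml_build_forward_expand(gf, loss);

    struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
    ggml_build_backward_expand(ctx, gf, gb, false); // now also appends one OPT_STEP_ADAM node per parameter

    // one training step: forward + backward + parameter update in a single compute call
    ggml_graph_compute_with_ctx(ctx, gb, /*n_threads =*/ 4);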
@@ -19106,6 +19203,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
            } break;
        case GGML_OP_CROSS_ENTROPY_LOSS:
        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAM:
            {
                n_tasks = n_threads;
            } break;