Commit ed5cde0
stochastic gradient descent op
1 parent a6bc691 commit ed5cde0

4 files changed: +112 -6 lines changed

examples/mnist/mnist-common.cpp (+5, -3)

@@ -514,7 +514,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
     opt_pars.print_backward_graph = false;
     opt_pars.n_threads = std::thread::hardware_concurrency();
     opt_pars.adam.n_iter = 1; // per call of ggml_opt_resume_g
-    ggml_opt_init(model.ctx_compute, &opt_ctx, opt_pars, 397510);
+    ggml_opt_init(model.ctx_compute, &opt_ctx, opt_pars, 0);
 
     model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
 
@@ -530,8 +530,10 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
             ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT,   0, ggml_nbytes(model.images));
             ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
 
-            enum ggml_opt_result opt_result = ggml_opt_resume_g(model.ctx_compute, &opt_ctx, model.loss, gf, gb, NULL, NULL);
-            GGML_ASSERT(opt_result == GGML_OPT_RESULT_OK || opt_result == GGML_OPT_RESULT_DID_NOT_CONVERGE);
+            const float onef = 1.0f;
+            ggml_backend_graph_compute(model.backend, gf);
+            ggml_backend_tensor_set(model.loss->grad, &onef, 0, sizeof(float));
+            ggml_backend_graph_compute(model.backend, gb);
 
             ggml_backend_tensor_get(model.loss,   &loss,          0, ggml_nbytes(model.loss));
             ggml_backend_tensor_get(model.logits, logits.data(),  0, ggml_nbytes(model.logits));
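With this change the example no longer drives training through ggml's built-in optimizer loop (ggml_opt_resume_g). Instead it evaluates the forward graph, seeds the loss gradient with 1.0f, and evaluates the backward graph, which after this commit ends in optimizer-step nodes that update the weights in place. A minimal sketch of one training step under the new scheme (it only restates the lines added above; model, gf and gb are the example's existing objects):

    const float onef = 1.0f;
    ggml_backend_graph_compute(model.backend, gf);                       // forward pass: computes model.loss
    ggml_backend_tensor_set(model.loss->grad, &onef, 0, sizeof(float));  // seed d(loss)/d(loss) = 1
    ggml_backend_graph_compute(model.backend, gb);                       // backward pass + GGML_OP_OPT_STEP_ADAM weight updates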

examples/mnist/mnist-common.h (+1, -1)

@@ -59,7 +59,7 @@ struct mnist_model {
     mnist_model() {
         // backend = ggml_backend_cuda_init(0);
         backend = ggml_backend_cpu_init();
-        ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency());
+        ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency()/2);
 
         buf_weight = malloc(size_weight);
         {

include/ggml.h (+6)

@@ -528,6 +528,7 @@ extern "C" {
 
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+        GGML_OP_OPT_STEP_ADAM,
 
         GGML_OP_COUNT,
     };
 
@@ -2033,6 +2034,11 @@ extern "C" {
             struct ggml_tensor * b,
             struct ggml_tensor * c);
 
+    GGML_API struct ggml_tensor * ggml_opt_step_adam(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            float alpha);
+
     //
     // automatic differentiation
     //
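The new public entry point takes a parameter tensor a (which must already have a gradient) and a learning rate alpha, and returns a node that updates a in place when it is evaluated. A rough usage sketch, assuming a float32 parameter registered with ggml_set_param; the names ctx, weights and gb are placeholders, and note that after this commit ggml_build_backward_expand already appends such a node for every parameter automatically:

    struct ggml_tensor * weights = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 784, 500); // hypothetical parameter tensor
    ggml_set_param(ctx, weights);                        // flags it as a parameter and allocates ->grad
    // ... build the model, the loss and the forward/backward graphs as usual ...
    struct ggml_tensor * step = ggml_opt_step_adam(ctx, weights, 0.001f);
    ggml_build_forward_expand(gb, step);                 // evaluating gb now also updates "weights"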

src/ggml.c (+100, -2)

@@ -2850,9 +2850,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
+    "OPT_STEP_ADAM",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
 
@@ -2942,9 +2943,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
+    "adam(x)",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -8103,6 +8105,26 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
     return result;
 }
 
+// opt_step_adam
+
+struct ggml_tensor * ggml_opt_step_adam(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        float alpha) {
+    GGML_ASSERT(a->grad);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op = GGML_OP_OPT_STEP_ADAM;
+    result->grad = NULL;
+    result->src[0] = a;
+    result->src[1] = a->grad;
+
+    ggml_set_op_params(result, &alpha, sizeof(alpha));
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_set_param(
@@ -17092,6 +17114,62 @@ static void ggml_compute_forward_cross_entropy_loss_back(
     }
 }
 
+static void ggml_compute_forward_opt_step_adam_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0      = dst->src[0];
+    const struct ggml_tensor * src0_grad = dst->src[1];
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float alpha = ggml_get_op_params_f32(dst, 0);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+        float * weight_ptr = (float *) ((char *) src0->data + offset);
+        const float * grad_ptr = (const float *) ((const char *) src0_grad->data + offset);
+
+        ggml_vec_mad_f32(ne00, weight_ptr, grad_ptr, -alpha);
+    }
+}
+
+static void ggml_compute_forward_opt_step_adam(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_adam_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
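Despite the op name, the kernel above currently performs a plain SGD step, which matches the commit title: ggml_vec_mad_f32(n, y, x, v) accumulates y[i] += x[i]*v, so the call with -alpha is equivalent to the loop below (a restatement for clarity, not code from the commit):

    // per-row equivalent of ggml_vec_mad_f32(ne00, weight_ptr, grad_ptr, -alpha):
    // weight[i] <- weight[i] - alpha*grad[i]   (vanilla stochastic gradient descent)
    for (int64_t i = 0; i < ne00; ++i) {
        weight_ptr[i] -= alpha * grad_ptr[i];
    }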
@@ -17433,6 +17511,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_cross_entropy_loss_back(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_ADAM:
+            {
+                ggml_compute_forward_opt_step_adam(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
 
@@ -18519,6 +18602,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ABORT("fatal error"); // not supported
             }
+        case GGML_OP_OPT_STEP_ADAM:
+            {
+                GGML_ABORT("fatal error"); // not supported
+            }
         case GGML_OP_NONE:
             {
                 // nop
 
@@ -18651,6 +18738,16 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }
 
+    for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+            GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+            struct ggml_tensor * opt_step = ggml_opt_step_adam(ctx, node, 0.001f);
+            ggml_build_forward_expand(gb, opt_step);
+        }
+    }
+
     ggml_hash_set_free(&zero_table);
 }
@@ -19106,6 +19203,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAM:
             {
                 n_tasks = n_threads;
             } break;
