Skip to content

Commit e7b2390

Browse files
JohannesGaesslerggerganovslaren
authored
ggml/examples: add backend support for numerical optimization (#949)
* CUDA eval works * stochastic gradient descent op * Adam except decay * CUDA CROSS_ENTROPY_LOSS_BACK * CUDA mnist-fc training works * backend CLI arg * refactor gguf load * remove sched from opt_step_adam * implement l1 regularization (weight decay) * extra call to add optimizer * initialize gradients with ggml_graph_reset * gradient accumulation * increment iter per eval instead of epoch * adjust backend interfaces * fix ggml_graph_reset without backend * fix ggml graph export/import * fixup * rename * revert ggml_opt changes * more general CUDA repeat_back * update documentation, fix CNN * validation split * add clarifying comment * optimize PyTorch training * adjust buffer size, thread count * fix 0.0f validation split * Update examples/mnist/mnist-common.cpp Co-authored-by: Georgi Gerganov <[email protected]> * fix gradient accumulation * tensor flag for accumulators -> tensor hash set * Update include/ggml.h Co-authored-by: slaren <[email protected]> * Update tests/test-backend-ops.cpp Co-authored-by: slaren <[email protected]> * Update tests/test-backend-ops.cpp Co-authored-by: slaren <[email protected]> * fix test prints * Update src/ggml-backend.c Co-authored-by: Georgi Gerganov <[email protected]> * better CUDA support for noncontiguous out_prod * add comment --------- Co-authored-by: Georgi Gerganov <[email protected]> Co-authored-by: slaren <[email protected]>
1 parent ea40f60 commit e7b2390

33 files changed

+1290
-344
lines changed

examples/mnist/README.md

+69-53
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ $ python3 mnist-train-fc.py mnist-fc-f32.gguf
1818

1919
...
2020

21-
Test loss: 0.069983+-0.009196, Test accuracy: 97.94+-0.14%
21+
Test loss: 0.066051+-0.011630, Test accuracy: 98.07+-0.14%
2222

2323
Model tensors saved to mnist-fc-f32.gguf:
2424
fc1.weight (500, 784)
@@ -28,7 +28,7 @@ fc2.bias (10,)
2828
```
2929

3030
The training script includes an evaluation of the model on the test set.
31-
To evaluate the model using GGML, run:
31+
To evaluate the model on the CPU using GGML, run:
3232

3333
```bash
3434
$ ../../build/bin/mnist-eval mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte
@@ -37,45 +37,50 @@ ________________________________________________________
3737
________________________________________________________
3838
________________________________________________________
3939
________________________________________________________
40-
________________________________######__________________
41-
____________________________########____________________
42-
________________________########________________________
43-
____________________########________________##__________
44-
__________________######____________________##__________
45-
________________######______________________####________
46-
______________######________________________####________
47-
____________######__________________________####________
48-
____________####____________________________####________
49-
__________####______________________________####________
50-
__________####______________________________####________
51-
__________##________________________________####________
52-
__________##______________________________####__________
53-
__________##____________________________######__________
54-
__________##__________________________######____________
55-
____________##____________________########______________
56-
____________##########################__________________
57-
______________##################________________________
58-
________________________________________________________
59-
________________________________________________________
40+
__________________________________####__________________
41+
______________________________########__________________
42+
__________________________##########____________________
43+
______________________##############____________________
44+
____________________######________####__________________
45+
__________________________________####__________________
46+
__________________________________####__________________
47+
________________________________####____________________
48+
______________________________####______________________
49+
________________________##########______________________
50+
______________________########__####____________________
51+
________________________##__________##__________________
52+
____________________________________##__________________
53+
__________________________________##____________________
54+
__________________________________##____________________
55+
________________________________##______________________
56+
____________________________####________________________
57+
__________##____________######__________________________
58+
__________##############________________________________
59+
________________####____________________________________
6060
________________________________________________________
6161
________________________________________________________
6262
________________________________________________________
6363
________________________________________________________
6464
mnist_graph_eval: trying to load a ggml graph from mnist-fc-f32.gguf
6565
ggml_graph_import: invalid magic number, got 46554747
6666
mnist_graph_eval: could not load a ggml graph from mnist-fc-f32.gguf
67+
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
68+
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
69+
ggml_cuda_init: found 1 CUDA devices:
70+
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
71+
mnist_model: using CPU backend
6772
mnist_model_init_from_file: loading model weights from 'mnist-fc-f32.gguf'
6873
mnist_model_init_from_file: model arch is mnist-fc
6974
mnist_model_init_from_file: successfully loaded weights from mnist-fc-f32.gguf
70-
main: loaded model in 1.52 ms
71-
mnist_model_eval: model evaluation on 10000 images took 26.65 ms, 2.66 us/image
72-
main: predicted digit is 0
73-
main: test_loss=0.069983+-0.009196
74-
main: test_acc=97.94+-0.14%
75+
main: loaded model in 13.03 ms
76+
mnist_model_eval: model evaluation on 10000 images took 95.02 ms, 9.50 us/image
77+
main: predicted digit is 3
78+
main: test_loss=0.066051+-0.009343
79+
main: test_acc=98.07+-0.14%
7580
```
7681

7782
In addition to the evaluation on the test set the GGML evaluation also prints a random image from the test set as well as the model prediction for said image.
78-
To train a fully connected model using GGML run:
83+
To train a fully connected model on the CPU using GGML run:
7984

8085
``` bash
8186
$ ../../build/bin/mnist-train mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte
@@ -96,12 +101,12 @@ $ python3 mnist-train-cnn.py mnist-cnn-f32.gguf
96101

97102
...
98103

99-
Test loss: 0.046456
100-
Test accuracy: 98.40%
104+
Test loss: 0.045483
105+
Test accuracy: 98.56%
101106
GGUF model saved to 'mnist-cnn-f32.gguf'
102107
```
103108

104-
The saved model can be evaluated using the `mnist-eval` binary:
109+
The saved model can be evaluated on the CPU using the `mnist-eval` binary:
105110

106111
```bash
107112
$ ../../build/bin/mnist-eval mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte
@@ -111,50 +116,61 @@ ________________________________________________________
111116
________________________________________________________
112117
________________________________________________________
113118
________________________________________________________
114-
________________________________________________________
115-
________________________________________________________
116-
________________________####____________________________
117-
__________________________##____________________________
118-
__________________________##____________________________
119-
__________________________##____________________________
120-
__________________________##____________________________
121-
__________________________##____________________________
122-
____________________________##__________________________
123-
____________________________##__________________________
124-
____________________________##__________________________
125-
______________________________##________________________
126-
______________________________##________________________
127-
______________________________####______________________
128-
________________________________##______________________
129-
________________________________##______________________
130-
________________________________####____________________
119+
______________________________________##________________
120+
______________________________________##________________
121+
______________________________________##________________
122+
____________________________________##__________________
123+
__________________________________####__________________
131124
__________________________________##____________________
132125
________________________________##______________________
126+
______________________________##________________________
127+
____________________________####________________________
128+
____________________________##__________________________
129+
__________________________##____________________________
130+
________________________##______________________________
131+
______________________##________________________________
132+
____________________####________________________________
133+
____________________##__________________________________
134+
__________________##____________________________________
135+
________________##______________________________________
136+
________________________________________________________
137+
________________________________________________________
133138
________________________________________________________
134139
________________________________________________________
135140
________________________________________________________
136141
________________________________________________________
137142
mnist_graph_eval: trying to load a ggml graph from mnist-cnn-f32.gguf
138143
ggml_graph_import: invalid magic number, got 46554747
139144
mnist_graph_eval: could not load a ggml graph from mnist-cnn-f32.gguf
145+
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
146+
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
147+
ggml_cuda_init: found 1 CUDA devices:
148+
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
149+
mnist_model: using CPU backend
140150
mnist_model_init_from_file: loading model weights from 'mnist-cnn-f32.gguf'
141151
mnist_model_init_from_file: model arch is mnist-cnn
142152
mnist_model_init_from_file: successfully loaded weights from mnist-cnn-f32.gguf
143-
main: loaded model in 5.45 ms
144-
mnist_model_eval: model evaluation on 10000 images took 605.60 ms, 60.56 us/image
153+
main: loaded model in 11.88 ms
154+
mnist_model_eval: model evaluation on 10000 images took 1074.09 ms, 107.41 us/image
145155
main: predicted digit is 1
146-
main: test_loss=0.046456+-0.007354
147-
main: test_acc=98.40+-0.13%
156+
main: test_loss=0.045483+-0.006884
157+
main: test_acc=98.56+-0.12%
148158
```
149159

150-
Like with the fully connected network the convolutional network can also be trained using GGML:
160+
Like with the fully connected network the convolutional network can also be trained on the CPU using GGML:
151161

152162
``` bash
153163
$ ../../build/bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte
154164
```
155165

156166
As always, the evaluation is done using `mnist-eval` and like with the fully connected network the GGML graph is exported to `mnist-cnn-f32.ggml`.
157167

168+
## CUDA
169+
170+
The fully connected model can be trained and evaluated using CUDA.
171+
`mnist-train` and `mnist-eval` accept an additional, optional argument behind those listed so far to specify the backend.
172+
The default is `CPU`, by specifying `CUDA0` the first available CUDA device can be used instead (make sure to compile GGML with CUDA cupport).
173+
158174
## Web demo
159175

160176
The evaluation code can be compiled to WebAssembly using [Emscripten](https://emscripten.org/) (may need to re-login to update `$PATH` after installation).

0 commit comments

Comments
 (0)