@@ -340,11 +340,17 @@ struct vk_device_struct {
340
340
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
341
341
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
342
342
vk_pipeline pipeline_acc_f32;
343
- vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
344
- vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
345
- vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
346
- vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
347
- vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
343
+
344
+ // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
345
+ vk_pipeline pipeline_add[2][2][2];
346
+ vk_pipeline pipeline_add_norepeat[2][2][2];
347
+ vk_pipeline pipeline_sub[2][2][2];
348
+ vk_pipeline pipeline_sub_norepeat[2][2][2];
349
+ vk_pipeline pipeline_mul[2][2][2];
350
+ vk_pipeline pipeline_mul_norepeat[2][2][2];
351
+ vk_pipeline pipeline_div[2][2][2];
352
+ vk_pipeline pipeline_div_norepeat[2][2][2];
353
+
348
354
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
349
355
vk_pipeline pipeline_upscale_f32;
350
356
vk_pipeline pipeline_scale_f32;
@@ -354,23 +360,26 @@ struct vk_device_struct {
354
360
vk_pipeline pipeline_clamp_f32;
355
361
vk_pipeline pipeline_pad_f32;
356
362
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
357
- vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f32_bf16;
358
- vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f32_bf16;
363
+ vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
364
+ vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
359
365
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
360
366
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
361
367
vk_pipeline pipeline_norm_f32;
362
368
vk_pipeline pipeline_group_norm_f32;
363
369
vk_pipeline pipeline_rms_norm_f32;
364
370
vk_pipeline pipeline_rms_norm_back_f32;
365
371
vk_pipeline pipeline_l2_norm_f32;
366
- vk_pipeline pipeline_gelu_f32;
367
- vk_pipeline pipeline_gelu_quick_f32;
368
- vk_pipeline pipeline_silu_f32;
369
- vk_pipeline pipeline_silu_back_f32;
370
- vk_pipeline pipeline_relu_f32;
372
+
373
+ // [src/dst 0=fp32,1=fp16]
374
+ vk_pipeline pipeline_gelu[2];
375
+ vk_pipeline pipeline_gelu_quick[2];
376
+ vk_pipeline pipeline_silu[2];
377
+ vk_pipeline pipeline_relu[2];
378
+ vk_pipeline pipeline_tanh[2];
379
+ vk_pipeline pipeline_sigmoid[2];
380
+
371
381
vk_pipeline pipeline_leaky_relu_f32;
372
- vk_pipeline pipeline_tanh_f32;
373
- vk_pipeline pipeline_sigmoid_f32;
382
+ vk_pipeline pipeline_silu_back_f32;
374
383
vk_pipeline pipeline_diag_mask_inf_f32;
375
384
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
376
385
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
@@ -2508,11 +2517,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
2508
2517
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2509
2518
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2510
2519
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2520
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2511
2521
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2512
2522
2513
2523
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2514
2524
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2515
2525
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2526
+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2516
2527
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2517
2528
2518
2529
if (device->float_controls_rte_fp16) {
@@ -2538,19 +2549,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
2538
2549
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2539
2550
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2540
2551
2541
- ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2542
- ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2543
- ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2544
- ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2552
// get_suffix: builds the three-part type suffix used when naming the
// binary-op pipelines created below. Each flag contributes "_f16" when
// true and "_f32" when false, appended in the order src0, src1, dst;
// e.g. (false, true, false) yields "_f32_f16_f32".
+ auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
2553
+ std::string s;
2554
+ s += std::string(src0_f16 ? "_f16" : "_f32");
2555
+ s += std::string(src1_f16 ? "_f16" : "_f32");
2556
+ s += std::string(dst_f16 ? "_f16" : "_f32");
2557
+ return s;
2558
+ };
2545
2559
2546
- ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2560
// CREATE_BINARY(name, namemod, spec): creates all eight type variants of one
// binary-op pipeline by iterating s0, s1 and d over {0,1}.
//   - target:  device->pipeline_<name><namemod>[s0][s1][d]
//   - label:   "<name>" + get_suffix(s0, s1, d) + "<namemod>"
//              (an index of 1 maps to "_f16", 0 to "_f32")
//   - shader:  <name>_len[s0][s1][d] / <name>_data[s0][s1][d]
//   - spec:    forwarded unchanged to ggml_vk_create_pipeline
// namemod may be left empty; it is both pasted into the member name and
// stringized onto the end of the label.
// The invocations below pair each op with a plain variant (spec {0}) and a
// "_norepeat" variant (spec {1}); both use the same <name>_len/_data arrays.
// NOTE(review): spec is presumably the specialization-constant list that
// selects the no-repeat path in the shader -- confirm against the shader source.
// The macro is #undef'd right after its last use so it does not leak further.
+ #define CREATE_BINARY(name, namemod, spec) \
2561
+ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
2562
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
2563
+ #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
2564
+ "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
2565
+
2566
+ CREATE_BINARY(add, , {0})
2567
+ CREATE_BINARY(add, _norepeat, {1})
2568
+ CREATE_BINARY(sub, , {0})
2569
+ CREATE_BINARY(sub, _norepeat, {1})
2570
+ CREATE_BINARY(mul, , {0})
2571
+ CREATE_BINARY(mul, _norepeat, {1})
2572
+ CREATE_BINARY(div, , {0})
2573
+ CREATE_BINARY(div, _norepeat, {1})
2574
+ #undef CREATE_BINARY
2547
2575
2548
- ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2549
- ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2550
- ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2551
- ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2552
- ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2553
- ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2576
+ ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2554
2577
2555
2578
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2556
2579
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -2571,14 +2594,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
2571
2594
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2572
2595
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2573
2596
2574
- ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2575
- ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2576
- ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2577
- ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2578
- ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2597
// CREATE_UNARY(name): creates the two type variants of one unary-op pipeline.
//   - device->pipeline_<name>[0] is labelled "<name>_f32" and built from
//     <name>_f32_len / <name>_f32_data
//   - device->pipeline_<name>[1] is labelled "<name>_f16" and built from
//     <name>_f16_len / <name>_f16_data
// Both calls pass the same remaining arguments ("main", 2,
// sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1).
// The macro expands to two full statements, so the invocations below take no
// trailing semicolon. It is #undef'd right after its last use.
+ #define CREATE_UNARY(name) \
2598
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
2599
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2600
+
2601
+ CREATE_UNARY(gelu)
2602
+ CREATE_UNARY(gelu_quick)
2603
+ CREATE_UNARY(silu)
2604
+ CREATE_UNARY(relu)
2605
+ CREATE_UNARY(tanh)
2606
+ CREATE_UNARY(sigmoid)
2607
+ #undef CREATE_UNARY
2608
+
2579
2609
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2580
- ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2581
- ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2610
+ ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2582
2611
2583
2612
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
2584
2613
@@ -4504,6 +4533,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
4504
4533
return ctx->device->pipeline_cpy_f16_f16;
4505
4534
}
4506
4535
}
4536
+ if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {
4537
+ if (contig) {
4538
+ return ctx->device->pipeline_contig_cpy_f16_f32;
4539
+ } else {
4540
+ return ctx->device->pipeline_cpy_f16_f32;
4541
+ }
4542
+ }
4507
4543
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
4508
4544
if (contig) {
4509
4545
return ctx->device->pipeline_contig_cpy_f32_bf16;
@@ -5894,26 +5930,37 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5894
5930
}
5895
5931
return nullptr;
5896
5932
case GGML_OP_ADD:
5897
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5898
- return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
5933
+ case GGML_OP_SUB:
5934
+ case GGML_OP_MUL:
5935
+ case GGML_OP_DIV:
5936
+ if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
5937
+ (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
5938
+ (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {
5939
+ return nullptr;
5899
5940
}
5900
- if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
5901
- return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
5941
+ switch (op) {
5942
+ case GGML_OP_ADD:
5943
+ {
5944
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
5945
+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
5902
5946
}
5903
- return nullptr;
5904
- case GGML_OP_SUB:
5905
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5906
- return ggml_are_same_shape( src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32 ;
5947
+ case GGML_OP_SUB:
5948
+ {
5949
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;
5950
+ return pipelines[ src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16] ;
5907
5951
}
5908
- return nullptr;
5909
- case GGML_OP_MUL:
5910
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5911
- return ggml_are_same_shape( src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32 ;
5952
+ case GGML_OP_MUL:
5953
+ {
5954
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;
5955
+ return pipelines[ src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16] ;
5912
5956
}
5913
- return nullptr;
5914
- case GGML_OP_DIV:
5915
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5916
- return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
5957
+ case GGML_OP_DIV:
5958
+ {
5959
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;
5960
+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
5961
+ }
5962
+ default:
5963
+ break;
5917
5964
}
5918
5965
return nullptr;
5919
5966
case GGML_OP_CONCAT:
@@ -6007,37 +6054,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6007
6054
}
6008
6055
return nullptr;
6009
6056
case GGML_OP_UNARY:
6057
+ if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
6058
+ (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
6059
+ (src0->type != dst->type)) {
6060
+ return nullptr;
6061
+ }
6062
+
6010
6063
switch (ggml_get_unary_op(dst)) {
6011
6064
case GGML_UNARY_OP_SILU:
6012
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6013
- return ctx->device->pipeline_silu_f32;
6014
- }
6015
- break;
6065
+ return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
6016
6066
case GGML_UNARY_OP_GELU:
6017
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6018
- return ctx->device->pipeline_gelu_f32;
6019
- }
6020
- break;
6067
+ return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
6021
6068
case GGML_UNARY_OP_GELU_QUICK:
6022
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6023
- return ctx->device->pipeline_gelu_quick_f32;
6024
- }
6025
- break;
6069
+ return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
6026
6070
case GGML_UNARY_OP_RELU:
6027
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6028
- return ctx->device->pipeline_relu_f32;
6029
- }
6030
- break;
6071
+ return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
6031
6072
case GGML_UNARY_OP_TANH:
6032
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6033
- return ctx->device->pipeline_tanh_f32;
6034
- }
6035
- break;
6073
+ return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
6036
6074
case GGML_UNARY_OP_SIGMOID:
6037
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6038
- return ctx->device->pipeline_sigmoid_f32;
6039
- }
6040
- break;
6075
+ return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
6041
6076
default:
6042
6077
break;
6043
6078
}
@@ -9423,7 +9458,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9423
9458
case GGML_UNARY_OP_RELU:
9424
9459
case GGML_UNARY_OP_TANH:
9425
9460
case GGML_UNARY_OP_SIGMOID:
9426
- return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
9461
+ return ggml_is_contiguous(op->src[0]) &&
9462
+ (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
9463
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
9464
+ (op->src[0]->type == op->type);
9427
9465
default:
9428
9466
return false;
9429
9467
}
@@ -9603,6 +9641,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9603
9641
}
9604
9642
if (src1_type == GGML_TYPE_F32) {
9605
9643
switch (src0_type) {
9644
+ case GGML_TYPE_F16:
9606
9645
case GGML_TYPE_Q4_0:
9607
9646
case GGML_TYPE_Q4_1:
9608
9647
case GGML_TYPE_Q5_0:
@@ -9641,6 +9680,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9641
9680
case GGML_OP_SUB:
9642
9681
case GGML_OP_MUL:
9643
9682
case GGML_OP_DIV:
9683
+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
9684
+ (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
9685
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
9644
9686
case GGML_OP_SILU_BACK:
9645
9687
case GGML_OP_RMS_NORM_BACK:
9646
9688
case GGML_OP_SQR:
0 commit comments