Skip to content

Commit 0a5a3b5

Browse files
authored
Add Conv2d for CPU (ggml-org#14388)
* Conv2D: Add CPU version * Half decent * Tiled approach for F32 * remove file * Fix tests * Support F16 operations * add assert about size * Review: further formatting fixes, add assert and use CPU version of fp32->fp16
1 parent 745f11f commit 0a5a3b5

File tree

5 files changed

+250
-3
lines changed

5 files changed

+250
-3
lines changed

ggml/include/ggml.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ extern "C" {
482482
GGML_OP_CONV_TRANSPOSE_1D,
483483
GGML_OP_IM2COL,
484484
GGML_OP_IM2COL_BACK,
485+
GGML_OP_CONV_2D,
485486
GGML_OP_CONV_2D_DW,
486487
GGML_OP_CONV_TRANSPOSE_2D,
487488
GGML_OP_POOL_1D,
@@ -1813,6 +1814,17 @@ extern "C" {
18131814
struct ggml_tensor * b,
18141815
int stride);
18151816

1817+
GGML_API struct ggml_tensor * ggml_conv_2d_direct(
1818+
struct ggml_context * ctx,
1819+
struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
1820+
struct ggml_tensor * b, // input data [W, H, C, N]
1821+
int s0, // stride dimension 0
1822+
int s1, // stride dimension 1
1823+
int p0, // padding dimension 0
1824+
int p1, // padding dimension 1
1825+
int d0, // dilation dimension 0
1826+
int d1); // dilation dimension 1
1827+
18161828
enum ggml_op_pool {
18171829
GGML_OP_POOL_MAX,
18181830
GGML_OP_POOL_AVG,

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,7 @@ static void ggml_compute_forward_mul_mat_one_chunk(
11931193
}
11941194
}
11951195

1196-
static void ggml_compute_forward_mul_mat(
1196+
void ggml_compute_forward_mul_mat(
11971197
const struct ggml_compute_params * params,
11981198
struct ggml_tensor * dst) {
11991199

@@ -1866,6 +1866,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
18661866
{
18671867
ggml_compute_forward_im2col_back_f32(params, tensor);
18681868
} break;
1869+
case GGML_OP_CONV_2D:
1870+
{
1871+
ggml_compute_forward_conv_2d(params, tensor);
1872+
} break;
18691873
case GGML_OP_CONV_2D_DW:
18701874
{
18711875
ggml_compute_forward_conv_2d_dw(params, tensor);
@@ -2228,6 +2232,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
22282232
} break;
22292233
case GGML_OP_IM2COL:
22302234
case GGML_OP_IM2COL_BACK:
2235+
case GGML_OP_CONV_2D:
22312236
case GGML_OP_CONV_2D_DW:
22322237
case GGML_OP_CONV_TRANSPOSE_1D:
22332238
case GGML_OP_CONV_TRANSPOSE_2D:
@@ -2746,6 +2751,10 @@ struct ggml_cplan ggml_graph_plan(
27462751
GGML_ABORT("fatal error");
27472752
}
27482753
} break;
2754+
case GGML_OP_CONV_2D:
2755+
{
2756+
cur = GGML_IM2COL_WORK_SIZE;
2757+
} break;
27492758
case GGML_OP_CONV_TRANSPOSE_2D:
27502759
{
27512760
const int64_t ne00 = node->src[0]->ne[0]; // W

ggml/src/ggml-cpu/ops.cpp

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "ggml-cpu.h"
44
#include "ggml-impl.h"
55
#include "binary-ops.h"
6+
#include "ggml.h"
67
#include "unary-ops.h"
78
#include "vec.h"
89

@@ -6545,6 +6546,186 @@ void ggml_compute_forward_im2col_back_f32(
65456546
}
65466547
}
65476548

6549+
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
6550+
void * a, void * b, float * c) {
6551+
const ggml_type_traits * traits = ggml_get_type_traits(type);
6552+
struct ggml_tensor src1 = {};
6553+
src1.type = type;
6554+
src1.ne[0] = k;
6555+
src1.ne[1] = m;
6556+
src1.ne[2] = 1;
6557+
src1.ne[3] = 1;
6558+
src1.nb[0] = traits->type_size;
6559+
src1.nb[1] = k * traits->type_size;
6560+
src1.nb[2] = src1.nb[1];
6561+
src1.nb[3] = src1.nb[2];
6562+
src1.data = a;
6563+
6564+
struct ggml_tensor src0 = {};
6565+
src0.type = type;
6566+
src0.ne[0] = k;
6567+
src0.ne[1] = n;
6568+
src0.ne[2] = 1;
6569+
src0.ne[3] = 1;
6570+
src0.nb[0] = traits->type_size;
6571+
src0.nb[1] = k * traits->type_size;
6572+
src0.nb[2] = src0.nb[1];
6573+
src0.nb[3] = src0.nb[2];
6574+
src0.data = b;
6575+
6576+
struct ggml_tensor dst = {};
6577+
dst.ne[0] = n;
6578+
dst.ne[1] = m;
6579+
dst.ne[2] = 1;
6580+
dst.ne[3] = 1;
6581+
dst.nb[0] = sizeof(float);
6582+
dst.nb[1] = n * sizeof(float);
6583+
dst.nb[2] = dst.nb[1];
6584+
dst.nb[3] = dst.nb[2];
6585+
dst.data = c;
6586+
dst.src[0] = &src0;
6587+
dst.src[1] = &src1;
6588+
6589+
ggml_compute_forward_mul_mat(params, &dst);
6590+
}
6591+
6592+
// ggml_compute_forward_conv_2d
6593+
6594+
static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
6595+
const ggml_tensor * kernel, // [KW, KH, IC, OC]
6596+
const ggml_tensor * src, // [W, H, C, N]
6597+
ggml_tensor * dst, // [OW, OH, OC, N]
6598+
ggml_type kernel_type) {
6599+
6600+
GGML_ASSERT(ggml_is_contiguous(kernel));
6601+
GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6602+
GGML_ASSERT(kernel->type == kernel_type);
6603+
6604+
const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6605+
6606+
const int32_t stride_x = dst->op_params[0];
6607+
const int32_t stride_y = dst->op_params[1];
6608+
const int32_t pad_x = dst->op_params[2];
6609+
const int32_t pad_y = dst->op_params[3];
6610+
const int32_t dilation_x = dst->op_params[4];
6611+
const int32_t dilation_y = dst->op_params[5];
6612+
6613+
const int64_t c_in = src->ne[2];
6614+
const int64_t c_out = kernel->ne[3];
6615+
GGML_ASSERT(c_in == kernel->ne[2]);
6616+
6617+
const int64_t src_w = src->ne[0];
6618+
const int64_t src_h = src->ne[1];
6619+
const int64_t knl_w = kernel->ne[0];
6620+
const int64_t knl_h = kernel->ne[1];
6621+
const int64_t dst_w = dst->ne[0];
6622+
const int64_t dst_h = dst->ne[1];
6623+
6624+
const float * src_data = (float *) src->data;
6625+
void * knl_data = kernel->data;
6626+
float * dst_data = (float *) dst->data;
6627+
6628+
const int64_t knl_n = knl_w * knl_h * c_in;
6629+
const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
6630+
6631+
const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
6632+
const int64_t batch_size = params->wsize / space_per_patch;
6633+
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6634+
const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
6635+
6636+
GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6637+
6638+
void * tmp = params->wdata;
6639+
6640+
for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6641+
6642+
const int64_t patch_start_batch = batch_i * patches_per_batch;
6643+
const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
6644+
patch_total);
6645+
const int64_t patch_n = patch_end_batch - patch_start_batch;
6646+
6647+
const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
6648+
const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
6649+
const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
6650+
6651+
//im2col for a patch
6652+
for (int64_t p = patch_start; p < patch_end; ++p) {
6653+
const int64_t batch_n = p / (dst_w * dst_h);
6654+
const int64_t src_x = (p / dst_w) % dst_h;
6655+
const int64_t src_y = p % dst_w;
6656+
6657+
const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
6658+
char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
6659+
6660+
for (int64_t ic = 0; ic < c_in; ++ic) {
6661+
for (int64_t ky = 0; ky < knl_h; ++ky) {
6662+
for (int64_t kx = 0; kx < knl_w; ++kx) {
6663+
const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
6664+
const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
6665+
6666+
int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
6667+
6668+
float src_val;
6669+
if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6670+
src_val = 0.0f;
6671+
} else {
6672+
const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6673+
src_val = *src_ptr;
6674+
}
6675+
6676+
char * element_ptr = dst_row + dst_idx * traits->type_size;
6677+
if (kernel_type == GGML_TYPE_F32) {
6678+
*(float *) element_ptr = src_val;
6679+
} else if (kernel_type == GGML_TYPE_F16) {
6680+
*(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6681+
}
6682+
}
6683+
}
6684+
}
6685+
} // patches handled by this thread
6686+
6687+
ggml_barrier(params->threadpool);
6688+
6689+
float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
6690+
6691+
GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);
6692+
6693+
// GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
6694+
ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
6695+
6696+
ggml_barrier(params->threadpool);
6697+
6698+
6699+
//permute back [OC, N, OH, OW] to [N, OC, OH, OW]
6700+
const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
6701+
const int64_t permute_start = params->ith * permute_per_thread;
6702+
const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
6703+
6704+
for (int64_t i = permute_start; i < permute_end; ++i) {
6705+
const int64_t p = patch_start_batch + i;
6706+
const int64_t batch_n = p / (dst_w * dst_h);
6707+
const int64_t dst_y = (p / dst_w) % dst_h;
6708+
const int64_t dst_x = p % dst_w;
6709+
6710+
for (int64_t oc = 0; oc < c_out; ++oc) {
6711+
const float value = gemm_output[i * c_out + oc];
6712+
float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
6713+
*dst_ptr = value;
6714+
}
6715+
}
6716+
}
6717+
}
6718+
6719+
void ggml_compute_forward_conv_2d(
6720+
const ggml_compute_params * params,
6721+
ggml_tensor * dst) {
6722+
6723+
const ggml_tensor * src0 = dst->src[0];
6724+
const ggml_tensor * src1 = dst->src[1];
6725+
6726+
ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
6727+
}
6728+
65486729
// ggml_compute_forward_conv_transpose_2d
65496730

65506731
void ggml_compute_forward_conv_transpose_2d(

ggml/src/ggml-cpu/ops.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020

2121
static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
2222

23+
// Work buffer size for im2col operations in CONV2D
24+
#define GGML_IM2COL_WORK_SIZE (16 * 1024 * 1024)
25+
2326
#ifdef __cplusplus
2427
extern "C" {
2528
#endif
@@ -65,6 +68,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
6568
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6669
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6770
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
71+
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6872
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6973
void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
7074
void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -107,6 +111,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
107111
void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
108112
void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
109113
void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
114+
void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
110115

111116
#ifdef __cplusplus
112117
}

ggml/src/ggml.c

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
945945
"CONV_TRANSPOSE_1D",
946946
"IM2COL",
947947
"IM2COL_BACK",
948+
"CONV_2D",
948949
"CONV_2D_DW",
949950
"CONV_TRANSPOSE_2D",
950951
"POOL_1D",
@@ -986,7 +987,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
986987
"GLU",
987988
};
988989

989-
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
990+
static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
990991

991992
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
992993
"none",
@@ -1044,6 +1045,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10441045
"conv_transpose_1d(x)",
10451046
"im2col(x)",
10461047
"im2col_back(x)",
1048+
"conv_2d(x)",
10471049
"conv_2d_dw(x)",
10481050
"conv_transpose_2d(x)",
10491051
"pool_1d(x)",
@@ -1085,7 +1087,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10851087
"glu(x)",
10861088
};
10871089

1088-
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
1090+
static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
10891091

10901092
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
10911093

@@ -4291,6 +4293,44 @@ struct ggml_tensor * ggml_conv_2d_dw_direct(
42914293
return result;
42924294
}
42934295

4296+
// ggml_conv_2d_direct
4297+
4298+
struct ggml_tensor * ggml_conv_2d_direct(
4299+
struct ggml_context * ctx,
4300+
struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
4301+
struct ggml_tensor * b, // input data [W, H, C, N]
4302+
int s0, // stride dimension 0
4303+
int s1, // stride dimension 1
4304+
int p0, // padding dimension 0
4305+
int p1, // padding dimension 1
4306+
int d0, // dilation dimension 0
4307+
int d1) {// dilation dimension 1
4308+
4309+
GGML_ASSERT(a->ne[2] == b->ne[2]);
4310+
//GGML_ASSERT(a->type == b->type);
4311+
4312+
int64_t ne[4];
4313+
ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
4314+
ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
4315+
ne[2] = a->ne[3];
4316+
ne[3] = b->ne[3];
4317+
4318+
struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
4319+
4320+
ggml_set_op_params_i32(result, 0, s0);
4321+
ggml_set_op_params_i32(result, 1, s1);
4322+
ggml_set_op_params_i32(result, 2, p0);
4323+
ggml_set_op_params_i32(result, 3, p1);
4324+
ggml_set_op_params_i32(result, 4, d0);
4325+
ggml_set_op_params_i32(result, 5, d1);
4326+
4327+
result->op = GGML_OP_CONV_2D;
4328+
result->src[0] = a;
4329+
result->src[1] = b;
4330+
4331+
return result;
4332+
}
4333+
42944334
// ggml_conv_transpose_2d_p0
42954335

42964336
static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {

0 commit comments

Comments
 (0)