3 | 3 | #include "ggml-cpu.h"
4 | 4 | #include "ggml-impl.h"
5 | 5 | #include "binary-ops.h"
| 6 | +#include "ggml.h" |
6 | 7 | #include "unary-ops.h"
7 | 8 | #include "vec.h"
8 | 9 |
@@ -6545,6 +6546,186 @@ void ggml_compute_forward_im2col_back_f32(
6545 | 6546 | }
6546 | 6547 | }
6547 | 6548 |
| 6549 | +static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k, |
| 6550 | + void * a, void * b, float * c) { |
| 6551 | + const ggml_type_traits * traits = ggml_get_type_traits(type); |
| 6552 | + struct ggml_tensor src1 = {}; |
| 6553 | + src1.type = type; |
| 6554 | + src1.ne[0] = k; |
| 6555 | + src1.ne[1] = m; |
| 6556 | + src1.ne[2] = 1; |
| 6557 | + src1.ne[3] = 1; |
| 6558 | + src1.nb[0] = traits->type_size; |
| 6559 | + src1.nb[1] = k * traits->type_size; |
| 6560 | + src1.nb[2] = src1.nb[1]; |
| 6561 | + src1.nb[3] = src1.nb[2]; |
| 6562 | + src1.data = a; |
| 6563 | + |
| 6564 | + struct ggml_tensor src0 = {}; |
| 6565 | + src0.type = type; |
| 6566 | + src0.ne[0] = k; |
| 6567 | + src0.ne[1] = n; |
| 6568 | + src0.ne[2] = 1; |
| 6569 | + src0.ne[3] = 1; |
| 6570 | + src0.nb[0] = traits->type_size; |
| 6571 | + src0.nb[1] = k * traits->type_size; |
| 6572 | + src0.nb[2] = src0.nb[1]; |
| 6573 | + src0.nb[3] = src0.nb[2]; |
| 6574 | + src0.data = b; |
| 6575 | + |
| 6576 | + struct ggml_tensor dst = {}; |
| 6577 | + dst.ne[0] = n; |
| 6578 | + dst.ne[1] = m; |
| 6579 | + dst.ne[2] = 1; |
| 6580 | + dst.ne[3] = 1; |
| 6581 | + dst.nb[0] = sizeof(float); |
| 6582 | + dst.nb[1] = n * sizeof(float); |
| 6583 | + dst.nb[2] = dst.nb[1]; |
| 6584 | + dst.nb[3] = dst.nb[2]; |
| 6585 | + dst.data = c; |
| 6586 | + dst.src[0] = &src0; |
| 6587 | + dst.src[1] = &src1; |
| 6588 | + |
| 6589 | + ggml_compute_forward_mul_mat(params, &dst); |
| 6590 | +} |
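Note on the helper above: it wraps `ggml_compute_forward_mul_mat` by building temporary `src0`/`src1`/`dst` views, so both input matrices are row-major with the reduction dimension `k` contiguous (ggml's `ne[0]`) and `c` receives an `m x n` row-major result. A minimal single-threaded reference of that contract, for review only (the function name is hypothetical; it ignores quantized types and the thread partitioning in `params`):

```cpp
#include <cstdint>

// Reference semantics assumed for ggml_call_mul_mat (f32 case):
// a is m x k, b is n x k, both row-major with k contiguous; c is m x n row-major.
static void mul_mat_ref_f32(int64_t m, int64_t n, int64_t k,
                            const float * a, const float * b, float * c) {
    for (int64_t i = 0; i < m; ++i) {
        for (int64_t j = 0; j < n; ++j) {
            float sum = 0.0f;
            for (int64_t l = 0; l < k; ++l) {
                sum += a[i*k + l] * b[j*k + l];  // dot(row i of a, row j of b)
            }
            c[i*n + j] = sum;
        }
    }
}
```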
| 6591 | + |
| 6592 | +// ggml_compute_forward_conv_2d |
| 6593 | + |
| 6594 | +static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params, |
| 6595 | + const ggml_tensor * kernel, // [KW, KH, IC, OC] |
| 6596 | + const ggml_tensor * src, // [W, H, C, N] |
| 6597 | + ggml_tensor * dst, // [OW, OH, OC, N] |
| 6598 | + ggml_type kernel_type) { |
| 6599 | + |
| 6600 | + GGML_ASSERT(ggml_is_contiguous(kernel)); |
| 6601 | + GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32); |
| 6602 | + GGML_ASSERT(kernel->type == kernel_type); |
| 6603 | + |
| 6604 | + const ggml_type_traits * traits = ggml_get_type_traits(kernel_type); |
| 6605 | + |
| 6606 | + const int32_t stride_x = dst->op_params[0]; |
| 6607 | + const int32_t stride_y = dst->op_params[1]; |
| 6608 | + const int32_t pad_x = dst->op_params[2]; |
| 6609 | + const int32_t pad_y = dst->op_params[3]; |
| 6610 | + const int32_t dilation_x = dst->op_params[4]; |
| 6611 | + const int32_t dilation_y = dst->op_params[5]; |
| 6612 | + |
| 6613 | + const int64_t c_in = src->ne[2]; |
| 6614 | + const int64_t c_out = kernel->ne[3]; |
| 6615 | + GGML_ASSERT(c_in == kernel->ne[2]); |
| 6616 | + |
| 6617 | + const int64_t src_w = src->ne[0]; |
| 6618 | + const int64_t src_h = src->ne[1]; |
| 6619 | + const int64_t knl_w = kernel->ne[0]; |
| 6620 | + const int64_t knl_h = kernel->ne[1]; |
| 6621 | + const int64_t dst_w = dst->ne[0]; |
| 6622 | + const int64_t dst_h = dst->ne[1]; |
| 6623 | + |
| 6624 | + const float * src_data = (float *) src->data; |
| 6625 | + void * knl_data = kernel->data; |
| 6626 | + float * dst_data = (float *) dst->data; |
| 6627 | + |
| 6628 | +    const int64_t knl_n = knl_w * knl_h * c_in;  // elements in one flattened kernel / one im2col row |
| 6629 | +    const int64_t patch_total = dst->ne[3] * dst_w * dst_h;  // one patch per output pixel across the whole batch |
| 6630 | + |
| 6631 | + const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float); |
| 6632 | + const int64_t batch_size = params->wsize / space_per_patch; |
| 6633 | +    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;  // round down to a multiple of 8 patches when possible |
| 6634 | + const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch; |
| 6635 | + |
| 6636 | + GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1); |
| 6637 | + |
| 6638 | + void * tmp = params->wdata; |
| 6639 | + |
| 6640 | + for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) { |
| 6641 | + |
| 6642 | + const int64_t patch_start_batch = batch_i * patches_per_batch; |
| 6643 | + const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, |
| 6644 | + patch_total); |
| 6645 | + const int64_t patch_n = patch_end_batch - patch_start_batch; |
| 6646 | + |
| 6647 | + const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth; |
| 6648 | + const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread; |
| 6649 | + const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch); |
| 6650 | + |
| 6651 | + //im2col for a patch |
| 6652 | + for (int64_t p = patch_start; p < patch_end; ++p) { |
| 6653 | +            const int64_t img_n = p / (dst_w * dst_h);  // image (batch) index of this patch |
| 6654 | +            const int64_t dst_y = (p / dst_w) % dst_h;   // output row covered by this patch |
| 6655 | +            const int64_t dst_x = p % dst_w;             // output column covered by this patch |
| 6656 | + |
| 6657 | +            const float * src_base = (const float *)((const char *)src_data + img_n * src->nb[3]); |
| 6658 | + char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size; |
| 6659 | + |
| 6660 | + for (int64_t ic = 0; ic < c_in; ++ic) { |
| 6661 | + for (int64_t ky = 0; ky < knl_h; ++ky) { |
| 6662 | + for (int64_t kx = 0; kx < knl_w; ++kx) { |
| 6663 | +                        const int64_t sy = dst_y * stride_y + ky * dilation_y - pad_y; |
| 6664 | +                        const int64_t sx = dst_x * stride_x + kx * dilation_x - pad_x; |
| 6665 | + |
| 6666 | + int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx; |
| 6667 | + |
| 6668 | + float src_val; |
| 6669 | + if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) { |
| 6670 | + src_val = 0.0f; |
| 6671 | + } else { |
| 6672 | + const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]); |
| 6673 | + src_val = *src_ptr; |
| 6674 | + } |
| 6675 | + |
| 6676 | + char * element_ptr = dst_row + dst_idx * traits->type_size; |
| 6677 | + if (kernel_type == GGML_TYPE_F32) { |
| 6678 | + *(float *) element_ptr = src_val; |
| 6679 | + } else if (kernel_type == GGML_TYPE_F16) { |
| 6680 | + *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val); |
| 6681 | + } |
| 6682 | + } |
| 6683 | + } |
| 6684 | + } |
| 6685 | + } // patches handled by this thread |
| 6686 | + |
| 6687 | + ggml_barrier(params->threadpool); |
| 6688 | + |
| 6689 | + float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size); |
| 6690 | + |
| 6691 | +        GGML_ASSERT((char *) (gemm_output + patch_n * c_out) <= (char *) tmp + params->wsize);  // params->wsize is in bytes |
| 6692 | + |
| 6693 | + // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out] |
| 6694 | + ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output); |
| 6695 | + |
| 6696 | + ggml_barrier(params->threadpool); |
| 6697 | + |
| 6698 | + |
| 6699 | +        // scatter the GEMM output rows [patch, OC] back into dst, which is stored [N, OC, OH, OW] (ggml ne = [OW, OH, OC, N]) |
| 6700 | + const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth; |
| 6701 | + const int64_t permute_start = params->ith * permute_per_thread; |
| 6702 | + const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n); |
| 6703 | + |
| 6704 | + for (int64_t i = permute_start; i < permute_end; ++i) { |
| 6705 | + const int64_t p = patch_start_batch + i; |
| 6706 | +            const int64_t img_n = p / (dst_w * dst_h);  // image (batch) index of this patch |
| 6707 | + const int64_t dst_y = (p / dst_w) % dst_h; |
| 6708 | + const int64_t dst_x = p % dst_w; |
| 6709 | + |
| 6710 | + for (int64_t oc = 0; oc < c_out; ++oc) { |
| 6711 | + const float value = gemm_output[i * c_out + oc]; |
| 6712 | +                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + img_n * dst->nb[3]); |
| 6713 | + *dst_ptr = value; |
| 6714 | + } |
| 6715 | + } |
| 6716 | + } |
| 6717 | +} |
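To keep the indexing above straight: each flattened patch index `p` corresponds to one output pixel, and the im2col scratch row for that patch holds its receptive field as `knl_w * knl_h * c_in` values with `kx` fastest, matching the memory order of one output channel of the kernel. A small sketch of that mapping, mirroring the loops above (helper names are hypothetical):

```cpp
#include <cstdint>

struct patch_coord { int64_t img, oy, ox; };

// p enumerates output pixels as ((img * dst_h) + oy) * dst_w + ox.
static patch_coord decompose_patch(int64_t p, int64_t dst_w, int64_t dst_h) {
    return { p / (dst_w * dst_h), (p / dst_w) % dst_h, p % dst_w };
}

// Offset of one element inside an im2col row (kx fastest, then ky, then ic),
// i.e. the same order in which one output channel of the kernel is stored.
static int64_t im2col_offset(int64_t ic, int64_t ky, int64_t kx,
                             int64_t knl_w, int64_t knl_h) {
    return ic * (knl_h * knl_w) + ky * knl_w + kx;
}
```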
| 6718 | + |
| 6719 | +void ggml_compute_forward_conv_2d( |
| 6720 | + const ggml_compute_params * params, |
| 6721 | + ggml_tensor * dst) { |
| 6722 | + |
| 6723 | + const ggml_tensor * src0 = dst->src[0]; |
| 6724 | + const ggml_tensor * src1 = dst->src[1]; |
| 6725 | + |
| 6726 | + ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type); |
| 6727 | +} |
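For callers, the op_params read at the top of the implementation are expected in the order (stride_x, stride_y, pad_x, pad_y, dilation_x, dilation_y), and `dst->ne[0]`/`ne[1]` are assumed to follow the usual dilated-convolution output-size formula. A hedged sketch of that formula (the helper name is made up for illustration; the graph-construction code is what actually sets these shapes):

```cpp
#include <cstdint>

// Usual dilated-convolution output size, assumed to match what the
// graph builder stores in dst->ne[0] / dst->ne[1].
static int64_t conv_out_size(int64_t in, int64_t knl,
                             int64_t stride, int64_t pad, int64_t dilation) {
    return (in + 2*pad - dilation*(knl - 1) - 1) / stride + 1;
}

// e.g. dst_w = conv_out_size(src_w, knl_w, stride_x, pad_x, dilation_x);
//      dst_h = conv_out_size(src_h, knl_h, stride_y, pad_y, dilation_y);
```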
| 6728 | + |
6548 | 6729 | // ggml_compute_forward_conv_transpose_2d
6549 | 6730 |
6550 | 6731 | void ggml_compute_forward_conv_transpose_2d(