ggml : add ggml_set_rows #14274
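For orientation before the diff (my summary, not text from the PR): as I read this change, `ggml_set_rows(ctx, a, b, c)` writes the rows of `b` into `a` at the row indices given by the I64 tensor `c`, converting through the destination type's `from_float` trait along the way. A minimal, hedged usage sketch — `update_cache_rows` and the tensor shapes are hypothetical, only `ggml_set_rows` comes from this PR:

#include "ggml.h"

// Hypothetical helper: overwrite selected rows of a cache tensor.
//   cache : [n_embd, n_cache, ...]  destination
//   cur   : [n_embd, n_tokens, ...] new rows, F32
//   rows  : [n_tokens, ...]         I64 row indices into `cache`
static struct ggml_tensor * update_cache_rows(struct ggml_context * ctx,
                                              struct ggml_tensor  * cache,
                                              struct ggml_tensor  * cur,
                                              struct ggml_tensor  * rows) {
    // rows of `cache` at positions rows[i] are replaced with the rows of `cur`
    return ggml_set_rows(ctx, cache, cur, rows);
}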
@@ -192,6 +192,7 @@ typedef pthread_t ggml_thread_t;
 static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = {
+        .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
         .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
         .vec_dot_type = GGML_TYPE_F32,
         .nrows = 1,
@@ -1814,6 +1815,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_get_rows_back(params, tensor);
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                ggml_compute_forward_set_rows(params, tensor);
+            } break;
         case GGML_OP_DIAG:
             {
                 ggml_compute_forward_diag(params, tensor);
@@ -2167,6 +2172,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_GET_ROWS:
+        case GGML_OP_SET_ROWS:
             {
                 // FIXME: get_rows can use additional threads, but the cost of launching additional threads
                 // decreases performance with GPU offloading
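A gloss on this hunk (my reading, hedged): the value chosen here becomes `params->nth` in the forward call, and each worker receives its own `params->ith` in `[0, nth)`. With the task count pinned for GET_ROWS/SET_ROWS per the FIXME above, the row-chunking arithmetic in `ggml_compute_forward_set_rows_f32` further down currently runs with a single thread. A toy, self-contained model of that chunking — the driver loop is hypothetical, the arithmetic matches the forward function:

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10; // total rows (example value)
    const int nth = 4;  // worker count; effectively 1 for SET_ROWS per the hunk above

    for (int ith = 0; ith < nth; ++ith) {
        const int dr  = (nr + nth - 1)/nth; // rows per thread, rounded up
        const int ir0 = dr*ith;             // first row for this thread
        const int ir1 = MIN(ir0 + dr, nr);  // one past the last row
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}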
@@ -3121,6 +3127,10 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
     return ggml_graph_compute(cgraph, &cplan);
 }

+void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
+    memcpy(y, x, n * sizeof(float));
+}
+
 void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
     int64_t i = 0;
 #if defined(__F16C__)
@@ -696,24 +696,8 @@ static void ggml_compute_forward_dup_f32(
     if (ggml_is_contiguous(dst)) {
         // TODO: simplify
         if (nb00 == sizeof(float)) {
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+            if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;

                 size_t id = 0;
                 size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -724,7 +708,7 @@ static void ggml_compute_forward_dup_f32(
                 id += rs * ir0;
                 for (int i01 = ir0; i01 < ir1; i01++) {
                     const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                    quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                    from_float(src0_ptr, dst_ptr + id, ne00);
                     id += rs;
                 }
                 id += rs * (ne01 - ir1);
@@ -2282,6 +2266,52 @@ static void ggml_compute_forward_repeat_f16(
     }
 }

+static void ggml_compute_forward_repeat_i64(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0 == sizeof(int64_t));
+    GGML_ASSERT(nb00 == sizeof(int64_t));
+
+    // TODO: maybe this is not optimal?
+    for (int i3 = 0; i3 < nr3; i3++) {
+        for (int k3 = 0; k3 < ne03; k3++) {
+            for (int i2 = 0; i2 < nr2; i2++) {
+                for (int k2 = 0; k2 < ne02; k2++) {
+                    for (int i1 = 0; i1 < nr1; i1++) {
+                        for (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                int64_t * y = (int64_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
+                                int64_t * x = (int64_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
+                                for (int i = 0; i < ne00; ++i) {
+                                    y[i] = x[i];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
 void ggml_compute_forward_repeat(
         const ggml_compute_params * params,
         ggml_tensor * dst) {

Review comment (on ggml_compute_forward_repeat_i64): After adding the broadcast support to […]

Review comment: It would be nice to use a template instead of duplicating the code, however. We need to start somewhere porting this code to C++.

Reply: Ok, will add a template in a follow-up PR. For now, removed the i64 support and added a TODO.
@@ -2300,6 +2330,10 @@ void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, dst);
             } break;
+        case GGML_TYPE_I64:
+            {
+                ggml_compute_forward_repeat_i64(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -4470,6 +4504,74 @@ void ggml_compute_forward_get_rows(
     //}
 }
+
+static void ggml_compute_forward_set_rows_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ne01;
+
+    assert(ne0 == nc);
+    assert(ne2 == ne02);
+    assert(ne3 == ne03);
+    assert(src0->type == GGML_TYPE_F32);
+    assert(ne02 % ne11 == 0);
+    assert(ne03 % ne12 == 0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
+
+    for (int64_t i03 = 0; i03 < ne03; ++i03) {
+        for (int64_t i02 = 0; i02 < ne02; ++i02) {
+            for (int64_t i = ir0; i < ir1; ++i) {
+                const int64_t i12 = i03%ne12;
+                const int64_t i11 = i02%ne11;
+                const int64_t i10 = i;
+
+                const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                GGML_ASSERT(i1 >= 0 && i1 < ne1);
+
+                from_float(
+                        (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
+                        ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_set_rows(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_get_rows_back

 static void ggml_compute_forward_get_rows_back_f32_f16(
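One detail of `ggml_compute_forward_set_rows_f32` worth spelling out (my explanation, not text from the PR): the asserts `ne02 % ne11 == 0` and `ne03 % ne12 == 0` let the row-index tensor `src1` be broadcast across the dim-2/dim-3 planes of `src0` through the `i02 % ne11` / `i03 % ne12` mapping, so one index vector can serve several heads or batch entries. A standalone sketch of just that mapping, with made-up sizes:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    // Hypothetical shapes: src0 has 4 x 2 planes, src1 has 2 x 1 planes,
    // satisfying ne02 % ne11 == 0 and ne03 % ne12 == 0.
    const int64_t ne02 = 4, ne03 = 2;
    const int64_t ne11 = 2, ne12 = 1;

    for (int64_t i03 = 0; i03 < ne03; ++i03) {
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
            // same modulo broadcast as in ggml_compute_forward_set_rows_f32
            const int64_t i11 = i02 % ne11;
            const int64_t i12 = i03 % ne12;
            printf("src0 plane (%lld,%lld) <- indices from src1 plane (%lld,%lld)\n",
                   (long long) i02, (long long) i03,
                   (long long) i11, (long long) i12);
        }
    }
    return 0;
}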