Skip to content

Commit 7568d1a

Browse files
authored
Support dup & cont ops on CUDA (ggml-org#2242)
1 parent b764743 commit 7568d1a

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

ggml-cuda.cu

+18-1
Original file line numberDiff line numberDiff line change
@@ -3537,6 +3537,11 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
35373537
(void) dst;
35383538
}
35393539

3540+
// Implements the dup operation by delegating to ggml_cuda_cpy.
// This function is installed as the handler for both GGML_OP_DUP and
// GGML_OP_CONT in the ggml_cuda_compute_forward switch.
//
//   src0 - tensor forwarded as the first argument of ggml_cuda_cpy
//   src1 - unused; present so the signature matches the other op handlers
//          that are assigned to `func`
//   dst  - tensor forwarded in ggml_cuda_cpy's second (src1) slot
void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3541+
// dst goes into ggml_cuda_cpy's src1 position and nullptr into its dst
// position; the visible tail of ggml_cuda_cpy discards its own dst via
// `(void) dst`.
// NOTE(review): presumably ggml_cuda_cpy copies its first argument into its
// second — confirm against the full body of ggml_cuda_cpy.
ggml_cuda_cpy(src0, dst, nullptr);
3542+
// src1 is intentionally unused.
(void) src1;
3543+
}
3544+
35403545
void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
35413546
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
35423547
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
@@ -3670,7 +3675,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
36703675
// recursively assign CUDA buffers until a compute tensor is found
36713676
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
36723677
const ggml_op src0_op = tensor->src[0]->op;
3673-
if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
3678+
if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
36743679
ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
36753680
}
36763681
}
@@ -3776,6 +3781,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
37763781
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
37773782

37783783
switch (tensor->op) {
3784+
case GGML_OP_DUP:
3785+
if (!any_on_device) {
3786+
return false;
3787+
}
3788+
func = ggml_cuda_dup;
3789+
break;
37793790
case GGML_OP_ADD:
37803791
if (!any_on_device) {
37813792
return false;
@@ -3830,6 +3841,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
38303841
}
38313842
func = ggml_cuda_cpy;
38323843
break;
3844+
case GGML_OP_CONT:
3845+
if (!any_on_device) {
3846+
return false;
3847+
}
3848+
func = ggml_cuda_dup;
3849+
break;
38333850
case GGML_OP_RESHAPE:
38343851
case GGML_OP_VIEW:
38353852
case GGML_OP_PERMUTE:

0 commit comments

Comments
 (0)