@@ -3537,6 +3537,11 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
3537
3537
(void ) dst;
3538
3538
}
3539
3539
3540
+ void ggml_cuda_dup (const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3541
+ ggml_cuda_cpy (src0, dst, nullptr );
3542
+ (void ) src1;
3543
+ }
3544
+
3540
3545
void ggml_cuda_diag_mask_inf (const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3541
3546
GGML_ASSERT (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
3542
3547
ggml_cuda_op (src0, src1, dst, ggml_cuda_op_diag_mask_inf, true , true );
@@ -3670,7 +3675,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
3670
3675
// recursively assign CUDA buffers until a compute tensor is found
3671
3676
if (tensor->src [0 ] != nullptr && tensor->src [0 ]->backend == GGML_BACKEND_CPU) {
3672
3677
const ggml_op src0_op = tensor->src [0 ]->op ;
3673
- if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
3678
+ if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE ) {
3674
3679
ggml_cuda_assign_buffers_impl (tensor->src [0 ], scratch, force_inplace);
3675
3680
}
3676
3681
}
@@ -3776,6 +3781,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
3776
3781
|| (tensor->src [1 ] != nullptr && tensor->src [1 ]->backend == GGML_BACKEND_GPU);
3777
3782
3778
3783
switch (tensor->op ) {
3784
+ case GGML_OP_DUP:
3785
+ if (!any_on_device) {
3786
+ return false ;
3787
+ }
3788
+ func = ggml_cuda_dup;
3789
+ break ;
3779
3790
case GGML_OP_ADD:
3780
3791
if (!any_on_device) {
3781
3792
return false ;
@@ -3830,6 +3841,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
3830
3841
}
3831
3842
func = ggml_cuda_cpy;
3832
3843
break ;
3844
+ case GGML_OP_CONT:
3845
+ if (!any_on_device) {
3846
+ return false ;
3847
+ }
3848
+ func = ggml_cuda_dup;
3849
+ break ;
3833
3850
case GGML_OP_RESHAPE:
3834
3851
case GGML_OP_VIEW:
3835
3852
case GGML_OP_PERMUTE:
0 commit comments