Skip to content

Commit 6def5cd

Browse files
committed
metal : add missing args for nb references in ssm_scan_f32_group
1 parent cf4f0a4 commit 6def5cd

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

ggml/src/ggml-metal/ggml-metal.metal

+8-8
Original file line numberDiff line numberDiff line change
@@ -1350,16 +1350,16 @@ kernel void kernel_ssm_scan_f32_group(
13501350

13511351
device const int32_t * ids = (device const int32_t *) src6;
13521352

1353-
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
1354-
device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);
1353+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
1354+
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
13551355

13561356
for (int64_t i2 = 0; i2 < n_t; ++i2) {
1357-
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
1358-
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns}
1359-
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh}
1360-
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
1361-
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
1362-
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
1357+
device const float * x = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
1358+
device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
1359+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
1360+
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
1361+
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
1362+
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns}
13631363

13641364
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
13651365
const float x_dt = x[0] * dt_soft_plus;

0 commit comments

Comments
 (0)