@@ -1350,16 +1350,16 @@ kernel void kernel_ssm_scan_f32_group(
1350
1350
1351
1351
device const int32_t * ids = (device const int32_t *) src6;
1352
1352
1353
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
1354
- device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);
1353
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args. nb02 + ids[i3]*args. nb03 );
1354
+ device float * s = (device float *) ((device char *) dst + ir*args. nb02 + i3*args. nb03 + s_off);
1355
1355
1356
1356
for (int64_t i2 = 0 ; i2 < n_t ; ++i2) {
1357
- device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
1358
- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns}
1359
- device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh}
1360
- device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1 ))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns}
1361
- device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1 ))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns}
1362
- device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t *nh*nr))*nb00); // {dim, nh, nt, ns}
1357
+ device const float * x = (device const float *) ((device const char *) src1 + i1*args. nb10 + ir*args. nb11 + i2*args. nb12 + i3*args. nb13 ); // {dim, nh, nt, ns}
1358
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*args. nb20 + i2*args. nb21 + i3*args. nb22 ); // {nh, nt, ns}
1359
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args. nb31 ); // {1, nh}
1360
+ device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1 ))*args. nb41 + i2*args. nb42 + i3*args. nb43 ); // {d_state, ng, nt, ns}
1361
+ device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1 ))*args. nb51 + i2*args. nb52 + i3*args. nb53 ); // {d_state, ng, nt, ns}
1362
+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t *nh*nr))*args. nb00 ); // {dim, nh, nt, ns}
1363
1363
1364
1364
const float dt_soft_plus = dt[0 ] <= 20 .0f ? log (1 .0f + exp (dt[0 ])) : dt[0 ];
1365
1365
const float x_dt = x[0 ] * dt_soft_plus;
0 commit comments