@@ -268,6 +268,14 @@ static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
268
268
{ MODEL_LARGE, 71ull *MB },
269
269
};
270
270
271
+ static const std::map<e_model, size_t > MEM_REQ_KV_ENC_SELF = {
272
+ { MODEL_TINY, 23ull *MB },
273
+ { MODEL_BASE, 26ull *MB },
274
+ { MODEL_SMALL, 216ull *MB },
275
+ { MODEL_MEDIUM, 243ull *MB },
276
+ { MODEL_LARGE, 271ull *MB },
277
+ };
278
+
271
279
static const std::map<e_model, size_t > MEM_REQ_KV_CROSS = {
272
280
{ MODEL_TINY, 9ull *MB },
273
281
{ MODEL_BASE, 18ull *MB },
@@ -571,6 +579,7 @@ struct whisper_context {
571
579
// cross-attention KV cache for the decoders
572
580
// shared between all decoders
573
581
whisper_kv_cache kv_cross;
582
+ whisper_kv_cache kv_enc_self;
574
583
575
584
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
576
585
@@ -807,7 +816,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
807
816
MEM_REQ_SCRATCH3.at (model.type ) +
808
817
scale*MEM_REQ_MODEL.at (model.type ) +
809
818
scale*MEM_REQ_KV_CROSS.at (model.type ) +
810
- scale*std::max (MEM_REQ_ENCODE.at (model.type ), MEM_REQ_DECODE.at (model.type ));
819
+ scale*std::max (MEM_REQ_ENCODE.at (model.type ), MEM_REQ_DECODE.at (model.type ));
811
820
812
821
// this is the memory required by one decoder
813
822
const size_t mem_required_decoder =
@@ -838,6 +847,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
838
847
return false ;
839
848
}
840
849
850
+ if (!kv_cache_init (model.hparams , scale*MEM_REQ_KV_ENC_SELF.at (model.type ), wctx.kv_enc_self , wctx.wtype , model.hparams .n_audio_ctx )) {
851
+ fprintf (stderr, " %s: kv_cache_init() failed for encoder self-attention cache\n " , __func__);
852
+ return false ;
853
+ }
854
+
841
855
{
842
856
const size_t memory_size = ggml_nbytes (wctx.kv_cross .k ) + ggml_nbytes (wctx.kv_cross .v );
843
857
fprintf (stderr, " %s: kv cross size = %7.2f MB\n " , __func__, memory_size/1024.0 /1024.0 );
@@ -1415,6 +1429,9 @@ static bool whisper_encode(
1415
1429
}
1416
1430
}
1417
1431
1432
+ struct ggml_cgraph gf = {};
1433
+ gf.n_threads = n_threads;
1434
+
1418
1435
struct ggml_tensor * cur;
1419
1436
1420
1437
// convolution + gelu
@@ -1442,6 +1459,18 @@ static bool whisper_encode(
1442
1459
cur = ggml_gelu (ctx0, cur);
1443
1460
}
1444
1461
1462
+ // {
1463
+ // //printf("cur: %d %d %d %d, size element = %d\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_element_size(cur));
1464
+
1465
+ // wctx.use_buf(ctx0, -1);
1466
+
1467
+ // struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(0*n_ctx));
1468
+ // //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
1469
+
1470
+ // ggml_build_forward_expand(&gf, ggml_cpy(ctx0, cur, k));
1471
+ // //ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
1472
+ // }
1473
+
1445
1474
wctx.use_buf (ctx0, 3 );
1446
1475
1447
1476
// ===================================================================
@@ -1522,6 +1551,18 @@ static bool whisper_encode(
1522
1551
Vcur),
1523
1552
Vcur);
1524
1553
1554
+ {
1555
+ // printf("Kcur: %d %d %d %d, size element = %d\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2], Kcur->ne[3], ggml_element_size(Kcur));
1556
+
1557
+ wctx.use_buf (ctx0, -1 );
1558
+
1559
+ struct ggml_tensor * k = ggml_view_1d (ctx0, wctx.kv_enc_self .k , n_state*n_ctx, (ggml_element_size (wctx.kv_enc_self .k )*n_state)*(il*n_ctx));
1560
+ struct ggml_tensor * v = ggml_view_1d (ctx0, wctx.kv_enc_self .v , n_state*n_ctx, (ggml_element_size (wctx.kv_enc_self .v )*n_state)*(il*n_ctx));
1561
+
1562
+ ggml_build_forward_expand (&gf, ggml_cpy (ctx0, Kcur, k));
1563
+ ggml_build_forward_expand (&gf, ggml_cpy (ctx0, Vcur, v));
1564
+ }
1565
+
1525
1566
// ------
1526
1567
1527
1568
wctx.use_buf (ctx0, 0 );
@@ -1606,6 +1647,18 @@ static bool whisper_encode(
1606
1647
cur = ggml_cpy (ctx0,
1607
1648
KQV_merged,
1608
1649
ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_state, n_ctx));
1650
+
1651
+ // {
1652
+ // //printf("cur: %d %d %d %d, size element = %d\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_element_size(cur));
1653
+
1654
+ // wctx.use_buf(ctx0, -1);
1655
+
1656
+ // struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(il*n_ctx));
1657
+ // //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
1658
+
1659
+ // ggml_build_forward_expand(&gf, ggml_cpy(ctx0, cur, k));
1660
+ // //ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
1661
+ // }
1609
1662
}
1610
1663
1611
1664
// projection
@@ -1715,8 +1768,6 @@ static bool whisper_encode(
1715
1768
1716
1769
// run the computation
1717
1770
{
1718
- struct ggml_cgraph gf = {};
1719
- gf.n_threads = n_threads;
1720
1771
1721
1772
ggml_build_forward_expand (&gf, cur);
1722
1773
ggml_graph_compute (ctx0, &gf);
@@ -4858,7 +4909,7 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
4858
4909
const int n_state = ctx->model .hparams .n_audio_state ;
4859
4910
const int n_layer = ctx->model .hparams .n_audio_layer ;
4860
4911
4861
- #if 1
4912
+ #if 0
4862
4913
// use the last layer of the encoder
4863
4914
{
4864
4915
std::vector<float> embd(n_segments*n_state);
@@ -4878,7 +4929,7 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
4878
4929
const int n_features = std::min(4, n_segments);
4879
4930
4880
4931
ggml_svd_reduce_dims(n_state, n_segments, embd.data(), n_features);
4881
- #else
4932
+ #elif 0
4882
4933
// use cross kv cache of various layers
4883
4934
for (int il = 0; il < n_layer; ++il) {
4884
4935
std::vector<float> embd(n_segments*n_ctx*n_state);
@@ -4900,6 +4951,29 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
4900
4951
4901
4952
const int n_features = std::min(4, n_segments);
4902
4953
4954
+ ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
4955
+ #else
4956
+ // use enc self kv cache of various layers
4957
+ for (int il = 0 ; il < n_layer; ++il) {
4958
+ std::vector<float > embd (n_segments*n_ctx*n_state);
4959
+
4960
+ for (int i = 0 ; i < n_segments; ++i) {
4961
+ const auto & segment_i = ctx->result_all [i];
4962
+ printf (" %s: layer %2d, segment %3d: t0 = %7d, t1 = %7d, text = %s\n " , __func__, il, i, (int ) segment_i.t0 , (int ) segment_i.t1 , segment_i.text .c_str ());
4963
+
4964
+ ctx->mel .n_len = segment_i.t1 ;
4965
+ whisper_encode (*ctx, segment_i.t0 , 7 , true );
4966
+
4967
+ const size_t offs = ggml_element_size (ctx->kv_enc_self .k )*(il*n_ctx*n_state);
4968
+ const ggml_fp16_t * f = (const ggml_fp16_t * )((const char *) ctx->kv_enc_self .k ->data + offs);
4969
+
4970
+ for (int j = 0 ; j < n_ctx*n_state; ++j) {
4971
+ embd[i*n_ctx*n_state + j] = ggml_fp16_to_fp32 (f[j]);
4972
+ }
4973
+ }
4974
+
4975
+ const int n_features = std::min (16 , n_segments);
4976
+
4903
4977
ggml_svd_reduce_dims (n_ctx*n_state, n_segments, embd.data (), n_features);
4904
4978
#endif
4905
4979
@@ -4973,6 +5047,7 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
4973
5047
double d0 = 0.0 ;
4974
5048
double d1 = 0.0 ;
4975
5049
5050
+ #if 0
4976
5051
// use the euclidean distance
4977
5052
{
4978
5053
for (int m = 0; m < n_features; ++m) {
@@ -4985,35 +5060,36 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
4985
5060
}
4986
5061
d1 = std::sqrt(d1);
4987
5062
}
4988
-
5063
+ # else
4989
5064
// use the cosine distance
4990
- // {
4991
- // double dot = 0.0;
4992
- // double norm0 = 0.0;
4993
- // double norm1 = 0.0;
5065
+ {
5066
+ double dot = 0.0 ;
5067
+ double norm0 = 0.0 ;
5068
+ double norm1 = 0.0 ;
4994
5069
4995
- // for (int m = 0; m < n_features; ++m) {
4996
- // dot += features[j][m]*centroids[k][m];
4997
- // norm0 += std::pow(features[j][m], 2.0);
4998
- // norm1 += std::pow(centroids[k][m], 2.0);
4999
- // }
5070
+ for (int m = 0 ; m < n_features; ++m) {
5071
+ dot += features[j][m]*centroids[k][m];
5072
+ norm0 += std::pow (features[j][m], 2.0 );
5073
+ norm1 += std::pow (centroids[k][m], 2.0 );
5074
+ }
5000
5075
5001
- // d0 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
5076
+ d0 = 1.0 - dot/(std::sqrt (norm0)*std::sqrt (norm1));
5002
5077
5003
- // dot = 0.0;
5004
- // norm0 = 0.0;
5005
- // norm1 = 0.0;
5078
+ dot = 0.0 ;
5079
+ norm0 = 0.0 ;
5080
+ norm1 = 0.0 ;
5006
5081
5007
- // for (int m = 0; m < n_features; ++m) {
5008
- // dot += features[j][m]*centroids[l][m];
5009
- // norm0 += std::pow(features[j][m], 2.0);
5010
- // norm1 += std::pow(centroids[l][m], 2.0);
5011
- // }
5082
+ for (int m = 0 ; m < n_features; ++m) {
5083
+ dot += features[j][m]*centroids[l][m];
5084
+ norm0 += std::pow (features[j][m], 2.0 );
5085
+ norm1 += std::pow (centroids[l][m], 2.0 );
5086
+ }
5012
5087
5013
- // d1 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
5014
- // }
5088
+ d1 = 1.0 - dot/(std::sqrt (norm0)*std::sqrt (norm1));
5089
+ }
5090
+ #endif
5015
5091
5016
- sum += std::pow (d0/d1, 2.0 /(1.15 - 1.0 ));
5092
+ sum += std::pow (d0/d1, 2.0 /(2.0 - 1.0 ));
5017
5093
}
5018
5094
5019
5095
membership[j][k] = sum == 0.0 ? 0.0 : 1.0 /sum;
@@ -5024,16 +5100,19 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
5024
5100
if (i == niter - 1 ) {
5025
5101
// {
5026
5102
for (int i = 0 ; i < n_segments; ++i) {
5103
+ #if 1
5027
5104
printf (" %s: membership %3d: " , __func__, i);
5028
5105
for (int j = 0 ; j < n_clusters; ++j) {
5029
- printf (" %f " , membership[i][j]);
5106
+ printf (" %.1f " , membership[i][j]);
5030
5107
}
5031
5108
printf (" '%s'\n " , ctx->result_all [i].text .c_str ());
5032
- // printf("%s: features : ", __func__);
5033
- // for (int j = 0; j < n_features; ++j) {
5034
- // printf("%8.3f ", features[i][j]);
5035
- // }
5036
- // printf(" '%s'\n", ctx->result_all[i].text.c_str());
5109
+ #else
5110
+ printf("%s: features : ", __func__);
5111
+ for (int j = 0; j < n_features; ++j) {
5112
+ printf("%8.3f ", features[i][j]);
5113
+ }
5114
+ printf(" '%s'\n", ctx->result_all[i].text.c_str());
5115
+ #endif
5037
5116
}
5038
5117
printf (" ----------------\n " );
5039
5118
}
0 commit comments