Skip to content

Commit 657741b

Browse files
committed
diarization : try conv and self-attention embeddings
1 parent d11f359 commit 657741b

File tree

2 files changed

+112
-35
lines changed

2 files changed

+112
-35
lines changed

ggml.c

-2
Original file line numberDiff line numberDiff line change
@@ -8674,8 +8674,6 @@ void ggml_svd_reduce_dims(
86748674
//}
86758675
//printf("\n");
86768676

8677-
8678-
printf("n = %d, m = %d, nd = %d\n", n, m, nd);
86798677
// project A0 onto U
86808678
for (int i = 0; i < n; ++i) {
86818679
for (int j = 0; j < nd; ++j) {

whisper.cpp

+112-33
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,14 @@ static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
268268
{ MODEL_LARGE, 71ull*MB },
269269
};
270270

271+
static const std::map<e_model, size_t> MEM_REQ_KV_ENC_SELF = {
272+
{ MODEL_TINY, 23ull*MB },
273+
{ MODEL_BASE, 26ull*MB },
274+
{ MODEL_SMALL, 216ull*MB },
275+
{ MODEL_MEDIUM, 243ull*MB },
276+
{ MODEL_LARGE, 271ull*MB },
277+
};
278+
271279
static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
272280
{ MODEL_TINY, 9ull*MB },
273281
{ MODEL_BASE, 18ull*MB },
@@ -571,6 +579,7 @@ struct whisper_context {
571579
// cross-attention KV cache for the decoders
572580
// shared between all decoders
573581
whisper_kv_cache kv_cross;
582+
whisper_kv_cache kv_enc_self;
574583

575584
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
576585

@@ -807,7 +816,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
807816
MEM_REQ_SCRATCH3.at (model.type) +
808817
scale*MEM_REQ_MODEL.at (model.type) +
809818
scale*MEM_REQ_KV_CROSS.at(model.type) +
810-
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
819+
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
811820

812821
// this is the memory required by one decoder
813822
const size_t mem_required_decoder =
@@ -838,6 +847,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
838847
return false;
839848
}
840849

850+
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_ENC_SELF.at(model.type), wctx.kv_enc_self, wctx.wtype, model.hparams.n_audio_ctx)) {
851+
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
852+
return false;
853+
}
854+
841855
{
842856
const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v);
843857
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
@@ -1415,6 +1429,9 @@ static bool whisper_encode(
14151429
}
14161430
}
14171431

1432+
struct ggml_cgraph gf = {};
1433+
gf.n_threads = n_threads;
1434+
14181435
struct ggml_tensor * cur;
14191436

14201437
// convolution + gelu
@@ -1442,6 +1459,18 @@ static bool whisper_encode(
14421459
cur = ggml_gelu(ctx0, cur);
14431460
}
14441461

1462+
//{
1463+
// //printf("cur: %d %d %d %d, size element = %d\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_element_size(cur));
1464+
1465+
// wctx.use_buf(ctx0, -1);
1466+
1467+
// struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(0*n_ctx));
1468+
// //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
1469+
1470+
// ggml_build_forward_expand(&gf, ggml_cpy(ctx0, cur, k));
1471+
// //ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
1472+
//}
1473+
14451474
wctx.use_buf(ctx0, 3);
14461475

14471476
// ===================================================================
@@ -1522,6 +1551,18 @@ static bool whisper_encode(
15221551
Vcur),
15231552
Vcur);
15241553

1554+
{
1555+
//printf("Kcur: %d %d %d %d, size element = %d\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2], Kcur->ne[3], ggml_element_size(Kcur));
1556+
1557+
wctx.use_buf(ctx0, -1);
1558+
1559+
struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(il*n_ctx));
1560+
struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
1561+
1562+
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
1563+
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
1564+
}
1565+
15251566
// ------
15261567

15271568
wctx.use_buf(ctx0, 0);
@@ -1606,6 +1647,18 @@ static bool whisper_encode(
16061647
cur = ggml_cpy(ctx0,
16071648
KQV_merged,
16081649
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
1650+
1651+
//{
1652+
// //printf("cur: %d %d %d %d, size element = %d\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_element_size(cur));
1653+
1654+
// wctx.use_buf(ctx0, -1);
1655+
1656+
// struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(il*n_ctx));
1657+
// //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
1658+
1659+
// ggml_build_forward_expand(&gf, ggml_cpy(ctx0, cur, k));
1660+
// //ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
1661+
//}
16091662
}
16101663

16111664
// projection
@@ -1715,8 +1768,6 @@ static bool whisper_encode(
17151768

17161769
// run the computation
17171770
{
1718-
struct ggml_cgraph gf = {};
1719-
gf.n_threads = n_threads;
17201771

17211772
ggml_build_forward_expand(&gf, cur);
17221773
ggml_graph_compute (ctx0, &gf);
@@ -4858,7 +4909,7 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
48584909
const int n_state = ctx->model.hparams.n_audio_state;
48594910
const int n_layer = ctx->model.hparams.n_audio_layer;
48604911

4861-
#if 1
4912+
#if 0
48624913
// use the last layer of the encoder
48634914
{
48644915
std::vector<float> embd(n_segments*n_state);
@@ -4878,7 +4929,7 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
48784929
const int n_features = std::min(4, n_segments);
48794930

48804931
ggml_svd_reduce_dims(n_state, n_segments, embd.data(), n_features);
4881-
#else
4932+
#elif 0
48824933
// use cross kv cache of various layers
48834934
for (int il = 0; il < n_layer; ++il) {
48844935
std::vector<float> embd(n_segments*n_ctx*n_state);
@@ -4900,6 +4951,29 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
49004951

49014952
const int n_features = std::min(4, n_segments);
49024953

4954+
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
4955+
#else
4956+
// use enc self kv cache of various layers
4957+
for (int il = 0; il < n_layer; ++il) {
4958+
std::vector<float> embd(n_segments*n_ctx*n_state);
4959+
4960+
for (int i = 0; i < n_segments; ++i) {
4961+
const auto & segment_i = ctx->result_all[i];
4962+
printf("%s: layer %2d, segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, il, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
4963+
4964+
ctx->mel.n_len = segment_i.t1;
4965+
whisper_encode(*ctx, segment_i.t0, 7, true);
4966+
4967+
const size_t offs = ggml_element_size(ctx->kv_enc_self.k)*(il*n_ctx*n_state);
4968+
const ggml_fp16_t * f = (const ggml_fp16_t * )((const char *) ctx->kv_enc_self.k->data + offs);
4969+
4970+
for (int j = 0; j < n_ctx*n_state; ++j) {
4971+
embd[i*n_ctx*n_state + j] = ggml_fp16_to_fp32(f[j]);
4972+
}
4973+
}
4974+
4975+
const int n_features = std::min(16, n_segments);
4976+
49034977
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
49044978
#endif
49054979

@@ -4973,6 +5047,7 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
49735047
double d0 = 0.0;
49745048
double d1 = 0.0;
49755049

5050+
#if 0
49765051
// use the euclidean distance
49775052
{
49785053
for (int m = 0; m < n_features; ++m) {
@@ -4985,35 +5060,36 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
49855060
}
49865061
d1 = std::sqrt(d1);
49875062
}
4988-
5063+
#else
49895064
// use the cosine distance
4990-
//{
4991-
// double dot = 0.0;
4992-
// double norm0 = 0.0;
4993-
// double norm1 = 0.0;
5065+
{
5066+
double dot = 0.0;
5067+
double norm0 = 0.0;
5068+
double norm1 = 0.0;
49945069

4995-
// for (int m = 0; m < n_features; ++m) {
4996-
// dot += features[j][m]*centroids[k][m];
4997-
// norm0 += std::pow(features[j][m], 2.0);
4998-
// norm1 += std::pow(centroids[k][m], 2.0);
4999-
// }
5070+
for (int m = 0; m < n_features; ++m) {
5071+
dot += features[j][m]*centroids[k][m];
5072+
norm0 += std::pow(features[j][m], 2.0);
5073+
norm1 += std::pow(centroids[k][m], 2.0);
5074+
}
50005075

5001-
// d0 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
5076+
d0 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
50025077

5003-
// dot = 0.0;
5004-
// norm0 = 0.0;
5005-
// norm1 = 0.0;
5078+
dot = 0.0;
5079+
norm0 = 0.0;
5080+
norm1 = 0.0;
50065081

5007-
// for (int m = 0; m < n_features; ++m) {
5008-
// dot += features[j][m]*centroids[l][m];
5009-
// norm0 += std::pow(features[j][m], 2.0);
5010-
// norm1 += std::pow(centroids[l][m], 2.0);
5011-
// }
5082+
for (int m = 0; m < n_features; ++m) {
5083+
dot += features[j][m]*centroids[l][m];
5084+
norm0 += std::pow(features[j][m], 2.0);
5085+
norm1 += std::pow(centroids[l][m], 2.0);
5086+
}
50125087

5013-
// d1 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
5014-
//}
5088+
d1 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
5089+
}
5090+
#endif
50155091

5016-
sum += std::pow(d0/d1, 2.0/(1.15 - 1.0));
5092+
sum += std::pow(d0/d1, 2.0/(2.0 - 1.0));
50175093
}
50185094

50195095
membership[j][k] = sum == 0.0 ? 0.0 : 1.0/sum;
@@ -5024,16 +5100,19 @@ void whisper_full_cluster_segments(struct whisper_context * ctx) {
50245100
if (i == niter - 1) {
50255101
//{
50265102
for (int i = 0; i < n_segments; ++i) {
5103+
#if 1
50275104
printf("%s: membership %3d: ", __func__, i);
50285105
for (int j = 0; j < n_clusters; ++j) {
5029-
printf("%f ", membership[i][j]);
5106+
printf("%.1f ", membership[i][j]);
50305107
}
50315108
printf(" '%s'\n", ctx->result_all[i].text.c_str());
5032-
//printf("%s: features : ", __func__);
5033-
//for (int j = 0; j < n_features; ++j) {
5034-
// printf("%8.3f ", features[i][j]);
5035-
//}
5036-
//printf(" '%s'\n", ctx->result_all[i].text.c_str());
5109+
#else
5110+
printf("%s: features : ", __func__);
5111+
for (int j = 0; j < n_features; ++j) {
5112+
printf("%8.3f ", features[i][j]);
5113+
}
5114+
printf(" '%s'\n", ctx->result_all[i].text.c_str());
5115+
#endif
50375116
}
50385117
printf("----------------\n");
50395118
}

0 commit comments

Comments
 (0)