Skip to content

Commit f25d596

Browse files
committed
feat: posterior mean
1 parent 89bf1a0 commit f25d596

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

probabilistic_word_embeddings/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,19 @@ def normalize_rotation(e, words):
130130

131131
e_new[vocabulary] = (Q.T @ e[vocabulary].numpy().T).T
132132
return e_new
133+
134+
def posterior_mean(paths):
135+
emb_paths = sorted(paths)
136+
e_ref = Embedding(saved_model_path=emb_paths[-1])
137+
words_reference = [f"{wd}_c" for wd in list(e_ref.vocabulary) if "_c" not in wd]
138+
139+
e_mean = Embedding(saved_model_path=emb_paths[-1])
140+
e_mean.theta = e_mean.theta * 0.0
141+
142+
for emb_path in emb_paths:
143+
e = Embedding(saved_model_path=emb_path)
144+
e_aligned = align(e_ref, e, words_reference)
145+
e_mean.theta += e_aligned.theta / len(emb_paths)
146+
147+
return e_mean
148+

0 commit comments

Comments
 (0)