Commit 1fb1596

chore: merge pull request #21 from ninpnin/dev
JSON loading and bugfix
2 parents: 8000c6a + f7f4f43

File tree

probabilistic_word_embeddings/embeddings.py
probabilistic_word_embeddings/evaluation.py
probabilistic_word_embeddings/preprocessing.py
pyproject.toml

4 files changed: +36 −18 lines

probabilistic_word_embeddings/embeddings.py

Lines changed: 27 additions & 10 deletions
@@ -4,7 +4,7 @@
 import numpy as np
 import tensorflow as tf
 import networkx as nx
-import random, pickle
+import random, pickle, json
 import progressbar
 from .utils import dict_to_tf
 import warnings
@@ -39,14 +39,18 @@ def __init__(self, vocabulary=None, dimensionality=100, lambda0=1.0, shared_cont
         else:
             if type(saved_model_path) != str:
                 raise TypeError("saved_model_path must be a str")
-            with open(saved_model_path, "rb") as f:
-                d = pickle.load(f)
+            d = None
+            if saved_model_path.split(".")[-1] == "json":
+                with open(saved_model_path, "r") as f:
+                    d = json.load(f)
+            else:
+                with open(saved_model_path, "rb") as f:
+                    d = pickle.load(f)
             self.vocabulary = d["vocabulary"]
             self.tf_vocabulary = dict_to_tf(self.vocabulary)
-            self.theta = tf.Variable(d["theta"])
+            self.theta = tf.Variable(d["theta"], dtype=tf.float64)
             self.lambda0 = d["lambda0"]

-    @tf.function
     def _get_embeddings(self, item):
         if type(item) == str:
             return self.theta[self.vocabulary[item]]
@@ -126,8 +130,16 @@ def save(self, path):
         if hasattr(self, 'graph'):
             d["graph"] = self.graph

-        with open(path, "wb") as f:
-            pickle.dump(d, f, protocol=4)
+        if path.split(".")[-1] == "json":
+            d["theta"] = theta.tolist()
+            if "graph" in d:
+                d["graph"] = nx.readwrite.json_graph.adjacency_data(self.graph)
+
+            with open(path, 'w') as f:
+                json.dump(d, f, indent=2, ensure_ascii=False)
+        else:
+            with open(path, "wb") as f:
+                pickle.dump(d, f, protocol=4)

 class LaplacianEmbedding(Embedding):
     """
@@ -147,11 +159,16 @@ def __init__(self, vocabulary=None, dimensionality=100, graph=None, lambda0=1.0,
             self.graph = graph
             self.edges_i = None
         else:
-            with open(saved_model_path, "rb") as f:
-                d = pickle.load(f)
+            d = None
+            if saved_model_path.split(".")[-1] == "json":
+                with open(saved_model_path, "r") as f:
+                    d = json.load(f)
+            else:
+                with open(saved_model_path, "rb") as f:
+                    d = pickle.load(f)
             self.vocabulary = d["vocabulary"]
             self.tf_vocabulary = dict_to_tf(self.vocabulary)
-            self.theta = tf.Variable(d["theta"])
+            self.theta = tf.Variable(d["theta"], dtype=tf.float64)
             self.lambda0 = d["lambda0"]
             self.lambda1 = d["lambda1"]
             self.graph = d["graph"]
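
For orientation, the net effect of this change is that the serialization format is picked from the file extension, both in save() and when constructing an embedding from saved_model_path. Below is a minimal round-trip sketch: the save()/saved_model_path calls mirror the diff above, while the import path, the toy word-to-index vocabulary, and the constructor call are assumptions inferred from the file path and the hunk headers rather than something this commit shows.

# Assumed usage sketch -- not part of the diff.
from probabilistic_word_embeddings.embeddings import Embedding

vocabulary = {"dog": 0, "cat": 1, "fish": 2}        # hypothetical word-to-index mapping
e = Embedding(vocabulary=vocabulary, dimensionality=25)

e.save("model.json")    # ".json" suffix -> json.dump(d, f, indent=2, ensure_ascii=False)
e.save("model.pkl")     # any other suffix -> pickle.dump(d, f, protocol=4), as before

e_json = Embedding(saved_model_path="model.json")   # json.load branch
e_pkl = Embedding(saved_model_path="model.pkl")     # pickle.load branch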

probabilistic_word_embeddings/evaluation.py

Lines changed: 3 additions & 3 deletions
@@ -200,9 +200,9 @@ def words_in_e(row):
     r = len(df)
     target_words = list(df[columns[-1]])

-    X1 = embedding[df[columns[0]]]
-    X2 = embedding[df[columns[1]]]
-    X3 = embedding[df[columns[2]]]
+    X1 = embedding[list(df[columns[0]])]
+    X2 = embedding[list(df[columns[1]])]
+    X3 = embedding[list(df[columns[2]])]
     X = X1 - X2 + X3

     inv_vocab = {v: k for k, v in e.vocabulary.items()}
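
The wrapped list(...) calls change what the embedding lookup receives: a plain Python list of words instead of a raw pandas Series. Assuming df is a pandas DataFrame, as the surrounding code suggests, the sketch below shows the difference; the column name and its contents are hypothetical.

# Hypothetical illustration of the bugfix (column data invented for the example).
import pandas as pd

df = pd.DataFrame({"w1": ["king", "paris"]})
before = df["w1"]         # pandas.Series -- what embedding[...] used to receive
after = list(df["w1"])    # ["king", "paris"] -- plain list of words, as passed now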

probabilistic_word_embeddings/preprocessing.py

Lines changed: 5 additions & 4 deletions
@@ -171,10 +171,11 @@ def add_subscript(t, subscript):
     if labels is not None:
         print("Add partition labels to words...")
         texts = [add_subscript(text, label) for text, label in zip(texts, progressbar.progressbar(labels))]
-    vocabs = [set(text) for text in progressbar.progressbar(texts)]
-    empty = set()
-    vocabulary = empty.union(*vocabs)
-
+
+    vocabulary = set()
+    for text in progressbar.progressbar(texts):
+        for wd in text:
+            vocabulary.add(wd)

     def _remove_subscript(wd):
         s = wd.split("_")
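
The rewritten loop produces the same set of unique tokens as the old per-text set union, just without building an intermediate set for every text. A toy check with a hypothetical two-document corpus (the progressbar wrapper from the real code is omitted):

texts = [["the", "dog", "barks"], ["the", "cat", "meows"]]   # hypothetical tokenized corpus

# Old construction: one set per text, then a union over all of them.
vocabs = [set(text) for text in texts]
old_vocabulary = set().union(*vocabs)

# New construction: add tokens to a single set while scanning the texts.
vocabulary = set()
for text in texts:
    for wd in text:
        vocabulary.add(wd)

assert vocabulary == old_vocabulary == {"the", "dog", "barks", "cat", "meows"}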

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "probabilistic-word-embeddings"
-version = "1.13.7"
+version = "1.15.1"
 description = "Probabilistic Word Embeddings for Python"
 authors = ["Your Name <[email protected]>"]
 license = "MIT"
