Skip to content

Commit 4d9744d

Browse files
committed
fix: preprocessing bug
1 parent 512ee69 commit 4d9744d

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

probabilistic_word_embeddings/preprocessing.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66

77
def _pb_if_needed(l):
8-
if len(iterator) >= 100:
8+
if len(l) >= 1000:
99
return progressbar.progressbar(l)
1010
else:
1111
return l
@@ -57,15 +57,19 @@ def downsample_common_words(data, counts, cutoff=0.00001, chunk_len=5000000, see
5757
if not isinstance(data, tf.Tensor):
5858
data = tf.constant(data)
5959

60-
print("Discard some instances of the most common words...")
60+
#print("Discard some instances of the most common words...")
6161
N = sum(counts.values())
6262
counts_tf = dict_to_tf(counts)
6363
# Randomize and fetch by this probability
6464
if seed is not None:
6565
tf.random.set_seed(seed)
6666

6767
if len(data) < chunk_len:
68-
frequencies = counts_tf.lookup(data) / N
68+
try:
69+
frequencies = counts_tf.lookup(data) / N
70+
except:
71+
print("Error downsampling:", data)
72+
return [wd.decode("utf-8") for wd in data.numpy()]
6973
# Discard probability based on relative frequency
7074
probs = 1. - tf.sqrt(cutoff / frequencies)
7175

@@ -132,16 +136,21 @@ def preprocess_partitioned(texts, labels=None, lowercase=True, remove_punctuatio
132136
assert isinstance(texts[0], list), "Data should be provided as a list of lists"
133137
N = sum([len(t) for t in texts])
134138
if lowercase:
139+
print("Convert to lowercase...")
135140
texts = [[wd.lower() for wd in t] for t in texts]
136141

137142
if remove_punctuation:
143+
print("Remove punctuation...")
138144
def remove_punctuation_fun(s):
139145
return s.replace(".", "").replace(",", "").replace("!", "").replace("?", "")
140146
texts = [[remove_punctuation_fun(wd) for wd in t] for t in texts]
141147

148+
if limit > 1:
149+
print("Filter rare words...")
142150
texts, counts = filter_rare_words(texts, limit=limit, keep_words=keep_words)
143151
if downsample:
144-
texts = [downsample_common_words(text, counts, seed=seed) for text in texts]
152+
print("Discard some instances of the most common words...")
153+
texts = [downsample_common_words(text, counts, seed=seed) for text in _pb_if_needed(texts)]
145154

146155
def add_subscript(t, subscript):
147156
if len(t) == 0:
@@ -160,7 +169,8 @@ def add_subscript(t, subscript):
160169
return t
161170

162171
if labels is not None:
163-
texts = [add_subscript(text, label) for text, label in progressbar.progressbar(zip(texts, labels))]
172+
print("Add partition labels to words...")
173+
texts = [add_subscript(text, label) for text, label in zip(texts, progressbar.progressbar(labels))]
164174
vocabs = [set(text) for text in progressbar.progressbar(texts)]
165175
empty = set()
166176
vocabulary = empty.union(*vocabs)
@@ -171,6 +181,7 @@ def _remove_subscript(wd):
171181
n = len(s)
172182
return "_".join(s[:n-1])
173183

184+
print("Calculate word frequencies...")
174185
if labels is None:
175186
unnormalized_freqs = {wd: counts[wd] / N for wd in list(vocabulary)}
176187
else:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "probabilistic-word-embeddings"
3-
version = "1.12.0"
3+
version = "1.13.7"
44
description = "Probabilistic Word Embeddings for Python"
55
authors = ["Your Name <[email protected]>"]
66
license = "MIT"

0 commit comments

Comments
 (0)