 import warnings

 def _pb_if_needed(l):
-    if len(iterator) >= 100:
+    if len(l) >= 1000:
         return progressbar.progressbar(l)
     else:
         return l
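The helper above matches the progressbar2 package's `progressbar.progressbar` convenience wrapper, which wants a sized iterable; short inputs are passed through unchanged so small loops don't print a bar. A quick usage sketch, assuming the helper above is in scope:

    words = ["token"] * 5000              # long enough to trigger the bar
    for wd in _pb_if_needed(words):       # iterates with a progress bar
        pass

    for wd in _pb_if_needed(["a", "b"]):  # short list: plain iteration
        pass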
@@ -57,15 +57,19 @@ def downsample_common_words(data, counts, cutoff=0.00001, chunk_len=5000000, see
     if not isinstance(data, tf.Tensor):
         data = tf.constant(data)

-    print("Discard some instances of the most common words...")
+    # print("Discard some instances of the most common words...")
     N = sum(counts.values())
     counts_tf = dict_to_tf(counts)
     # Randomize and fetch by this probability
     if seed is not None:
         tf.random.set_seed(seed)

     if len(data) < chunk_len:
-        frequencies = counts_tf.lookup(data) / N
+        try:
+            frequencies = counts_tf.lookup(data) / N
+        except:
+            print("Error downsampling:", data)
+            return [wd.decode("utf-8") for wd in data.numpy()]
         # Discard probability based on relative frequency
         probs = 1. - tf.sqrt(cutoff / frequencies)

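The discard rule in this hunk is the word2vec-style subsampling heuristic: a token with relative corpus frequency f is dropped with probability 1 - sqrt(cutoff / f), so words at or below the cutoff are always kept. A minimal standalone sketch of that rule (toy frequencies, not this repo's data):

    import tensorflow as tf

    cutoff = 0.00001
    frequencies = tf.constant([0.05, 0.001, 0.00001])  # toy relative frequencies (count / N)
    probs = 1. - tf.sqrt(cutoff / frequencies)         # per-token discard probability
    keep = tf.random.uniform(tf.shape(probs)) > probs  # True where the token survives

Here the common word (f = 0.05) is discarded about 98.6% of the time, while the word exactly at the cutoff is never discarded.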
@@ -132,16 +136,21 @@ def preprocess_partitioned(texts, labels=None, lowercase=True, remove_punctuatio
     assert isinstance(texts[0], list), "Data should be provided as a list of lists"
     N = sum([len(t) for t in texts])
     if lowercase:
+        print("Convert to lowercase...")
         texts = [[wd.lower() for wd in t] for t in texts]

     if remove_punctuation:
+        print("Remove punctuation...")
         def remove_punctuation_fun(s):
             return s.replace(".", "").replace(",", "").replace("!", "").replace("?", "")
         texts = [[remove_punctuation_fun(wd) for wd in t] for t in texts]

+    if limit > 1:
+        print("Filter rare words...")
     texts, counts = filter_rare_words(texts, limit=limit, keep_words=keep_words)
     if downsample:
-        texts = [downsample_common_words(text, counts, seed=seed) for text in texts]
+        print("Discard some instances of the most common words...")
+        texts = [downsample_common_words(text, counts, seed=seed) for text in _pb_if_needed(texts)]

     def add_subscript(t, subscript):
         if len(t) == 0:
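Judging from `_remove_subscript` in the last hunk below (split on "_", drop the final field), `add_subscript` appends the partition label to each word. A hypothetical round-trip sketch of that convention:

    def add_subscript_sketch(t, subscript):
        # Hypothetical: tag each word with its partition label.
        return [wd + "_" + str(subscript) for wd in t]

    def remove_subscript_sketch(wd):
        # Mirror of _remove_subscript: drop only the final "_" field,
        # so words that themselves contain underscores survive intact.
        s = wd.split("_")
        return "_".join(s[:len(s) - 1])

    assert add_subscript_sketch(["ice_cream"], 3) == ["ice_cream_3"]
    assert remove_subscript_sketch("ice_cream_3") == "ice_cream"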
@@ -160,7 +169,8 @@ def add_subscript(t, subscript):
         return t

     if labels is not None:
-        texts = [add_subscript(text, label) for text, label in progressbar.progressbar(zip(texts, labels))]
+        print("Add partition labels to words...")
+        texts = [add_subscript(text, label) for text, label in zip(texts, progressbar.progressbar(labels))]
     vocabs = [set(text) for text in progressbar.progressbar(texts)]
     empty = set()
     vocabulary = empty.union(*vocabs)
@@ -171,6 +181,7 @@ def _remove_subscript(wd):
         n = len(s)
         return "_".join(s[:n - 1])

+    print("Calculate word frequencies...")
     if labels is None:
         unnormalized_freqs = {wd: counts[wd] / N for wd in list(vocabulary)}
     else:
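`counts_tf.lookup(data)` in the second hunk implies that `dict_to_tf` turns the Python counts dict into a TensorFlow lookup table. Its actual implementation is not part of this diff; a plausible sketch with `tf.lookup.StaticHashTable` (everything here, including the default value, is an assumption):

    import tensorflow as tf

    def dict_to_tf(counts):
        # Hypothetical reconstruction: static word -> count table.
        keys = tf.constant(list(counts.keys()))
        values = tf.constant(list(counts.values()), dtype=tf.float32)
        return tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(keys, values),
            default_value=0.)

    counts_tf = dict_to_tf({"the": 120, "ice": 3})
    counts_tf.lookup(tf.constant(["the", "unseen"]))  # -> [120., 0.]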