4
4
import numpy as np
5
5
import tensorflow as tf
6
6
import networkx as nx
7
- import random , pickle
7
+ import random , pickle , json
8
8
import progressbar
9
9
from .utils import dict_to_tf
10
10
import warnings
@@ -39,14 +39,18 @@ def __init__(self, vocabulary=None, dimensionality=100, lambda0=1.0, shared_cont
39
39
else :
40
40
if type (saved_model_path ) != str :
41
41
raise TypeError ("saved_model_path must be a str" )
42
- with open (saved_model_path , "rb" ) as f :
43
- d = pickle .load (f )
42
+ d = None
43
+ if saved_model_path .split ("." )[- 1 ] == "json" :
44
+ with open (saved_model_path , "r" ) as f :
45
+ d = json .load (f )
46
+ else :
47
+ with open (saved_model_path , "rb" ) as f :
48
+ d = pickle .load (f )
44
49
self .vocabulary = d ["vocabulary" ]
45
50
self .tf_vocabulary = dict_to_tf (self .vocabulary )
46
- self .theta = tf .Variable (d ["theta" ])
51
+ self .theta = tf .Variable (d ["theta" ], dtype = tf . float64 )
47
52
self .lambda0 = d ["lambda0" ]
48
53
49
- @tf .function
50
54
def _get_embeddings (self , item ):
51
55
if type (item ) == str :
52
56
return self .theta [self .vocabulary [item ]]
@@ -126,8 +130,16 @@ def save(self, path):
126
130
if hasattr (self , 'graph' ):
127
131
d ["graph" ] = self .graph
128
132
129
- with open (path , "wb" ) as f :
130
- pickle .dump (d , f , protocol = 4 )
133
+ if path .split ("." )[- 1 ] == "json" :
134
+ d ["theta" ] = theta .tolist ()
135
+ if "graph" in d :
136
+ d ["graph" ] = nx .readwrite .json_graph .adjacency_data (self .graph )
137
+
138
+ with open (path , 'w' ) as f :
139
+ json .dump (d , f , indent = 2 , ensure_ascii = False )
140
+ else :
141
+ with open (path , "wb" ) as f :
142
+ pickle .dump (d , f , protocol = 4 )
131
143
132
144
class LaplacianEmbedding (Embedding ):
133
145
"""
@@ -147,11 +159,16 @@ def __init__(self, vocabulary=None, dimensionality=100, graph=None, lambda0=1.0,
147
159
self .graph = graph
148
160
self .edges_i = None
149
161
else :
150
- with open (saved_model_path , "rb" ) as f :
151
- d = pickle .load (f )
162
+ d = None
163
+ if saved_model_path .split ("." )[- 1 ] == "json" :
164
+ with open (saved_model_path , "r" ) as f :
165
+ d = json .load (f )
166
+ else :
167
+ with open (saved_model_path , "rb" ) as f :
168
+ d = pickle .load (f )
152
169
self .vocabulary = d ["vocabulary" ]
153
170
self .tf_vocabulary = dict_to_tf (self .vocabulary )
154
- self .theta = tf .Variable (d ["theta" ])
171
+ self .theta = tf .Variable (d ["theta" ], dtype = tf . float64 )
155
172
self .lambda0 = d ["lambda0" ]
156
173
self .lambda1 = d ["lambda1" ]
157
174
self .graph = d ["graph" ]
0 commit comments