20
20
# MA 02111-1307 USA
21
21
#
22
22
import base64
23
+ import builtins
24
+ import dill
23
25
import io
24
26
import json
25
27
import os
26
- import pickle
27
28
import random
28
29
import string
29
30
import sys
31
+ from typing import Union
30
32
31
33
from dlg .common .reproducibility .reproducibility import common_hash
32
34
from dlg .data .drops .data_base import DataDROP , logger
33
35
from dlg .data .io import SharedMemoryIO , MemoryIO
34
36
37
def get_builtins() -> dict:
    """
    Build a lookup of the plain builtin types that pydata values are compared with.

    Only public, non-exception classes exposed by the ``builtins`` module are
    kept (e.g. ``bool``, ``int``, ``float``, ``str``, ``list``, ``dict``).
    Exception and warning classes as well as private names such as
    ``__loader__`` are left out.

    :returns a dictionary mapping each type's ``__name__`` to the type object
    """
    builtin_types = {}
    # dir() yields the names sorted alphabetically, so the resulting dictionary
    # has a stable, reproducible key order.
    for name in dir(builtins):
        candidate = getattr(builtins, name)
        if name.startswith("_") or not isinstance(candidate, type):
            continue
        # Filter exceptions explicitly rather than relying on the fact that
        # their capitalised names happen to sort before "bool".
        if issubclass(candidate, BaseException):
            continue
        builtin_types[candidate.__name__] = candidate
    return builtin_types
35
45
36
- def parse_pydata (pd_dict : dict ) -> bytes :
46
+
47
+ def parse_pydata (pd : Union [bytes , dict ]) -> bytes :
37
48
"""
38
49
Parse and evaluate the pydata argument to populate memory during initialization
39
50
40
- :param pd_dict: the pydata dictionary from the graph node
51
+ :param pd: either the pydata dictionary from the graph node or the value directly
41
52
42
53
:returns a byte encoded value
43
54
"""
55
+ pd_dict = pd if isinstance (pd , dict ) else {"value" :pd , "type" :"raw" }
44
56
pydata = pd_dict ["value" ]
45
- logger .debug (f "pydata value provided: { pydata } , { pd_dict [ ' type' ]. lower () } " )
57
+ logger .debug ("pydata value provided: '%s' with type '%s'" , pydata , type ( pydata ) )
46
58
47
59
if pd_dict ["type" ].lower () in ["string" , "str" ]:
48
60
return pydata if pydata != "None" else None
61
+ builtin_types = get_builtins ()
62
+ if pd_dict ["type" ] != "raw" and type (pydata ) in builtin_types .values () and pd_dict ["type" ] not in builtin_types .keys ():
63
+ logger .warning ("Type of pydata %s provided differs from specified type: %s" , type (pydata ).__name__ , pd_dict ["type" ])
64
+ pd_dict ["type" ] = type (pydata ).__name__
49
65
if pd_dict ["type" ].lower () == "json" :
50
66
try :
51
67
pydata = json .loads (pydata )
@@ -56,28 +72,35 @@ def parse_pydata(pd_dict: dict) -> bytes:
56
72
pydata = eval (pydata )
57
73
# except:
58
74
# pydata = pydata.encode()
59
- elif pd_dict ["type" ].lower () == "int" :
75
+ elif pd_dict ["type" ].lower () == "int" or isinstance ( pydata , int ) :
60
76
try :
61
77
pydata = int (pydata )
78
+ pd_dict ["type" ] = "int"
62
79
except :
63
80
pydata = pydata .encode ()
64
- elif pd_dict ["type" ].lower () == "float" :
81
+ elif pd_dict ["type" ].lower () == "float" or isinstance ( pydata , float ) :
65
82
try :
66
83
pydata = float (pydata )
84
+ pd_dict ["type" ] = "float"
67
85
except :
68
86
pydata = pydata .encode ()
69
- elif pd_dict ["type" ].lower () == "boolean" :
87
+ elif pd_dict ["type" ].lower () == "boolean" or isinstance ( pydata , bool ) :
70
88
try :
71
89
pydata = bool (pydata )
90
+ pd_dict ["type" ] = "bool"
72
91
except :
73
92
pydata = pydata .encode ()
74
93
elif pd_dict ["type" ].lower () == "object" :
75
94
pydata = base64 .b64decode (pydata .encode ())
76
95
try :
77
- pydata = pickle .loads (pydata )
96
+ pydata = dill .loads (pydata )
78
97
except :
79
98
raise
80
- return pickle .dumps (pydata )
99
+ elif pd_dict ["type" ].lower () == "raw" :
100
+ pydata = dill .loads (base64 .b64decode (pydata ))
101
+ logger .debug ("Returning pydata of type: %s" , type (pydata ))
102
+ # return pydata
103
+ return dill .dumps (pydata )
81
104
82
105
83
106
##
@@ -117,34 +140,30 @@ def initialize(self, **kwargs):
117
140
"""
118
141
args = []
119
142
pydata = None
143
+ # pdict = {}
120
144
pdict = {"type" : "raw" } # initialize this value to enforce BytesIO
121
145
self .data_type = pdict ["type" ]
122
146
field_names = (
123
147
[f ["name" ] for f in kwargs ["fields" ]] if "fields" in kwargs else []
124
148
)
125
149
if "pydata" in kwargs and not (
126
150
"fields" in kwargs and "pydata" in field_names
127
- ): # means that is was passed directly
151
+ ): # means that is was passed directly (e.g. from tests)
128
152
pydata = kwargs .pop ("pydata" )
129
- logger .debug ("pydata value provided: %s, %s" , pydata , kwargs )
130
- try : # test whether given value is valid
131
- _ = pickle .loads (base64 .b64decode (pydata ))
132
- pydata = base64 .b64decode (pydata )
133
- except :
134
- pydata = None
153
+ pdict ["value" ] = pydata
154
+ pydata = parse_pydata (pdict )
135
155
elif "fields" in kwargs and "pydata" in field_names :
136
156
data_pos = field_names .index ("pydata" )
137
157
pdict = kwargs ["fields" ][data_pos ]
138
158
pydata = parse_pydata (pdict )
139
- if pdict ["type" ].lower () in ["str" ,"string" ]:
140
- self .data_type = "String"
141
- self ._buf = io .StringIO (* args )
159
+ if pdict and pdict ["type" ].lower () in ["str" ,"string" ]:
160
+ self .data_type = "String" if pydata else "raw"
142
161
else :
143
- self .data_type = pdict ["type" ]
144
- self ._buf = io .BytesIO (* args )
162
+ self .data_type = pdict ["type" ] if pdict else ""
145
163
if pydata :
146
164
args .append (pydata )
147
- logger .debug ("Loaded into memory: %s, %s" , pydata , self .data_type )
165
+ logger .debug ("Loaded into memory: %s, %s, %s" , pydata , self .data_type , type (pydata ))
166
+ self ._buf = io .BytesIO (* args ) if self .data_type != "String" else io .StringIO (* args )
148
167
self .size = len (pydata ) if pydata else 0
149
168
150
169
def getIO (self ):
@@ -230,7 +249,7 @@ def initialize(self, **kwargs):
230
249
pydata = kwargs .pop ("pydata" )
231
250
logger .debug ("pydata value provided: %s" , pydata )
232
251
try : # test whether given value is valid
233
- _ = pickle .loads (base64 .b64decode (pydata .encode ("latin1" )))
252
+ _ = dill .loads (base64 .b64decode (pydata .encode ("latin1" )))
234
253
pydata = base64 .b64decode (pydata .encode ("latin1" ))
235
254
except :
236
255
pydata = None
0 commit comments