Skip to content

Commit 7660a0c

Browse files
committed
Finally all tests passing
1 parent cc8e553 commit 7660a0c

File tree

11 files changed

+135
-69
lines changed

11 files changed

+135
-69
lines changed

daliuge-common/dlg/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -528,12 +528,12 @@ def prepareUser(DLG_ROOT=getDlgDir()):
528528

529529

530530
def serialize_data(d):
531-
# return pickle.dumps(d)
531+
# return dill.dumps(d)
532532
return b2s(base64.b64encode(dill.dumps(d)))
533533

534534

535535
def deserialize_data(d):
536-
# return pickle.loads()
536+
# return dill.loads(d)
537537
return dill.loads(base64.b64decode(d.encode("utf8")))
538538

539539

daliuge-engine/dlg/apps/pyfunc.py

+38-28
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import pickle
3434
import re
3535

36-
from typing import Callable
36+
from typing import Callable, Union
3737
import dill
3838
from io import StringIO
3939
from contextlib import redirect_stdout
@@ -68,7 +68,8 @@ def serialize_func(f):
6868
parts = f.split(".")
6969
f = getattr(importlib.import_module(".".join(parts[:-1])), parts[-1])
7070

71-
fser = dill.dumps(f)
71+
fser = base64.b64encode(dill.dumps(f)).decode()
72+
# fser = inspect.getsource(f)
7273
fdefaults = {"args": [], "kwargs": {}}
7374
adefaults = {"args": [], "kwargs": {}}
7475
a = inspect.getfullargspec(f)
@@ -139,32 +140,43 @@ def import_using_name(app, fname):
139140
logger.debug("Loaded module: %s", mod)
140141
return mod
141142

143+
def import_using_code_ser(func_code: Union[str, bytes], func_name: str):
144+
"""
145+
Import the function provided as a serialised code string.
146+
"""
147+
try:
148+
func_code = func_code if isinstance(func_code, bytes) else func_code.encode()
149+
func = dill.loads(base64.b64decode(func_code))
150+
except Exception as err:
151+
logger.warning("Unable to deserialize func_code: %s", err)
152+
raise
153+
if func_name and func_name.split(".")[-1] != func.__name__:
154+
raise ValueError(
155+
f"Function with name '{func.__name__}' instead of '{func_name}' found!"
156+
)
157+
return func
142158

143159
def import_using_code(func_code: str, func_name: str, serialized: bool = True):
144160
"""
145161
Import the function provided as a code string. Plain code as well as serialized code
146162
is supported. If the func_name does not match the provided func_name the load will fail.
147163
"""
148-
if not serialized:
149-
logger.debug(f"Trying to import code from string: {func_code}")
150-
mod = pyext.RuntimeModule.from_string("mod", func_name, func_code)
164+
mod = None
165+
if not serialized and not isinstance(func_code, bytes):
166+
logger.debug(f"Trying to import code from string: {func_code.strip()}")
167+
try:
168+
mod = pyext.RuntimeModule.from_string("mod", func_name, func_code.strip())
169+
except Exception:
170+
func = import_using_code_ser(func_code, func_name)
151171
logger.debug("Imported function: %s", func_name)
152-
if func_name:
172+
if mod and func_name:
153173
if hasattr(mod, func_name):
154174
func = getattr(mod, func_name)
155175
else:
156176
logger.warning("Function with name '%s' not found!", func_name)
157177
raise ValueError(f"Function with name '{func_name}' not found!")
158178
else:
159-
try:
160-
func = dill.loads(func_code)
161-
except Exception as err:
162-
logger.warning("Unable to deserialize func_code: %s", err)
163-
raise
164-
if func_name and func_name.split(".")[-1] != func.__name__:
165-
raise ValueError(
166-
f"Function with name '{func.__name__}' instead of '{func_name}' found!"
167-
)
179+
func = import_using_code_ser(func_code, func_name)
168180
logger.debug("Imported function: %s", func_name)
169181
return func
170182

@@ -570,16 +582,16 @@ def initialize_with_func_code(self):
570582
)
571583
except (SyntaxError, NameError) as err:
572584
logger.warning(
573-
"Problem importing code: %s. Checking whether it was serialized.",
585+
"Problem importing code: %s.",
574586
err,
575587
)
576-
serialized = True
577-
if isinstance(self.func_code, bytes) or serialized:
578-
if isinstance(self.func_code, str):
579-
self.func_code = base64.b64decode(self.func_code.encode("utf8"))
580-
self.func = import_using_code(
581-
self.func_code, self.func_name, serialized=True
582-
)
588+
# serialized = True
589+
# if isinstance(self.func_code, bytes) or serialized:
590+
# if isinstance(self.func_code, str):
591+
# self.func_code = base64.b64decode(self.func_code.encode("utf8"))
592+
# self.func = import_using_code(
593+
# self.func_code, self.func_name, serialized=True
594+
# )
583595

584596
self._init_fn_defaults()
585597
# make sure defaults are dicts
@@ -778,9 +790,9 @@ def write_results(self, result):
778790
result = result_iter[i]
779791

780792
parser = self._match_parser(o)
781-
if parser is DropParser.PICKLE:
782-
logger.debug(f"Writing pickeled result {type(result)} to {o}")
783-
o.write(pickle.dumps(result))
793+
if parser in [DropParser.PICKLE, DropParser.DILL]:
794+
logger.debug(f"Writing dilled result {type(result)} to {o}")
795+
o.write(dill.dumps(result))
784796
elif parser is DropParser.EVAL or parser is DropParser.UTF8:
785797
encoded_result = repr(result).encode("utf-8")
786798
o.write(encoded_result)
@@ -795,8 +807,6 @@ def write_results(self, result):
795807
drop_loaders.save_npy(o, result)
796808
elif parser is DropParser.RAW:
797809
o.write(result)
798-
elif parser is DropParser.DILL:
799-
o.write(dill.dumps(result))
800810
elif parser is DropParser.BINARY:
801811
drop_loaders.save_binary(o, result)
802812
else:

daliuge-engine/dlg/apps/simple.py

+4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# MA 02111-1307 USA
2121
#
2222
"""Applications used as examples, for testing, or in simple situations"""
23+
import base64
2324
import dill
2425
import _pickle
2526
from numbers import Number
@@ -1007,9 +1008,12 @@ def initialize(self, **kwargs):
10071008
def readData(self):
10081009
input = self.inputs[0]
10091010
data = pickle.loads(droputils.allDropContents(input))
1011+
# data = droputils.allDropContents(input)
1012+
# data = dill.loads(base64.b64decode(data))
10101013

10111014
# make sure we always have a ndarray with at least 1dim.
10121015
if type(data) not in (list, tuple) and not isinstance(data, (np.ndarray)):
1016+
logger.warning("Data type not in [list, tuple]: %s", data)
10131017
raise TypeError
10141018
if isinstance(data, np.ndarray) and data.ndim == 0:
10151019
data = np.array([data])

daliuge-engine/dlg/dask_emulation.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import base64
2727
import contextlib
28+
import inspect
2829
import logging
2930
import pickle
3031
import socket
@@ -312,6 +313,7 @@ def __init__(self, f, nout):
312313
if hasattr(f, "__name__"):
313314
self.fname = f.__name__
314315
self.fcode, self.fdefaults = pyfunc.serialize_func(f)
316+
# self.fcode = inspect.getsource(f)
315317
self.original_kwarg_names = []
316318
self.original_arg_names = []
317319
self.nout = nout
@@ -336,7 +338,10 @@ def make_dropdict(self):
336338
my_dropdict["func_name"] = self.fname
337339
my_dropdict["name"] = simple_fname
338340
if self.fcode is not None:
339-
my_dropdict["func_code"] = utils.b2s(base64.b64encode(self.fcode))
341+
logger.debug("func_code provided: %s", self.fcode)
342+
my_dropdict["func_code"] = self.fcode
343+
# my_dropdict["func_code"] = base64.b64encode(self.fcode)
344+
# my_dropdict["func_code"] = utils.b2s(base64.b64encode(self.fcode))
340345
if self.fdefaults:
341346
# APPLICATION ARGUMENTS
342347
my_dropdict["func_defaults"] = self.fdefaults
@@ -435,6 +440,7 @@ def make_dropdict(self):
435440
)
436441
if not self.producer:
437442
my_dropdict["pydata"] = pyfunc.serialize_data(self.pydata)
443+
# my_dropdict["pydata"] = self.pydata
438444
return my_dropdict
439445

440446
def __repr__(self):

daliuge-engine/dlg/data/drops/data_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def decrRefCount(self):
9797
with self._refLock:
9898
self._refCount -= 1
9999

100-
@track_current_drop
100+
# @track_current_drop
101101
def open(self, **kwargs):
102102
"""
103103
Opens the DROP for reading, and returns a "DROP descriptor"

daliuge-engine/dlg/data/drops/memory.py

+42-23
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,48 @@
2020
# MA 02111-1307 USA
2121
#
2222
import base64
23+
import builtins
24+
import dill
2325
import io
2426
import json
2527
import os
26-
import pickle
2728
import random
2829
import string
2930
import sys
31+
from typing import Union
3032

3133
from dlg.common.reproducibility.reproducibility import common_hash
3234
from dlg.data.drops.data_base import DataDROP, logger
3335
from dlg.data.io import SharedMemoryIO, MemoryIO
3436

37+
def get_builtins()-> dict:
38+
"""
39+
Get a tuple of buitlin types to compare pydata with.
40+
"""
41+
builtin_types = tuple(getattr(builtins, t) for t in dir(builtins) if isinstance(getattr(builtins, t), type))
42+
builtin_types = builtin_types[builtin_types.index(bool):]
43+
builtin_names = [b.__name__ for b in builtin_types]
44+
return dict(zip(builtin_names, builtin_types))
3545

36-
def parse_pydata(pd_dict: dict) -> bytes:
46+
47+
def parse_pydata(pd: Union[bytes, dict]) -> bytes:
3748
"""
3849
Parse and evaluate the pydata argument to populate memory during initialization
3950
40-
:param pd_dict: the pydata dictionary from the graph node
51+
:param pd: either the pydata dictionary from the graph node or the value directly
4152
4253
:returns a byte encoded value
4354
"""
55+
pd_dict = pd if isinstance(pd, dict) else {"value":pd, "type":"raw"}
4456
pydata = pd_dict["value"]
45-
logger.debug(f"pydata value provided: {pydata}, {pd_dict['type'].lower()}")
57+
logger.debug("pydata value provided: '%s' with type '%s'", pydata, type(pydata))
4658

4759
if pd_dict["type"].lower() in ["string", "str"]:
4860
return pydata if pydata != "None" else None
61+
builtin_types = get_builtins()
62+
if pd_dict["type"] != "raw" and type(pydata) in builtin_types.values() and pd_dict["type"] not in builtin_types.keys():
63+
logger.warning("Type of pydata %s provided differs from specified type: %s", type(pydata).__name__, pd_dict["type"])
64+
pd_dict["type"] = type(pydata).__name__
4965
if pd_dict["type"].lower() == "json":
5066
try:
5167
pydata = json.loads(pydata)
@@ -56,28 +72,35 @@ def parse_pydata(pd_dict: dict) -> bytes:
5672
pydata = eval(pydata)
5773
# except:
5874
# pydata = pydata.encode()
59-
elif pd_dict["type"].lower() == "int":
75+
elif pd_dict["type"].lower() == "int" or isinstance(pydata, int):
6076
try:
6177
pydata = int(pydata)
78+
pd_dict["type"] = "int"
6279
except:
6380
pydata = pydata.encode()
64-
elif pd_dict["type"].lower() == "float":
81+
elif pd_dict["type"].lower() == "float" or isinstance(pydata, float):
6582
try:
6683
pydata = float(pydata)
84+
pd_dict["type"] = "float"
6785
except:
6886
pydata = pydata.encode()
69-
elif pd_dict["type"].lower() == "boolean":
87+
elif pd_dict["type"].lower() == "boolean" or isinstance(pydata, bool):
7088
try:
7189
pydata = bool(pydata)
90+
pd_dict["type"] = "bool"
7291
except:
7392
pydata = pydata.encode()
7493
elif pd_dict["type"].lower() == "object":
7594
pydata = base64.b64decode(pydata.encode())
7695
try:
77-
pydata = pickle.loads(pydata)
96+
pydata = dill.loads(pydata)
7897
except:
7998
raise
80-
return pickle.dumps(pydata)
99+
elif pd_dict["type"].lower() == "raw":
100+
pydata = dill.loads(base64.b64decode(pydata))
101+
logger.debug("Returning pydata of type: %s", type(pydata))
102+
# return pydata
103+
return dill.dumps(pydata)
81104

82105

83106
##
@@ -117,34 +140,30 @@ def initialize(self, **kwargs):
117140
"""
118141
args = []
119142
pydata = None
143+
# pdict = {}
120144
pdict = {"type": "raw"} # initialize this value to enforce BytesIO
121145
self.data_type = pdict["type"]
122146
field_names = (
123147
[f["name"] for f in kwargs["fields"]] if "fields" in kwargs else []
124148
)
125149
if "pydata" in kwargs and not (
126150
"fields" in kwargs and "pydata" in field_names
127-
): # means that is was passed directly
151+
): # means that is was passed directly (e.g. from tests)
128152
pydata = kwargs.pop("pydata")
129-
logger.debug("pydata value provided: %s, %s", pydata, kwargs)
130-
try: # test whether given value is valid
131-
_ = pickle.loads(base64.b64decode(pydata))
132-
pydata = base64.b64decode(pydata)
133-
except:
134-
pydata = None
153+
pdict["value"] = pydata
154+
pydata = parse_pydata(pdict)
135155
elif "fields" in kwargs and "pydata" in field_names:
136156
data_pos = field_names.index("pydata")
137157
pdict = kwargs["fields"][data_pos]
138158
pydata = parse_pydata(pdict)
139-
if pdict["type"].lower() in ["str","string"]:
140-
self.data_type = "String"
141-
self._buf = io.StringIO(*args)
159+
if pdict and pdict["type"].lower() in ["str","string"]:
160+
self.data_type = "String" if pydata else "raw"
142161
else:
143-
self.data_type = pdict["type"]
144-
self._buf = io.BytesIO(*args)
162+
self.data_type = pdict["type"] if pdict else ""
145163
if pydata:
146164
args.append(pydata)
147-
logger.debug("Loaded into memory: %s, %s", pydata, self.data_type)
165+
logger.debug("Loaded into memory: %s, %s, %s", pydata, self.data_type, type(pydata))
166+
self._buf = io.BytesIO(*args) if self.data_type != "String" else io.StringIO(*args)
148167
self.size = len(pydata) if pydata else 0
149168

150169
def getIO(self):
@@ -230,7 +249,7 @@ def initialize(self, **kwargs):
230249
pydata = kwargs.pop("pydata")
231250
logger.debug("pydata value provided: %s", pydata)
232251
try: # test whether given value is valid
233-
_ = pickle.loads(base64.b64decode(pydata.encode("latin1")))
252+
_ = dill.loads(base64.b64decode(pydata.encode("latin1")))
234253
pydata = base64.b64decode(pydata.encode("latin1"))
235254
except:
236255
pydata = None

daliuge-engine/dlg/data/io.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# MA 02111-1307 USA
2121
#
2222
from abc import abstractmethod, ABCMeta
23+
import base64
2324
from http.client import HTTPConnection
2425
from multiprocessing.sharedctypes import Value
2526
from overrides import overrides
@@ -32,6 +33,7 @@
3233
from typing import Optional, Union
3334

3435
from dlg import ngaslite
36+
from dlg.common import b2s
3537

3638
if sys.version_info >= (3, 8):
3739
from dlg.shared_memory import DlgSharedMemory
@@ -251,16 +253,25 @@ def _open(self, **kwargs):
251253
elif self._mode == OpenMode.OPEN_READ:
252254
# TODO: potentially wasteful copy
253255
if isinstance(self._buf, io.StringIO):
256+
self._desc = io.StringIO
254257
return io.StringIO(self._buf.getvalue())
255258
return io.BytesIO(self._buf.getbuffer())
256259
else:
257260
raise ValueError()
258261

259262
@overrides
260263
def _write(self, data, **kwargs) -> int:
261-
if isinstance(data, str):
264+
if isinstance(self._desc, io.BytesIO) and isinstance(data, str):
262265
data = bytes(data, encoding="utf8")
263-
self._desc.write(data)
266+
elif isinstance(self._desc, io.StringIO) and isinstance(data, bytes):
267+
data = b2s(base64.b64encode(data))
268+
elif isinstance(data, memoryview):
269+
data = bytes(data)
270+
try:
271+
self._desc.write(data)
272+
except Exception:
273+
logger.debug("Writing of %s failed: %s", data)
274+
raise
264275
return len(data)
265276

266277
@overrides

0 commit comments

Comments
 (0)