diff --git a/src/madp_pymad.py b/src/madp_pymad.py index ca185122f..4091ecb9d 100644 --- a/src/madp_pymad.py +++ b/src/madp_pymad.py @@ -4,64 +4,63 @@ __all__ = ["mad_process"] -def get_typestring(a: Union[str, int, float, np.ndarray, bool, list]): - if isinstance(a, np.ndarray): - return a.dtype - elif type(a) is int: # Check for signed 32 bit int - if a.bit_length() < 31: return int - else: return float - else: - return type(a) - -data_types = { - type(None) : "nil_", - str : "str_", - int : "int_", - np.int32 : "int_", - float : "num_", - np.float64 : "num_", - complex : "cpx_", - np.complex128 : "cpx_", - bool : "bool", - list : "tbl_", - range : "irng", - np.dtype("float64") : "mat_", - np.dtype("complex128") : "cmat", - np.dtype("int32") : "imat", - np.dtype("ubyte") : "mono", -} + +def is_not_private(varname): + if varname[:2] == "__" and varname[:8] != "__last__": + return False + return True + + class mad_process: def __init__(self, mad_path: str, py_name: str = "py", debug: bool = False) -> None: self.py_name = py_name + # Create the pipes for communication self.from_mad, mad_write = os.pipe() mad_read, self.to_mad = os.pipe() - self.fto_mad = os.fdopen(self.to_mad, "wb", buffering=0) # Sensible to not buffer stdin? - - startupChunk = f"MAD.pymad '{py_name}' {{_dbg = {str(debug).lower()}}} :__ini({mad_write})" + # Open the pipes for communication to MAD (the stdin of MAD) + self.fto_mad = os.fdopen(self.to_mad, "wb", buffering=0) + + # Create a chunk of code to start the process + startupChunk = ( + f"MAD.pymad '{py_name}' {{_dbg = {str(debug).lower()}}} :__ini({mad_write})" + ) + + # Start the process self.process = subprocess.Popen( [mad_path, "-q", "-e", startupChunk], bufsize=0, - stdin=mad_read, - stdout=sys.stdout.fileno(), + stdin=mad_read, # Set the stdin of MAD to the read end of the pipe + stdout=sys.stdout.fileno(), # Forward stdout preexec_fn=os.setpgrp, # Don't forward signals - pass_fds=[mad_write, sys.stdout.fileno(), sys.stderr.fileno()], + pass_fds=[ + mad_write, + sys.stdout.fileno(), + sys.stderr.fileno(), + ], # Don't close these (python closes all fds by default) ) + + # Close the ends of the pipes that are not used by the process os.close(mad_write) os.close(mad_read) - self.globalVars = {"np" : np} + # Create a global variable dictionary for the exec function (could be extended to include more variables) + self.globalVars = {"np": np} + + # Open the pipe from MAD (this is where MAD will no longer hang) self.ffrom_mad = os.fdopen(self.from_mad, "rb") - # stdout should be line buffered by default, but for jupyter notebook, + # stdout should be line buffered by default, but for jupyter notebook, # stdout is redirected and not line buffered by default self.send( f"""io.stdout:setvbuf('line') - {self.py_name}:send(1)""" + {self.py_name}:send(1)""" ) - checker = select.select([self.ffrom_mad], [], [], 1) # May not work on windows - if not checker[0] or self.recv() != 1: # Need to check number? + + # Check if MAD started successfully using select + checker = select.select([self.ffrom_mad], [], [], 1) # May not work on windows + if not checker[0] or self.recv() != 1: # Need to check number? raise OSError(f"Unsuccessful starting of {mad_path} process") def send_rng(self, start: float, stop: float, size: int): @@ -87,20 +86,19 @@ def send_ctpsa(self, monos: np.ndarray, coefficients: np.ndarray): def send(self, data: Union[str, int, float, np.ndarray, bool, list]): """Send data to MAD, returns self for chaining""" try: - typ = data_types[get_typestring(data)] + typ = type_str[get_typestr(data)] self.fto_mad.write(typ.encode("utf-8")) - str_to_fun[typ]["send"](self, data) + type_fun[typ]["send"](self, data) return self except KeyError: # raise not in exception to reduce error output - pass - raise TypeError( - f"Unsupported data type, expected a type in: \n{list(data_types.keys())}, got {type(data)}" - ) + raise TypeError( + f"Unsupported data type, expected a type in: \n{list(type_str.keys())}, got {type(data)}" + ) from None def psend(self, string: str): """Perform a protected send to MAD, by first enabling error handling, so that if an error occurs, an error is returned""" return self.send(f"{self.py_name}:__err(true); {string}; {self.py_name}:__err(false);") - + def precv(self, name: str): """Perform a protected send receive to MAD, by first enabling error handling, so that if an error occurs, an error is received""" return self.send(f"{self.py_name}:__err(true):send({name}):__err(false)").recv(name) @@ -109,70 +107,58 @@ def errhdlr(self, on_off: bool): """Enable or disable error handling""" self.send(f"{self.py_name}:__err({str(on_off).lower()})") - def recv( - self, varname: str = None - ) -> Union[str, int, float, np.ndarray, bool, list]: + def recv(self, varname: str = None): """Receive data from MAD, if a function is returned, it will be executed with the argument mad_communication""" typ = self.ffrom_mad.read(4).decode("utf-8") self.varname = varname # For mad reference - return str_to_fun[typ]["recv"](self) + return type_fun[typ]["recv"](self) - def recv_and_exec(self, env: dict = {}) -> dict: + def recv_and_exec(self, env: dict = {}): """Read data from MAD and execute it""" - try: env["mad"] # Check if user has already defined mad (madp_object will have mad defined) - except KeyError: env["mad"] = self # If not, define it + # Check if user has already defined mad (madp_object will have mad defined), otherwise define it + try: env["mad"] + except KeyError: env["mad"] = self exec(compile(self.recv(), "ffrom_mad", "exec"), self.globalVars, env) return env - # ----------------- Dealing with communication of variables ---------------# - def send_vars(self, names, vars): - if isinstance(names, str): - names = [names] - vars = [vars] - else: - assert isinstance(vars, list), "A list of names must be matched with a list of variables" - assert len(vars) == len(names), "The number of names must match the number of variables" - for i, var in enumerate(vars): - if isinstance(vars[i], mad_ref): - self.send(f"{names[i]} = {var.__name__}") + # ----------------- Dealing with communication of variables ---------------- # + def send_vars(self, **vars): + for name, var in vars.items(): + if isinstance(var, mad_ref): + self.send(f"{name} = {var.__name__}") else: - self.send(f"{names[i]} = {self.py_name}:recv()").send(var) + self.send(f"{name} = {self.py_name}:recv()").send(var) - def recv_vars(self, names) -> Any: - if isinstance(names, str): - names = [names] - cnvrt = lambda rtrn: rtrn[0] - else: - cnvrt = lambda rtrn: tuple(rtrn) - - rtrn_vars = [] - for name in names: - if name[:2] != "__" or name[:8] == "__last__": # Check for private variables - rtrn_vars.append(self.precv(name)) - return cnvrt(rtrn_vars) + def recv_vars(self, *names): + if len(names) == 1: + if is_not_private(names[0]): + return self.precv(names[0]) + else: + return tuple(self.precv(name) for name in names if is_not_private(name)) - # -------------------------------------------------------------------------# + # -------------------------------------------------------------------------- # def __del__(self): self.send(f"{self.py_name}:__fin()") self.ffrom_mad.close() - self.process.terminate() #In case user left mad waiting + self.process.terminate() # In case user left mad waiting self.fto_mad.close() self.process.wait() - -class mad_ref(object): - def __init__(self, name: str, mad_proc: mad_process) -> None: + + +class mad_ref(object): + def __init__(self, name: str, mad_proc: mad_process): assert name is not None, "Reference must have a variable to reference to. Did you forget to put a name in the receive functions?" self.__name__ = name self.__mad__ = mad_proc - def __getattr__ (self, item): + def __getattr__(self, item): if item[0] != "_": try: return self[item] except (IndexError, KeyError): pass - raise AttributeError (item) # For python + raise AttributeError(item) # For python def __getitem__(self, item: Union[str, int]): if isinstance(item, int): @@ -191,105 +177,148 @@ def __getitem__(self, item: Union[str, int]): def eval(self): return self.__mad__.recv_vars(self.__name__) -data_types[mad_ref] = "ref_" # Add mad_ref to the datatypes -# -------------------------------- Sending data -------------------------------# +# data transfer -------------------------------------------------------------- # + +# Data ----------------------------------------------------------------------- # + + +def send_dat(self: mad_process, dat_fmt: str, *dat: Any): + self.fto_mad.write(struct.pack(dat_fmt, *dat)) + + +def recv_dat(self: mad_process, dat_sz: int, dat_typ: np.dtype): + return np.frombuffer(self.ffrom_mad.read(dat_sz), dtype=dat_typ) + + +# None ----------------------------------------------------------------------- # + send_nil = lambda self, input: None +recv_nil = lambda self: None + +# Boolean -------------------------------------------------------------------- # + +send_bool = lambda self, input: self.fto_mad.write(struct.pack("?", input)) +recv_bool = lambda self: recv_dat(self, 1, np.bool_)[0] -def send_int(self: mad_process, input: int) -> None: - self.fto_mad.write(struct.pack("i", input)) +# int32 ---------------------------------------------------------------------- # -def send_str(self: mad_process, input: str) -> None: +send_int = lambda self, input: send_dat(self, "i", input) +recv_int = lambda self: recv_dat(self, 4, np.int32)[ + 0 +] # Should it be a python int or a numpy int32? + +# String --------------------------------------------------------------------- # + + +def send_str(self: mad_process, input: str): send_int(self, len(input)) self.fto_mad.write(input.encode("utf-8")) -def send_ref(self: mad_process, obj: mad_ref) -> None: - send_str(self, f"return {obj.__name__}") -def send_num(self: mad_process, input: float) -> None: - self.fto_mad.write(struct.pack("d", input)) +def recv_str(self: mad_process) -> str: + return self.ffrom_mad.read(recv_int(self)).decode("utf-8") + + +# number (float64) ----------------------------------------------------------- # + +send_num = lambda self, input: send_dat(self, "d", input) +recv_num = lambda self: recv_dat(self, 8, np.float64)[0] + +# Complex (complex128) ------------------------------------------------------- # + +send_cpx = lambda self, input: send_dat(self, "dd", input.real, input.imag) +recv_cpx = lambda self: recv_dat(self, 16, np.complex128)[0] + +# Range ---------------------------------------------------------------------- # + +send_grng = lambda self, start, stop, size: send_dat(self, "ddi", start, stop, size) + + +def recv_rng(self: mad_process) -> np.ndarray: + return np.linspace(*struct.unpack("ddi", self.ffrom_mad.read(20))) + + +def recv_lrng(self: mad_process) -> np.ndarray: + return np.geomspace(*struct.unpack("ddi", self.ffrom_mad.read(20))) + -def send_cpx(self: mad_process, input: complex) -> None: - self.fto_mad.write(struct.pack("dd", input.real, input.imag)) +# irange --------------------------------------------------------------------- # -def send_bool(self: mad_process, input: bool) -> None: - self.fto_mad.write(struct.pack("?", input)) +send_irng = lambda self, rng: send_dat(self, "iii", rng.start, rng.stop, rng.step) -def send_gmat(self: mad_process, mat: np.ndarray) -> None: +def recv_irng(self: mad_process) -> range: + start, stop, step = recv_dat(self, 12, np.int32) + return range(start, stop + 1, step) # MAD is inclusive at both ends + + +# matrix --------------------------------------------------------------------- # + + +def send_gmat(self: mad_process, mat: np.ndarray): assert len(mat.shape) == 2, "Matrix must be of two dimensions" - send_int(self, mat.shape[0]) - send_int(self, mat.shape[1]) + send_dat(self, "ii", *mat.shape) self.fto_mad.write(mat.tobytes()) -def send_list(self: mad_process, lst: list) -> None: - n = len(lst) - send_int(self, n) - for item in lst: - self.send(item) # deep copy - return self -def send_grng(self: mad_process, start: float, stop: float, size: int) -> None: - self.fto_mad.write(struct.pack("ddi", start, stop, size)) +def recv_gmat(self: mad_process, dtype: np.dtype) -> str: + shape = recv_dat(self, 8, np.int32) + return recv_dat(self, shape[0] * shape[1] * dtype.itemsize, dtype).reshape(shape) + -def send_irng(self: mad_process, rng: range) -> None: - self.fto_mad.write(struct.pack("iii", rng.start, rng.stop, rng.step)) +recv_mat = lambda self: recv_gmat(self, np.dtype("float64")) +recv_cmat = lambda self: recv_gmat(self, np.dtype("complex128")) +recv_imat = lambda self: recv_gmat(self, np.dtype("int32")) -def send_mono(self: mad_process, mono: np.ndarray) -> None: +# monomial ------------------------------------------------------------------- # + + +def send_mono(self: mad_process, mono: np.ndarray): send_int(self, mono.size) self.fto_mad.write(mono.tobytes()) +recv_mono = lambda self: recv_dat(self, recv_int(self), np.ubyte) + +# TPSA ----------------------------------------------------------------------- # + + def send_gtpsa( self: mad_process, monos: np.ndarray, coefficients: np.ndarray, - fsendNum: Callable[[mad_process, Union[float, complex]], None], -) -> None: + send_num: Callable[[mad_process, Union[float, complex]], None], +): assert len(monos.shape) == 2, "The list of monomials must have two dimensions" assert len(monos) == len(coefficients), "The number of monomials must be equal to the number of coefficients" assert monos.dtype == np.uint8, "The monomials must be of type 8-bit unsigned integer " - send_int(self, len(monos)) # Num monomials - send_int(self, len(monos[0])) # Monomial length + send_dat(self, "ii", len(monos), len(monos[0])) for mono in monos: self.fto_mad.write(mono.tobytes()) for coefficient in coefficients: - fsendNum(self, coefficient) -# -----------------------------------------------------------------------------# - -# --------------------------- Receiving data ----------------------------------# - -recv_nil = lambda self: None + send_num(self, coefficient) -def recv_ref(self: mad_process) -> mad_ref: - return mad_ref(self.varname, self) -def recv_str(self: mad_process) -> str: - return self.ffrom_mad.read(recv_int(self)).decode("utf-8") - -def recv_int(self: mad_process) -> int: # Must be int32 - return int.from_bytes(self.ffrom_mad.read(4), sys.byteorder, signed=True) +def recv_gtpsa(self: mad_process, dtype: np.dtype) -> np.ndarray: + num_mono, mono_len = recv_dat(self, 8, np.int32) + mono_list = np.reshape( + recv_dat(self, mono_len * num_mono, np.ubyte), + (num_mono, mono_len), + ) + coefficients = recv_dat(self, num_mono * dtype.itemsize, dtype) + return mono_list, coefficients -def recv_num(self: mad_process) -> float: - return np.frombuffer(self.ffrom_mad.read(8), dtype=np.float64)[0] -def recv_cpx(self: mad_process) -> complex: - return np.frombuffer(self.ffrom_mad.read(16), dtype=np.complex128)[0] +recv_ctpa = lambda self: recv_gtpsa(self, np.dtype("complex128")) +recv_tpsa = lambda self: recv_gtpsa(self, np.dtype("float64")) -def recv_bool(self: mad_process) -> str: - return np.frombuffer(self.ffrom_mad.read(1), dtype=np.bool_)[0] +# lists ---------------------------------------------------------------------- # -def recv_gmat(self: mad_process, dtype: np.dtype) -> str: - shape = np.frombuffer(self.ffrom_mad.read(8), dtype=np.int32) - arraySize = shape[0] * shape[1] * dtype.itemsize - return np.frombuffer(self.ffrom_mad.read(arraySize), dtype=dtype).reshape(shape) -def recv_mat(self: mad_process) -> str: - return recv_gmat(self, np.dtype("float64")) - -def recv_cmat(self: mad_process) -> str: - return recv_gmat(self, np.dtype("complex128")) +def send_list(self: mad_process, lst: list): + send_int(self, len(lst)) + for item in lst: + self.send(item) -def recv_imat(self: mad_process) -> str: - return recv_gmat(self, np.dtype("int32")) def recv_list(self: mad_process) -> list: varname = self.varname # cache @@ -298,51 +327,28 @@ def recv_list(self: mad_process) -> list: vals = [self.recv(varname and varname + f"[{i+1}]") for i in range(lstLen)] self.varname = varname # reset if haskeys and lstLen == 0: - return recv_ref(self) + return type_fun["ref_"]["recv"](self) elif haskeys: - return vals, recv_ref(self) + return vals, type_fun["ref_"]["recv"](self) else: return vals -def recv_irng(self: mad_process) -> range: - start, stop, step = np.frombuffer(self.ffrom_mad.read(12), dtype=np.int32) - return range(start, stop + 1, step) # MAD is inclusive at both ends -def recv_rng(self: mad_process) -> np.ndarray: - return np.linspace(*struct.unpack("ddi", self.ffrom_mad.read(20))) +# object (table with metatable are treated as pure reference) ---------------- # -def recv_lrng(self: mad_process) -> np.ndarray: - return np.geomspace(*struct.unpack("ddi", self.ffrom_mad.read(20))) +recv_ref = lambda self: mad_ref(self.varname, self) +send_ref = lambda self, obj: send_str(self, f"return {obj.__name__}") -def recv_mono(self: mad_process) -> np.ndarray: - mono_len = recv_int(self) - return np.frombuffer(self.ffrom_mad.read(mono_len), dtype=np.ubyte) +# error ---------------------------------------------------------------------- # -def recv_gtpsa(self: mad_process, dtype: np.dtype) -> np.ndarray: - num_mono, mono_len = np.frombuffer(self.ffrom_mad.read(8), dtype=np.int32) - mono_list = np.reshape( - np.frombuffer(self.ffrom_mad.read(mono_len * num_mono), dtype=np.ubyte), - (num_mono, mono_len), - ) - coefficients = np.frombuffer( - self.ffrom_mad.read(num_mono * dtype.itemsize), dtype=dtype - ) - return mono_list, coefficients - -def recv_ctpa(self: mad_process): - return recv_gtpsa(self, np.dtype("complex128")) - -def recv_tpsa(self: mad_process): - return recv_gtpsa(self, np.dtype("float64")) def recv_err(self: mad_process): self.errhdlr(False) raise RuntimeError("MAD Errored (see the MAD error output)") -# -----------------------------------------------------------------------------# -# ----------------------------- Dict for data ---------------------------------# -str_to_fun = { +# ---------------------------- dispatch tables ------------------------------- # +type_fun = { "nil_": {"recv": recv_nil , "send": send_nil }, "bool": {"recv": recv_bool, "send": send_bool}, "str_": {"recv": recv_str , "send": send_str }, @@ -364,4 +370,37 @@ def recv_err(self: mad_process): "ctpa": {"recv": recv_ctpa, }, "err_": {"recv": recv_err , }, } -# ---------------------------------------------------------------------------- # \ No newline at end of file + + +def get_typestr(a: Union[str, int, float, np.ndarray, bool, list]): + if isinstance(a, np.ndarray): + return a.dtype + elif type(a) is int: # Check for signed 32 bit int + if a.bit_length() < 31: + return int + else: + return float + else: + return type(a) + + +type_str = { + type(None) : "nil_", + bool : "bool", + str : "str_", + list : "tbl_", + tuple : "tbl_", + mad_ref : "ref_", + int : "int_", + np.int32 : "int_", + float : "num_", + np.float64 : "num_", + complex : "cpx_", + np.complex128 : "cpx_", + range : "irng", + np.dtype("float64") : "mat_", + np.dtype("complex128") : "cmat", + np.dtype("int32") : "imat", + np.dtype("ubyte") : "mono", +} +# ---------------------------------------------------------------------------- # diff --git a/tests/utests/comm_tests.py b/tests/utests/comm_tests.py index f27d05bfe..e9cf63c52 100644 --- a/tests/utests/comm_tests.py +++ b/tests/utests/comm_tests.py @@ -11,8 +11,8 @@ class TestExecution(unittest.TestCase): def test_recv_and_exec(self): mad = MAD("../mad") mad.send("""py:send([==[mad.send('''py:send([=[mad.send("py:send([[a = 100/2]])")]=])''')]==])""") - mad.recv_and_exec({"mad": mad}) - mad.recv_and_exec({"mad": mad}) + mad.recv_and_exec() + mad.recv_and_exec() a = mad.recv_and_exec()["a"] self.assertEqual(a, 50) @@ -121,10 +121,10 @@ def test_send_recv_int(self): mad.send(int_lst[i]) recv_num = mad.recv() self.assertEqual(recv_num, int_lst[i]) - self.assertTrue(isinstance(recv_num, int)) + self.assertTrue(isinstance(recv_num, np.int32)) recv_num = mad.recv() self.assertEqual(recv_num, -int_lst[i]) - self.assertTrue(isinstance(recv_num, int)) + self.assertTrue(isinstance(recv_num, np.int32)) self.assertTrue(mad.recv()) def test_send_recv_num(self): @@ -159,6 +159,45 @@ def test_send_recv_cpx(self): self.assertEqual(mad.recv(), -my_cpx) self.assertEqual(mad.recv(), my_cpx * 1.31j) +class TestMatrices(unittest.TestCase): + + def test_send_recv_imat(self): + mad = MAD("../mad") + mad.send(""" + local imat = py:recv() + py:send(imat) + py:send(MAD.imatrix(3, 5):seq()) + """) + imat = np.random.randint(0, 255, (5, 5), dtype=np.int32) + mad.send(imat) + self.assertTrue(np.all(mad.recv() == imat)) + self.assertTrue(np.all(mad.recv() == np.arange(1, 16).reshape(3, 5))) + + def test_send_recv_mat(self): + mad = MAD("../mad") + mad.send(""" + local mat = py:recv() + py:send(mat) + py:send(MAD.matrix(3, 5):seq() / 2) + """) + mat = np.arange(1, 25).reshape(4, 6) / 4 + mad.send(mat) + self.assertTrue(np.all(mad.recv() == mat)) + self.assertTrue(np.all(mad.recv() == np.arange(1, 16).reshape(3, 5) / 2)) + + def test_send_recv_cmat(self): + mad = MAD("../mad") + mad.send(""" + local cmat = py:recv() + py:send(cmat) + py:send(MAD.cmatrix(3, 5):seq() / 2i) + """) + cmat = np.arange(1, 25).reshape(4, 6) / 4 + 1j * np.arange(1, 25).reshape(4, 6) / 4 + mad.send(cmat) + self.assertTrue(np.all(mad.recv() == cmat)) + self.assertTrue(np.all(mad.recv() == (np.arange(1, 16).reshape(3, 5) / 2j))) + + class TestRngs(unittest.TestCase): def test_recv(self):