diff --git a/lyncs_quda/enum.py b/lyncs_quda/enum.py index 4b4d4e3..29b3fc1 100644 --- a/lyncs_quda/enum.py +++ b/lyncs_quda/enum.py @@ -118,7 +118,7 @@ class Enum(metaclass=EnumMeta): def __init__(self, fnc, lpath=None, default=None, callback=None): # fnc is supposed to return either a stripped key name or value of - # the corresponding QUDA enum type + # the corresponding QUDA enum type so as to decorate a property obj self.fnc = fnc self.lpath = lpath self.default = default diff --git a/lyncs_quda/lib.py b/lyncs_quda/lib.py index 58c2a66..c9d550d 100644 --- a/lyncs_quda/lib.py +++ b/lyncs_quda/lib.py @@ -222,7 +222,7 @@ def copy_struct(self): """ ) return self.lyncs_quda_copy_struct - + def save_tuning(self): if self.tune_enabled: self.saveTuneCache() @@ -261,8 +261,8 @@ def __del__(self): PATHS = list(__path__) headers = [ - "comm_quda.h", "quda.h", + "comm_quda.h", "gauge_field.h", "gauge_tools.h", "gauge_path_quda.h", diff --git a/lyncs_quda/struct.py b/lyncs_quda/struct.py index 2fd3ffd..046d238 100644 --- a/lyncs_quda/struct.py +++ b/lyncs_quda/struct.py @@ -6,8 +6,9 @@ "to_code", ] -from lyncs_cppyy import nullptr -from lyncs_utils import isiterable, setitems +import numpy as np +from lyncs_cppyy import nullptr, to_pointer, addressof +from lyncs_utils import isiterable from .lib import lib from . import enums @@ -17,11 +18,13 @@ def to_human(val, typ=None): if typ is None: return val if typ in dir(enums): - return getattr(enums, typ)[val] + return str(getattr(enums, typ)[val]) if isinstance(val, (int, float)): return val if "*" in typ and val == nullptr: return 0 + if "char" in typ: + return "".join(list(val)) return val @@ -32,7 +35,7 @@ def to_code(val, typ=None): if typ in dir(enums): if isinstance(val, int): return val - return getattr(enums, typ)[val] + return int(getattr(enums, typ)[val]) if typ in ["int", "float", "double"]: return val if "char" in typ: @@ -45,14 +48,69 @@ def to_code(val, typ=None): return val +def get_dtype(typ): + if "*" in typ: + return np.dtype(object) + typ_dict = {"complex":"c", "unsigned":"u", "float":"single", "char":"byte", "comlex_double":"double"} + typ_list = typ.split() + dtype = "" + for w in typ_list: + if w in ("complex", "unsigned"): + dtype = typ_dict[w] + dtype + dtype += typ_dict.get(w, w) + if typ in ("bool", "long"): + dtype += "_" + if "int" in typ: + dtype += "c" + return np.dtype(dtype) + + +def setitems(arr, vals, shape=None, is_string=False): + "Sets items of an iterable object" + shape = shape if shape is not None else arr.shape + size = shape[0] #len(arr) + if not is_string and type(vals) == str: + # sometimes, vals is turned into str + vals = eval(vals) + if hasattr(vals, "__len__") and type(vals) != bytes: + if len(vals) > size: + raise ValueError( + f"Values size ({len(vals)}) larger than array size ({size})" + ) + else: + vals = (vals,) * size + for i, val in enumerate(vals): + if len(shape)>1 and hasattr(arr[i], "__len__"): + is_string = len(shape[1:]) == 1 and type(vals[0]) == str + setitems(arr[i], val, shape = shape[1:], is_string=is_string) + else: + arr[i] = val + + class Struct: "Struct base class" _types = {} def __init__(self, *args, **kwargs): - # ? better to simply store (key, val) pair into an instance's own __dict__, if key is in _types.keys() - self._params = getattr(lib, type(self).__name__)() # ? recursive? - + #? is *args necessary? when provided, it causes error in update + self._quda_params = getattr(lib, "new"+type(self).__name__)() + + # some fields are not set by QUDA's new* function + default_params = getattr(lib, type(self).__name__)() + for key in self.keys(): + # to avoid Enum error due to unexpected key-value pair + if self._types[key] in dir(enums) and not key in kwargs: + enm = getattr(enums, self._types[key]) + if not getattr(self._quda_params, key) in enm.values(): + val = list(enm.values())[-1] + self._assign(key, val) + + # temporal fix: newQudaMultigridParam does not assign a default value to n_level + if "Multigrid" in type(self).__name__: + n = getattr(self._quda_params, "n_level") + n = lib.QUDA_MAX_MG_LEVEL if n < 0 or n > lib.QUDA_MAX_MG_LEVEL else n + setattr(self._quda_params, "n_level", n) + for arg in args: self.update(arg) self.update(kwargs) @@ -67,28 +125,67 @@ def items(self): def update(self, params): "Updates values of the structure" - if not hasattr( - params, "items" - ): # ? in __init__, it takes *args, which is a tuple. expect a tuple of dict's? + if not hasattr(params, "items"): raise TypeError(f"Unsopported type for params: {type(params)}") for key, val in params.items(): setattr(self, key, val) def _assign(self, key, val): typ = self._types[key] - val = to_code(val, typ) - cur = getattr(self._params, key) - - if hasattr(cur, "shape"): # ? what is this? - setitems(cur, val) # ? what is this? + val = to_code(val, typ) + cur = getattr(self._quda_params, key) + + if "[" in self._types[key] and not hasattr(cur, "shape"):# not sure if this is needed for cppyy3.0.0 + # safeguard against hectic behavior of cppyy + raise RuntimeError("cppyy is not happy for now. Try again!") + + + if typ.count("[") > 1: + # cppyy<=3.0.0 cannot handle subviews properly + # Trying to manipulate the sub-array either results in error or segfault + # => array of arrays is set using glb.memcpy + # Alternative: + # use ctypes (C = ctypes, arr = LowlevelView of array of arrays) + # ptr = C.cast(cppyy.ll.addressof(arr), C.POINTER(C.c_int)) + # narr = np.ctypeslib.as_array(ptr, shape=arr.shape) + # This allows to access sub-indicies properly, i.e., narr[2][3] = 9 works + assert hasattr(cur, "shape") + if "file" in key: + #? array = np.zeros(cur.shape, dtype="S1") and remove setitems(array, b"\0"); is this ok? + #? not sure of this as "" is not b"\0" + array = np.chararray(cur.shape) + setitems(array, b"\0") + setitems(array, val) + size = 1 + else: + dtype = get_dtype(typ[:typ.index("[")].strip()) + array = np.asarray(val, dtype=dtype) + size = dtype.itemsize + lib.memcpy(to_pointer(addressof(cur)), to_pointer(array.__array_interface__["data"][0]), int(np.prod(cur.shape))*size) + elif typ.count("[") == 1: + assert hasattr(cur, "shape") + shape = tuple([getattr(lib, macro) for macro in typ.split(" ") if "QUDA_" in macro or macro.isnumeric()]) #not necessary for cppyy3.0.0? + cur.reshape(shape) #? not necessary for cppyy3.0.0? + if "*" in typ: + for i in range(shape[0]): + val = to_pointer(addressof(val), ctype = typ[:-typ.index("[")].strip()) + is_string = True if "char" in typ else False + if is_string: + setitems(cur, b"\0") # for printing + setitems(cur, val, is_string=is_string) else: - setattr(self._params, key, val) + if "*" in typ: + # cannot set nullptr to void *, int *, etc; works for classes such as Enum classes with bind_object + if val == nullptr: + raise ValueError("Cannot cast nullptr to a valid pointer") + val = to_pointer(addressof(val), ctype = typ) + setattr(self._quda_params, key, val) def __dir__(self): - return list(set(list(super().__dir__()) + list(self._params.keys()))) + return list(set(list(super().__dir__()) + list(self._quda_params.keys()))) def __getattr__(self, key): - return to_human(getattr(self._params, key), self._types[key]) + return to_human(getattr(self._quda_params, key), self._types[key]) def __setattr__(self, key, val): if key in self.keys(): @@ -98,8 +195,23 @@ def __setattr__(self, key, val): raise TypeError( f"Cannot assign '{val}' to '{key}' of type '{self._types[key]}'" ) - else: + else: #should we allow this? super().__setattr__(key, val) def __str__(self): return str(dict(self.items())) + + @property + def quda(self): + return self._quda_params + + @property + def address(self): + return addressof(self.quda) + + @property + def ptr(self): + return to_pointer(addressof(self.quda), ctype = type(self).__name__ + " *") + + def printf(self): + getattr(lib, "print"+type(self).__name__)(self._quda_params) diff --git a/test/test_structs.py b/test/test_structs.py index cf48951..38fda97 100644 --- a/test/test_structs.py +++ b/test/test_structs.py @@ -1,13 +1,37 @@ from lyncs_quda import structs # This is also importing Enum +from lyncs_quda.enum import Enum from lyncs_quda.testing import fixlib as lib - def test_assign_zero(lib): for struct in dir(structs): - if struct.startswith("_") or struct == "Struct": - continue + if struct.startswith("_") or struct == "Struct" or struct == "Enum" or issubclass(getattr(structs, struct), Enum): + continue + params = getattr(structs, struct)() - for key in params.keys(): - setattr(params, key, 0) + typ = getattr(structs, struct)._types[key] + obj = getattr(structs, typ) if typ in dir(structs) else None + val = 0 + if obj != Enum and issubclass(obj, Enum): + val = list(obj.values())[0] + elif "*" in typ: # cannot set a pointer field to nullptr via cppyy + continue + print("tst",struct,key,typ, obj,val) + + setattr(params, key, val) + +def test_assign_something(lib): + mp = structs.QudaMultigridParam() + ip = structs.QudaInvertParam() + ep = structs.QudaEigParam() + + # ptr to strct class works + mp.n_level = 3 # This is supposed to be set explicitly + mp.invert_param = ip.quda + ip.split_grid = list(range(lib.QUDA_MAX_DIM)) + ip.madwf_param_infile = "hi I'm here!" + mp.geo_block_size = [[i+j+1 for j in range(lib.QUDA_MAX_DIM)] for i in range(lib.QUDA_MAX_MG_LEVEL)] + mp.vec_infile = ["infile" + str(i) for i in range(lib.QUDA_MAX_MG_LEVEL)] + mp.printf() + print(ip.madwf_param_infile)