Changes in sasmodels/kerneldll.py [4d76711:a5b8477] in sasmodels


Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/kerneldll.py

    r4d76711 ra5b8477  
    4949import os 
    5050import tempfile 
    51 import ctypes as ct 
    52 from ctypes import c_void_p, c_int, c_longdouble, c_double, c_float 
    53 import _ctypes 
    54  
    55 import numpy as np 
     51import ctypes as ct  # type: ignore 
     52from ctypes import c_void_p, c_int32, c_longdouble, c_double, c_float  # type: ignore 
     53 
     54import numpy as np  # type: ignore 
    5655 
    5756from . import generate 
    58 from .kernelpy import PyInput, PyModel 
     57from .kernel import KernelModel, Kernel 
     58from .kernelpy import PyInput 
    5959from .exception import annotate_exception 
     60from .generate import F16, F32, F64 
     61 
     62try: 
     63    from typing import Tuple, Callable, Any 
     64    from .modelinfo import ModelInfo 
     65    from .details import CallDetails 
     66except ImportError: 
     67    pass 
    6068 
    6169# Compiler platform details 
     
    8189        COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm" 
    8290        if "SAS_OPENMP" in os.environ: 
    83             COMPILE = COMPILE + " -fopenmp" 
     91            COMPILE += " -fopenmp" 
    8492else: 
    8593    COMPILE = "cc -shared -fPIC -fopenmp -std=c99 -O2 -Wall %(source)s -o %(output)s -lm" 
     
    9098 
    9199 
    92 def dll_path(model_info, dtype="double"): 
    93     """ 
    94     Path to the compiled model defined by *model_info*. 
    95     """ 
    96     from os.path import join as joinpath, split as splitpath, splitext 
    97     basename = splitext(splitpath(model_info['filename'])[1])[0] 
    98     if np.dtype(dtype) == generate.F32: 
    99         basename += "32" 
    100     elif np.dtype(dtype) == generate.F64: 
    101         basename += "64" 
    102     else: 
    103         basename += "128" 
    104     return joinpath(DLL_PATH, basename+'.so') 
    105  
    106  
    107 def make_dll(source, model_info, dtype="double"): 
    108     """ 
    109     Load the compiled model defined by *kernel_module*. 
    110  
    111     Recompile if any files are newer than the model file. 
     100def dll_name(model_info, dtype): 
     101    # type: (ModelInfo, np.dtype) ->  str 
     102    """ 
     103    Name of the dll containing the model.  This is the base file name without 
     104    any path or extension, with a form such as 'sas_sphere32'. 
     105    """ 
     106    bits = 8*dtype.itemsize 
     107    return "sas_%s%d"%(model_info.id, bits) 
     108 
     109 
     110def dll_path(model_info, dtype): 
     111    # type: (ModelInfo, np.dtype) -> str 
     112    """ 
     113    Complete path to the dll for the model.  Note that the dll may not 
     114    exist yet if it hasn't been compiled. 
     115    """ 
     116    return os.path.join(DLL_PATH, dll_name(model_info, dtype)+".so") 
     117 
     118 
     119def make_dll(source, model_info, dtype=F64): 
     120    # type: (str, ModelInfo, np.dtype) -> str 
     121    """ 
     122    Returns the path to the compiled model defined by *kernel_module*. 
     123 
     124    If the model has not been compiled, or if the source file(s) are newer 
     125    than the dll, then *make_dll* will compile the model before returning. 
     126    This routine does not load the resulting dll. 
    112127 
    113128    *dtype* is a numpy floating point precision specifier indicating whether 
    114     the model should be single or double precision.  The default is double 
    115     precision. 
    116  
    117     The DLL is not loaded until the kernel is called so models can 
    118     be defined without using too many resources. 
     129    the model should be single, double or long double precision.  The default 
     130    is double precision, *np.dtype('d')*. 
     131 
     132    Set *sasmodels.ALLOW_SINGLE_PRECISION_DLLS* to False if single precision 
     133    models are not allowed as DLLs. 
    119134 
    120135    Set *sasmodels.kerneldll.DLL_PATH* to the compiled dll output path. 
    121136    The default is the system temporary directory. 
    122  
    123     Set *sasmodels.ALLOW_SINGLE_PRECISION_DLLS* to True if single precision 
    124     models are allowed as DLLs. 
    125     """ 
    126     if callable(model_info.get('Iq', None)): 
    127         return PyModel(model_info) 
    128      
    129     dtype = np.dtype(dtype) 
    130     if dtype == generate.F16: 
     137    """ 
     138    if dtype == F16: 
    131139        raise ValueError("16 bit floats not supported") 
    132     if dtype == generate.F32 and not ALLOW_SINGLE_PRECISION_DLLS: 
    133         dtype = generate.F64  # Force 64-bit dll 
    134  
    135     if dtype == generate.F32: # 32-bit dll 
    136         tempfile_prefix = 'sas_' + model_info['name'] + '32_' 
    137     elif dtype == generate.F64: 
    138         tempfile_prefix = 'sas_' + model_info['name'] + '64_' 
    139     else: 
    140         tempfile_prefix = 'sas_' + model_info['name'] + '128_' 
    141   
    142     source = generate.convert_type(source, dtype) 
    143     source_files = generate.model_sources(model_info) + [model_info['filename']] 
     140    if dtype == F32 and not ALLOW_SINGLE_PRECISION_DLLS: 
     141        dtype = F64  # Force 64-bit dll 
     142    # Note: dtype may be F128 for long double precision 
     143 
     144    newest = generate.timestamp(model_info) 
    144145    dll = dll_path(model_info, dtype) 
    145     newest = max(os.path.getmtime(f) for f in source_files) 
    146146    if not os.path.exists(dll) or os.path.getmtime(dll) < newest: 
    147         # Replace with a proper temp file 
    148         fid, filename = tempfile.mkstemp(suffix=".c", prefix=tempfile_prefix) 
     147        basename = dll_name(model_info, dtype) + "_" 
     148        fid, filename = tempfile.mkstemp(suffix=".c", prefix=basename) 
     149        source = generate.convert_type(source, dtype) 
    149150        os.fdopen(fid, "w").write(source) 
    150151        command = COMPILE%{"source":filename, "output":dll} 
     
    160161 
    161162 
    162 def load_dll(source, model_info, dtype="double"): 
     163def load_dll(source, model_info, dtype=F64): 
     164    # type: (str, ModelInfo, np.dtype) -> "DllModel" 
    163165    """ 
    164166    Create and load a dll corresponding to the source, info pair returned 
     
    172174 
    173175 
    174 IQ_ARGS = [c_void_p, c_void_p, c_int] 
    175 IQXY_ARGS = [c_void_p, c_void_p, c_void_p, c_int] 
    176  
    177 class DllModel(object): 
     176class DllModel(KernelModel): 
    178177    """ 
    179178    ctypes wrapper for a single model. 
     
    191190     
    192191    def __init__(self, dllpath, model_info, dtype=generate.F32): 
     192        # type: (str, ModelInfo, np.dtype) -> None 
    193193        self.info = model_info 
    194194        self.dllpath = dllpath 
    195         self.dll = None 
     195        self._dll = None  # type: ct.CDLL 
    196196        self.dtype = np.dtype(dtype) 
    197197 
    198198    def _load_dll(self): 
    199         Nfixed1d = len(self.info['partype']['fixed-1d']) 
    200         Nfixed2d = len(self.info['partype']['fixed-2d']) 
    201         Npd1d = len(self.info['partype']['pd-1d']) 
    202         Npd2d = len(self.info['partype']['pd-2d']) 
    203  
     199        # type: () -> None 
    204200        #print("dll", self.dllpath) 
    205201        try: 
    206             self.dll = ct.CDLL(self.dllpath) 
     202            self._dll = ct.CDLL(self.dllpath) 
    207203        except: 
    208204            annotate_exception("while loading "+self.dllpath) 
     
    212208              else c_double if self.dtype == generate.F64 
    213209              else c_longdouble) 
    214         pd_args_1d = [c_void_p, fp] + [c_int]*Npd1d if Npd1d else [] 
    215         pd_args_2d = [c_void_p, fp] + [c_int]*Npd2d if Npd2d else [] 
    216         self.Iq = self.dll[generate.kernel_name(self.info, False)] 
    217         self.Iq.argtypes = IQ_ARGS + pd_args_1d + [fp]*Nfixed1d 
    218  
    219         self.Iqxy = self.dll[generate.kernel_name(self.info, True)] 
    220         self.Iqxy.argtypes = IQXY_ARGS + pd_args_2d + [fp]*Nfixed2d 
    221          
    222         self.release() 
     210 
     211        # int, int, int, int*, double*, double*, double*, double*, double*, double 
     212        argtypes = [c_int32]*3 + [c_void_p]*5 + [fp] 
     213        self._Iq = self._dll[generate.kernel_name(self.info, is_2d=False)] 
     214        self._Iqxy = self._dll[generate.kernel_name(self.info, is_2d=True)] 
     215        self._Iq.argtypes = argtypes 
     216        self._Iqxy.argtypes = argtypes 
    223217 
    224218    def __getstate__(self): 
     219        # type: () -> Tuple[ModelInfo, str] 
    225220        return self.info, self.dllpath 
    226221 
    227222    def __setstate__(self, state): 
     223        # type: (Tuple[ModelInfo, str]) -> None 
    228224        self.info, self.dllpath = state 
    229         self.dll = None 
     225        self._dll = None 
    230226 
    231227    def make_kernel(self, q_vectors): 
     228        # type: (List[np.ndarray]) -> DllKernel 
    232229        q_input = PyInput(q_vectors, self.dtype) 
    233         if self.dll is None: self._load_dll() 
    234         kernel = self.Iqxy if q_input.is_2d else self.Iq 
     230        # Note: pickle not supported for DllKernel 
     231        if self._dll is None: 
     232            self._load_dll() 
     233        kernel = self._Iqxy if q_input.is_2d else self._Iq 
    235234        return DllKernel(kernel, self.info, q_input) 
    236235 
    237236    def release(self): 
     237        # type: () -> None 
    238238        """ 
    239239        Release any resources associated with the model. 
     
    244244            libHandle = dll._handle 
    245245            #libHandle = ct.c_void_p(dll._handle) 
    246             del dll, self.dll 
    247             self.dll = None 
     246            del dll, self._dll 
     247            self._dll = None 
    248248            #_ctypes.FreeLibrary(libHandle) 
    249249            ct.windll.kernel32.FreeLibrary(libHandle) 
     
    252252 
    253253 
    254 class DllKernel(object): 
     254class DllKernel(Kernel): 
    255255    """ 
    256256    Callable SAS kernel. 
     
    272272    """ 
    273273    def __init__(self, kernel, model_info, q_input): 
     274        # type: (Callable[[], np.ndarray], ModelInfo, PyInput) -> None 
     275        self.kernel = kernel 
    274276        self.info = model_info 
    275277        self.q_input = q_input 
    276         self.kernel = kernel 
    277         self.res = np.empty(q_input.nq, q_input.dtype) 
    278         dim = '2d' if q_input.is_2d else '1d' 
    279         self.fixed_pars = model_info['partype']['fixed-' + dim] 
    280         self.pd_pars = model_info['partype']['pd-' + dim] 
    281  
    282         # In dll kernel, but not in opencl kernel 
    283         self.p_res = self.res.ctypes.data 
    284  
    285     def __call__(self, fixed_pars, pd_pars, cutoff): 
    286         real = (np.float32 if self.q_input.dtype == generate.F32 
    287                 else np.float64 if self.q_input.dtype == generate.F64 
    288                 else np.float128) 
    289  
    290         nq = c_int(self.q_input.nq) 
    291         if pd_pars: 
    292             cutoff = real(cutoff) 
    293             loops_N = [np.uint32(len(p[0])) for p in pd_pars] 
    294             loops = np.hstack(pd_pars) 
    295             loops = np.ascontiguousarray(loops.T, self.q_input.dtype).flatten() 
    296             p_loops = loops.ctypes.data 
    297             dispersed = [p_loops, cutoff] + loops_N 
    298         else: 
    299             dispersed = [] 
    300         fixed = [real(p) for p in fixed_pars] 
    301         args = self.q_input.q_pointers + [self.p_res, nq] + dispersed + fixed 
    302         #print(pars) 
    303         self.kernel(*args) 
    304  
    305         return self.res 
     278        self.dtype = q_input.dtype 
     279        self.dim = '2d' if q_input.is_2d else '1d' 
     280        self.result = np.empty(q_input.nq+1, q_input.dtype) 
     281        self.real = (np.float32 if self.q_input.dtype == generate.F32 
     282                     else np.float64 if self.q_input.dtype == generate.F64 
     283                     else np.float128) 
     284 
     285    def __call__(self, call_details, weights, values, cutoff): 
     286        # type: (CallDetails, np.ndarray, np.ndarray, float) -> np.ndarray 
     287 
     288        #print("in kerneldll") 
     289        #print("weights", weights) 
     290        #print("values", values) 
     291        start, stop = 0, call_details.total_pd 
     292        args = [ 
     293            self.q_input.nq, # nq 
     294            start, # pd_start 
     295            stop, # pd_stop pd_stride[MAX_PD] 
     296            call_details.ctypes.data, # problem 
     297            weights.ctypes.data,  # weights 
     298            values.ctypes.data,  #pars 
     299            self.q_input.q.ctypes.data, #q 
     300            self.result.ctypes.data,   # results 
     301            self.real(cutoff), # cutoff 
     302            ] 
     303        self.kernel(*args) # type: ignore 
     304        return self.result[:-1] 
    306305 
    307306    def release(self): 
     307        # type: () -> None 
    308308        """ 
    309309        Release any resources associated with the kernel. 
    310310        """ 
    311         pass 
     311        self.q_input.release() 
Note: See TracChangeset for help on using the changeset viewer.