Changeset f619de7 in sasmodels for sasmodels/kerneldll.py


Ignore:
Timestamp:
Apr 11, 2016 9:14:50 AM (8 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Children:
7ae2b7f
Parents:
9a943d0
Message:

more type hinting

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/kerneldll.py

    r6d6508e rf619de7  
    5656from . import generate 
    5757from . import details 
    58 from .kernelpy import PyInput, PyModel 
     58from .kernel import KernelModel, Kernel 
     59from .kernelpy import PyInput 
    5960from .exception import annotate_exception 
     61from .generate import F16, F32, F64 
     62 
     63try: 
     64    from typing import Tuple, Callable, Any 
     65    from .modelinfo import ModelInfo 
     66    from .details import CallDetails 
     67except ImportError: 
     68    pass 
    6069 
    6170# Compiler platform details 
     
    91100 
    92101def dll_name(model_info, dtype): 
     102    # type: (ModelInfo, np.dtype) ->  str 
    93103    """ 
    94104    Name of the dll containing the model.  This is the base file name without 
     
    98108    return "sas_%s%d"%(model_info.id, bits) 
    99109 
     110 
    100111def dll_path(model_info, dtype): 
     112    # type: (ModelInfo, np.dtype) -> str 
    101113    """ 
    102114    Complete path to the dll for the model.  Note that the dll may not 
     
    105117    return os.path.join(DLL_PATH, dll_name(model_info, dtype)+".so") 
    106118 
    107 def make_dll(source, model_info, dtype="double"): 
    108     """ 
    109     Load the compiled model defined by *kernel_module*. 
    110  
    111     Recompile if any files are newer than the model file. 
     119 
     120def make_dll(source, model_info, dtype=F64): 
     121    # type: (str, ModelInfo, np.dtype) -> str 
     122    """ 
     123    Returns the path to the compiled model defined by *kernel_module*. 
     124 
     125    If the model has not been compiled, or if the source file(s) are newer 
     126    than the dll, then *make_dll* will compile the model before returning. 
     127    This routine does not load the resulting dll. 
    112128 
    113129    *dtype* is a numpy floating point precision specifier indicating whether 
    114     the model should be single or double precision.  The default is double 
    115     precision. 
    116  
    117     The DLL is not loaded until the kernel is called so models can 
    118     be defined without using too many resources. 
     130    the model should be single, double or long double precision.  The default 
     131    is double precision, *np.dtype('d')*. 
     132 
     133    Set *sasmodels.ALLOW_SINGLE_PRECISION_DLLS* to False if single precision 
     134    models are not allowed as DLLs. 
    119135 
    120136    Set *sasmodels.kerneldll.DLL_PATH* to the compiled dll output path. 
    121137    The default is the system temporary directory. 
    122  
    123     Set *sasmodels.ALLOW_SINGLE_PRECISION_DLLS* to True if single precision 
    124     models are allowed as DLLs. 
    125     """ 
    126     if callable(model_info.Iq): 
    127         return PyModel(model_info) 
    128      
    129     dtype = np.dtype(dtype) 
    130     if dtype == generate.F16: 
     138    """ 
     139    if dtype == F16: 
    131140        raise ValueError("16 bit floats not supported") 
    132     if dtype == generate.F32 and not ALLOW_SINGLE_PRECISION_DLLS: 
    133         dtype = generate.F64  # Force 64-bit dll 
    134  
    135     source = generate.convert_type(source, dtype) 
     141    if dtype == F32 and not ALLOW_SINGLE_PRECISION_DLLS: 
     142        dtype = F64  # Force 64-bit dll 
     143    # Note: dtype may be F128 for long double precision 
     144 
    136145    newest = generate.timestamp(model_info) 
    137146    dll = dll_path(model_info, dtype) 
     
    139148        basename = dll_name(model_info, dtype) + "_" 
    140149        fid, filename = tempfile.mkstemp(suffix=".c", prefix=basename) 
     150        source = generate.convert_type(source, dtype) 
    141151        os.fdopen(fid, "w").write(source) 
    142152        command = COMPILE%{"source":filename, "output":dll} 
     
    152162 
    153163 
    154 def load_dll(source, model_info, dtype="double"): 
     164def load_dll(source, model_info, dtype=F64): 
     165    # type: (str, ModelInfo, np.dtype) -> "DllModel" 
    155166    """ 
    156167    Create and load a dll corresponding to the source, info pair returned 
     
    163174    return DllModel(filename, model_info, dtype=dtype) 
    164175 
    165 class DllModel(object): 
     176 
     177class DllModel(KernelModel): 
    166178    """ 
    167179    ctypes wrapper for a single model. 
     
    179191     
    180192    def __init__(self, dllpath, model_info, dtype=generate.F32): 
     193        # type: (str, ModelInfo, np.dtype) -> None 
    181194        self.info = model_info 
    182195        self.dllpath = dllpath 
    183         self.dll = None 
     196        self._dll = None  # type: ct.CDLL 
    184197        self.dtype = np.dtype(dtype) 
    185198 
    186199    def _load_dll(self): 
     200        # type: () -> None 
    187201        #print("dll", self.dllpath) 
    188202        try: 
    189             self.dll = ct.CDLL(self.dllpath) 
     203            self._dll = ct.CDLL(self.dllpath) 
    190204        except: 
    191205            annotate_exception("while loading "+self.dllpath) 
     
    198212        # int, int, int, int*, double*, double*, double*, double*, double*, double 
    199213        argtypes = [c_int32]*3 + [c_void_p]*5 + [fp] 
    200         self.Iq = self.dll[generate.kernel_name(self.info, False)] 
    201         self.Iqxy = self.dll[generate.kernel_name(self.info, True)] 
    202         self.Iq.argtypes = argtypes 
    203         self.Iqxy.argtypes = argtypes 
     214        self._Iq = self._dll[generate.kernel_name(self.info, is_2d=False)] 
     215        self._Iqxy = self._dll[generate.kernel_name(self.info, is_2d=True)] 
     216        self._Iq.argtypes = argtypes 
     217        self._Iqxy.argtypes = argtypes 
    204218 
    205219    def __getstate__(self): 
     220        # type: () -> Tuple[ModelInfo, str] 
    206221        return self.info, self.dllpath 
    207222 
    208223    def __setstate__(self, state): 
     224        # type: (Tuple[ModelInfo, str]) -> None 
    209225        self.info, self.dllpath = state 
    210         self.dll = None 
     226        self._dll = None 
    211227 
    212228    def make_kernel(self, q_vectors): 
     229        # type: (List[np.ndarray]) -> DllKernel 
    213230        q_input = PyInput(q_vectors, self.dtype) 
    214         if self.dll is None: self._load_dll() 
    215         kernel = self.Iqxy if q_input.is_2d else self.Iq 
     231        # Note: pickle not supported for DllKernel 
     232        if self._dll is None: 
     233            self._load_dll() 
     234        kernel = self._Iqxy if q_input.is_2d else self._Iq 
    216235        return DllKernel(kernel, self.info, q_input) 
    217236 
    218237    def release(self): 
     238        # type: () -> None 
    219239        """ 
    220240        Release any resources associated with the model. 
     
    225245            libHandle = dll._handle 
    226246            #libHandle = ct.c_void_p(dll._handle) 
    227             del dll, self.dll 
    228             self.dll = None 
     247            del dll, self._dll 
     248            self._dll = None 
    229249            #_ctypes.FreeLibrary(libHandle) 
    230250            ct.windll.kernel32.FreeLibrary(libHandle) 
     
    233253 
    234254 
    235 class DllKernel(object): 
     255class DllKernel(Kernel): 
    236256    """ 
    237257    Callable SAS kernel. 
     
    253273    """ 
    254274    def __init__(self, kernel, model_info, q_input): 
     275        # type: (Callable[[], np.ndarray], ModelInfo, PyInput) -> None 
    255276        self.kernel = kernel 
    256277        self.info = model_info 
     
    261282 
    262283    def __call__(self, call_details, weights, values, cutoff): 
     284        # type: (CallDetails, np.ndarray, np.ndarray, float) -> np.ndarray 
    263285        real = (np.float32 if self.q_input.dtype == generate.F32 
    264286                else np.float64 if self.q_input.dtype == generate.F64 
     
    282304            real(cutoff), # cutoff 
    283305            ] 
    284         self.kernel(*args) 
     306        self.kernel(*args) # type: ignore 
    285307        return self.result[:-3] 
    286308 
    287309    def release(self): 
     310        # type: () -> None 
    288311        """ 
    289312        Release any resources associated with the kernel. 
    290313        """ 
    291         pass 
     314        self.q_input.release() 
Note: See TracChangeset for help on using the changeset viewer.