Changes in sasmodels/kerneldll.py [4d76711:a5b8477] in sasmodels
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
-
sasmodels/kerneldll.py
r4d76711 ra5b8477 49 49 import os 50 50 import tempfile 51 import ctypes as ct 52 from ctypes import c_void_p, c_int, c_longdouble, c_double, c_float 53 import _ctypes 54 55 import numpy as np 51 import ctypes as ct # type: ignore 52 from ctypes import c_void_p, c_int32, c_longdouble, c_double, c_float # type: ignore 53 54 import numpy as np # type: ignore 56 55 57 56 from . import generate 58 from .kernelpy import PyInput, PyModel 57 from .kernel import KernelModel, Kernel 58 from .kernelpy import PyInput 59 59 from .exception import annotate_exception 60 from .generate import F16, F32, F64 61 62 try: 63 from typing import Tuple, Callable, Any 64 from .modelinfo import ModelInfo 65 from .details import CallDetails 66 except ImportError: 67 pass 60 68 61 69 # Compiler platform details … … 81 89 COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm" 82 90 if "SAS_OPENMP" in os.environ: 83 COMPILE = COMPILE +" -fopenmp"91 COMPILE += " -fopenmp" 84 92 else: 85 93 COMPILE = "cc -shared -fPIC -fopenmp -std=c99 -O2 -Wall %(source)s -o %(output)s -lm" … … 90 98 91 99 92 def dll_path(model_info, dtype="double"): 93 """ 94 Path to the compiled model defined by *model_info*. 95 """ 96 from os.path import join as joinpath, split as splitpath, splitext 97 basename = splitext(splitpath(model_info['filename'])[1])[0] 98 if np.dtype(dtype) == generate.F32: 99 basename += "32" 100 elif np.dtype(dtype) == generate.F64: 101 basename += "64" 102 else: 103 basename += "128" 104 return joinpath(DLL_PATH, basename+'.so') 105 106 107 def make_dll(source, model_info, dtype="double"): 108 """ 109 Load the compiled model defined by *kernel_module*. 110 111 Recompile if any files are newer than the model file. 100 def dll_name(model_info, dtype): 101 # type: (ModelInfo, np.dtype) -> str 102 """ 103 Name of the dll containing the model. This is the base file name without 104 any path or extension, with a form such as 'sas_sphere32'. 
105 """ 106 bits = 8*dtype.itemsize 107 return "sas_%s%d"%(model_info.id, bits) 108 109 110 def dll_path(model_info, dtype): 111 # type: (ModelInfo, np.dtype) -> str 112 """ 113 Complete path to the dll for the model. Note that the dll may not 114 exist yet if it hasn't been compiled. 115 """ 116 return os.path.join(DLL_PATH, dll_name(model_info, dtype)+".so") 117 118 119 def make_dll(source, model_info, dtype=F64): 120 # type: (str, ModelInfo, np.dtype) -> str 121 """ 122 Returns the path to the compiled model defined by *kernel_module*. 123 124 If the model has not been compiled, or if the source file(s) are newer 125 than the dll, then *make_dll* will compile the model before returning. 126 This routine does not load the resulting dll. 112 127 113 128 *dtype* is a numpy floating point precision specifier indicating whether 114 the model should be single or double precision. The default is double115 precision.116 117 The DLL is not loaded until the kernel is called so models can118 be defined without using too many resources.129 the model should be single, double or long double precision. The default 130 is double precision, *np.dtype('d')*. 131 132 Set *sasmodels.ALLOW_SINGLE_PRECISION_DLLS* to False if single precision 133 models are not allowed as DLLs. 119 134 120 135 Set *sasmodels.kerneldll.DLL_PATH* to the compiled dll output path. 121 136 The default is the system temporary directory. 122 123 Set *sasmodels.ALLOW_SINGLE_PRECISION_DLLS* to True if single precision 124 models are allowed as DLLs. 
125 """ 126 if callable(model_info.get('Iq', None)): 127 return PyModel(model_info) 128 129 dtype = np.dtype(dtype) 130 if dtype == generate.F16: 137 """ 138 if dtype == F16: 131 139 raise ValueError("16 bit floats not supported") 132 if dtype == generate.F32 and not ALLOW_SINGLE_PRECISION_DLLS: 133 dtype = generate.F64 # Force 64-bit dll 134 135 if dtype == generate.F32: # 32-bit dll 136 tempfile_prefix = 'sas_' + model_info['name'] + '32_' 137 elif dtype == generate.F64: 138 tempfile_prefix = 'sas_' + model_info['name'] + '64_' 139 else: 140 tempfile_prefix = 'sas_' + model_info['name'] + '128_' 141 142 source = generate.convert_type(source, dtype) 143 source_files = generate.model_sources(model_info) + [model_info['filename']] 140 if dtype == F32 and not ALLOW_SINGLE_PRECISION_DLLS: 141 dtype = F64 # Force 64-bit dll 142 # Note: dtype may be F128 for long double precision 143 144 newest = generate.timestamp(model_info) 144 145 dll = dll_path(model_info, dtype) 145 newest = max(os.path.getmtime(f) for f in source_files)146 146 if not os.path.exists(dll) or os.path.getmtime(dll) < newest: 147 # Replace with a proper temp file 148 fid, filename = tempfile.mkstemp(suffix=".c", prefix=tempfile_prefix) 147 basename = dll_name(model_info, dtype) + "_" 148 fid, filename = tempfile.mkstemp(suffix=".c", prefix=basename) 149 source = generate.convert_type(source, dtype) 149 150 os.fdopen(fid, "w").write(source) 150 151 command = COMPILE%{"source":filename, "output":dll} … … 160 161 161 162 162 def load_dll(source, model_info, dtype="double"): 163 def load_dll(source, model_info, dtype=F64): 164 # type: (str, ModelInfo, np.dtype) -> "DllModel" 163 165 """ 164 166 Create and load a dll corresponding to the source, info pair returned … … 172 174 173 175 174 IQ_ARGS = [c_void_p, c_void_p, c_int] 175 IQXY_ARGS = [c_void_p, c_void_p, c_void_p, c_int] 176 177 class DllModel(object): 176 class DllModel(KernelModel): 178 177 """ 179 178 ctypes wrapper for a single model. 
… … 191 190 192 191 def __init__(self, dllpath, model_info, dtype=generate.F32): 192 # type: (str, ModelInfo, np.dtype) -> None 193 193 self.info = model_info 194 194 self.dllpath = dllpath 195 self. dll = None195 self._dll = None # type: ct.CDLL 196 196 self.dtype = np.dtype(dtype) 197 197 198 198 def _load_dll(self): 199 Nfixed1d = len(self.info['partype']['fixed-1d']) 200 Nfixed2d = len(self.info['partype']['fixed-2d']) 201 Npd1d = len(self.info['partype']['pd-1d']) 202 Npd2d = len(self.info['partype']['pd-2d']) 203 199 # type: () -> None 204 200 #print("dll", self.dllpath) 205 201 try: 206 self. dll = ct.CDLL(self.dllpath)202 self._dll = ct.CDLL(self.dllpath) 207 203 except: 208 204 annotate_exception("while loading "+self.dllpath) … … 212 208 else c_double if self.dtype == generate.F64 213 209 else c_longdouble) 214 pd_args_1d = [c_void_p, fp] + [c_int]*Npd1d if Npd1d else [] 215 pd_args_2d = [c_void_p, fp] + [c_int]*Npd2d if Npd2d else [] 216 self.Iq = self.dll[generate.kernel_name(self.info, False)] 217 self.Iq.argtypes = IQ_ARGS + pd_args_1d + [fp]*Nfixed1d 218 219 self.Iqxy = self.dll[generate.kernel_name(self.info, True)] 220 self.Iqxy.argtypes = IQXY_ARGS + pd_args_2d + [fp]*Nfixed2d 221 222 self.release() 210 211 # int, int, int, int*, double*, double*, double*, double*, double*, double 212 argtypes = [c_int32]*3 + [c_void_p]*5 + [fp] 213 self._Iq = self._dll[generate.kernel_name(self.info, is_2d=False)] 214 self._Iqxy = self._dll[generate.kernel_name(self.info, is_2d=True)] 215 self._Iq.argtypes = argtypes 216 self._Iqxy.argtypes = argtypes 223 217 224 218 def __getstate__(self): 219 # type: () -> Tuple[ModelInfo, str] 225 220 return self.info, self.dllpath 226 221 227 222 def __setstate__(self, state): 223 # type: (Tuple[ModelInfo, str]) -> None 228 224 self.info, self.dllpath = state 229 self. 
dll = None225 self._dll = None 230 226 231 227 def make_kernel(self, q_vectors): 228 # type: (List[np.ndarray]) -> DllKernel 232 229 q_input = PyInput(q_vectors, self.dtype) 233 if self.dll is None: self._load_dll() 234 kernel = self.Iqxy if q_input.is_2d else self.Iq 230 # Note: pickle not supported for DllKernel 231 if self._dll is None: 232 self._load_dll() 233 kernel = self._Iqxy if q_input.is_2d else self._Iq 235 234 return DllKernel(kernel, self.info, q_input) 236 235 237 236 def release(self): 237 # type: () -> None 238 238 """ 239 239 Release any resources associated with the model. … … 244 244 libHandle = dll._handle 245 245 #libHandle = ct.c_void_p(dll._handle) 246 del dll, self. dll247 self. dll = None246 del dll, self._dll 247 self._dll = None 248 248 #_ctypes.FreeLibrary(libHandle) 249 249 ct.windll.kernel32.FreeLibrary(libHandle) … … 252 252 253 253 254 class DllKernel( object):254 class DllKernel(Kernel): 255 255 """ 256 256 Callable SAS kernel. … … 272 272 """ 273 273 def __init__(self, kernel, model_info, q_input): 274 # type: (Callable[[], np.ndarray], ModelInfo, PyInput) -> None 275 self.kernel = kernel 274 276 self.info = model_info 275 277 self.q_input = q_input 276 self.kernel = kernel 277 self.res = np.empty(q_input.nq, q_input.dtype) 278 dim = '2d' if q_input.is_2d else '1d' 279 self.fixed_pars = model_info['partype']['fixed-' + dim] 280 self.pd_pars = model_info['partype']['pd-' + dim] 281 282 # In dll kernel, but not in opencl kernel 283 self.p_res = self.res.ctypes.data 284 285 def __call__(self, fixed_pars, pd_pars, cutoff): 286 real = (np.float32 if self.q_input.dtype == generate.F32 287 else np.float64 if self.q_input.dtype == generate.F64 288 else np.float128) 289 290 nq = c_int(self.q_input.nq) 291 if pd_pars: 292 cutoff = real(cutoff) 293 loops_N = [np.uint32(len(p[0])) for p in pd_pars] 294 loops = np.hstack(pd_pars) 295 loops = np.ascontiguousarray(loops.T, self.q_input.dtype).flatten() 296 p_loops = loops.ctypes.data 297 
dispersed = [p_loops, cutoff] + loops_N 298 else: 299 dispersed = [] 300 fixed = [real(p) for p in fixed_pars] 301 args = self.q_input.q_pointers + [self.p_res, nq] + dispersed + fixed 302 #print(pars) 303 self.kernel(*args) 304 305 return self.res 278 self.dtype = q_input.dtype 279 self.dim = '2d' if q_input.is_2d else '1d' 280 self.result = np.empty(q_input.nq+1, q_input.dtype) 281 self.real = (np.float32 if self.q_input.dtype == generate.F32 282 else np.float64 if self.q_input.dtype == generate.F64 283 else np.float128) 284 285 def __call__(self, call_details, weights, values, cutoff): 286 # type: (CallDetails, np.ndarray, np.ndarray, float) -> np.ndarray 287 288 #print("in kerneldll") 289 #print("weights", weights) 290 #print("values", values) 291 start, stop = 0, call_details.total_pd 292 args = [ 293 self.q_input.nq, # nq 294 start, # pd_start 295 stop, # pd_stop pd_stride[MAX_PD] 296 call_details.ctypes.data, # problem 297 weights.ctypes.data, # weights 298 values.ctypes.data, #pars 299 self.q_input.q.ctypes.data, #q 300 self.result.ctypes.data, # results 301 self.real(cutoff), # cutoff 302 ] 303 self.kernel(*args) # type: ignore 304 return self.result[:-1] 306 305 307 306 def release(self): 307 # type: () -> None 308 308 """ 309 309 Release any resources associated with the kernel. 310 310 """ 311 pass311 self.q_input.release()
Note: See TracChangeset
for help on using the changeset viewer.