source: sasmodels/sasmodels/kerneldll.py @ eafc9fa

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since eafc9fa was eafc9fa, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor kernel wrappers to simplify q input handling

  • Property mode set to 100644
File size: 10.8 KB
Line 
1r"""
2DLL driver for C kernels
3
The global attribute *ALLOW_SINGLE_PRECISION_DLLS* defaults to *True*, which
allows single precision floating point evaluation for the compiled models.
Set it to *False* to force the models to be compiled in double precision.
7
8The compiler command line is stored in the attribute *COMPILE*, with string
9substitutions for %(source)s and %(output)s indicating what to compile and
10where to store it.  The actual command is system dependent.
11
On Windows systems, you have a choice of compilers.  *MinGW* is the GNU
compiler toolchain, available in packages such as Anaconda and Python(x,y),
or as a standalone install. This toolchain has had difficulties on some
15systems, and may or may not work for you.  In order to build DLLs, *gcc*
16must be on your path.  If the environment variable *SAS_OPENMP* is given
17then -fopenmp is added to the compiler flags.  This requires a version
18of MinGW compiled with OpenMP support.
19
An alternative toolchain uses the Microsoft Visual C++ compiler, available
free from Microsoft:
22
23    `<http://www.microsoft.com/en-us/download/details.aspx?id=44266>`_
24
25Again, this requires that the compiler is available on your path.  This is
done by running vcvarsall.bat in a Windows terminal.  Install locations are
27system dependent, such as:
28
29    C:\Program Files (x86)\Common Files\Microsoft\Visual C++ for Python\9.0\vcvarsall.bat
30
31or maybe
32
33    C:\Users\yourname\AppData\Local\Programs\Common\Microsoft\Visual C++ for Python\9.0\vcvarsall.bat
34
35And again, the environment variable *SAS_OPENMP* controls whether OpenMP is
36used to compile the C code.  This requires the Microsoft vcomp90.dll library,
37which doesn't seem to be included with the compiler, nor does there appear
38to be a public download location.  There may be one on your machine already
39in a location such as:
40
41    C:\Windows\winsxs\x86_microsoft.vc90.openmp*\vcomp90.dll
42
43If you copy this onto your path, such as the python directory or the install
44directory for this application, then OpenMP should be supported.
45"""
46from __future__ import print_function
47
48import sys
49import os
50import tempfile
51import ctypes as ct
52from ctypes import c_void_p, c_int, c_longdouble, c_double, c_float
53
54import numpy as np
55
56from . import generate
57from .kernelpy import PyInput, PyModel
58from .exception import annotate_exception
59
60# Compiler platform details
61if sys.platform == 'darwin':
62    #COMPILE = "gcc-mp-4.7 -shared -fPIC -std=c99 -fopenmp -O2 -Wall %s -o %s -lm -lgomp"
63    COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"
64elif os.name == 'nt':
65    # call vcvarsall.bat before compiling to set path, headers, libs, etc.
66    if "VCINSTALLDIR" in os.environ:
67        # MSVC compiler is available, so use it.  OpenMP requires a copy of
68        # vcomp90.dll on the path.  One may be found here:
69        #       C:/Windows/winsxs/x86_microsoft.vc90.openmp*/vcomp90.dll
70        # Copy this to the python directory and uncomment the OpenMP COMPILE
71        # TODO: remove intermediate OBJ file created in the directory
72        # TODO: maybe don't use randomized name for the c file
73        CC = "cl /nologo /Ox /MD /W3 /GS- /DNDEBUG /Tp%(source)s "
74        LN = "/link /DLL /INCREMENTAL:NO /MANIFEST /OUT:%(output)s"
75        if "SAS_OPENMP" in os.environ:
76            COMPILE = " ".join((CC, "/openmp", LN))
77        else:
78            COMPILE = " ".join((CC, LN))
79    else:
80        COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"
81        if "SAS_OPENMP" in os.environ:
82            COMPILE = COMPILE + " -fopenmp"
83else:
84    COMPILE = "cc -shared -fPIC -fopenmp -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"
85
86DLL_PATH = tempfile.gettempdir()
87
88ALLOW_SINGLE_PRECISION_DLLS = True
89
90
def dll_path(info, dtype="double"):
    """
    Return the path of the compiled library for the model described by *info*.

    The file name is the base name of ``info['filename']`` (directory and
    extension removed) followed by a precision tag: "32" for single
    precision, "64" for double precision and "128" for anything else.  The
    ".so" extension is added and the file is placed in *DLL_PATH*.
    """
    from os.path import basename, join, splitext
    stem = splitext(basename(info['filename']))[0]
    precision = np.dtype(dtype)
    if precision == generate.F32:
        tag = "32"
    elif precision == generate.F64:
        tag = "64"
    else:
        tag = "128"
    return join(DLL_PATH, stem + tag + '.so')
104
105
def make_dll(source, info, dtype="double"):
    """
    Compile the model defined by *source* and *info*, returning the dll path.

    *source* is the C source of the model.  It is converted to the requested
    precision before compiling.  *info* is the model information, which
    supplies the model name and the list of source files checked for changes.

    The dll is rebuilt only if it is missing or older than any of the model
    source files.  Pure Python models (a callable *Iq* in *info*) are not
    compiled.  For those, a :class:`PyModel` wrapping *info* is returned
    instead of a path.

    *dtype* is a numpy floating point precision specifier indicating whether
    the model should be single or double precision.  The default is double
    precision.  Half precision raises *ValueError*.

    The DLL is not loaded here, so models can be defined without using too
    many resources.

    Set *sasmodels.kerneldll.DLL_PATH* to the compiled dll output path.
    The default is the system temporary directory.

    Set *sasmodels.kerneldll.ALLOW_SINGLE_PRECISION_DLLS* to False to compile
    single precision requests as double precision instead.

    Raises *RuntimeError* if compilation fails.  The generated C file is then
    left in place for inspection, and the error message gives its path.
    """
    if callable(info.get('Iq', None)):
        return PyModel(info)

    dtype = np.dtype(dtype)
    if dtype == generate.F16:
        raise ValueError("16 bit floats not supported")
    if dtype == generate.F32 and not ALLOW_SINGLE_PRECISION_DLLS:
        dtype = generate.F64  # Force 64-bit dll

    # Tag the temporary C file with the precision, matching dll_path.
    if dtype == generate.F32:
        bits = '32'
    elif dtype == generate.F64:
        bits = '64'
    else:
        bits = '128'
    tempfile_prefix = 'sas_' + info['name'] + bits + '_'

    source = generate.convert_type(source, dtype)
    source_files = generate.model_sources(info) + [info['filename']]
    dll = dll_path(info, dtype)
    newest = max(os.path.getmtime(f) for f in source_files)
    if not os.path.exists(dll) or os.path.getmtime(dll) < newest:
        fid, filename = tempfile.mkstemp(suffix=".c", prefix=tempfile_prefix)
        # Close the file explicitly (rather than relying on garbage
        # collection) so the source is flushed before the compiler reads it.
        with os.fdopen(fid, "w") as handle:
            handle.write(source)
        command = COMPILE % {"source": filename, "output": dll}
        print("Compile command: " + command)
        status = os.system(command)
        if status != 0 or not os.path.exists(dll):
            raise RuntimeError("compile failed.  File is in %r" % filename)
        # Comment out the following to keep the generated c file.
        os.unlink(filename)
    return dll
149        command = COMPILE%{"source":filename, "output":dll}
150        print("Compile command: "+command)
151        status = os.system(command)
152        if status != 0 or not os.path.exists(dll):
153            raise RuntimeError("compile failed.  File is in %r"%filename)
154        else:
155            ## comment the following to keep the generated c file
156            os.unlink(filename)
157            #print("saving compiled file in %r"%filename)
158    return dll
159
160
161def load_dll(source, info, dtype="double"):
162    """
163    Create and load a dll corresponding to the source, info pair returned
164    from :func:`sasmodels.generate.make` compiled for the target precision.
165
166    See :func:`make_dll` for details on controlling the dll path and the
167    allowed floating point precision.
168    """
169    filename = make_dll(source, info, dtype=dtype)
170    return DllModel(filename, info, dtype=dtype)
171
172
173IQ_ARGS = [c_void_p, c_void_p, c_int]
174IQXY_ARGS = [c_void_p, c_void_p, c_void_p, c_int]
175
class DllModel(object):
    """
    ctypes wrapper for a single compiled model.

    *dllpath* is the path to the compiled library, as returned by
    :func:`make_dll`.  *info* is the model interface as returned from
    :func:`generate.make`.

    *dtype* is the model precision.  Any numpy dtype for single or double
    precision floats will do, such as 'f', 'float32' or 'single' for single
    and 'd', 'float64' or 'double' for double.  Other dtypes use long double
    arguments.  It selects the ctypes argument types, so it should match the
    precision the library was compiled with.  Note that the default here is
    single precision, while :func:`load_dll` defaults to double.

    The library is loaded lazily on the first call.  The ctypes handle is not
    part of the pickled state, so a pickled model reloads it on first use.

    Call :meth:`release` when done with the kernel.
    """
    def __init__(self, dllpath, info, dtype=generate.F32):
        self.info = info
        self.dllpath = dllpath
        self.dll = None
        self.dtype = np.dtype(dtype)

    def _load_dll(self):
        """
        Load the library and set the ctypes signatures of its 1-D (Iq) and
        2-D (Iqxy) kernel entry points.
        """
        partype = self.info['partype']
        Nfixed1d = len(partype['fixed-1d'])
        Nfixed2d = len(partype['fixed-2d'])
        Npd1d = len(partype['pd-1d'])
        Npd2d = len(partype['pd-2d'])

        try:
            self.dll = ct.CDLL(self.dllpath)
        except Exception as exc:
            annotate_exception(exc, "while loading "+self.dllpath)
            raise

        fp = (c_float if self.dtype == generate.F32
              else c_double if self.dtype == generate.F64
              else c_longdouble)
        # With polydisperse parameters, the kernel takes a pointer to the
        # interleaved (value, weight) loop data and the cutoff.  It also
        # takes one int per polydisperse parameter, giving its number of
        # points.  Without polydisperse parameters, none of these are passed.
        pd_args_1d = ([c_void_p, fp] + [c_int]*Npd1d) if Npd1d else []
        pd_args_2d = ([c_void_p, fp] + [c_int]*Npd2d) if Npd2d else []
        self.Iq = self.dll[generate.kernel_name(self.info, False)]
        self.Iq.argtypes = IQ_ARGS + pd_args_1d + [fp]*Nfixed1d

        self.Iqxy = self.dll[generate.kernel_name(self.info, True)]
        self.Iqxy.argtypes = IQXY_ARGS + pd_args_2d + [fp]*Nfixed2d

    def __getstate__(self):
        # The ctypes handle cannot be pickled.  Save what is needed to
        # reload it, including the precision used to build the signatures.
        return self.info, self.dllpath, self.dtype

    def __setstate__(self, state):
        self.info, self.dllpath = state[0], state[1]
        # States pickled before dtype was saved hold only (info, dllpath).
        # For those, dtype is left unset, as before.
        if len(state) > 2:
            self.dtype = np.dtype(state[2])
        self.dll = None

    def __call__(self, q_vectors):
        """
        Return a :class:`DllKernel` that evaluates this model at *q_vectors*.
        The 2-D entry point is used for 2-D input and the 1-D entry point
        otherwise.
        """
        q_input = PyInput(q_vectors, self.dtype)
        if self.dll is None:
            self._load_dll()
        kernel = self.Iqxy if q_input.is_2d else self.Iq
        return DllKernel(kernel, self.info, q_input)

    def release(self):
        """
        Release any resources associated with the model.
        """
        pass # TODO: should release the dll
228        q_input = PyInput(q_vectors, self.dtype)
229        if self.dll is None: self._load_dll()
230        kernel = self.Iqxy if q_input.is_2d else self.Iq
231        return DllKernel(kernel, self.info, q_input)
232
233    def release(self):
234        """
235        Release any resources associated with the model.
236        """
237        pass # TODO: should release the dll
238
239
240class DllKernel(object):
241    """
242    Callable SAS kernel.
243
244    *kernel* is the c function to call.
245
246    *info* is the module information
247
248    *q_input* is the DllInput q vectors at which the kernel should be
249    evaluated.
250
251    The resulting call method takes the *pars*, a list of values for
252    the fixed parameters to the kernel, and *pd_pars*, a list of (value, weight)
253    vectors for the polydisperse parameters.  *cutoff* determines the
254    integration limits: any points with combined weight less than *cutoff*
255    will not be calculated.
256
257    Call :meth:`release` when done with the kernel instance.
258    """
259    def __init__(self, kernel, info, q_input):
260        self.info = info
261        self.q_input = q_input
262        self.kernel = kernel
263        self.res = np.empty(q_input.nq, q_input.dtype)
264        dim = '2d' if q_input.is_2d else '1d'
265        self.fixed_pars = info['partype']['fixed-'+dim]
266        self.pd_pars = info['partype']['pd-'+dim]
267
268        # In dll kernel, but not in opencl kernel
269        self.p_res = self.res.ctypes.data
270
271    def __call__(self, fixed_pars, pd_pars, cutoff):
272        real = (np.float32 if self.q_input.dtype == generate.F32
273                else np.float64 if self.q_input.dtype == generate.F64
274                else np.float128)
275
276        nq = c_int(self.q_input.nq)
277        if pd_pars:
278            cutoff = real(cutoff)
279            loops_N = [np.uint32(len(p[0])) for p in pd_pars]
280            loops = np.hstack(pd_pars)
281            loops = np.ascontiguousarray(loops.T, self.q_input.dtype).flatten()
282            p_loops = loops.ctypes.data
283            dispersed = [p_loops, cutoff] + loops_N
284        else:
285            dispersed = []
286        fixed = [real(p) for p in fixed_pars]
287        args = self.q_input.q_pointers + [self.p_res, nq] + dispersed + fixed
288        #print(pars)
289        self.kernel(*args)
290
291        return self.res
292
293    def release(self):
294        """
295        Release any resources associated with the kernel.
296        """
297        pass
Note: See TracBrowser for help on using the repository browser.