source: sasmodels/sasmodels/kerneldll.py @ 63b32bb

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 63b32bb was 63b32bb, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

lint cleaning

  • Property mode set to 100644
File size: 7.8 KB
Line 
1"""
2C types wrapper for sasview models.
3
4The global attribute *ALLOW_SINGLE_PRECISION_DLLS* should be set to *True* if
5you wish to allow single precision floating point evaluation for the compiled
6models, otherwise it defaults to *False*.
7"""
8
9import sys
10import os
11import tempfile
12import ctypes as ct
13from ctypes import c_void_p, c_int, c_double, c_float
14
15import numpy as np
16
17from . import generate
18from .kernelpy import PyInput, PyModel
19
# Compiler platform details.
#
# COMPILE is a %-format template with "source" (the generated C file) and
# "output" (the shared library path) keys; load_model fills it in and runs
# the result through os.system.
if sys.platform == 'darwin':
    # The stock compiler is used without OpenMP.  An OpenMP build would be:
    #   gcc-mp-4.7 -shared -fPIC -std=c99 -fopenmp -O2 -Wall %s -o %s -lm -lgomp
    COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"
elif os.name == 'nt' and "VCINSTALLDIR" in os.environ:
    # MSVC is available (vcvarsall.bat has been run to set path, headers,
    # libs, etc.), so use it.
    # TODO: remove intermediate OBJ file created in the directory
    # TODO: maybe don't use randomized name for the c file
    # NOTE(review): an earlier note says VCOMP90.DLL could not be found and
    # that OpenMP should be removed from the Windows build, yet /openmp is
    # still passed below -- confirm which is intended.  Without OpenMP:
    #   cl /nologo /Ox /MD /W3 /GS- /DNDEBUG /Tp%(source)s /link /DLL /INCREMENTAL:NO /MANIFEST /OUT:%(output)s
    COMPILE = "cl /nologo /Ox /MD /W3 /GS- /DNDEBUG /Tp%(source)s /openmp /link /DLL /INCREMENTAL:NO /MANIFEST /OUT:%(output)s"
elif os.name == 'nt':
    # No MSVC environment: fall back to gcc (presumably MinGW) without
    # OpenMP.  With OpenMP this would add -fopenmp after -std=c99.
    COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"
else:
    # Other POSIX systems: system C compiler with OpenMP enabled.
    COMPILE = "cc -shared -fPIC -std=c99 -fopenmp -O2 -Wall %(source)s -o %(output)s -lm"

# Compiled models are cached in the system temporary directory.
# NOTE(review): this directory is shared between users on multi-user hosts;
# consider a per-user cache directory.
DLL_PATH = tempfile.gettempdir()

# When False, load_model always builds and loads double precision libraries
# regardless of the dtype requested.
ALLOW_SINGLE_PRECISION_DLLS = False
43
44
def dll_path(info, dtype="double"):
    """
    Return the path of the shared library compiled for the model in *info*.

    The library is named after the model's source file (directory and
    extension stripped) and lives in *DLL_PATH*.  Single precision builds
    get a "32" suffix so they do not overwrite the double precision build.
    """
    model_file = os.path.basename(info['filename'])
    stem = os.path.splitext(model_file)[0]
    suffix = "32" if np.dtype(dtype) == generate.F32 else ""
    return os.path.join(DLL_PATH, stem + suffix + '.so')
54
55
def load_model(kernel_module, dtype="double"):
    """
    Load the compiled model defined by *kernel_module*.

    The model is (re)compiled if the compiled library does not exist yet or
    if any of its source files is newer than the library.

    *dtype* is the requested precision.  It is only honoured when
    *ALLOW_SINGLE_PRECISION_DLLS* is True; otherwise a double precision
    library is always built and loaded.

    Models whose *Iq* is a Python callable are returned as a
    :class:`PyModel` and nothing is compiled.

    The DLL is not loaded until the kernel is called so models can be
    defined without using too many resources.

    Raises *RuntimeError* if the compiler exits with a nonzero status; the
    generated C file is then left in place so the failure can be inspected.
    """
    if not ALLOW_SINGLE_PRECISION_DLLS:
        dtype = "double"   # Force 64-bit dll
    dtype = np.dtype(dtype)

    source, info = generate.make(kernel_module)
    if callable(info.get('Iq', None)):
        return PyModel(info)

    if dtype == generate.F32:  # 32-bit dll
        source = generate.use_single(source)
        tempfile_prefix = 'sas_' + info['name'] + '32_'
    else:
        tempfile_prefix = 'sas_' + info['name'] + '_'

    source_files = generate.sources(info) + [info['filename']]
    dllpath = dll_path(info, dtype)
    newest = max(os.path.getmtime(f) for f in source_files)
    if not os.path.exists(dllpath) or os.path.getmtime(dllpath) < newest:
        fid, filename = tempfile.mkstemp(suffix=".c", prefix=tempfile_prefix)
        # Close (and therefore flush) the file before the compiler reads it;
        # leaving that to garbage collection is not guaranteed to happen in
        # time on every Python implementation.
        with os.fdopen(fid, "w") as source_file:
            source_file.write(source)
        command = COMPILE % {"source": filename, "output": dllpath}
        # Single-argument print gives the same output on Python 2 and 3.
        print("Compile command: %s" % command)
        status = os.system(command)
        if status != 0:
            raise RuntimeError("compile failed.  File is in %r" % filename)
        # The generated C file is only needed for compilation; remove it so
        # repeated builds do not accumulate files in the temp directory.
        # Cleanup is best effort: the library itself was built successfully.
        try:
            os.unlink(filename)
        except OSError:
            pass
    return DllModel(dllpath, info, dtype=dtype)
97
98
# Leading ctypes argument types shared by every kernel call: the q vector
# pointer(s) from the input (one for 1-D, two for 2-D), then the result
# buffer pointer and the number of q points (see DllKernel.__call__).
IQ_ARGS = [c_void_p] * 2 + [c_int]
IQXY_ARGS = [c_void_p] * 3 + [c_int]
101
class DllModel(object):
    """
    ctypes wrapper for a single compiled model.

    *dllpath* is the path to the compiled shared library and *info* is the
    model interface as returned from :func:`generate.make`.

    *dtype* is the model precision.  Any numpy dtype for single or double
    precision floats will do, such as 'f', 'float32' or 'single' for single
    and 'd', 'float64' or 'double' for double.  It selects the C floating
    point type (float or double) declared for the kernel arguments, so it
    should match the precision the library was compiled with.

    The library itself is loaded lazily on the first call.  Call
    :meth:`release` when done with the kernel.
    """
    def __init__(self, dllpath, info, dtype=generate.F32):
        self.info = info
        self.dllpath = dllpath
        self.dll = None   # loaded on first call by _load_dll
        self.dtype = np.dtype(dtype)

    def _load_dll(self):
        """
        Load the library at *dllpath* and declare the ctypes argument types
        of its 1-D (*Iq*) and 2-D (*Iqxy*) kernel entry points.
        """
        partype = self.info['partype']
        Nfixed1d = len(partype['fixed-1d'])
        Nfixed2d = len(partype['fixed-2d'])
        Npd1d = len(partype['pd-1d'])
        Npd2d = len(partype['pd-2d'])

        self.dll = ct.CDLL(self.dllpath)

        fp = c_float if self.dtype == generate.F32 else c_double
        # Polydisperse kernels take a pointer to the packed value/weight
        # table, the weight cutoff and one loop length per polydisperse
        # parameter (see DllKernel.__call__); other kernels take none of these.
        pd_args_1d = ([c_void_p, fp] + [c_int]*Npd1d) if Npd1d else []
        pd_args_2d = ([c_void_p, fp] + [c_int]*Npd2d) if Npd2d else []
        self.Iq = self.dll[generate.kernel_name(self.info, False)]
        self.Iq.argtypes = IQ_ARGS + pd_args_1d + [fp]*Nfixed1d

        self.Iqxy = self.dll[generate.kernel_name(self.info, True)]
        self.Iqxy.argtypes = IQXY_ARGS + pd_args_2d + [fp]*Nfixed2d

    def __getstate__(self):
        # The loaded library and the Iq/Iqxy function pointers bound to it
        # are left out of the pickled state; the library is reloaded on the
        # first call after unpickling.  dtype must be kept because both
        # __call__ and _load_dll read it.
        return {'info': self.info, 'dllpath': self.dllpath,
                'dtype': self.dtype, 'dll': None}

    def __setstate__(self, state):
        self.__dict__ = state

    def __call__(self, q_input):
        """
        Return a :class:`DllKernel` that evaluates this model at *q_input*.

        Raises *TypeError* if *q_input* does not have the model's dtype.
        """
        if self.dtype != q_input.dtype:
            raise TypeError("data is %s kernel is %s" % (q_input.dtype, self.dtype))
        if self.dll is None:
            self._load_dll()
        kernel = self.Iqxy if q_input.is_2D else self.Iq
        return DllKernel(kernel, self.info, q_input)

    def make_input(self, q_vectors):
        """
        Make q input vectors available to the model.

        Note that each model needs its own q vector even in the case of
        mixture models because some models may be OpenCL, some may be
        ctypes and some may be pure python.
        """
        return PyInput(q_vectors, dtype=self.dtype)

    def release(self):
        """
        Release resources held by the model.

        Currently a no-op: the library stays loaded.
        """
        # TODO: should release the dll
166
167
class DllKernel(object):
    """
    Callable SAS kernel.

    *kernel* is the ctypes function to call (the model's Iq or Iqxy entry
    point).

    *info* is the model information.

    *q_input* holds the q vectors at which the kernel should be evaluated,
    typically the :class:`PyInput` built by :meth:`DllModel.make_input`.

    Calling the kernel takes *fixed_pars*, a list of values for the fixed
    parameters to the kernel, *pd_pars*, a list of (value, weight) vectors
    for the polydisperse parameters, and *cutoff*, which determines the
    integration limits: any points with combined weight less than *cutoff*
    will not be calculated.

    Results are written into a buffer owned by the kernel and that same
    array is returned from every call, so copy it if it must survive the
    next call.

    Call :meth:`release` when done with the kernel instance.
    """
    def __init__(self, kernel, info, q_input):
        self.info = info
        self.q_input = q_input
        self.kernel = kernel
        self.res = np.empty(q_input.nq, q_input.dtype)
        dim = '2d' if q_input.is_2D else '1d'
        self.fixed_pars = info['partype']['fixed-' + dim]
        self.pd_pars = info['partype']['pd-' + dim]

        # Raw address of the result buffer, passed to the kernel as a
        # void pointer; self.res keeps the memory alive.
        self.p_res = self.res.ctypes.data

    def __call__(self, fixed_pars, pd_pars, cutoff):
        real = np.float32 if self.q_input.dtype == generate.F32 else np.float64

        nq = c_int(self.q_input.nq)
        if pd_pars:
            cutoff = real(cutoff)
            # Loop lengths are declared as c_int in the kernel signature,
            # so pass plain Python ints rather than unsigned numpy scalars.
            loops_N = [len(p[0]) for p in pd_pars]
            # Pack all tables into one contiguous array laid out as
            # value, weight, value, weight, ... for each parameter in turn.
            loops = np.hstack(pd_pars)
            loops = np.ascontiguousarray(loops.T, self.q_input.dtype).flatten()
            # loops stays referenced until this call returns, so the
            # pointer remains valid for the duration of the kernel call.
            p_loops = loops.ctypes.data
            dispersed = [p_loops, cutoff] + loops_N
        else:
            dispersed = []
        fixed = [real(p) for p in fixed_pars]
        args = self.q_input.q_pointers + [self.p_res, nq] + dispersed + fixed
        self.kernel(*args)

        return self.res

    def release(self):
        """Release resources held by the kernel (currently nothing)."""
        pass
Note: See TracBrowser for help on using the repository browser.