source: sasmodels/sasmodels/kerneldll.py @ 750ffa5

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 750ffa5 was 750ffa5, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

allow test of dll using single precision

  • Property mode set to 100644
File size: 7.9 KB
Line 
1"""
2C types wrapper for sasview models.
3
4The global attribute *ALLOW_SINGLE_PRECISION_DLLS* should be set to *True* if
5you wish to allow single precision floating point evaluation for the compiled
6models, otherwise it defaults to *False*.
7"""
8
9import sys
10import os
11import tempfile
12import ctypes as ct
13from ctypes import c_void_p, c_int, c_double, c_float
14
15import numpy as np
16
17from . import generate
18from .kernelpy import PyInput, PyModel
19
20from .generate import F32, F64
# Compiler platform details.
#
# COMPILE is a %-format template with the keys "source" (the C file to
# build) and "output" (the shared library to create).
if sys.platform == 'darwin':
    # The active command builds without OpenMP.  The OpenMP variant is kept
    # here for reference:
    #COMPILE = "gcc-mp-4.7 -shared -fPIC -std=c99 -fopenmp -O2 -Wall %s -o %s -lm -lgomp"
    COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"
elif os.name != 'nt':
    # Every other non-windows platform: system C compiler with OpenMP.
    COMPILE = "cc -shared -fPIC -std=c99 -fopenmp -O2 -Wall %(source)s -o %(output)s -lm"
elif "VCINSTALLDIR" in os.environ:
    # Windows with the MSVC compiler available, so use it.
    # Call vcvarsall.bat before compiling to set path, headers, libs, etc.
    # TODO: remove intermediate OBJ file created in the directory
    # TODO: maybe don't use randomized name for the c file
    COMPILE = (
        "cl /nologo /Ox /MD /W3 /GS- /DNDEBUG /Tp%(source)s /openmp "
        "/link /DLL /INCREMENTAL:NO /MANIFEST /OUT:%(output)s"
    )
    # Can't find VCOMP90.DLL (don't know why), so remove openmp support
    # from windows compiler build.
    # NOTE(review): the active command above still passes /openmp; the
    # variant without it is the commented line below -- confirm which one
    # is intended.
    #COMPILE = "cl /nologo /Ox /MD /W3 /GS- /DNDEBUG /Tp%(source)s /link /DLL /INCREMENTAL:NO /MANIFEST /OUT:%(output)s"
else:
    # Windows without MSVC: gcc, built without OpenMP.  OpenMP variant:
    #COMPILE = "gcc -shared -fPIC -std=c99 -fopenmp -O2 -Wall %(source)s -o %(output)s -lm"
    COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"

# Directory in which the compiled model libraries are stored.
DLL_PATH = tempfile.gettempdir()

# When False, load_model always builds and loads double precision libraries.
ALLOW_SINGLE_PRECISION_DLLS = False
44
45
def dll_path(info, dtype="double"):
    """
    Return the location of the compiled library for the model *info*.

    The library name is the base name of ``info['filename']`` with its
    extension removed, followed by "32" when *dtype* is single precision,
    followed by ".so".  The library lives in the *DLL_PATH* directory.
    """
    stem, _extension = os.path.splitext(os.path.basename(info['filename']))
    # Single precision builds get their own file so that they do not
    # overwrite the double precision build of the same model.
    precision_tag = "32" if np.dtype(dtype) == generate.F32 else ""
    return os.path.join(DLL_PATH, stem + precision_tag + '.so')
55
56
def load_model(kernel_module, dtype="double"):
    """
    Load the compiled model defined by *kernel_module*.

    Recompile if any files are newer than the model file.

    *dtype* is the requested precision.  It is only honoured when the module
    attribute *ALLOW_SINGLE_PRECISION_DLLS* is true; otherwise it is ignored
    and the compiled model is always double precision.

    If the model defines *Iq* as a python callable then nothing is compiled
    and a :class:`PyModel` is returned instead of a :class:`DllModel`.

    Raises *RuntimeError* if the compile command returns a nonzero status.

    The DLL is not loaded until the kernel is called so models can
    be defined without using too many resources.
    """
    if not ALLOW_SINGLE_PRECISION_DLLS:
        dtype = "double"   # Force 64-bit dll
    dtype = np.dtype(dtype)

    source, info = generate.make(kernel_module)
    if callable(info.get('Iq', None)):
        return PyModel(info)

    if dtype == generate.F32: # 32-bit dll
        source = generate.use_single(source)
        tempfile_prefix = 'sas_'+info['name']+'32_'
    else:
        tempfile_prefix = 'sas_'+info['name']+'_'

    source_files = generate.sources(info) + [info['filename']]
    dllpath = dll_path(info, dtype)
    newest = max(os.path.getmtime(f) for f in source_files)
    if not os.path.exists(dllpath) or os.path.getmtime(dllpath) < newest:
        fid, filename = tempfile.mkstemp(suffix=".c", prefix=tempfile_prefix)
        # Close the file explicitly so the source is flushed to disk before
        # the compiler reads it; relying on garbage collection to close it
        # is not guaranteed to happen in time on every interpreter.
        with os.fdopen(fid, "w") as handle:
            handle.write(source)
        command = COMPILE%{"source":filename, "output":dllpath}
        # Single-argument form prints the same text under python 2 and 3.
        print("Compile command: %s" % command)
        status = os.system(command)
        if status != 0:
            raise RuntimeError("compile failed.  File is in %r"%filename)
        # NOTE(review): the generated C file is currently left behind in the
        # temp directory after a successful build.  Uncomment the following
        # line to remove it instead.
        #os.unlink(filename)
    return DllModel(dllpath, info, dtype=dtype)
98
99
# Leading ctypes argument types shared by every compiled kernel; the
# polydispersity and fixed parameter types are appended per model in
# DllModel._load_dll.
# 1D kernel: two pointers followed by an int count.
# (presumably q and result, matching the call in DllKernel -- confirm)
IQ_ARGS = [c_void_p] * 2 + [c_int]
# 2D kernel: three pointers followed by an int count.
# (presumably qx, qy and result -- confirm)
IQXY_ARGS = [c_void_p] * 3 + [c_int]
102
class DllModel(object):
    """
    ctypes wrapper for a single model.

    *dllpath* is the path to the compiled shared library and *info* is the
    model interface as returned from :func:`generate.make`.

    *dtype* is the desired model precision.  Any numpy dtype for single
    or double precision floats will do, such as 'f', 'float32' or 'single'
    for single and 'd', 'float64' or 'double' for double.  It selects the
    ctypes floating point type (c_float for single, c_double otherwise)
    declared for the kernel arguments, so it should match the precision
    the library was compiled with.

    The library itself is not opened until the model is first called.

    Call :meth:`release` when done with the kernel.
    """
    def __init__(self, dllpath, info, dtype=generate.F32):
        self.info = info
        self.dllpath = dllpath
        self.dll = None  # opened lazily by _load_dll
        self.dtype = np.dtype(dtype)

    def _load_dll(self):
        """
        Open the shared library and declare the kernel argument types.

        Sets *self.dll*, *self.Iq* (1D kernel) and *self.Iqxy* (2D kernel).
        """
        partype = self.info['partype']
        n_fixed_1d = len(partype['fixed-1d'])
        n_fixed_2d = len(partype['fixed-2d'])
        n_pd_1d = len(partype['pd-1d'])
        n_pd_2d = len(partype['pd-2d'])

        #print "dll",self.dllpath
        self.dll = ct.CDLL(self.dllpath)

        fp = c_float if self.dtype == generate.F32 else c_double
        # The polydispersity arguments (pointer, one floating point value,
        # then one int per polydisperse parameter) are only part of the
        # signature when the model has polydisperse parameters.
        pd_args_1d = ([c_void_p, fp] + [c_int]*n_pd_1d) if n_pd_1d else []
        pd_args_2d = ([c_void_p, fp] + [c_int]*n_pd_2d) if n_pd_2d else []

        self.Iq = self.dll[generate.kernel_name(self.info, False)]
        self.Iq.argtypes = IQ_ARGS + pd_args_1d + [fp]*n_fixed_1d

        self.Iqxy = self.dll[generate.kernel_name(self.info, True)]
        self.Iqxy.argtypes = IQXY_ARGS + pd_args_2d + [fp]*n_fixed_2d

    def __getstate__(self):
        # The loaded library and its function handles are left out of the
        # pickled state; dll is reset to None so that the library is
        # reopened on first use after unpickling.  dtype must be kept,
        # otherwise __call__ and make_input fail on the restored object.
        return {
            'info': self.info,
            'dllpath': self.dllpath,
            'dll': None,
            'dtype': self.dtype,
        }

    def __setstate__(self, state):
        self.__dict__ = state

    def __call__(self, q_input):
        """
        Return a :class:`DllKernel` bound to *q_input*.

        Raises *TypeError* if the precision of *q_input* differs from the
        precision of the model.
        """
        if self.dtype != q_input.dtype:
            raise TypeError("data is %s kernel is %s" % (q_input.dtype, self.dtype))
        if self.dll is None:
            self._load_dll()
        kernel = self.Iqxy if q_input.is_2D else self.Iq
        return DllKernel(kernel, self.info, q_input)

    # pylint: disable=no-self-use
    def make_input(self, q_vectors):
        """
        Make q input vectors available to the model.

        Note that each model needs its own q vector even if the case of
        mixture models because some models may be OpenCL, some may be
        ctypes and some may be pure python.
        """
        return PyInput(q_vectors, dtype=self.dtype)

    def release(self):
        """Currently does nothing."""
        pass # TODO: should release the dll
167
168
class DllKernel(object):
    """
    Callable SAS kernel.

    *kernel* is the c function to call.

    *info* is the module information.

    *q_input* holds the q vectors at which the kernel should be evaluated
    (the object returned by :meth:`DllModel.make_input`).

    Calling the instance takes *fixed_pars*, a list of values for the
    fixed parameters to the kernel, *pd_pars*, a list of (value, weight)
    vectors for the polydisperse parameters, and *cutoff*, which determines
    the integration limits: any points with combined weight less than
    *cutoff* will not be calculated.

    Call :meth:`release` when done with the kernel instance.
    """
    def __init__(self, kernel, info, q_input):
        self.info = info
        self.q_input = q_input
        self.kernel = kernel
        # Output buffer with one entry per q point; it is reused by every
        # call to the kernel.
        self.res = np.empty(q_input.nq, q_input.dtype)

        if q_input.is_2D:
            suffix = '2d'
        else:
            suffix = '1d'
        partype = info['partype']
        self.fixed_pars = partype['fixed-' + suffix]
        self.pd_pars = partype['pd-' + suffix]

        # In dll kernel, but not in opencl kernel: raw address of the
        # output buffer, passed to the c function.
        self.p_res = self.res.ctypes.data

    def __call__(self, fixed_pars, pd_pars, cutoff):
        """
        Evaluate the kernel and return the result array.

        The returned array is the instance's own output buffer, so it is
        overwritten by the next call.
        """
        if self.q_input.dtype == F32:
            real = np.float32
        else:
            real = np.float64

        nq = c_int(self.q_input.nq)

        dispersed = []
        if pd_pars:
            cutoff = real(cutoff)
            counts = [np.uint32(len(pair[0])) for pair in pd_pars]
            stacked = np.hstack(pd_pars)
            # Keep *loops* bound to a local name until the kernel returns:
            # only its raw address is handed to the c function.
            loops = np.ascontiguousarray(stacked.T, self.q_input.dtype).flatten()
            dispersed = [loops.ctypes.data, cutoff] + counts

        fixed = [real(value) for value in fixed_pars]
        args = self.q_input.q_pointers + [self.p_res, nq] + dispersed + fixed
        #print pars
        self.kernel(*args)

        return self.res

    def release(self):
        """Nothing to release."""
        pass
Note: See TracBrowser for help on using the repository browser.