source: sasmodels/sasmodels/kerneldll.py @ 0763009

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 0763009 was 0763009, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

minor code cleanup

  • Property mode set to 100644
File size: 8.7 KB
RevLine 
[14de349]1"""
2C types wrapper for sasview models.
[750ffa5]3
4The global attribute *ALLOW_SINGLE_PRECISION_DLLS* should be set to *True* if
5you wish to allow single precision floating point evaluation for the compiled
6models, otherwise it defaults to *False*.
[14de349]7"""
[750ffa5]8
[5d4777d]9import sys
10import os
[df4dc86]11import tempfile
[14de349]12import ctypes as ct
[750ffa5]13from ctypes import c_void_p, c_int, c_double, c_float
[14de349]14
15import numpy as np
16
[cb6ecf4]17from . import generate
[f734e7d]18from .kernelpy import PyInput, PyModel
[2c801fe]19from .exception import annotate_exception
[14de349]20
# Compiler platform details.
#
# COMPILE is a %-style template with "source" and "output" keys; make_dll
# fills them in with the generated C file and the target library path.
if sys.platform == 'darwin':
    # OpenMP is not enabled here.  An OpenMP build with MacPorts gcc was:
    #   gcc-mp-4.7 -shared -fPIC -std=c99 -fopenmp -O2 -Wall %s -o %s -lm -lgomp
    COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"
elif os.name == 'nt' and "VCINSTALLDIR" in os.environ:
    # The MSVC environment is set up (call vcvarsall.bat before compiling to
    # set path, headers, libs, etc.), so use the MSVC compiler.
    # TODO: remove intermediate OBJ file created in the directory
    # TODO: maybe don't use randomized name for the c file
    # /openmp is left out because VCOMP90.DLL could not be found at runtime.
    COMPILE = "cl /nologo /Ox /MD /W3 /GS- /DNDEBUG /Tp%(source)s /link /DLL /INCREMENTAL:NO /MANIFEST /OUT:%(output)s"
elif os.name == 'nt':
    # No MSVC environment detected: use gcc, without OpenMP.
    COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"
else:
    # Any other platform: the system C compiler, with OpenMP enabled.
    COMPILE = "cc -shared -fPIC -std=c99 -fopenmp -O2 -Wall %(source)s -o %(output)s -lm"

# Directory that receives the compiled model libraries (the system
# temporary directory unless overridden).
DLL_PATH = tempfile.gettempdir()

# When False, make_dll ignores requests for single precision and always
# builds double precision libraries.
ALLOW_SINGLE_PRECISION_DLLS = False
[5d4777d]44
[750ffa5]45
46def dll_path(info, dtype="double"):
[5d4777d]47    """
48    Path to the compiled model defined by *info*.
49    """
50    from os.path import join as joinpath, split as splitpath, splitext
51    basename = splitext(splitpath(info['filename'])[1])[0]
[750ffa5]52    if np.dtype(dtype) == generate.F32:
53        basename += "32"
[5d4777d]54    return joinpath(DLL_PATH, basename+'.so')
55
56
def make_dll(source, info, dtype="double"):
    """
    Compile the model source into a shared library and return its path.

    *source* is the C source for the model and *info* is the model
    information, as returned by :func:`sasmodels.generate.make`.

    The library is rebuilt only when it does not exist yet or when any of
    the model source files is newer than it.  If *info* defines a python
    callable for *Iq* there is nothing to compile, and a
    :class:`sasmodels.kernelpy.PyModel` is returned instead of a path.

    *dtype* is a numpy floating point precision specifier indicating whether
    the model should be single or double precision.  The default is double
    precision.  Single precision is only honoured when
    *sasmodels.kerneldll.ALLOW_SINGLE_PRECISION_DLLS* is True; otherwise
    the library is always compiled in double precision.

    The library is not loaded here.  :func:`load_dll` wraps the path in a
    :class:`DllModel`, which defers loading until the model is first called.

    Set *sasmodels.kerneldll.DLL_PATH* to the compiled dll output path.
    The default is the system temporary directory.

    Raises *RuntimeError* if compilation fails.  In that case the generated
    C file is left on disk and its name is included in the message.
    """
    if not ALLOW_SINGLE_PRECISION_DLLS:
        dtype = "double"   # Force 64-bit dll
    dtype = np.dtype(dtype)

    if callable(info.get('Iq', None)):
        return PyModel(info)

    if dtype == generate.F32:  # 32-bit dll
        source = generate.use_single(source)
        tempfile_prefix = 'sas_' + info['name'] + '32_'
    else:
        tempfile_prefix = 'sas_' + info['name'] + '_'

    source_files = generate.sources(info) + [info['filename']]
    dll = dll_path(info, dtype)
    newest = max(os.path.getmtime(f) for f in source_files)
    if not os.path.exists(dll) or os.path.getmtime(dll) < newest:
        fid, filename = tempfile.mkstemp(suffix=".c", prefix=tempfile_prefix)
        # Close (and therefore flush) the C file before the compiler reads
        # it, instead of relying on garbage collection of the file object.
        with os.fdopen(fid, "w") as fd:
            fd.write(source)
        command = COMPILE % {"source": filename, "output": dll}
        print("Compile command: %s" % command)
        status = os.system(command)
        if status != 0 or not os.path.exists(dll):
            raise RuntimeError("compile failed.  File is in %r" % filename)
        # Compilation succeeded, so the generated C file is no longer
        # needed.  Comment out the next line to keep it for debugging.
        os.unlink(filename)
    return dll
104
105
def load_dll(source, info, dtype="double"):
    """
    Create and load a dll corresponding to the source, info pair returned
    from :func:`sasmodels.generate.make` compiled for the target precision.

    Returns a :class:`DllModel` wrapping the compiled library.  If *info*
    defines a python callable for *Iq* there is nothing to compile, and a
    :class:`sasmodels.kernelpy.PyModel` is returned instead.

    Unless *ALLOW_SINGLE_PRECISION_DLLS* is True, the model is compiled and
    wrapped in double precision regardless of *dtype*.

    See :func:`make_dll` for details on controlling the dll path and the
    allowed floating point precision.
    """
    # For pure python models make_dll returns a PyModel rather than a
    # library path; wrapping that in DllModel would fail once ctypes tried
    # to load it as a shared library.
    if callable(info.get('Iq', None)):
        return PyModel(info)
    # make_dll forces double precision when single precision DLLs are not
    # allowed.  The wrapper must use the same precision as the library, or
    # the ctypes argument types and input arrays would not match the
    # compiled kernel.
    if not ALLOW_SINGLE_PRECISION_DLLS:
        dtype = "double"
    filename = make_dll(source, info, dtype=dtype)
    return DllModel(filename, info, dtype=dtype)
[5d4777d]116
[14de349]117
# Leading ctypes argument types shared by every compiled kernel: the q
# vector pointer(s) (one for Iq, two for Iqxy), the result buffer pointer
# and the number of q points.  They are lists so that DllModel can append
# the model-specific argument types with "+".
IQ_ARGS = [c_void_p] * 2 + [c_int]
IQXY_ARGS = [c_void_p] * 3 + [c_int]
[14de349]120
class DllModel(object):
    """
    ctypes wrapper for a single compiled model.

    *dllpath* is the path to the compiled model library (as returned by
    :func:`make_dll`) and *info* is the model interface as returned
    from :func:`sasmodels.generate.make`.

    *dtype* is the desired model precision.  Any numpy dtype for single
    or double precision floats will do, such as 'f', 'float32' or 'single'
    for single and 'd', 'float64' or 'double' for double.  It must match
    the precision the library was compiled with, since it determines the
    ctypes argument types used to call the kernels and the dtype required
    of the q input.

    The library is not loaded until the model is first called.

    Call :meth:`release` when done with the kernel.
    """
    def __init__(self, dllpath, info, dtype=generate.F32):
        self.info = info
        self.dllpath = dllpath
        self.dll = None
        self.dtype = np.dtype(dtype)

    def _load_dll(self):
        """
        Load the library and declare the argument types of its Iq and Iqxy
        kernels, based on the parameter table in *info*.
        """
        partype = self.info['partype']
        Nfixed1d = len(partype['fixed-1d'])
        Nfixed2d = len(partype['fixed-2d'])
        Npd1d = len(partype['pd-1d'])
        Npd2d = len(partype['pd-2d'])

        try:
            self.dll = ct.CDLL(self.dllpath)
        except Exception as exc:
            # Annotate the error with the library path, then re-raise it.
            annotate_exception(exc, "while loading " + self.dllpath)
            raise

        fp = c_float if self.dtype == generate.F32 else c_double
        # Polydispersity arguments are present only when the model has
        # polydisperse parameters: a pointer to the loop data, the cutoff,
        # and one loop length per polydisperse parameter (see DllKernel).
        pd_args_1d = [c_void_p, fp] + [c_int]*Npd1d if Npd1d else []
        pd_args_2d = [c_void_p, fp] + [c_int]*Npd2d if Npd2d else []
        self.Iq = self.dll[generate.kernel_name(self.info, False)]
        self.Iq.argtypes = IQ_ARGS + pd_args_1d + [fp]*Nfixed1d

        self.Iqxy = self.dll[generate.kernel_name(self.info, True)]
        self.Iqxy.argtypes = IQXY_ARGS + pd_args_2d + [fp]*Nfixed2d

    def __getstate__(self):
        # The loaded library and its function handles are not pickled; the
        # library is reloaded on the first call after unpickling.  dtype must
        # be kept, because __call__ and make_input depend on it.
        return {'info': self.info, 'dllpath': self.dllpath,
                'dtype': self.dtype, 'dll': None}

    def __setstate__(self, state):
        self.__dict__ = state

    def __call__(self, q_input):
        """
        Return a :class:`DllKernel` evaluating this model at *q_input*.

        Raises *TypeError* if *q_input* does not have the model's dtype.
        """
        if self.dtype != q_input.dtype:
            raise TypeError("data is %s kernel is %s" % (q_input.dtype, self.dtype))
        if self.dll is None:
            self._load_dll()
        kernel = self.Iqxy if q_input.is_2D else self.Iq
        return DllKernel(kernel, self.info, q_input)

    # pylint: disable=no-self-use
    def make_input(self, q_vectors):
        """
        Make q input vectors available to the model.

        Note that each model needs its own q vector even in the case of
        mixture models because some models may be OpenCL, some may be
        ctypes and some may be pure python.
        """
        return PyInput(q_vectors, dtype=self.dtype)

    def release(self):
        """
        Release resources held by the model.  Currently a no-op: once loaded,
        the library stays loaded.
        """
        pass # TODO: should release the dll
189
[14de349]190
class DllKernel(object):
    """
    Callable SAS kernel backed by a compiled C function.

    *kernel* is the ctypes function to call (the Iq or Iqxy entry point
    selected by :class:`DllModel`).

    *info* is the model information.  Its *partype* table provides the
    fixed and polydisperse parameter lists for the kernel's dimension.

    *q_input* holds the q vectors at which the kernel should be evaluated
    (e.g. as created by :meth:`DllModel.make_input`).

    Calling the kernel takes *fixed_pars*, a list of values for the fixed
    parameters, *pd_pars*, a list of (value, weight) vectors for the
    polydisperse parameters, and *cutoff*, which determines the integration
    limits: any points with combined weight less than *cutoff* will not be
    calculated.  The result buffer is allocated once and owned by the
    kernel.  Each call returns that same array, so a later call overwrites
    the values returned by an earlier one.

    Call :meth:`release` when done with the kernel instance.
    """
    def __init__(self, kernel, info, q_input):
        suffix = '2d' if q_input.is_2D else '1d'
        self.info = info
        self.q_input = q_input
        self.kernel = kernel
        self.res = np.empty(q_input.nq, q_input.dtype)
        self.fixed_pars = info['partype']['fixed-' + suffix]
        self.pd_pars = info['partype']['pd-' + suffix]
        # Raw address of the result buffer, handed to the C function on
        # every call (the OpenCL kernel does not keep this).
        self.p_res = self.res.ctypes.data

    def __call__(self, fixed_pars, pd_pars, cutoff):
        dtype = self.q_input.dtype
        real = np.float32 if dtype == generate.F32 else np.float64

        nq = c_int(self.q_input.nq)
        if not pd_pars:
            dispersed = []
        else:
            # Pack every (value, weight) pair into one contiguous array of
            # the input dtype.  Assuming each entry is a pair of 1-D arrays,
            # the transposed stack flattens to value, weight, value, ...
            # The loop length of each parameter is passed separately.
            lengths = [np.uint32(len(pair[0])) for pair in pd_pars]
            packed = np.ascontiguousarray(np.hstack(pd_pars).T, dtype).flatten()
            # "packed" stays referenced until this method returns, so the
            # address remains valid for the duration of the kernel call.
            dispersed = [packed.ctypes.data, real(cutoff)] + lengths
        fixed = [real(value) for value in fixed_pars]

        self.kernel(*(self.q_input.q_pointers + [self.p_res, nq]
                      + dispersed + fixed))
        return self.res

    def release(self):
        """Release kernel resources (nothing to do for ctypes kernels)."""
        pass
Note: See TracBrowser for help on using the repository browser.