source: sasmodels/sasmodels/kerneldll.py @ 754c454

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 754c454 was f786ff3, checked in by pkienzle, 9 years ago

rename modules for clarity

  • Property mode set to 100644
File size: 6.7 KB
RevLine 
[14de349]1"""
2C types wrapper for sasview models.
3"""
[5d4777d]4import sys
5import os
[df4dc86]6import tempfile
[14de349]7import ctypes as ct
8from ctypes import c_void_p, c_int, c_double
9
10import numpy as np
11
[cb6ecf4]12from . import generate
[f786ff3]13from .kernelpy import PyInput, PyKernel
[14de349]14
[cb6ecf4]15from .generate import F32, F64
# Compiler platform details.
# COMPILE is a shell command template; load_model fills in the
# %(source)s (generated C file) and %(output)s (shared library) fields.
if sys.platform == 'darwin':
    # OpenMP is not enabled here.  An alternative gcc-mp-4.7 build with
    # OpenMP is kept for reference:
    #COMPILE = "gcc-mp-4.7 -shared -fPIC -std=c99 -fopenmp -O2 -Wall %(source)s -o %(output)s -lm -lgomp"
    COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"
elif os.name == 'nt':
    # make sure vcvarsall.bat is called first in order to set compiler, headers, lib paths, etc.
    ##COMPILER = r'"C:\Program Files (x86)\Common Files\Microsoft\Visual C++ for Python\9.0\VC\Bin\cl.exe"'
    # The active command below builds with /openmp.  An earlier version
    # dropped OpenMP because VCOMP90.DLL could not be found at run time; that
    # variant (no /openmp) is kept here in case the problem resurfaces:
    #COMPILE = "cl /nologo /Ox /MD /W3 /GS- /DNDEBUG /Tp%(source)s /link /DLL /INCREMENTAL:NO /MANIFEST /OUT:%(output)s"
    COMPILE = "cl /nologo /Ox /MD /W3 /GS- /DNDEBUG /Tp%(source)s /openmp /link /DLL /INCREMENTAL:NO /MANIFEST /OUT:%(output)s"
    # Alternative MinGW gcc build with OpenMP:
    #COMPILE = "gcc -shared -fPIC -std=c99 -fopenmp -O2 -Wall %(source)s -o %(output)s -lm"
else:
    COMPILE = "cc -shared -fPIC -std=c99 -fopenmp -O2 -Wall %(source)s -o %(output)s -lm"

# Directory where compiled model libraries are written (see dll_path).
DLL_PATH = tempfile.gettempdir()
[5d4777d]31
32
def dll_path(info):
    """
    Return the location of the shared library compiled for model *info*.

    The library is placed in DLL_PATH and named after the model's source
    file (``info['filename']``), with the extension replaced by ``.so``.
    """
    stem = os.path.splitext(os.path.basename(info['filename']))[0]
    return os.path.join(DLL_PATH, stem + '.so')
40
41
def load_model(kernel_module, dtype=None):
    """
    Load the compiled model defined by *kernel_module*.

    Recompile if the model file or any of its C sources is newer than the
    compiled library.

    *dtype* is ignored.  Compiled files are always double.

    The DLL is not loaded until the kernel is called so models can
    be defined without using too many resources.

    If compilation fails, a message naming the generated C file is printed
    and a DllModel for the library path is still returned; the library may
    then be missing, or be a stale build from an earlier compile.
    """
    source, info = generate.make(kernel_module)
    source_files = generate.sources(info) + [info['filename']]
    newest = max(os.path.getmtime(f) for f in source_files)
    dllpath = dll_path(info)
    if not os.path.exists(dllpath) or os.path.getmtime(dllpath) < newest:
        # Write the generated source to a uniquely named temp file.  The
        # context manager closes the mkstemp descriptor (flushing the data)
        # before the compiler reads the file, rather than relying on GC.
        fid, filename = tempfile.mkstemp(suffix=".c", prefix="sas_"+info['name'])
        with os.fdopen(fid, "w") as fd:
            fd.write(source)
        command = COMPILE % {"source": filename, "output": dllpath}
        print("Compile command: %s" % command)
        status = os.system(command)
        if status != 0:
            print("compile failed.  File is in %r" % filename)
        # The generated .c file is not deleted; it stays in the temp
        # directory for inspection.
    return DllModel(dllpath, info)
73
[14de349]74
# Leading ctypes argument types shared by every generated 1-D kernel: the q
# pointer, result pointer, number of q points, pointer to the packed
# polydispersity loops, and the weight cutoff.  Per-model doubles (fixed
# parameters) and ints (loop lengths) are appended when the library loads.
IQ_ARGS = [c_void_p, c_void_p, c_int, c_void_p, c_double]
# 2-D kernels take one more leading pointer (presumably separate qx and qy).
IQXY_ARGS = [c_void_p] + IQ_ARGS
77
class DllModel(object):
    """
    ctypes wrapper for a single compiled model.

    *dllpath* is the path to the compiled shared library, as returned by
    :func:`dll_path`.

    *info* is the model interface description as returned from
    :func:`generate.make`.

    The library is not opened until a C kernel is first requested, so
    models can be defined without holding library handles.  Compiled
    kernels always use double precision.

    Call :meth:`release` when done with the model.
    """
    def __init__(self, dllpath, info):
        self.info = info
        self.dllpath = dllpath
        self.dll = None  # opened on demand by _load_dll()

    def _load_dll(self):
        """
        Open the shared library and bind its 1-D and 2-D kernel functions.

        Sets :attr:`dll`, :attr:`Iq` and :attr:`Iqxy`, declaring each
        kernel's ctypes argument types: the common leading arguments, then
        one double per fixed parameter and one int per polydisperse
        parameter.
        """
        partype = self.info['partype']
        Nfixed1d = len(partype['fixed-1d'])
        Nfixed2d = len(partype['fixed-2d'])
        Npd1d = len(partype['pd-1d'])
        Npd2d = len(partype['pd-2d'])

        self.dll = ct.CDLL(self.dllpath)

        self.Iq = self.dll[generate.kernel_name(self.info, False)]
        self.Iq.argtypes = IQ_ARGS + [c_double]*Nfixed1d + [c_int]*Npd1d

        self.Iqxy = self.dll[generate.kernel_name(self.info, True)]
        self.Iqxy.argtypes = IQXY_ARGS + [c_double]*Nfixed2d + [c_int]*Npd2d

    def __getstate__(self):
        """
        Pickle support: only *info* and *dllpath* are kept.  The library
        handle and bound kernels are dropped and are reopened by the next
        call that needs a C kernel.
        """
        return {'info': self.info, 'dllpath': self.dllpath, 'dll': None}

    def __setstate__(self, state):
        self.__dict__ = state

    def __call__(self, input):
        """
        Return a kernel that evaluates the model at the q points in *input*.

        If the model info supplies a python callable for Iq (1-D) or Iqxy
        (2-D), a PyKernel is returned.  Otherwise the library is opened if
        needed and a :class:`DllKernel` wraps the matching C function.
        """
        # Support pure python kernel call
        if input.is_2D and callable(self.info['Iqxy']):
            return PyKernel(self.info['Iqxy'], self.info, input)
        elif not input.is_2D and callable(self.info['Iq']):
            return PyKernel(self.info['Iq'], self.info, input)

        if self.dll is None:
            self._load_dll()
        kernel = self.Iqxy if input.is_2D else self.Iq
        return DllKernel(kernel, self.info, input)

    def make_input(self, q_vectors):
        """
        Make q input vectors available to the model.

        Note that each model needs its own q vector even in the case of
        mixture models because some models may be OpenCL, some may be
        ctypes and some may be pure python.
        """
        return PyInput(q_vectors, dtype=F64)

    def release(self):
        """
        Release resources held by the model.

        Currently does nothing: once opened, the library stays loaded.
        """
        pass # TODO: should release the dll
141
[14de349]142
class DllKernel(object):
    """
    Callable SAS kernel.

    *kernel* is the ctypes function to call (the Iq or Iqxy function bound
    by :class:`DllModel`).

    *info* is the module information; its ``partype`` entry lists the fixed
    and polydisperse parameters for the 1-D and 2-D cases.

    *input* holds the q vectors at which the kernel should be evaluated
    (e.g. the object returned by :meth:`DllModel.make_input`).  It must
    provide ``nq``, ``dtype``, ``is_2D`` and ``q_pointers``.

    The resulting call method takes *pars*, a list of values for the fixed
    parameters to the kernel, and *pd_pars*, a list of (value, weight)
    vectors for the polydisperse parameters (which may be empty).  *cutoff*
    determines the integration limits: any points with combined weight
    less than *cutoff* will not be calculated.

    The result array is allocated once and reused.  Each call overwrites
    and returns the same array, so copy it to keep earlier results.

    Call :meth:`release` when done with the kernel instance.
    """
    def __init__(self, kernel, info, input):
        self.info = info
        self.input = input
        self.kernel = kernel
        self.res = np.empty(input.nq, input.dtype)
        dim = '2d' if input.is_2D else '1d'
        self.fixed_pars = info['partype']['fixed-'+dim]
        self.pd_pars = info['partype']['pd-'+dim]

        # In dll kernel, but not in opencl kernel: the C function writes
        # straight into self.res through this raw address.
        self.p_res = self.res.ctypes.data

    def __call__(self, pars, pd_pars, cutoff):
        real = np.float32 if self.input.dtype == F32 else np.float64
        fixed = [real(p) for p in pars]
        cutoff = real(cutoff)

        # Pack the polydispersity loops.  Each (value, weight) pair stacks
        # to a 2 x n block; hstack joins the blocks along n.  After the
        # transpose, the flat buffer reads v0, w0, v1, w1, ...
        if pd_pars:
            stacked = np.hstack(pd_pars)
        else:
            # np.hstack raises ValueError on an empty list, but models with
            # no polydisperse parameters still need an (empty) loop buffer.
            stacked = np.empty((2, 0))
        loops = np.ascontiguousarray(stacked.T, self.input.dtype).flatten()
        loops_N = [np.uint32(len(p[0])) for p in pd_pars]

        nq = c_int(self.input.nq)
        # Only the raw address is passed, so `loops` must stay referenced
        # (it does, as a local) until the kernel returns.
        p_loops = loops.ctypes.data
        args = self.input.q_pointers + [self.p_res, nq, p_loops, cutoff] + fixed + loops_N
        self.kernel(*args)

        return self.res

    def release(self):
        """Release kernel resources; nothing to do for ctypes kernels."""
        pass
Note: See TracBrowser for help on using the repository browser.