source: sasmodels/sasmodels/kerneldll.py @ 3e428ec

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 3e428ec was 3c56da87, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

lint cleanup

  • Property mode set to 100644
File size: 7.0 KB
Line 
1"""
2C types wrapper for sasview models.
3"""
4import sys
5import os
6import tempfile
7import ctypes as ct
8from ctypes import c_void_p, c_int, c_double
9
10import numpy as np
11
12from . import generate
13from .kernelpy import PyInput, PyModel
14
15from .generate import F32, F64
# Compiler platform details.
# Each COMPILE command is a %-format template filled with a dict holding
# "source" (the generated C file) and "output" (the library to produce).
if sys.platform == 'darwin':
    # OpenMP is not enabled here.  An OpenMP build via gcc-mp-4.7 would use:
    #   gcc-mp-4.7 -shared -fPIC -std=c99 -fopenmp -O2 -Wall %(source)s -o %(output)s -lm -lgomp
    COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"
elif os.name == 'nt':
    # call vcvarsall.bat before compiling to set path, headers, libs, etc.
    if "VCINSTALLDIR" in os.environ:
        # MSVC compiler is available, so use it.
        # NOTE: this command builds WITH OpenMP (/openmp), so the resulting
        # DLL needs the OpenMP runtime (VCOMP90.DLL) when it is loaded; that
        # DLL has been observed to be missing.  If loading fails for that
        # reason, use the same command without /openmp:
        #   cl /nologo /Ox /MD /W3 /GS- /DNDEBUG /Tp%(source)s /link /DLL /INCREMENTAL:NO /MANIFEST /OUT:%(output)s
        # TODO: remove intermediate OBJ file created in the directory
        # TODO: maybe don't use randomized name for the c file
        COMPILE = "cl /nologo /Ox /MD /W3 /GS- /DNDEBUG /Tp%(source)s /openmp /link /DLL /INCREMENTAL:NO /MANIFEST /OUT:%(output)s"
    else:
        # No MSVC environment: fall back to gcc on the PATH, without OpenMP.
        # The OpenMP variant would add -fopenmp after -std=c99.
        COMPILE = "gcc -shared -fPIC -std=c99 -O2 -Wall %(source)s -o %(output)s -lm"
else:
    COMPILE = "cc -shared -fPIC -std=c99 -fopenmp -O2 -Wall %(source)s -o %(output)s -lm"

# Directory in which compiled model libraries are cached (see dll_path).
DLL_PATH = tempfile.gettempdir()
37
38
def dll_path(info):
    """
    Return where the compiled library for the model described by *info* lives.

    The library is placed in DLL_PATH and named after the model's source
    file (``info['filename']``) with its directory dropped and its extension
    replaced by ``.so``.
    """
    stem = os.path.splitext(os.path.basename(info['filename']))[0]
    library_name = stem + '.so'
    return os.path.join(DLL_PATH, library_name)
46
47
def load_model(kernel_module, dtype=None):
    """
    Load the compiled model defined by *kernel_module*.

    Recompile if any files are newer than the model file.

    *dtype* is ignored.  Compiled files are always double.

    The DLL is not loaded until the kernel is called so models can
    be defined without using too many resources.

    If the model's *Iq* is a Python callable, a :class:`PyModel` is returned
    and nothing is compiled.  Otherwise the generated C source is compiled
    with the COMPILE command into the library named by :func:`dll_path`
    whenever that library is missing or older than any model source file.

    Raises RuntimeError if the compiler exits with a nonzero status; the
    generated C file is kept in that case so the failure can be inspected.
    """
    source, info = generate.make(kernel_module)
    if callable(info.get('Iq', None)):
        return PyModel(info)
    source_files = generate.sources(info) + [info['filename']]
    newest = max(os.path.getmtime(f) for f in source_files)
    dllpath = dll_path(info)
    if not os.path.exists(dllpath) or os.path.getmtime(dllpath) < newest:
        # Write the generated source to a uniquely named temp file and close
        # it explicitly so the compiler sees the complete contents.
        fid, filename = tempfile.mkstemp(suffix=".c", prefix="sas_"+info['name'])
        with os.fdopen(fid, "w") as fd:
            fd.write(source)
        command = COMPILE % {"source": filename, "output": dllpath}
        # Single-string form prints the same under Python 2 and Python 3.
        print("Compile command: " + command)
        status = os.system(command)
        if status != 0:
            # Keep the generated file so the failure can be diagnosed.
            raise RuntimeError("compile failed.  File is in %r" % filename)
        # Compilation succeeded; don't leave a stray .c file in the temp dir.
        os.unlink(filename)
    return DllModel(dllpath, info)
79
80
# Leading C argument types shared by every kernel call (see DllKernel):
# the q-vector pointer(s), the result-buffer pointer, then the point count.
IQ_ARGS = [c_void_p] * 2 + [c_int]      # 1 q pointer, result, nq
IQXY_ARGS = [c_void_p] * 3 + [c_int]    # 2 q pointers, result, nq
83
class DllModel(object):
    """
    ctypes wrapper for a single model.

    *dllpath* is the path to the compiled model library and *info* is the
    model interface as returned from :func:`generate.make`.

    Models are always compiled for double precision.  The library itself
    is not opened until the model is first called, so models can be
    defined without using too many resources.

    Call :meth:`release` when done with the kernel.
    """
    def __init__(self, dllpath, info):
        self.info = info
        self.dllpath = dllpath
        self.dll = None  # opened lazily by _load_dll() on first call

    @staticmethod
    def _pd_argtypes(n_pd):
        """
        Return the ctypes argument types for the polydispersity arguments.

        With *n_pd* dispersed parameters the kernel receives a pointer to the
        packed (value, weight) loop data, the cutoff, and one loop length per
        parameter (matching what DllKernel passes); with none, the arguments
        are omitted entirely.
        """
        if not n_pd:
            return []
        return [c_void_p, c_double] + [c_int]*n_pd

    def _load_dll(self):
        """Open the library and declare the C signatures of Iq and Iqxy."""
        partype = self.info['partype']
        n_fixed_1d = len(partype['fixed-1d'])
        n_fixed_2d = len(partype['fixed-2d'])
        n_pd_1d = len(partype['pd-1d'])
        n_pd_2d = len(partype['pd-2d'])

        self.dll = ct.CDLL(self.dllpath)

        # Declared argtypes let ctypes convert the Python arguments.
        self.Iq = self.dll[generate.kernel_name(self.info, False)]
        self.Iq.argtypes = (IQ_ARGS + self._pd_argtypes(n_pd_1d)
                            + [c_double]*n_fixed_1d)

        self.Iqxy = self.dll[generate.kernel_name(self.info, True)]
        self.Iqxy.argtypes = (IQXY_ARGS + self._pd_argtypes(n_pd_2d)
                              + [c_double]*n_fixed_2d)

    def __getstate__(self):
        # The ctypes library handle cannot be pickled; it is reopened lazily
        # after unpickling because dll is None.
        return {'info': self.info, 'dllpath': self.dllpath, 'dll': None}

    def __setstate__(self, state):
        self.__dict__ = state

    def __call__(self, q_input):
        """Return a DllKernel evaluating this model at *q_input*."""
        if self.dll is None:
            self._load_dll()
        kernel = self.Iqxy if q_input.is_2D else self.Iq
        return DllKernel(kernel, self.info, q_input)

    # pylint: disable=no-self-use
    def make_input(self, q_vectors):
        """
        Make q input vectors available to the model.

        Note that each model needs its own q vector even if the case of
        mixture models because some models may be OpenCL, some may be
        ctypes and some may be pure python.

        The vectors are always stored in double precision (F64).
        """
        return PyInput(q_vectors, dtype=F64)

    def release(self):
        """Release model resources.  Currently a no-op."""
        pass # TODO: should release the dll
144
145
class DllKernel(object):
    """
    Callable SAS kernel.

    *kernel* is the c function to call.

    *info* is the module information

    *q_input* is the q input object (such as the one returned by
    :meth:`DllModel.make_input`) holding the q vectors at which the kernel
    should be evaluated.

    The resulting call method takes *fixed_pars*, a list of values for
    the fixed parameters to the kernel, and *pd_pars*, a list of
    (value, weight) vectors for the polydisperse parameters.  *cutoff*
    determines the integration limits: any points with combined weight less
    than *cutoff* will not be calculated.

    Call :meth:`release` when done with the kernel instance.
    """
    def __init__(self, kernel, info, q_input):
        self.info = info
        self.q_input = q_input
        self.kernel = kernel
        # Result buffer with one entry per q point, reused on every call.
        self.res = np.empty(q_input.nq, q_input.dtype)
        dim = '2d' if q_input.is_2D else '1d'
        self.fixed_pars = info['partype']['fixed-'+dim]
        self.pd_pars = info['partype']['pd-'+dim]

        # In dll kernel, but not in opencl kernel.  The raw address stays
        # valid because self.res keeps the buffer alive.
        self.p_res = self.res.ctypes.data

    def __call__(self, fixed_pars, pd_pars, cutoff):
        """
        Run the kernel and return the result array.

        The returned array is this kernel's own buffer, so its contents are
        overwritten by the next call.
        """
        real = np.float32 if self.q_input.dtype == F32 else np.float64

        nq = c_int(self.q_input.nq)
        if pd_pars:
            # One loop length per dispersed parameter; plain ints are the
            # canonical input for the c_int argtypes declared by DllModel.
            loops_N = [len(p[0]) for p in pd_pars]
            # Pack all (value, weight) vectors into one contiguous buffer.
            # `loops` must remain referenced until the call returns, since
            # only its raw address is handed to C.
            loops = np.hstack(pd_pars)
            loops = np.ascontiguousarray(loops.T, self.q_input.dtype).flatten()
            dispersed = [loops.ctypes.data, real(cutoff)] + loops_N
        else:
            dispersed = []
        fixed = [real(p) for p in fixed_pars]
        args = self.q_input.q_pointers + [self.p_res, nq] + dispersed + fixed
        self.kernel(*args)

        return self.res

    def release(self):
        """Release kernel resources.  Nothing to free for a DLL kernel."""
        pass
Note: See TracBrowser for help on using the repository browser.