source: sasmodels/sasmodels/core.py @ 40a87fa

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 40a87fa was 40a87fa, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

lint and latex cleanup

  • Property mode set to 100644
File size: 9.2 KB
Line 
1"""
2Core model handling routines.
3"""
4from __future__ import print_function
5
6__all__ = [
7    "list_models", "load_model", "load_model_info",
8    "build_model", "precompile_dlls",
9    ]
10
11import os
12from os.path import basename, dirname, join as joinpath
13from glob import glob
14
15import numpy as np # type: ignore
16
17from . import generate
18from . import modelinfo
19from . import product
20from . import mixture
21from . import kernelpy
22from . import kerneldll
23try:
24    from . import kernelcl
25    HAVE_OPENCL = True
26except Exception:
27    HAVE_OPENCL = False
28
29try:
30    from typing import List, Union, Optional, Any
31    from .kernel import KernelModel
32    from .modelinfo import ModelInfo
33except ImportError:
34    pass
35
try:
    # Probe: does this numpy accept a single vector in meshgrid?
    np.meshgrid([])
except Exception:
    # CRUFT: np.meshgrid requires multiple vectors
    def meshgrid(*args):
        """
        Allow meshgrid with a single argument.

        With two or more vectors this defers to np.meshgrid; otherwise each
        argument is converted with np.asarray and returned in a list.
        """
        if len(args) <= 1:
            return [np.asarray(v) for v in args]
        return np.meshgrid(*args)
else:
    # The probe succeeded, so the numpy implementation can be used directly.
    meshgrid = np.meshgrid
47
48# TODO: refactor composite model support
49# The current load_model_info/build_model does not reuse existing model
50# definitions when loading a composite model, instead reloading and
51# rebuilding the kernel for each component model in the expression.  This
52# is fine in a scripting environment where the model is built when the script
53# starts and is thrown away when the script ends, but may not be the best
54# solution in a long-lived application.  This affects the following functions:
55#
56#    load_model
57#    load_model_info
58#    build_model
59
# Category names accepted by list_models().
KINDS = ("all", "py", "c", "double", "single", "1d", "2d",
         "nonmagnetic", "magnetic")
def list_models(kind=None):
    # type: (Optional[str]) -> List[str]
    """
    Return the names of the available models, optionally filtered by *kind*.

    Models are the ``[a-zA-Z]*.py`` files in the *models* directory next to
    this file; names are returned without the ``.py`` extension, in sorted
    file order.

    *kind* can be one of the following:

        * all: all models
        * py: python models only
        * c: compiled models only
        * single: models which support single precision
        * double: models which require double precision
        * 1d: models which are 1D only, or 2D using abs(q)
        * 2d: models which can be 2D
        * magnetic: models with an sld
        * nonmagnetic: models without an sld

    If *kind* is None (or empty), all models are returned.

    Raises *ValueError* if *kind* is given but is not one of the above.
    """
    if kind and kind not in KINDS:
        raise ValueError("kind not in " + ", ".join(KINDS))

    pattern = joinpath(dirname(__file__), 'models', "[a-zA-Z]*.py")
    selected = []
    for path in sorted(glob(pattern)):
        name = basename(path)[:-3]  # strip the ".py" extension
        if _matches(name, kind):
            selected.append(name)
    return selected
87
def _matches(name, kind):
    # type: (str, Optional[str]) -> bool
    """
    Return True if the model *name* belongs to the category *kind*.

    *kind* is one of the names accepted by :func:`list_models`.  For None
    or "all" every model matches and the model definition is not loaded;
    for any other kind the model info is loaded with
    :func:`load_model_info` and inspected.  An unrecognized *kind* matches
    nothing.
    """
    if kind is None or kind == "all":
        return True
    info = load_model_info(name)
    pars = info.parameters.kernel_parameters
    if kind == "py":
        # Python models define Iq as a callable.
        return callable(info.Iq)
    elif kind == "c":
        return not callable(info.Iq)
    elif kind == "double":
        return not info.single
    elif kind == "single":
        return bool(info.single)
    elif kind == "2d":
        return any(p.type == 'orientation' for p in pars)
    elif kind == "1d":
        return all(p.type != 'orientation' for p in pars)
    elif kind == "magnetic":
        return any(p.type == 'sld' for p in pars)
    elif kind == "nonmagnetic":
        # Must hold for *every* parameter: a model is nonmagnetic only if
        # none of its parameters is an sld.  (Using any() here would also
        # accept magnetic models that have at least one non-sld parameter.)
        return all(p.type != 'sld' for p in pars)
    return False
110
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Load the model definition for *model_name* and build a kernel model.

    *model_name* is interpreted by :func:`load_model_info`.  The *dtype*
    and *platform* arguments are forwarded unchanged to :func:`build_model`.
    """
    model_info = load_model_info(model_name)
    return build_model(model_info, dtype=dtype, platform=platform)
121
122
def load_model_info(model_name):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name.

    *model_name* may be a plain model name, a sum "A+B+..." of models, or a
    product "P*S".  Sums are split first, so each term of a sum may itself
    be a product.  Each component is loaded recursively and the results are
    combined with :func:`mixture.make_mixture_info` or
    :func:`product.make_product_info`.  A plain name is loaded with
    :func:`generate.load_kernel_module` and converted with
    :func:`modelinfo.make_model_info`.

    Raises *ValueError* if a product has more than two factors.
    """
    # Mixture model: A+B+...
    if '+' in model_name:
        term_infos = [load_model_info(term) for term in model_name.split('+')]
        return mixture.make_mixture_info(term_infos)

    # Product model: P*S, exactly two factors.
    if '*' in model_name:
        factors = model_name.split('*')
        if len(factors) != 2:
            raise ValueError("use P*S to apply structure factor S to model P")
        first_info, second_info = [load_model_info(f) for f in factors]
        return product.make_product_info(first_info, second_info)

    # Plain model name.
    kernel_module = generate.load_kernel_module(model_name)
    return modelinfo.make_model_info(kernel_module)
135
136    parts = model_name.split('*')
137    if len(parts) > 1:
138        if len(parts) > 2:
139            raise ValueError("use P*S to apply structure factor S to model P")
140        P_info, Q_info = [load_model_info(p) for p in parts]
141        return product.make_product_info(P_info, Q_info)
142
143    kernel_module = generate.load_kernel_module(model_name)
144    return modelinfo.make_model_info(kernel_module)
145
146
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Prepare the model for the default execution platform.

    Depending on the model and the computing platform, the result is a
    composite (mixture or product) model, a python model, a DLL model or an
    OpenCL model.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* indicates whether the model should use single or double precision
    for the calculation.  Choices are 'single', 'double', 'quad', 'half',
    or 'fast'.  If *dtype* ends with '!', then force the use of the DLL rather
    than OpenCL for the calculation.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl".

    Raises *ValueError* if *model_info.composition* names an unknown
    composition type.
    """
    composition = model_info.composition
    if composition is not None:
        # Composite model: build every component with the same settings,
        # then wrap them according to the composition type.
        kind, parts = composition
        built = [build_model(part, dtype=dtype, platform=platform)
                 for part in parts]
        if kind == 'mixture':
            return mixture.MixtureModel(model_info, built)
        if kind == 'product':
            P, S = built
            return product.ProductModel(model_info, P, S)
        raise ValueError('unknown mixture type %s'%kind)

    # Python models need no compilation, so return them immediately.
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    # parse_dtype may override the requested platform.
    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)

    source = generate.make_source(model_info)
    if platform != "dll":
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
    return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
181
182    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
183
184    source = generate.make_source(model_info)
185    if platform == "dll":
186        #print("building dll", numpy_dtype)
187        return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
188    else:
189        #print("building ocl", numpy_dtype)
190        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
191
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Precompile the dlls for all builtin models, returning a list of dll paths.

    *path* is the directory in which to save the dlls.  It will be created if
    it does not already exist.

    *dtype* is the precision to compile for, given as anything accepted by
    np.dtype.

    This can be used when building the windows distribution of sasmodels
    which may be missing the OpenCL driver and the dll compiler.
    """
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)

    compiled_dlls = []
    for model_name in list_models():
        model_info = load_model_info(model_name)
        if callable(model_info.Iq):
            # Pure python model: nothing to compile.
            continue
        source = generate.make_source(model_info)['dll']
        # Temporarily redirect the dll output directory to *path*, putting
        # the previous value back even if compilation fails.
        saved_path = kerneldll.DLL_PATH
        kerneldll.DLL_PATH = path
        try:
            dll = kerneldll.make_dll(source, model_info, dtype=numpy_dtype)
        finally:
            kerneldll.DLL_PATH = saved_path
        compiled_dlls.append(dll)
    return compiled_dlls
219
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
    """
    Interpret dtype string, returning np.dtype, fast flag and platform.

    Possible types include 'half', 'single', 'double' and 'quad'.  If the
    type is 'fast', then this is equivalent to dtype 'single' with the
    fast flag set to True.  If the type ends with '!' the platform is
    forced to "dll".  If *dtype* is None, single precision is chosen for
    OpenCL when *model_info.single* is true, and double precision otherwise.

    *platform* is "ocl" or "dll", with None meaning "ocl".  The returned
    platform is "dll" whenever OpenCL is unavailable or the OpenCL
    environment does not support the selected type.
    """
    # Assign default platform, overriding ocl with dll if OpenCL is unavailable
    if platform is None:
        platform = "ocl"
    if platform == "ocl" and not HAVE_OPENCL:
        platform = "dll"

    # Check if type indicates dll regardless of which platform is given
    if dtype is not None and dtype.endswith('!'):
        platform = "dll"
        dtype = dtype[:-1]

    # Convert special type names "half", "fast", and "quad"
    fast = (dtype == "fast")
    if fast:
        dtype = "single"
    elif dtype == "quad":
        dtype = "longdouble"
    elif dtype == "half":
        # Spell out the bit width: in numpy type strings the number after
        # 'f' counts *bytes*, so "f16" would be a 16-byte float, not half.
        dtype = "float16"

    # Convert dtype string to numpy dtype.
    if dtype is None:
        numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
                       else generate.F64)
    else:
        numpy_dtype = np.dtype(dtype)

    # Make sure that the type is supported by opencl, otherwise use dll
    if platform == "ocl":
        env = kernelcl.environment()
        if not env.has_type(numpy_dtype):
            platform = "dll"
            # Only an implicit precision is changed; an explicit request
            # is kept as given and simply run on the dll.
            if dtype is None:
                numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
255
256    # Make sure that the type is supported by opencl, otherwise use dll
257    if platform == "ocl":
258        env = kernelcl.environment()
259        if not env.has_type(numpy_dtype):
260            platform = "dll"
261            if dtype is None:
262                numpy_dtype = generate.F64
263
264    return numpy_dtype, fast, platform
265
def list_models_main():
    # type: () -> None
    """
    Run list_models as a main program.  See :func:`list_models` for the
    kinds of models that can be requested on the command line.

    The first command line argument, if present, is the kind; otherwise
    "all" is used.  The matching model names are printed one per line.
    """
    import sys
    args = sys.argv[1:]
    kind = args[0] if args else "all"
    print("\n".join(list_models(kind)))

# Allow the module to be run directly as a script.
if __name__ == "__main__":
    list_models_main()
Note: See TracBrowser for help on using the repository browser.