source: sasmodels/sasmodels/core.py @ 672978c

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 672978c was 672978c, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

Merge branch 'master' of github.com:sasview/sasmodels

  • Property mode set to 100644
File size: 9.2 KB
Line 
1"""
2Core model handling routines.
3"""
4from __future__ import print_function
5
# Names exported by "from sasmodels.core import *".
__all__ = [
    "list_models",
    "load_model",
    "load_model_info",
    "build_model",
    "precompile_dlls",
]
10
11import os
12from os.path import basename, dirname, join as joinpath
13from glob import glob
14
15import numpy as np # type: ignore
16
17from . import generate
18from . import modelinfo
19from . import product
20from . import mixture
21from . import kernelpy
22from . import kerneldll
23try:
24    from . import kernelcl
25    HAVE_OPENCL = True
26except Exception:
27    HAVE_OPENCL = False
28
29try:
30    from typing import List, Union, Optional, Any
31    from .kernel import KernelModel
32    from .modelinfo import ModelInfo
33except ImportError:
34    pass
35
try:
    # Probe whether this numpy accepts np.meshgrid with a single vector.
    np.meshgrid([])
except Exception:
    # CRUFT: older np.meshgrid requires multiple vectors
    def meshgrid(*args):
        """
        Wrapper for np.meshgrid that tolerates being given a single vector.

        With zero or one argument the inputs are returned as a list of
        arrays; with two or more the call is forwarded to np.meshgrid.
        """
        if len(args) <= 1:
            return [np.asarray(vec) for vec in args]
        return np.meshgrid(*args)
else:
    meshgrid = np.meshgrid
47
48# TODO: refactor composite model support
49# The current load_model_info/build_model does not reuse existing model
50# definitions when loading a composite model, instead reloading and
51# rebuilding the kernel for each component model in the expression.  This
52# is fine in a scripting environment where the model is built when the script
53# starts and is thrown away when the script ends, but may not be the best
54# solution in a long-lived application.  This affects the following functions:
55#
56#    load_model
57#    load_model_info
58#    build_model
59
# Model categories accepted by list_models(kind=...).
KINDS = (
    "all",
    "py",
    "c",
    "double",
    "single",
    "1d",
    "2d",
    "nonmagnetic",
    "magnetic",
)
def list_models(kind=None):
    # type: (Optional[str]) -> List[str]
    """
    Return the list of available models on the model path.

    *kind* can be one of the following:

        * all: all models
        * py: python models only
        * c: compiled models only
        * single: models which support single precision
        * double: models which require double precision
        * 1d: models which are 1D only, or 2D using abs(q)
        * 2d: models which can be 2D
        * magnetic: models with an sld
        * nonmagnetic: models without an sld

    If *kind* is None, it is treated as "all".

    Model names are the basenames (without ".py") of the files on
    generate.MODEL_PATH whose names start with a letter, returned in
    sorted order.

    Raises *ValueError* if *kind* is not one of the names in KINDS.
    """
    if kind is None:
        kind = "all"
    # Validate every non-None value, including "", which previously slipped
    # past a truthiness check and silently matched nothing.
    if kind not in KINDS:
        raise ValueError("kind not in " + ", ".join(KINDS))
    files = sorted(glob(joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")))
    available_models = [basename(f)[:-3] for f in files]
    return [name for name in available_models if _matches(name, kind)]
86
def _matches(name, kind):
    # type: (str, Optional[str]) -> bool
    """
    Return True if the model called *name* belongs to category *kind*.

    See :func:`list_models` for the meaning of each *kind*.  A *kind* of
    None or "all" matches every model without loading its definition; a
    *kind* not handled below matches nothing.
    """
    if kind is None or kind == "all":
        return True
    info = load_model_info(name)
    pars = info.parameters.kernel_parameters
    is_python = callable(info.Iq)
    if kind == "py":
        return is_python
    if kind == "c":
        return not is_python
    if kind == "double":
        return not info.single
    if kind == "single":
        return bool(info.single)
    if kind in ("1d", "2d"):
        has_orientation = any(p.type == 'orientation' for p in pars)
        return has_orientation if kind == "2d" else not has_orientation
    if kind in ("magnetic", "nonmagnetic"):
        # "nonmagnetic" means no sld parameter at all.  The previous test,
        # any(p.type != 'sld'), was true for almost every model.
        has_sld = any(p.type == 'sld' for p in pars)
        return has_sld if kind == "magnetic" else not has_sld
    return False
109
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Load model info and build model.

    *model_name* is the name of the model as used by :func:`load_model_info`.

    *dtype* and *platform* are passed directly to :func:`build_model`; see
    that function for the accepted values.
    """
    model_info = load_model_info(model_name)
    return build_model(model_info, dtype=dtype, platform=platform)
120
121
def load_model_info(model_name):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name.

    *model_name* is either the name of a single model, or an expression
    combining model names:

        * "A+B" builds a mixture of models A and B (any number of terms);
        * "P*S" applies structure factor S to model P.

    The name is split on '+' before '*', so '*' binds more tightly:
    "A+P*S" is a mixture of A with the product P*S.

    This returns the model info structure (the result of
    modelinfo.make_model_info, or of the mixture/product equivalents for
    composite names).  It can be used with functions in generate to build
    the docs, or passed to :func:`build_model`.

    Raises *ValueError* if a product term contains more than one '*'.
    """
    parts = model_name.split('+')
    if len(parts) > 1:
        model_info_list = [load_model_info(p) for p in parts]
        return mixture.make_mixture_info(model_info_list)

    parts = model_name.split('*')
    if len(parts) > 1:
        if len(parts) > 2:
            raise ValueError("use P*S to apply structure factor S to model P")
        P_info, S_info = [load_model_info(p) for p in parts]
        return product.make_product_info(P_info, S_info)

    kernel_module = generate.load_kernel_module(model_name)
    return modelinfo.make_model_info(kernel_module)
144
145
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Prepare the model for the default execution platform.

    This will return an OpenCL model, a DLL model or a python model depending
    on the model and the computing platform.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* indicates whether the model should use single or double precision
    for the calculation.  Choices are 'single', 'double', 'quad', 'half',
    or 'fast'.  If *dtype* ends with '!', then force the use of the DLL rather
    than OpenCL for the calculation.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl".

    Composite models (mixtures and products) are built by building each
    component with the same *dtype* and *platform*.  Python models (those
    whose *Iq* is callable) are returned as kernelpy.PyModel without
    consulting *dtype* or *platform*.

    Raises *ValueError* if *model_info* has a composition type other than
    'mixture' or 'product'.
    """
    composition = model_info.composition
    if composition is not None:
        composition_type, parts = composition
        # Reject unknown composition types before spending time building
        # the component kernels.
        if composition_type not in ('mixture', 'product'):
            raise ValueError('unknown composition type %s' % composition_type)
        models = [build_model(p, dtype=dtype, platform=platform) for p in parts]
        if composition_type == 'mixture':
            return mixture.MixtureModel(model_info, models)
        P, S = models
        return product.ProductModel(model_info, P, S)

    # If it is a python model, return it immediately
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)

    source = generate.make_source(model_info)
    if platform == "dll":
        return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
    else:
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
190
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Precompile the dlls for all builtin models, returning a list of dll paths.

    *path* is the directory in which to save the dlls.  It will be created if
    it does not already exist.

    *dtype* is the precision to compile for; it is converted with
    numpy.dtype, so names such as "double" or "single" are accepted.

    Python models (those whose *Iq* is callable) have no C kernel and are
    skipped.

    This can be used when building the windows distribution of sasmodels,
    which may be missing the OpenCL driver and the dll compiler.
    """
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)
    compiled_dlls = []
    # Redirect kerneldll's output directory once for the whole run, and
    # restore the previous setting even if a model fails to build.
    old_path = kerneldll.DLL_PATH
    kerneldll.DLL_PATH = path
    try:
        for model_name in list_models():
            model_info = load_model_info(model_name)
            if callable(model_info.Iq):
                continue
            source = generate.make_source(model_info)['dll']
            dll = kerneldll.make_dll(source, model_info, dtype=numpy_dtype)
            compiled_dlls.append(dll)
    finally:
        kerneldll.DLL_PATH = old_path
    return compiled_dlls
218
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
    """
    Interpret dtype string, returning (np.dtype, fast flag, platform).

    Possible types include 'half', 'single', 'double' and 'quad'.  Any other
    name is passed to numpy.dtype as-is.  If the type is 'fast', then this
    is equivalent to dtype 'single' with the fast flag set to True.  A
    trailing '!' on *dtype* forces the "dll" platform.

    *platform* defaults to "ocl".  It is switched to "dll" when OpenCL is
    not available, or when the OpenCL environment does not support the
    requested type.

    When *dtype* is None, the model's preference is used: generate.F32 if
    the platform is "ocl" and *model_info.single* is set, otherwise
    generate.F64.
    """
    # Assign default platform, overriding ocl with dll if OpenCL is unavailable
    if platform is None:
        platform = "ocl"
    if platform == "ocl" and not HAVE_OPENCL:
        platform = "dll"

    # Check if type indicates dll regardless of which platform is given
    if dtype is not None and dtype.endswith('!'):
        platform = "dll"
        dtype = dtype[:-1]

    # Convert special type names "half", "fast", and "quad"
    fast = (dtype == "fast")
    if fast:
        dtype = "single"
    elif dtype == "quad":
        dtype = "longdouble"
    elif dtype == "half":
        # Use the explicit name: numpy reads 'f16' as a 16-byte float
        # (quad precision), not a 16-bit half-precision float.
        dtype = "float16"

    # Convert dtype string to numpy dtype.
    if dtype is None:
        numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
                       else generate.F64)
    else:
        numpy_dtype = np.dtype(dtype)

    # Make sure that the type is supported by opencl, otherwise use dll
    if platform == "ocl":
        env = kernelcl.environment()
        if not env.has_type(numpy_dtype):
            platform = "dll"
            if dtype is None:
                numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
264
def list_models_main():
    # type: () -> None
    """
    Print the names of models in one category, one per line.

    The category is read from the first command line argument and defaults
    to "all"; see :func:`list_models` for the accepted category names.
    """
    import sys
    cli_args = sys.argv[1:]
    if cli_args:
        kind = cli_args[0]
    else:
        kind = "all"
    names = list_models(kind)
    print("\n".join(names))

if __name__ == "__main__":
    list_models_main()
Note: See TracBrowser for help on using the repository browser.