source: sasmodels/sasmodels/core.py @ fb9a3b6

core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since fb9a3b6 was fb9a3b6, checked in by lewis, 7 years ago

Allow mixture models to do multiplication as well as addition

  • Property mode set to 100644
File size: 10.8 KB
Line 
1"""
2Core model handling routines.
3"""
4from __future__ import print_function
5
# Names exported by ``from sasmodels.core import *``.  Other module-level
# functions (e.g. parse_dtype, list_models_main) are still importable by name.
__all__ = [
    "list_models",
    "load_model",
    "load_model_info",
    "build_model",
    "precompile_dlls",
    ]
10
11import os
12from os.path import basename, dirname, join as joinpath
13from glob import glob
14
15import numpy as np # type: ignore
16
17from . import generate
18from . import modelinfo
19from . import product
20from . import mixture
21from . import kernelpy
22from . import kerneldll
23
# OpenCL support is optional.  Setting the environment variable SAS_OPENCL
# to "none" (any case) disables it outright.  Otherwise any exception raised
# while importing kernelcl is swallowed on purpose so that this module still
# loads and the DLL/python kernels can be used instead.
HAVE_OPENCL = False
if os.environ.get("SAS_OPENCL", "").lower() != "none":
    try:
        from . import kernelcl
    except Exception:
        pass
    else:
        HAVE_OPENCL = True
32
33try:
34    from typing import List, Union, Optional, Any
35    from .kernel import KernelModel
36    from .modelinfo import ModelInfo
37except ImportError:
38    pass
39
40# TODO: refactor composite model support
41# The current load_model_info/build_model does not reuse existing model
42# definitions when loading a composite model, instead reloading and
43# rebuilding the kernel for each component model in the expression.  This
44# is fine in a scripting environment where the model is built when the script
45# starts and is thrown away when the script ends, but may not be the best
46# solution in a long-lived application.  This affects the following functions:
47#
48#    load_model
49#    load_model_info
50#    build_model
51
# Model categories accepted by list_models (combine several with '+').
KINDS = ("all", "py", "c", "double", "single", "opencl", "1d", "2d",
         "nonmagnetic", "magnetic")


def list_models(kind=None):
    # type: (Optional[str]) -> List[str]
    """
    Return the list of available models on the model path.

    *kind* can be one of the following:

        * all: all models
        * py: python models only
        * c: compiled models only
        * single: models which support single precision
        * double: models which require double precision
        * opencl: models which can be run with OpenCL
        * 1d: models which are 1D only, or 2D using abs(q)
        * 2d: models which can be 2D
        * magnetic: models with an sld
        * nonmagnetic: models without an sld

    For multiple conditions, combine with plus.  For example, *c+single+2d*
    would return all oriented models implemented in C which can be computed
    accurately with single precision arithmetic.  A *kind* of None (or the
    empty string) applies no filter.

    Raises *ValueError* if any of the '+'-separated kinds is not in KINDS.
    """
    conditions = kind.split('+') if kind else []
    if any(k not in KINDS for k in conditions):
        raise ValueError("kind not in " + ", ".join(KINDS))
    files = sorted(glob(joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")))
    # Strip the ".py" extension to get the model name.
    available_models = [basename(f)[:-3] for f in files]
    if not conditions:
        return available_models
    # A model is selected only if it matches every requested kind.
    return [name for name in available_models
            if all(_matches(name, k) for k in conditions)]
88
89def _matches(name, kind):
90    if kind is None or kind == "all":
91        return True
92    info = load_model_info(name)
93    pars = info.parameters.kernel_parameters
94    if kind == "py" and callable(info.Iq):
95        return True
96    elif kind == "c" and not callable(info.Iq):
97        return True
98    elif kind == "double" and not info.single:
99        return True
100    elif kind == "single" and info.single:
101        return True
102    elif kind == "opencl" and info.opencl:
103        return True
104    elif kind == "2d" and any(p.type == 'orientation' for p in pars):
105        return True
106    elif kind == "1d" and all(p.type != 'orientation' for p in pars):
107        return True
108    elif kind == "magnetic" and any(p.type == 'sld' for p in pars):
109        return True
110    elif kind == "nonmagnetic" and any(p.type != 'sld' for p in pars):
111        return True
112    return False
113
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Load the definition for *model_name* and build a kernel model from it.

    *model_name* may be a single model name or a model expression such as
    sphere*hardsphere or sphere+cylinder; it is resolved by
    :func:`load_model_info`.

    *dtype* and *platform* are forwarded unchanged to :func:`build_model`,
    which documents their meaning.
    """
    info = load_model_info(model_name)
    return build_model(info, dtype=dtype, platform=platform)
126
127
def load_model_info(model_name, force_mixture=False):
    # type: (str, bool) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name.

    *model_name* is the name of the model, or perhaps a model expression
    such as sphere*hardsphere or sphere+cylinder.  The name is first split
    on '+', and then each term is split on '*'.  This makes addition bind
    more loosely than multiplication.

    *force_mixture*, if true, makes MixtureModel be used for every
    combination of models, including multiplication.  Otherwise MixtureModel
    is used for addition and ProductModel for multiplication, in which case a
    product must have exactly two factors (P*S).  The flag is applied to every
    term of a sum, so with *force_mixture* the product in a*b+c also becomes
    a mixture.

    This returns the model info structure, which can be used with the
    functions in generate to build the docs, or with :func:`build_model` to
    create a kernel.

    Raises *ValueError* if a product has more than two factors and
    *force_mixture* is false.
    """
    terms = model_name.split('+')
    if len(terms) > 1:
        # Always use MixtureModel for addition.  Pass force_mixture down so
        # products inside the sum honour it as well.
        model_info_list = [load_model_info(p, force_mixture=force_mixture)
                           for p in terms]
        return mixture.make_mixture_info(model_info_list)

    factors = model_name.split('*')
    if len(factors) > 1:
        if force_mixture:
            # Use MixtureModel for multiplication if forced.
            model_info_list = [load_model_info(p) for p in factors]
            return mixture.make_mixture_info(model_info_list, operation='*')
        if len(factors) > 2:
            raise ValueError("use P*S to apply structure factor S to model P")
        # Use ProductModel: form factor P times structure factor Q.
        P_info, Q_info = [load_model_info(p) for p in factors]
        return product.make_product_info(P_info, Q_info)

    kernel_module = generate.load_kernel_module(model_name)
    return modelinfo.make_model_info(kernel_module)
163
164
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Prepare the model for the default execution platform.

    Returns an OpenCL model, a DLL model or a python model, depending on the
    model and the computing platform.  For a composite model (a mixture or a
    product), each component is built with the same *dtype* and *platform*,
    and the results are wrapped together.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* selects the precision: 'single', 'double', 'quad', 'half' or
    'fast'.  If *dtype* ends with '!', the DLL is used rather than OpenCL
    (see :func:`parse_dtype`).

    *platform* should be "dll" to force the dll to be used for C models;
    otherwise the default "ocl" is used.

    Raises *ValueError* if the composition type is not recognized.
    """
    composition = model_info.composition
    if composition is None:
        # Pure python models need no compilation at all.
        if callable(model_info.Iq):
            return kernelpy.PyModel(model_info)
        numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
        source = generate.make_source(model_info)
        if platform == "dll":
            return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)

    kind, parts = composition
    models = [build_model(part, dtype=dtype, platform=platform)
              for part in parts]
    if kind == 'mixture':
        return mixture.MixtureModel(model_info, models)
    if kind == 'product':
        P, S = models
        return product.ProductModel(model_info, P, S)
    raise ValueError('unknown mixture type %s'%kind)
209
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Compile the DLL for every builtin C model into *path*.

    *path* is the directory in which to save the dlls; it is created if it
    does not already exist.  *dtype* is converted with numpy.dtype and
    selects the precision of the compiled kernels.

    Returns the list of whatever kerneldll.make_dll produced for each model
    (presumably the dll paths).  Python models are skipped.

    This can be used when building the windows distribution of sasmodels,
    which may be missing the OpenCL driver and the dll compiler.
    """
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)
    dll_paths = []
    for name in list_models():
        info = load_model_info(name)
        if callable(info.Iq):
            # Pure python models have no C source to compile.
            continue
        dll_source = generate.make_source(info)['dll']
        saved_dll_path = kerneldll.DLL_PATH
        try:
            # Redirect the dll output directory only for this build.
            kerneldll.DLL_PATH = path
            built = kerneldll.make_dll(dll_source, info, dtype=numpy_dtype)
        finally:
            kerneldll.DLL_PATH = saved_dll_path
        dll_paths.append(built)
    return dll_paths
237
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
    """
    Interpret the *dtype* string, returning (numpy dtype, fast flag, platform).

    Recognized names include 'half', 'single', 'double' and 'quad', plus
    anything else accepted by numpy.dtype.  The name 'fast' means single
    precision using fast native functions, rather than those whose precision
    is guaranteed by the OpenCL standard.  When *dtype* is None, the result is
    generate.F32 for an OpenCL model flagged as single, and generate.F64
    otherwise.

    *platform* may be "ocl" or "dll"; None means "ocl".  The DLL is used
    instead of OpenCL in any of these cases:

    * OpenCL is unavailable;
    * the model does not allow OpenCL;
    * the dtype name ends with '!';
    * the OpenCL environment cannot handle the requested type.

    When an explicit *dtype* is given, the precision preference in the model
    definition is deliberately ignored.  This lets models flagged as needing
    double precision be tested in single precision, and lets performance be
    compared across platforms without editing the model definition.
    """
    # Default to OpenCL.  Fall back to the DLL when OpenCL is missing or the
    # model does not allow it.
    if platform is None:
        platform = "ocl"
    if (platform == "ocl" and not HAVE_OPENCL) or not model_info.opencl:
        platform = "dll"

    # A trailing '!' forces the DLL, whatever platform was requested.
    if dtype is not None and dtype.endswith('!'):
        platform, dtype = "dll", dtype[:-1]

    # Map the special type names onto numpy type names.
    fast = dtype == "fast"
    aliases = {"fast": "single", "quad": "longdouble", "half": "float16"}
    dtype = aliases.get(dtype, dtype)

    if dtype is not None:
        numpy_dtype = np.dtype(dtype)
    elif platform == "ocl" and model_info.single:
        numpy_dtype = generate.F32
    else:
        numpy_dtype = generate.F64

    # If the OpenCL device cannot handle the type, fall back to the DLL.
    # Without an explicit dtype, the DLL then uses double precision.
    if platform == "ocl" and not kernelcl.environment().has_type(numpy_dtype):
        platform = "dll"
        if dtype is None:
            numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
296
def list_models_main():
    # type: () -> None
    """
    Command-line entry point: print the models of a given kind, one per line.

    The kind is read from the first command-line argument and defaults to
    "all".  See :func:`list_models` for the kinds that can be requested.
    """
    import sys
    args = sys.argv[1:]
    kind = args[0] if args else "all"
    selected = list_models(kind)
    print("\n".join(selected))

if __name__ == "__main__":
    list_models_main()
Note: See TracBrowser for help on using the repository browser.