source: sasmodels/sasmodels/core.py @ 481ff64

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 481ff64 was 481ff64, checked in by lewis, 7 years ago

Use '@' for structure factor product

  • Property mode set to 100644
File size: 11.0 KB
Line 
1"""
2Core model handling routines.
3"""
4from __future__ import print_function
5
# Names exported by ``from sasmodels.core import *``.
__all__ = [
    "list_models",
    "load_model",
    "load_model_info",
    "build_model",
    "precompile_dlls",
]
10
11import os
12from os.path import basename, dirname, join as joinpath
13from glob import glob
14
15import numpy as np # type: ignore
16
17from . import generate
18from . import modelinfo
19from . import product
20from . import mixture
21from . import kernelpy
22from . import kerneldll
23
# OpenCL is optional.  Setting the environment variable SAS_OPENCL=none
# disables it outright; otherwise any exception raised while importing the
# kernelcl module leaves HAVE_OPENCL False so callers fall back to the DLL.
HAVE_OPENCL = False
if os.environ.get("SAS_OPENCL", "").lower() != "none":
    try:
        from . import kernelcl
    except Exception:
        # Deliberately broad: any import-time failure just means "no OpenCL".
        pass
    else:
        HAVE_OPENCL = True
32
33try:
34    from typing import List, Union, Optional, Any
35    from .kernel import KernelModel
36    from .modelinfo import ModelInfo
37except ImportError:
38    pass
39
40# TODO: refactor composite model support
41# The current load_model_info/build_model does not reuse existing model
42# definitions when loading a composite model, instead reloading and
43# rebuilding the kernel for each component model in the expression.  This
44# is fine in a scripting environment where the model is built when the script
45# starts and is thrown away when the script ends, but may not be the best
46# solution in a long-lived application.  This affects the following functions:
47#
48#    load_model
49#    load_model_info
50#    build_model
51
# Model categories accepted by list_models (combine several with '+').
KINDS = ("all", "py", "c", "double", "single", "opencl", "1d", "2d",
         "nonmagnetic", "magnetic")

def list_models(kind=None):
    # type: (str) -> List[str]
    """
    Return the list of available models on the model path.

    *kind* can be one of the following:

        * all: all models
        * py: python models only
        * c: compiled models only
        * single: models which support single precision
        * double: models which require double precision
        * opencl: models which can run on OpenCL (those whose model info
          does not disable OpenCL)
        * 1d: models which are 1D only, or 2D using abs(q)
        * 2d: models which can be 2D
        * magnetic: models with an sld
        * nonmagnetic: models without an sld

    For multiple conditions, combine with plus.  For example, *c+single+2d*
    would return all oriented models implemented in C which can be computed
    accurately with single precision arithmetic.

    Raises *ValueError* if any '+'-separated component of *kind* is not one
    of :data:`KINDS`.
    """
    # A single kind (or None) is just the one-element case of a '+' list, so
    # one filter handles both.
    all_kinds = kind.split('+') if kind else [kind]
    # Validate before scanning the model path so a typo fails immediately.
    if kind and any(k not in KINDS for k in all_kinds):
        raise ValueError("kind not in " + ", ".join(KINDS))
    files = sorted(glob(joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")))
    available_models = [basename(f)[:-3] for f in files]
    return [name for name in available_models
            if all(_matches(name, k) for k in all_kinds)]
88
def _matches(name, kind):
    # type: (str, str) -> bool
    """
    Return True if the model *name* belongs to the category *kind*.

    *kind* is a single entry of :data:`KINDS` (no '+' combinations).  A
    *kind* of None or "all" matches every model without loading it; an
    unrecognized kind matches nothing.
    """
    if kind is None or kind == "all":
        return True
    info = load_model_info(name)
    pars = info.parameters.kernel_parameters
    if kind == "py":
        return callable(info.Iq)
    if kind == "c":
        return not callable(info.Iq)
    if kind == "double":
        return not info.single
    if kind == "single":
        return bool(info.single)
    if kind == "opencl":
        return bool(info.opencl)
    if kind == "2d":
        return any(p.type == 'orientation' for p in pars)
    if kind == "1d":
        return all(p.type != 'orientation' for p in pars)
    if kind == "magnetic":
        return any(p.type == 'sld' for p in pars)
    if kind == "nonmagnetic":
        # Must be the exact complement of "magnetic": no sld parameters at
        # all.  Testing any(p.type != 'sld') would match nearly every model,
        # magnetic or not.
        return all(p.type != 'sld' for p in pars)
    return False
113
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Load model info and build model.

    *model_name* is the name of the model, or perhaps a model expression
    such as sphere@hardsphere (structure factor product), sphere*cylinder
    or sphere+cylinder; see :func:`load_model_info` for the syntax.

    *dtype* and *platform* are given by :func:`build_model`.
    """
    return build_model(load_model_info(model_name),
                       dtype=dtype, platform=platform)
126
def load_model_info(model_string):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name.

    *model_string* is the name of the model, or perhaps a model expression
    such as sphere*cylinder or sphere+cylinder. Use '@' for a structure
    factor product, eg sphere@hardsphere.

    The operators bind in the order '@' (loosest), then '+', then '*', so
    ``a*b+c@d`` applies the structure factor *d* to the sum of the product
    *a*b* and *c*.  At most one '@' is allowed.

    Returns the :class:`modelinfo.ModelInfo` describing the model (or the
    composite model for an expression).  This can be used with functions in
    generate to build the docs or extract model info.

    Raises *ValueError* if *model_string* contains more than one '@'.
    """
    if '@' in model_string:
        parts = model_string.split('@')
        if len(parts) != 2:
            raise ValueError("Use P@S to apply a structure factor S to model P")
        P_info, Q_info = [load_model_info(part) for part in parts]
        return product.make_product_info(P_info, Q_info)

    # '+' binds more loosely than '*', so split on it first; each term may
    # itself be a '*' expression handled by the recursive call.
    addition_names = model_string.split('+')
    if len(addition_names) > 1:
        addition_parts = [load_model_info(part) for part in addition_names]
        return mixture.make_mixture_info(addition_parts, operation='+')

    product_names = model_string.split('*')
    if len(product_names) > 1:
        product_parts = [load_model_info(part) for part in product_names]
        return mixture.make_mixture_info(product_parts, operation='*')

    # No operators left: a single plain model name.
    kernel_module = generate.load_kernel_module(model_string)
    return modelinfo.make_model_info(kernel_module)
168
169
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Build an executable kernel model from a model definition.

    *model_info* is the structure returned by :func:`load_model_info`.
    Composite definitions (mixtures and structure factor products) are built
    part by part with the same *dtype* and *platform*.  Python models (those
    with a callable Iq) become a python kernel.  C models are compiled either
    as a DLL or for OpenCL, as decided by :func:`parse_dtype`.

    *dtype* selects the precision: 'single', 'double', 'quad', 'half' or
    'fast'.  A trailing '!' on *dtype* forces the DLL instead of OpenCL.

    *platform* is "dll" to force the DLL for C models; any other value
    (default "ocl") requests OpenCL.
    """
    compo = model_info.composition
    if compo is not None:
        compo_kind, sub_infos = compo
        sub_models = [build_model(sub, dtype=dtype, platform=platform)
                      for sub in sub_infos]
        if compo_kind == 'mixture':
            return mixture.MixtureModel(model_info, sub_models)
        if compo_kind == 'product':
            form_factor, structure_factor = sub_models
            return product.ProductModel(model_info, form_factor,
                                        structure_factor)
        raise ValueError('unknown mixture type %s'%compo_kind)

    # Pure python models need no compilation.
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
    source = generate.make_source(model_info)
    if platform != "dll":
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
    return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
214
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Compile the DLL of every builtin C model into *path*.

    Returns the list of compiled dll paths.  Python models are skipped since
    they have nothing to compile.  *path* is created if it does not already
    exist.  *dtype* is any name understood by :func:`numpy.dtype`.

    This is useful when building the windows distribution of sasmodels,
    which may be missing the OpenCL driver and the dll compiler.
    """
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)
    dll_paths = []
    for name in list_models():
        info = load_model_info(name)
        if callable(info.Iq):
            continue
        dll_source = generate.make_source(info)['dll']
        # Temporarily redirect kerneldll output into *path*, restoring the
        # previous location even if compilation fails.
        saved_dll_path = kerneldll.DLL_PATH
        try:
            kerneldll.DLL_PATH = path
            compiled = kerneldll.make_dll(dll_source, info, dtype=numpy_dtype)
        finally:
            kerneldll.DLL_PATH = saved_dll_path
        dll_paths.append(compiled)
    return dll_paths
242
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
    """
    Interpret dtype string, returning (numpy dtype, fast flag, platform).

    Possible types include 'half', 'single', 'double' and 'quad'.  If the
    type is 'fast', then this is equivalent to dtype 'single' but using
    fast native functions rather than those with the precision level guaranteed
    by the OpenCL standard.  Any other string is passed to :func:`numpy.dtype`.

    Platform preference can be specified ("ocl" vs "dll"), with the default
    being OpenCL if it is available.  If the dtype name ends with '!' then
    platform is forced to be DLL rather than OpenCL.  A model whose info sets
    *opencl* to False always uses the DLL, as does any dtype the OpenCL
    environment does not support.

    If *dtype* is None, the default is single precision for OpenCL models
    flagged as *single*, and double precision otherwise.

    An explicit *dtype* overrides the precision preference within the model
    definition.  This is by design.  It allows us to test models in single
    precision even when we have flagged them as requiring double precision so
    we can easily check the performance on different platforms without having
    to change the model definition.
    """
    # Assign default platform.  A model flagged opencl=False is always
    # switched to the DLL; otherwise fall back to the DLL only when OpenCL
    # was requested but is unavailable.
    if platform is None:
        platform = "ocl"
    if not model_info.opencl or (platform == "ocl" and not HAVE_OPENCL):
        platform = "dll"

    # Check if type indicates dll regardless of which platform is given
    if dtype is not None and dtype.endswith('!'):
        platform = "dll"
        dtype = dtype[:-1]

    # Convert special type names "half", "fast", and "quad"
    fast = (dtype == "fast")
    if fast:
        dtype = "single"
    elif dtype == "quad":
        dtype = "longdouble"
    elif dtype == "half":
        dtype = "float16"

    # Convert dtype string to numpy dtype.
    if dtype is None:
        numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
                       else generate.F64)
    else:
        numpy_dtype = np.dtype(dtype)

    # Make sure that the type is supported by opencl, otherwise use dll
    if platform == "ocl":
        env = kernelcl.environment()
        if not env.has_type(numpy_dtype):
            platform = "dll"
            if dtype is None:
                # Default precision on the DLL is double.
                numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
301
def list_models_main():
    # type: () -> None
    """
    Command-line entry point: print the models of the requested kind.

    The optional first command-line argument is the *kind* filter accepted
    by :func:`list_models` (default "all").  Matching model names are printed
    one per line.
    """
    import sys
    argv = sys.argv
    requested = argv[1] if len(argv) >= 2 else "all"
    names = list_models(requested)
    print("\n".join(names))


if __name__ == "__main__":
    list_models_main()
Note: See TracBrowser for help on using the repository browser.