source: sasmodels/sasmodels/core.py @ 71bf6de

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 71bf6de was 71bf6de, checked in by Adam Washington <adam.washington@…>, 6 years ago

First draft of solution

  • Property mode set to 100644
File size: 12.9 KB
Line 
1"""
2Core model handling routines.
3"""
4from __future__ import print_function
5
# Names exported by a wildcard import of this module.
__all__ = [
    "list_models", "load_model", "load_model_info",
    "build_model", "precompile_dlls",
    ]
10
11import os
12from os.path import basename, join as joinpath
13from glob import glob
14import re
15
16import numpy as np # type: ignore
17
18from . import generate
19from . import modelinfo
20from . import product
21from . import mixture
22from . import kernelpy
23from . import kerneldll
24from . import custom
25
# OpenCL support is optional.  Setting SAS_OPENCL=none (any case) disables it
# outright; otherwise it is enabled only if kernelcl imports cleanly, and any
# exception raised during that import leaves HAVE_OPENCL false.
HAVE_OPENCL = False
if os.environ.get("SAS_OPENCL", "").lower() != "none":
    try:
        from . import kernelcl
    except Exception:
        pass
    else:
        HAVE_OPENCL = True
34
# Directory searched for "custom.ModelName" models: $SAS_MODELPATH when it is
# set and non-empty, otherwise ~/.sasmodels/custom_models, which is created
# at import time if it does not exist yet.
CUSTOM_MODEL_PATH = os.environ.get('SAS_MODELPATH', "")
if CUSTOM_MODEL_PATH == "":
    CUSTOM_MODEL_PATH = joinpath(os.path.expanduser("~"), ".sasmodels", "custom_models")
    if not os.path.isdir(CUSTOM_MODEL_PATH):
        try:
            os.makedirs(CUSTOM_MODEL_PATH)
        except OSError:
            # Another process importing this module at the same time may have
            # created the directory between the isdir() check and makedirs().
            # Only re-raise if the directory still does not exist.  (Not using
            # makedirs(exist_ok=True): the __future__ import suggests this
            # file also targets Python 2, which lacks that argument.)
            if not os.path.isdir(CUSTOM_MODEL_PATH):
                raise
40
41# pylint: disable=unused-import
42try:
43    from typing import List, Union, Optional, Any
44    from .kernel import KernelModel
45    from .modelinfo import ModelInfo
46except ImportError:
47    pass
48# pylint: enable=unused-import
49
50# TODO: refactor composite model support
51# The current load_model_info/build_model does not reuse existing model
52# definitions when loading a composite model, instead reloading and
53# rebuilding the kernel for each component model in the expression.  This
54# is fine in a scripting environment where the model is built when the script
55# starts and is thrown away when the script ends, but may not be the best
56# solution in a long-lived application.  This affects the following functions:
57#
58#    load_model
59#    load_model_info
60#    build_model
61
# Filter keywords accepted by list_models(); several may be joined with '+'.
KINDS = ("all", "py", "c", "double", "single", "opencl", "1d", "2d",
         "nonmagnetic", "magnetic")
def list_models(kind=None):
    # type: (str) -> List[str]
    """
    Return the list of available models on the model path.

    *kind* can be one of the following:

        * all: all models
        * py: python models only
        * c: compiled models only
        * single: models which support single precision
        * double: models which require double precision
        * opencl: models for which OpenCL is not suppressed (the model's
          *opencl* flag is set)
        * 1d: models which are 1D only, or 2D using abs(q)
        * 2d: models which can be 2D
        * magnetic: models with an sld parameter
        * nonmagnetic: models without any sld parameter

    For multiple conditions, combine with plus.  For example, *c+single+2d*
    would return all oriented models implemented in C which can be computed
    accurately with single precision arithmetic.

    Raises *ValueError* if any '+'-separated kind is not one of KINDS.
    """
    if kind and any(k not in KINDS for k in kind.split('+')):
        raise ValueError("kind not in " + ", ".join(KINDS))
    files = sorted(glob(joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")))
    available_models = [basename(f)[:-3] for f in files]
    # A model is selected only if it satisfies every requested condition.
    # A single kind (or no kind at all) is simply a one-element condition list.
    all_kinds = kind.split('+') if kind else [kind]
    return [name for name in available_models
            if all(_matches(name, k) for k in all_kinds)]
98
def _matches(name, kind):
    # type: (str, Optional[str]) -> bool
    """
    Return True if model *name* satisfies the single filter *kind*.

    *kind* is one entry of KINDS, or None, which like "all" matches every
    model.  Any unrecognized kind matches nothing.  See :func:`list_models`
    for the meaning of each kind.
    """
    if kind is None or kind == "all":
        return True
    info = load_model_info(name)
    pars = info.parameters.kernel_parameters
    if kind == "py":
        return callable(info.Iq)
    if kind == "c":
        return not callable(info.Iq)
    if kind == "double":
        return not info.single
    if kind == "single":
        return bool(info.single)
    if kind == "opencl":
        return bool(info.opencl)
    if kind == "2d":
        return any(p.type == 'orientation' for p in pars)
    if kind == "1d":
        return all(p.type != 'orientation' for p in pars)
    if kind == "magnetic":
        return any(p.type == 'sld' for p in pars)
    if kind == "nonmagnetic":
        # Exact complement of "magnetic": no sld parameter at all.  Testing
        # any(p.type != 'sld') instead would match nearly every model, since
        # almost all of them have at least one non-sld parameter.
        return all(p.type != 'sld' for p in pars)
    return False
113        return True
114    elif kind == "2d" and any(p.type == 'orientation' for p in pars):
115        return True
116    elif kind == "1d" and all(p.type != 'orientation' for p in pars):
117        return True
118    elif kind == "magnetic" and any(p.type == 'sld' for p in pars):
119        return True
120    elif kind == "nonmagnetic" and any(p.type != 'sld' for p in pars):
121        return True
122    return False
123
124def load_model(model_name, dtype=None, platform='ocl'):
125    # type: (str, str, str) -> KernelModel
126    """
127    Load model info and build model.
128
129    *model_name* is the name of the model, or perhaps a model expression
130    such as sphere*hardsphere or sphere+cylinder.
131
132    *dtype* and *platform* are given by :func:`build_model`.
133    """
134    return build_model(load_model_info(model_name),
135                       dtype=dtype, platform=platform)
136
137def load_model_info(model_string):
138    # type: (str) -> modelinfo.ModelInfo
139    """
140    Load a model definition given the model name.
141
142    *model_string* is the name of the model, or perhaps a model expression
143    such as sphere*cylinder or sphere+cylinder. Use '@' for a structure
144    factor product, e.g. sphere@hardsphere. Custom models can be specified by
145    prefixing the model name with 'custom.', e.g. 'custom.MyModel+sphere'.
146
147    This returns a handle to the module defining the model.  This can be
148    used with functions in generate to build the docs or extract model info.
149    """
150    if '@' in model_string:
151        terms = model_string.split('+')
152        results = []
153        for term in terms:
154            if '@' in term:
155                p_info, q_info = [load_model_info(part)
156                                  for part in term.split("@")]
157                results.append(product.make_product_info(p_info, q_info))
158            else:
159                results.append(load_model_info(term))
160        return mixture.make_mixture_info(results, operation='+')
161        # parts = model_string.split('@')
162        # if len(parts) != 2:
163        #     raise ValueError("Use P@S to apply a structure factor S to model P")
164        # P_info, Q_info = [load_model_info(part) for part in parts]
165        # return product.make_product_info(P_info, Q_info)
166
167    product_parts = []
168    addition_parts = []
169
170    addition_parts_names = model_string.split('+')
171    if len(addition_parts_names) >= 2:
172        addition_parts = [load_model_info(part) for part in addition_parts_names]
173    elif len(addition_parts_names) == 1:
174        product_parts_names = model_string.split('*')
175        if len(product_parts_names) >= 2:
176            product_parts = [load_model_info(part) for part in product_parts_names]
177        elif len(product_parts_names) == 1:
178            if "custom." in product_parts_names[0]:
179                # Extract ModelName from "custom.ModelName"
180                pattern = "custom.([A-Za-z0-9_-]+)"
181                result = re.match(pattern, product_parts_names[0])
182                if result is None:
183                    raise ValueError("Model name in invalid format: " + product_parts_names[0])
184                model_name = result.group(1)
185                # Use ModelName to find the path to the custom model file
186                model_path = joinpath(CUSTOM_MODEL_PATH, model_name + ".py")
187                if not os.path.isfile(model_path):
188                    raise ValueError("The model file {} doesn't exist".format(model_path))
189                kernel_module = custom.load_custom_kernel_module(model_path)
190                return modelinfo.make_model_info(kernel_module)
191            # Model is a core model
192            kernel_module = generate.load_kernel_module(product_parts_names[0])
193            return modelinfo.make_model_info(kernel_module)
194
195    model = None
196    if len(product_parts) > 1:
197        model = mixture.make_mixture_info(product_parts, operation='*')
198    if len(addition_parts) > 1:
199        if model is not None:
200            addition_parts.append(model)
201        model = mixture.make_mixture_info(addition_parts, operation='+')
202    return model
203
204
205def build_model(model_info, dtype=None, platform="ocl"):
206    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
207    """
208    Prepare the model for the default execution platform.
209
210    This will return an OpenCL model, a DLL model or a python model depending
211    on the model and the computing platform.
212
213    *model_info* is the model definition structure returned from
214    :func:`load_model_info`.
215
216    *dtype* indicates whether the model should use single or double precision
217    for the calculation.  Choices are 'single', 'double', 'quad', 'half',
218    or 'fast'.  If *dtype* ends with '!', then force the use of the DLL rather
219    than OpenCL for the calculation.
220
221    *platform* should be "dll" to force the dll to be used for C models,
222    otherwise it uses the default "ocl".
223    """
224    composition = model_info.composition
225    if composition is not None:
226        composition_type, parts = composition
227        models = [build_model(p, dtype=dtype, platform=platform) for p in parts]
228        if composition_type == 'mixture':
229            return mixture.MixtureModel(model_info, models)
230        elif composition_type == 'product':
231            P, S = models
232            return product.ProductModel(model_info, P, S)
233        else:
234            raise ValueError('unknown mixture type %s'%composition_type)
235
236    # If it is a python model, return it immediately
237    if callable(model_info.Iq):
238        return kernelpy.PyModel(model_info)
239
240    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
241
242    source = generate.make_source(model_info)
243    if platform == "dll":
244        #print("building dll", numpy_dtype)
245        return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
246    else:
247        #print("building ocl", numpy_dtype)
248        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
249
250def precompile_dlls(path, dtype="double"):
251    # type: (str, str) -> List[str]
252    """
253    Precompile the dlls for all builtin models, returning a list of dll paths.
254
255    *path* is the directory in which to save the dlls.  It will be created if
256    it does not already exist.
257
258    This can be used when build the windows distribution of sasmodels
259    which may be missing the OpenCL driver and the dll compiler.
260    """
261    numpy_dtype = np.dtype(dtype)
262    if not os.path.exists(path):
263        os.makedirs(path)
264    compiled_dlls = []
265    for model_name in list_models():
266        model_info = load_model_info(model_name)
267        if not callable(model_info.Iq):
268            source = generate.make_source(model_info)['dll']
269            old_path = kerneldll.DLL_PATH
270            try:
271                kerneldll.DLL_PATH = path
272                dll = kerneldll.make_dll(source, model_info, dtype=numpy_dtype)
273            finally:
274                kerneldll.DLL_PATH = old_path
275            compiled_dlls.append(dll)
276    return compiled_dlls
277
278def parse_dtype(model_info, dtype=None, platform=None):
279    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
280    """
281    Interpret dtype string, returning np.dtype and fast flag.
282
283    Possible types include 'half', 'single', 'double' and 'quad'.  If the
284    type is 'fast', then this is equivalent to dtype 'single' but using
285    fast native functions rather than those with the precision level
286    guaranteed by the OpenCL standard.  'default' will choose the appropriate
287    default for the model and platform.
288
289    Platform preference can be specfied ("ocl" vs "dll"), with the default
290    being OpenCL if it is availabe.  If the dtype name ends with '!' then
291    platform is forced to be DLL rather than OpenCL.
292
293    This routine ignores the preferences within the model definition.  This
294    is by design.  It allows us to test models in single precision even when
295    we have flagged them as requiring double precision so we can easily check
296    the performance on different platforms without having to change the model
297    definition.
298    """
299    # Assign default platform, overriding ocl with dll if OpenCL is unavailable
300    # If opencl=False OpenCL is switched off
301
302    if platform is None:
303        platform = "ocl"
304    if platform == "ocl" and not HAVE_OPENCL or not model_info.opencl:
305        platform = "dll"
306
307    # Check if type indicates dll regardless of which platform is given
308    if dtype is not None and dtype.endswith('!'):
309        platform = "dll"
310        dtype = dtype[:-1]
311
312    # Convert special type names "half", "fast", and "quad"
313    fast = (dtype == "fast")
314    if fast:
315        dtype = "single"
316    elif dtype == "quad":
317        dtype = "longdouble"
318    elif dtype == "half":
319        dtype = "float16"
320
321    # Convert dtype string to numpy dtype.
322    if dtype is None or dtype == "default":
323        numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
324                       else generate.F64)
325    else:
326        numpy_dtype = np.dtype(dtype)
327
328    # Make sure that the type is supported by opencl, otherwise use dll
329    if platform == "ocl":
330        env = kernelcl.environment()
331        if not env.has_type(numpy_dtype):
332            platform = "dll"
333            if dtype is None:
334                numpy_dtype = generate.F64
335
336    return numpy_dtype, fast, platform
337
338def list_models_main():
339    # type: () -> None
340    """
341    Run list_models as a main program.  See :func:`list_models` for the
342    kinds of models that can be requested on the command line.
343    """
344    import sys
345    kind = sys.argv[1] if len(sys.argv) > 1 else "all"
346    print("\n".join(list_models(kind)))
347
348if __name__ == "__main__":
349    list_models_main()
Note: See TracBrowser for help on using the repository browser.