source: sasmodels/sasmodels/core.py @ 4341dd4

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 4341dd4 was 4341dd4, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

no need to create custom model directory if it doesn't already exist

  • Property mode set to 100644
File size: 13.8 KB
Line 
1"""
2Core model handling routines.
3"""
4from __future__ import print_function
5
6__all__ = [
7    "list_models", "load_model", "load_model_info",
8    "build_model", "precompile_dlls",
9    ]
10
11import os
12from os.path import basename, join as joinpath
13from glob import glob
14import re
15
16import numpy as np # type: ignore
17
18from . import generate
19from . import modelinfo
20from . import product
21from . import mixture
22from . import kernelpy
23from . import kernelcl
24from . import kerneldll
25from . import custom
26
27# pylint: disable=unused-import
28try:
29    from typing import List, Union, Optional, Any
30    from .kernel import KernelModel
31    from .modelinfo import ModelInfo
32except ImportError:
33    pass
34# pylint: enable=unused-import
35
# Directory searched for "custom.<name>" models.  A non-empty SAS_MODELPATH
# environment variable overrides the default ~/.sasmodels/custom_models.
# The directory is intentionally not created here; it only needs to exist
# when a custom model is actually requested.
CUSTOM_MODEL_PATH = (
    os.environ.get('SAS_MODELPATH', "")
    or joinpath(os.path.expanduser("~"), ".sasmodels", "custom_models")
)
41
42# TODO: refactor composite model support
43# The current load_model_info/build_model does not reuse existing model
44# definitions when loading a composite model, instead reloading and
45# rebuilding the kernel for each component model in the expression.  This
46# is fine in a scripting environment where the model is built when the script
47# starts and is thrown away when the script ends, but may not be the best
48# solution in a long-lived application.  This affects the following functions:
49#
50#    load_model
51#    load_model_info
52#    build_model
53
# Model classifications understood by list_models (combine with '+').
KINDS = ("all", "py", "c", "double", "single", "opencl", "1d", "2d",
         "nonmagnetic", "magnetic")
def list_models(kind=None):
    # type: (str) -> List[str]
    """
    Return the list of available models on the model path.

    *kind* can be one of the following:

        * all: all models
        * py: python models only
        * c: compiled models only
        * single: models which support single precision
        * double: models which require double precision
        * opencl: models which are allowed to run under OpenCL
        * 1d: models which are 1D only, or 2D using abs(q)
        * 2d: models which can be 2D
        * magnetic: models with an sld
        * nonmagnetic: models without an sld

    For multiple conditions, combine with plus.  For example, *c+single+2d*
    would return all oriented models implemented in C which can be computed
    accurately with single precision arithmetic.

    A *kind* of None or "" is the same as "all".

    Raises ValueError if any '+'-separated component of *kind* is not one
    of :data:`KINDS`.
    """
    # Empty/None means no filtering; validate before touching the file system.
    kinds = kind.split('+') if kind else ["all"]
    if any(k not in KINDS for k in kinds):
        raise ValueError("kind not in " + ", ".join(KINDS))
    files = sorted(glob(joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")))
    available_models = [basename(f)[:-3] for f in files]
    # A model is selected only if it satisfies every requested kind.
    return [name for name in available_models
            if all(_matches(name, k) for k in kinds)]
90
91def _matches(name, kind):
92    if kind is None or kind == "all":
93        return True
94    info = load_model_info(name)
95    pars = info.parameters.kernel_parameters
96    if kind == "py" and callable(info.Iq):
97        return True
98    elif kind == "c" and not callable(info.Iq):
99        return True
100    elif kind == "double" and not info.single:
101        return True
102    elif kind == "single" and info.single:
103        return True
104    elif kind == "opencl" and info.opencl:
105        return True
106    elif kind == "2d" and any(p.type == 'orientation' for p in pars):
107        return True
108    elif kind == "1d" and all(p.type != 'orientation' for p in pars):
109        return True
110    elif kind == "magnetic" and any(p.type == 'sld' for p in pars):
111        return True
112    elif kind == "nonmagnetic" and any(p.type != 'sld' for p in pars):
113        return True
114    return False
115
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Load the definition of *model_name* and build a kernel model from it.

    *model_name* may be a single model name or a composite expression such
    as sphere+cylinder or sphere@hardsphere.

    *dtype* and *platform* are forwarded unchanged to :func:`build_model`.
    """
    info = load_model_info(model_name)
    return build_model(info, dtype=dtype, platform=platform)
128
def load_model_info(model_string):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name.

    *model_string* is the name of the model, or perhaps a model expression
    such as sphere*cylinder or sphere+cylinder. Use '@' for a structure
    factor product, e.g. sphere@hardsphere. Custom models can be specified by
    prefixing the model name with 'custom.', e.g. 'custom.MyModel+sphere'.

    The expression is split on '+' first, then '*', then '@', so
    'a+b*c@d' is interpreted as a + (b * (c @ d)).

    Returns the :class:`modelinfo.ModelInfo` describing the model, which can
    be used with functions in generate to build the docs or with
    :func:`build_model` to create a kernel.

    Raises ValueError if a structure factor product does not have exactly
    two terms, if a custom model name contains characters other than
    letters, digits, '_' and '-', or if the custom model file is missing.
    """
    if "+" in model_string:
        parts = [load_model_info(part)
                 for part in model_string.split("+")]
        return mixture.make_mixture_info(parts, operation='+')
    elif "*" in model_string:
        parts = [load_model_info(part)
                 for part in model_string.split("*")]
        return mixture.make_mixture_info(parts, operation='*')
    elif "@" in model_string:
        terms = model_string.split("@")
        # Check the shape before loading anything so the error is clear
        # (previously an opaque tuple-unpacking error after loading all terms).
        if len(terms) != 2:
            raise ValueError("Expected exactly one '@' in structure factor "
                             "product: " + model_string)
        p_info, q_info = [load_model_info(part) for part in terms]
        return product.make_product_info(p_info, q_info)
    # We are now dealing with a pure model
    elif "custom." in model_string:
        # Escape the '.' and anchor at the end so that trailing characters
        # (e.g. "custom.my model") are rejected rather than silently ignored.
        pattern = r"custom\.([A-Za-z0-9_-]+)\Z"
        result = re.match(pattern, model_string)
        if result is None:
            raise ValueError("Model name in invalid format: " + model_string)
        model_name = result.group(1)
        # Use ModelName to find the path to the custom model file
        model_path = joinpath(CUSTOM_MODEL_PATH, model_name + ".py")
        if not os.path.isfile(model_path):
            raise ValueError("The model file {} doesn't exist".format(model_path))
        kernel_module = custom.load_custom_kernel_module(model_path)
        return modelinfo.make_model_info(kernel_module)
    kernel_module = generate.load_kernel_module(model_string)
    return modelinfo.make_model_info(kernel_module)
169
170
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Create an executable kernel model from a model definition.

    *model_info* is the structure returned by :func:`load_model_info`.
    Composite definitions are built recursively, one kernel per component.
    Python models become a kernelpy model; C models become an OpenCL model
    or a DLL model depending on *dtype*, *platform* and what is available.

    *dtype* selects the precision: 'single', 'double', 'quad', 'half' or
    'fast'.  A trailing '!' forces the DLL instead of OpenCL.

    *platform* is "dll" to force the DLL for C models; the default "ocl"
    uses OpenCL when :func:`parse_dtype` allows it.

    Raises ValueError for an unknown composition type.
    """
    composite = model_info.composition
    if composite is not None:
        how, components = composite
        kernels = [build_model(part, dtype=dtype, platform=platform)
                   for part in components]
        if how == 'mixture':
            return mixture.MixtureModel(model_info, kernels)
        if how == 'product':
            form_factor, structure_factor = kernels
            return product.ProductModel(model_info, form_factor, structure_factor)
        raise ValueError('unknown mixture type %s' % how)

    # Python models need no compilation, so dtype/platform are irrelevant.
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
    source = generate.make_source(model_info)
    if platform != "dll":
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
    return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
215
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Compile a DLL for every builtin C model and return the list of DLL paths.

    *path* is the output directory and is created if it is missing.  *dtype*
    is any type name accepted by numpy.dtype.

    This is useful when building the Windows distribution of sasmodels,
    whose target machines may have neither an OpenCL driver nor a compiler.
    """
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)
    dll_paths = []
    for name in list_models():
        info = load_model_info(name)
        # Python models have nothing to compile.
        if callable(info.Iq):
            continue
        dll_source = generate.make_source(info)['dll']
        # Point kerneldll at *path* only for this build, restoring the
        # previous location even if compilation raises.
        saved_dll_path = kerneldll.DLL_PATH
        try:
            kerneldll.DLL_PATH = path
            built = kerneldll.make_dll(dll_source, info, dtype=numpy_dtype)
        finally:
            kerneldll.DLL_PATH = saved_dll_path
        dll_paths.append(built)
    return dll_paths
243
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
    """
    Interpret dtype string, returning (numpy dtype, fast flag, platform).

    Possible types include 'half', 'single', 'double' and 'quad'.  If the
    type is 'fast', then this is equivalent to dtype 'single' but using
    fast native functions rather than those with the precision level
    guaranteed by the OpenCL standard.  'default' (or None) chooses the
    default for the model and platform: single precision under OpenCL if
    the model allows it, otherwise double precision.

    Platform preference can be specified ("ocl" vs "dll"), with the default
    being OpenCL if it is available.  If the dtype name ends with '!' then
    platform is forced to be DLL rather than OpenCL; a bare '!' means the
    default precision on the DLL.  The DLL is also used when OpenCL is
    unavailable, when the model disables OpenCL, or when the OpenCL
    environment does not support the requested type.

    An explicit dtype overrides the precision preferences within the model
    definition.  This is by design.  It allows us to test models in single
    precision even when we have flagged them as requiring double precision
    so we can easily check the performance on different platforms without
    having to change the model definition.
    """
    # Assign default platform, overriding ocl with dll if OpenCL is
    # unavailable or switched off for this model (model_info.opencl).
    if platform is None:
        platform = "ocl"
    if not kernelcl.use_opencl() or not model_info.opencl:
        platform = "dll"

    # Check if type indicates dll regardless of which platform is given
    if dtype is not None and dtype.endswith('!'):
        platform = "dll"
        dtype = dtype[:-1]

    # Convert special type names "half", "fast", and "quad"
    fast = (dtype == "fast")
    if fast:
        dtype = "single"
    elif dtype == "quad":
        dtype = "longdouble"
    elif dtype == "half":
        dtype = "float16"

    # None, "default" and "" (left over from a bare "!") all request the
    # default precision; track this once so both decisions below agree.
    use_default = not dtype or dtype == "default"

    # Convert dtype string to numpy dtype.
    if use_default:
        numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
                       else generate.F64)
    else:
        numpy_dtype = np.dtype(dtype)

    # Make sure that the type is supported by opencl, otherwise use dll
    if platform == "ocl":
        env = kernelcl.environment()
        if not env.has_type(numpy_dtype):
            platform = "dll"
            # A defaulted precision falls back to double on the DLL; an
            # explicitly requested dtype is kept as given.
            if use_default:
                numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
303
def list_models_main():
    # type: () -> None
    """
    Command line entry point: print matching model names, one per line.

    The kind filter is the first command line argument, defaulting to "all";
    see :func:`list_models` for the accepted values.
    """
    import sys
    args = sys.argv[1:]
    kind = args[0] if args else "all"
    names = list_models(kind)
    print("\n".join(names))
313
def test_composite_order():
    """
    Generate checks that reordering the terms of a composite model
    expression leaves its parameter set unchanged.

    Each yielded item is a (test, description) pair; calling test with the
    description loads both models and compares their parameter names.
    """
    def check_same_parameters(first_name, second_name):
        """Confirm that two models produce the same parameters"""
        def reduced_names(model_name):
            # Component tags such as "A_sld" differ between orderings, so
            # drop uppercase characters (a cheap stand-in for stripping the
            # tag) and keep each name as a list of characters.
            model = load_model(model_name)
            return [[ch for ch in par.name if ch == ch.lower()]
                    for par in model.info.parameters.kernel_parameters]
        lhs = reduced_names(first_name)
        rhs = reduced_names(second_name)
        assert sorted(lhs) == sorted(rhs), "{} != {}".format(lhs, rhs)

    def build_test(first, second):
        test = lambda description: check_same_parameters(first, second)
        return test, first + " vs. " + second

    pairs = [
        ("cylinder+sphere",
         "sphere+cylinder"),
        ("cylinder*sphere",
         "sphere*cylinder"),
        ("cylinder@hardsphere*sphere",
         "sphere*cylinder@hardsphere"),
        ("barbell+sphere*cylinder@hardsphere",
         "sphere*cylinder@hardsphere+barbell"),
        ("barbell+cylinder@hardsphere*sphere",
         "cylinder@hardsphere*sphere+barbell"),
        ("barbell+sphere*cylinder@hardsphere",
         "barbell+cylinder@hardsphere*sphere"),
        ("sphere*cylinder@hardsphere+barbell",
         "cylinder@hardsphere*sphere+barbell"),
        ("barbell+sphere*cylinder@hardsphere",
         "cylinder@hardsphere*sphere+barbell"),
        ("barbell+cylinder@hardsphere*sphere",
         "sphere*cylinder@hardsphere+barbell"),
    ]
    for first, second in pairs:
        yield build_test(first, second)
360
def test_composite():
    # type: () -> None
    """Check that model load works"""
    # Expected kernel parameter names, in order, for this composite model.
    expected = ["sld", "sld_solvent", "radius", "length", "theta", "phi",
                "volfraction", "A_sld", "A_sld_solvent", "A_radius"]
    model = load_model("cylinder@hardsphere*sphere")
    names = [par.name for par in model.info.parameters.kernel_parameters]
    assert expected == names, "%s != %s" % (expected, names)
370
if __name__ == "__main__":
    # Running the module as a script lists models matching argv[1].
    list_models_main()
Note: See TracBrowser for help on using the repository browser.