source: sasmodels/sasmodels/core.py @ 01c8d9e

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 01c8d9e was 01c8d9e, checked in by Suczewski <ges3@…>, 6 years ago

beta approximation, first pass

  • Property mode set to 100644
File size: 13.8 KB
RevLine 
[aa4946b]1"""
2Core model handling routines.
3"""
[6d6508e]4from __future__ import print_function
5
# Public API of this module.
__all__ = [
    "list_models",
    "load_model",
    "load_model_info",
    "build_model",
    "precompile_dlls",
]
[f734e7d]10
[7bf4757]11import os
[e65c3ba]12from os.path import basename, join as joinpath
[f734e7d]13from glob import glob
[e65c3ba]14import re
[f734e7d]15
[7ae2b7f]16import numpy as np # type: ignore
[f734e7d]17
[aa4946b]18from . import generate
[6d6508e]19from . import modelinfo
20from . import product
[72a081d]21from . import mixture
[aa4946b]22from . import kernelpy
[6dba2f0]23from . import kernelcl
[aa4946b]24from . import kerneldll
[60335cc]25from . import custom
[880a2ed]26
[2d81cfe]27# pylint: disable=unused-import
[f619de7]28try:
29    from typing import List, Union, Optional, Any
30    from .kernel import KernelModel
[dd7fc12]31    from .modelinfo import ModelInfo
[f619de7]32except ImportError:
33    pass
[2d81cfe]34# pylint: enable=unused-import
[f619de7]35
[3221de0]36CUSTOM_MODEL_PATH = os.environ.get('SAS_MODELPATH', "")
37if CUSTOM_MODEL_PATH == "":
38    CUSTOM_MODEL_PATH = joinpath(os.path.expanduser("~"), ".sasmodels", "custom_models")
39    if not os.path.isdir(CUSTOM_MODEL_PATH):
40        os.makedirs(CUSTOM_MODEL_PATH)
41
[4d76711]42# TODO: refactor composite model support
43# The current load_model_info/build_model does not reuse existing model
44# definitions when loading a composite model, instead reloading and
45# rebuilding the kernel for each component model in the expression.  This
46# is fine in a scripting environment where the model is built when the script
47# starts and is thrown away when the script ends, but may not be the best
48# solution in a long-lived application.  This affects the following functions:
49#
50#    load_model
51#    load_model_info
52#    build_model
[f734e7d]53
[8407d8c]54KINDS = ("all", "py", "c", "double", "single", "opencl", "1d", "2d",
[2547694]55         "nonmagnetic", "magnetic")
[0b03001]56def list_models(kind=None):
[52e9a45]57    # type: (str) -> List[str]
[aa4946b]58    """
59    Return the list of available models on the model path.
[40a87fa]60
61    *kind* can be one of the following:
62
63        * all: all models
64        * py: python models only
65        * c: compiled models only
66        * single: models which support single precision
67        * double: models which require double precision
[8407d8c]68        * opencl: controls if OpenCL is supperessed
[40a87fa]69        * 1d: models which are 1D only, or 2D using abs(q)
70        * 2d: models which can be 2D
71        * magnetic: models with an sld
72        * nommagnetic: models without an sld
[5124c969]73
74    For multiple conditions, combine with plus.  For example, *c+single+2d*
75    would return all oriented models implemented in C which can be computed
76    accurately with single precision arithmetic.
[aa4946b]77    """
[5124c969]78    if kind and any(k not in KINDS for k in kind.split('+')):
[2547694]79        raise ValueError("kind not in " + ", ".join(KINDS))
[3d9001f]80    files = sorted(glob(joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")))
[f734e7d]81    available_models = [basename(f)[:-3] for f in files]
[5124c969]82    if kind and '+' in kind:
83        all_kinds = kind.split('+')
84        condition = lambda name: all(_matches(name, k) for k in all_kinds)
85    else:
86        condition = lambda name: _matches(name, kind)
87    selected = [name for name in available_models if condition(name)]
[0b03001]88
89    return selected
90
def _matches(name, kind):
    # type: (str, str) -> bool
    """
    Return True if model *name* satisfies the single filter *kind*.

    *kind* is one entry of KINDS (no '+').  None or "all" accepts every
    model without loading it.  Any other value loads the model info and
    tests it.  Unrecognized kinds return False.
    """
    if kind is None or kind == "all":
        return True
    info = load_model_info(name)
    pars = info.parameters.kernel_parameters
    if kind == "py" and callable(info.Iq):
        return True
    elif kind == "c" and not callable(info.Iq):
        return True
    elif kind == "double" and not info.single:
        return True
    elif kind == "single" and info.single:
        return True
    elif kind == "opencl" and info.opencl:
        return True
    elif kind == "2d" and any(p.type == 'orientation' for p in pars):
        return True
    elif kind == "1d" and all(p.type != 'orientation' for p in pars):
        return True
    elif kind == "magnetic" and any(p.type == 'sld' for p in pars):
        return True
    elif kind == "nonmagnetic" and all(p.type != 'sld' for p in pars):
        # Complement of "magnetic": the model must have no sld parameter at
        # all, not merely some parameter that is not an sld.
        return True
    return False
[f734e7d]115
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Load the definition of *model_name* and build a kernel model from it.

    *model_name* is a model name or a composite model expression such as
    sphere*hardsphere or sphere+cylinder.  It is passed to
    :func:`load_model_info`.

    *dtype* and *platform* are forwarded unchanged to :func:`build_model`.
    """
    info = load_model_info(model_name)
    return build_model(info, dtype=dtype, platform=platform)
[aa4946b]128
def load_model_info(model_string):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name.

    *model_string* is the name of the model, or perhaps a model expression
    such as sphere*cylinder or sphere+cylinder. Use '@' for a structure
    factor product, e.g. sphere@hardsphere. Custom models can be specified by
    prefixing the model name with 'custom.', e.g. 'custom.MyModel+sphere'.

    The expression is split on '+' first, then '*', then '@'.  So '+' binds
    most loosely: 'a+b*c@d' is the mixture a + (b * (c@d)).

    This returns the :class:`modelinfo.ModelInfo` for the model.  It can be
    passed to :func:`build_model` or used with functions in generate to
    build the docs.

    Raises ValueError if a 'custom.' name is malformed or if its model file
    does not exist in CUSTOM_MODEL_PATH.
    """
    if "+" in model_string:
        parts = [load_model_info(part)
                 for part in model_string.split("+")]
        return mixture.make_mixture_info(parts, operation='+')
    elif "*" in model_string:
        parts = [load_model_info(part)
                 for part in model_string.split("*")]
        return mixture.make_mixture_info(parts, operation='*')
    elif "@" in model_string:
        p_info, q_info = [load_model_info(part)
                          for part in model_string.split("@")]
        return product.make_product_info(p_info, q_info)
    # We are now dealing with a pure model
    elif "custom." in model_string:
        # The dot is escaped so only a literal "custom." prefix is accepted.
        # An unescaped '.' would accept any character after "custom".
        result = re.match(r"custom\.([A-Za-z0-9_-]+)", model_string)
        if result is None:
            raise ValueError("Model name in invalid format: " + model_string)
        model_name = result.group(1)
        # Use ModelName to find the path to the custom model file
        model_path = joinpath(CUSTOM_MODEL_PATH, model_name + ".py")
        if not os.path.isfile(model_path):
            raise ValueError("The model file {} doesn't exist".format(model_path))
        kernel_module = custom.load_custom_kernel_module(model_path)
        return modelinfo.make_model_info(kernel_module)
    kernel_module = generate.load_kernel_module(model_string)
    return modelinfo.make_model_info(kernel_module)
[d19962c]170
171
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Prepare the model for the default execution platform.

    Depending on the model and the computing platform, this returns an
    OpenCL model, a DLL model or a python model.  Composite models
    (mixtures and products) are built from their component models.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* indicates whether the model should use single or double precision
    for the calculation.  Choices are 'single', 'double', 'quad', 'half',
    'fast' or 'default'.  If *dtype* ends with '!', then force the use of the
    DLL rather than OpenCL for the calculation.  See :func:`parse_dtype`.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl".

    Raises ValueError for an unknown composition type.
    """
    composition = model_info.composition
    if composition is not None:
        # Composite model: build each component recursively, then combine.
        composition_type, parts = composition
        models = [build_model(part, dtype=dtype, platform=platform)
                  for part in parts]
        if composition_type == 'mixture':
            return mixture.MixtureModel(model_info, models)
        if composition_type == 'product':
            form_model, structure_model = models
            return product.ProductModel(model_info, form_model, structure_model)
        raise ValueError('unknown mixture type %s'%composition_type)

    # Python models have no C source, so they are returned without compiling.
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
    source = generate.make_source(model_info)
    if platform != "dll":
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
    return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
[aa4946b]215
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Precompile the dlls for all builtin models, returning a list of dll paths.

    *path* is the directory in which to save the dlls.  It will be created if
    it does not already exist.

    *dtype* is converted with numpy.dtype and used as the precision of every
    compiled dll.

    This can be used when building the windows distribution of sasmodels,
    which may be missing the OpenCL driver and the dll compiler.  Python
    models are skipped because they have no C source.
    """
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)

    dll_paths = []
    for name in list_models():
        info = load_model_info(name)
        if callable(info.Iq):
            continue  # pure python model: nothing to compile
        dll_source = generate.make_source(info)['dll']
        # Temporarily point kerneldll.DLL_PATH at *path* (presumably where
        # make_dll places its output), always restoring the previous value.
        saved_dll_path = kerneldll.DLL_PATH
        try:
            kerneldll.DLL_PATH = path
            compiled = kerneldll.make_dll(dll_source, info, dtype=numpy_dtype)
        finally:
            kerneldll.DLL_PATH = saved_dll_path
        dll_paths.append(compiled)
    return dll_paths
[b8e5e21]243
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
    """
    Interpret dtype string, returning (np.dtype, fast flag, platform).

    Possible types include 'half', 'single', 'double' and 'quad'.  If the
    type is 'fast', then this is equivalent to dtype 'single' but using
    fast native functions rather than those with the precision level
    guaranteed by the OpenCL standard.  'default' (or None) will choose the
    appropriate default for the model and platform.  That is single precision
    for OpenCL when the model allows it, and double precision otherwise.

    Platform preference can be specified ("ocl" vs "dll"), with the default
    being OpenCL if it is available.  If the dtype name ends with '!' then
    platform is forced to be DLL rather than OpenCL.  The DLL is also used
    when the model disables OpenCL, or when the OpenCL environment does not
    support the selected type.

    An explicit dtype overrides the precision preference within the model
    definition.  This is by design.  It allows us to test models in single
    precision even when we have flagged them as requiring double precision.
    That way we can easily check the performance on different platforms
    without having to change the model definition.
    """
    # Assign default platform, overriding ocl with dll if OpenCL is
    # unavailable or switched off for this model (opencl=False).
    if platform is None:
        platform = "ocl"
    if not kernelcl.use_opencl() or not model_info.opencl:
        platform = "dll"

    # Check if type indicates dll regardless of which platform is given
    if dtype is not None and dtype.endswith('!'):
        platform = "dll"
        dtype = dtype[:-1]

    # Convert special type names "half", "fast", and "quad"
    fast = (dtype == "fast")
    if fast:
        dtype = "single"
    elif dtype == "quad":
        dtype = "longdouble"
    elif dtype == "half":
        dtype = "float16"

    # Convert dtype string to numpy dtype.  None and "default" are
    # equivalent: the precision is chosen from the model and platform.
    use_default = dtype is None or dtype == "default"
    if use_default:
        numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
                       else generate.F64)
    else:
        numpy_dtype = np.dtype(dtype)

    # Make sure that the type is supported by opencl, otherwise use dll.
    # An explicitly requested precision is kept.  The default precision
    # reverts to double, which is the default for the dll platform.
    if platform == "ocl":
        env = kernelcl.environment()
        if not env.has_type(numpy_dtype):
            platform = "dll"
            if use_default:
                numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
[0c24a82]301
def list_models_main():
    # type: () -> None
    """
    Run list_models as a main program.  See :func:`list_models` for the
    kinds of models that can be requested on the command line.

    The kind is the first command line argument, defaulting to "all".  The
    matching model names are printed one per line.
    """
    import sys
    args = sys.argv[1:]
    kind = args[0] if args else "all"
    print("\n".join(list_models(kind)))
[2547694]311
def test_composite_order():
    """
    Yield nose-style checks that reordering the terms of a composite model
    expression leaves its set of parameters unchanged.

    Each item is a ``(check, description)`` pair.  Calling
    ``check(description)`` loads both models and compares their parameters.
    """
    def test_models(fst, snd):
        """Confirm that two models produce the same parameters"""
        first_model = load_model(fst)
        second_model = load_model(snd)

        def stripped(model):
            # Un-disambiguate parameter names so that we can check if the
            # same parameters are in a pair of composite models.  Each
            # parameter in a mixture model is tagged as e.g. A_sld, so we
            # ought to use a regex substitution s/^[A-Z]+_/_/.  Removing all
            # uppercase letters is good enough.
            return [[ch for ch in par.name if ch == ch.lower()]
                    for par in model.info.parameters.kernel_parameters]

        fst = stripped(first_model)
        snd = stripped(second_model)
        assert sorted(fst) == sorted(snd), "{} != {}".format(fst, snd)

    def build_test(first, second):
        # first/second are bound as arguments, so each check keeps its own pair.
        check = lambda description: test_models(first, second)
        return check, first + " vs. " + second

    equivalent_pairs = (
        ("cylinder+sphere", "sphere+cylinder"),
        ("cylinder*sphere", "sphere*cylinder"),
        ("cylinder@hardsphere*sphere", "sphere*cylinder@hardsphere"),
        ("barbell+sphere*cylinder@hardsphere", "sphere*cylinder@hardsphere+barbell"),
        ("barbell+cylinder@hardsphere*sphere", "cylinder@hardsphere*sphere+barbell"),
        ("barbell+sphere*cylinder@hardsphere", "barbell+cylinder@hardsphere*sphere"),
        ("sphere*cylinder@hardsphere+barbell", "cylinder@hardsphere*sphere+barbell"),
        ("barbell+sphere*cylinder@hardsphere", "cylinder@hardsphere*sphere+barbell"),
        ("barbell+cylinder@hardsphere*sphere", "sphere*cylinder@hardsphere+barbell"),
    )
    for first, second in equivalent_pairs:
        yield build_test(first, second)
358
def test_composite():
    # type: () -> None
    """Check that model load works"""
    # Expected kernel parameters of the product model followed by the
    # A_-prefixed parameters of the sphere term of the mixture.
    expected = ("sld sld_solvent radius length theta phi volfraction"
                " A_sld A_sld_solvent A_radius").split()
    model = load_model("cylinder@hardsphere*sphere")
    names = [par.name for par in model.info.parameters.kernel_parameters]
    assert expected == names, "%s != %s"%(expected, names)
368
if __name__ == "__main__":
    # Command line entry point: print the models of the requested kind.
    list_models_main()
Note: See TracBrowser for help on using the repository browser.