source: sasmodels/sasmodels/core.py @ d8b7efa

ticket-1257-vesicle-productticket_1156ticket_822_more_unit_tests
Last change on this file since d8b7efa was d8b7efa, checked in by richardh, 5 years ago

more changes to variable names, but now lost radius_effective in S(Q), so this is not working yet

  • Property mode set to 100644
File size: 15.1 KB
RevLine 
[aa4946b]1"""
2Core model handling routines.
3"""
[6d6508e]4from __future__ import print_function
5
[98f60fc]6__all__ = [
[6d6508e]7    "list_models", "load_model", "load_model_info",
[2547694]8    "build_model", "precompile_dlls",
[98f60fc]9    ]
[f734e7d]10
[7bf4757]11import os
[e65c3ba]12from os.path import basename, join as joinpath
[f734e7d]13from glob import glob
[e65c3ba]14import re
[f734e7d]15
[7ae2b7f]16import numpy as np # type: ignore
[f734e7d]17
[aa4946b]18from . import generate
[6d6508e]19from . import modelinfo
20from . import product
[72a081d]21from . import mixture
[aa4946b]22from . import kernelpy
[b0de252]23from . import kernelcuda
[6dba2f0]24from . import kernelcl
[aa4946b]25from . import kerneldll
[60335cc]26from . import custom
[880a2ed]27
[2d81cfe]28# pylint: disable=unused-import
[f619de7]29try:
30    from typing import List, Union, Optional, Any
31    from .kernel import KernelModel
[dd7fc12]32    from .modelinfo import ModelInfo
[f619de7]33except ImportError:
34    pass
[2d81cfe]35# pylint: enable=unused-import
[f619de7]36
# Directory searched for user-supplied ("custom.") models.  A non-empty
# SAS_MODELPATH environment variable takes precedence; otherwise the per-user
# default ~/.sasmodels/custom_models is used.  The directory is deliberately
# not created here.
CUSTOM_MODEL_PATH = (
    os.environ.get('SAS_MODELPATH', "")
    or joinpath(os.path.expanduser("~"), ".sasmodels", "custom_models")
)
[3221de0]42
[4d76711]43# TODO: refactor composite model support
44# The current load_model_info/build_model does not reuse existing model
45# definitions when loading a composite model, instead reloading and
46# rebuilding the kernel for each component model in the expression.  This
47# is fine in a scripting environment where the model is built when the script
48# starts and is thrown away when the script ends, but may not be the best
49# solution in a long-lived application.  This affects the following functions:
50#
51#    load_model
52#    load_model_info
53#    build_model
[f734e7d]54
# Selectors accepted by list_models().  Every kind handled by _matches()
# must be listed here, otherwise list_models() rejects it before matching.
KINDS = ("all", "py", "c", "double", "single", "opencl", "dll", "1d", "2d",
         "nonmagnetic", "magnetic")


def list_models(kind=None):
    # type: (Optional[str]) -> List[str]
    """
    Return the list of available models on the model path.

    *kind* can be one of the following:

        * all: all models
        * py: python models only
        * c: c models only
        * single: c models which support single precision
        * double: c models which require double precision
        * opencl: c models which run in opencl
        * dll: c models which do not run in opencl
        * 1d: models without orientation
        * 2d: models with orientation
        * magnetic: models supporting magnetic sld
        * nonmagnetic: models without magnetic parameter

    For multiple conditions, combine with plus.  For example, *c+single+2d*
    would return all oriented models implemented in C which can be computed
    accurately with single precision arithmetic.

    Raises *ValueError* if any of the requested kinds is not in *KINDS*.
    """
    # None (or an empty string) is passed through to _matches() untouched so
    # that the "no filter" behaviour stays in one place.
    kinds = kind.split('+') if kind else [kind]
    if kind and any(k not in KINDS for k in kinds):
        raise ValueError("kind not in " + ", ".join(KINDS))

    # Model names are the python files in the model directory which start
    # with a letter, with the ".py" extension removed.
    files = sorted(glob(joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")))
    available_models = [basename(f)[:-3] for f in files]

    # A model is selected only if it satisfies every requested kind.
    return [name for name in available_models
            if all(_matches(name, k) for k in kinds)]
92
93def _matches(name, kind):
[2547694]94    if kind is None or kind == "all":
[0b03001]95        return True
96    info = load_model_info(name)
97    pars = info.parameters.kernel_parameters
[d92182f]98    # TODO: may be adding Fq to the list at some point
99    is_pure_py = callable(info.Iq)
100    if kind == "py":
101        return is_pure_py
102    elif kind == "c":
103        return not is_pure_py
104    elif kind == "double":
105        return not info.single and not is_pure_py
106    elif kind == "single":
107        return info.single and not is_pure_py
108    elif kind == "opencl":
109        return info.opencl
110    elif kind == "dll":
111        return not info.opencl and not is_pure_py
112    elif kind == "2d":
113        return any(p.type == 'orientation' for p in pars)
114    elif kind == "1d":
115        return all(p.type != 'orientation' for p in pars)
116    elif kind == "magnetic":
117        return any(p.type == 'sld' for p in pars)
118    elif kind == "nonmagnetic":
119        return not any(p.type == 'sld' for p in pars)
[0b03001]120    return False
[f734e7d]121
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Load the model info for *model_name* and build the model from it.

    *model_name* is the name of the model, or perhaps a model expression
    such as sphere*hardsphere or sphere+cylinder.

    *dtype* and *platform* are given by :func:`build_model`.
    """
    model_info = load_model_info(model_name)
    return build_model(model_info, dtype=dtype, platform=platform)
[aa4946b]134
def load_model_info(model_string):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name.

    *model_string* is the name of the model, or perhaps a model expression
    such as sphere*cylinder or sphere+cylinder. Use '@' for a structure
    factor product, e.g. sphere@hardsphere. Custom models can be specified by
    prefixing the model name with 'custom.', e.g. 'custom.MyModel+sphere'.
    Whitespace around the expression and around each operand is ignored.

    Operators are split in the order '+', '*', '@', so '@' binds tightest
    and '+' loosest.

    Returns the model info built by :func:`modelinfo.make_model_info` for a
    single model, or by the *mixture* / *product* modules for an expression.

    Raises *ValueError* if a '@' expression does not have exactly two
    operands, if a custom model name is malformed, or if the custom model
    file does not exist.
    """
    # Operands come from str.split(), so strip here to accept "a + b".
    model_string = model_string.strip()

    if "+" in model_string:
        parts = [load_model_info(part)
                 for part in model_string.split("+")]
        return mixture.make_mixture_info(parts, operation='+')

    if "*" in model_string:
        parts = [load_model_info(part)
                 for part in model_string.split("*")]
        return mixture.make_mixture_info(parts, operation='*')

    if "@" in model_string:
        operands = model_string.split("@")
        # Check the arity before loading anything so the caller gets a
        # readable message instead of a tuple-unpacking error.
        if len(operands) != 2:
            raise ValueError("Expected form@structure with a single '@': "
                             + model_string)
        p_info, q_info = [load_model_info(part) for part in operands]
        return product.make_product_info(p_info, q_info)

    # We are now dealing with a pure model
    if "custom." in model_string:
        # Anchor at both ends so trailing junk such as "custom.foo bar" is
        # rejected rather than silently loading "foo".
        result = re.match(r"custom\.([A-Za-z0-9_-]+)$", model_string)
        if result is None:
            raise ValueError("Model name in invalid format: " + model_string)
        model_name = result.group(1)
        # Use ModelName to find the path to the custom model file
        model_path = joinpath(CUSTOM_MODEL_PATH, model_name + ".py")
        if not os.path.isfile(model_path):
            raise ValueError("The model file {} doesn't exist".format(model_path))
        kernel_module = custom.load_custom_kernel_module(model_path)
        return modelinfo.make_model_info(kernel_module)

    kernel_module = generate.load_kernel_module(model_string)
    return modelinfo.make_model_info(kernel_module)
[d19962c]175
176
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Prepare the model for the default execution platform.

    Depending on the model and the computing platform this returns a python
    model, a DLL model, a CUDA model or an OpenCL model.  Composite models
    are built recursively from their parts.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* indicates whether the model should use single or double precision
    for the calculation.  Choices are 'single', 'double', 'quad', 'half',
    or 'fast'.  If *dtype* ends with '!', then force the use of the DLL rather
    than OpenCL for the calculation.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl".

    Raises *ValueError* if the composition type of *model_info* is neither
    'mixture' nor 'product'.
    """
    composition = model_info.composition
    if composition is not None:
        kind, parts = composition
        # Every component is built with the same precision and platform.
        submodels = [build_model(part, dtype=dtype, platform=platform)
                     for part in parts]
        if kind == 'mixture':
            return mixture.MixtureModel(model_info, submodels)
        if kind == 'product':
            form_factor, structure_factor = submodels
            return product.ProductModel(model_info, form_factor,
                                        structure_factor)
        raise ValueError('unknown mixture type %s'%kind)

    # Python models need no compilation, so return them immediately.
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    # parse_dtype may override the requested platform.
    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
    source = generate.make_source(model_info)
    if platform == "dll":
        return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
    if platform == "cuda":
        return kernelcuda.GpuModel(source, model_info, numpy_dtype, fast=fast)
    return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
[aa4946b]222
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Precompile the dlls for all builtin models, returning a list of dll paths.

    *path* is the directory in which to save the dlls.  It will be created if
    it does not already exist.

    *dtype* is any name accepted by *numpy.dtype*; it selects the precision
    of the compiled kernels.

    This can be used when building the windows distribution of sasmodels
    which may be missing the OpenCL driver and the dll compiler.
    """
    # Convert first so an invalid dtype fails before touching the disk.
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)

    dll_paths = []
    for name in list_models():
        info = load_model_info(name)
        # Pure python models have nothing to compile.
        if callable(info.Iq):
            continue
        dll_source = generate.make_source(info)['dll']
        # Temporarily redirect the dll output directory, making sure it is
        # restored even if compilation fails.
        saved_path = kerneldll.SAS_DLL_PATH
        kerneldll.SAS_DLL_PATH = path
        try:
            dll = kerneldll.make_dll(dll_source, info, dtype=numpy_dtype)
        finally:
            kerneldll.SAS_DLL_PATH = saved_path
        dll_paths.append(dll)
    return dll_paths
[b8e5e21]250
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
    """
    Interpret dtype string, returning np.dtype, fast flag and platform.

    Possible types include 'half', 'single', 'double' and 'quad'.  If the
    type is 'fast', then this is equivalent to dtype 'single' but using
    fast native functions rather than those with the precision level
    guaranteed by the OpenCL standard.  'default' (or None) will choose the
    appropriate default for the model and platform: single precision on a
    GPU platform when the model allows it, otherwise double precision.

    Platform preference can be specified ("ocl", "cuda", "dll"), with the
    default being OpenCL or CUDA if available, otherwise DLL.  If the dtype
    name ends with '!' then platform is forced to be DLL rather than GPU.
    The default platform is set by the environment variable SAS_OPENCL,
    SAS_OPENCL=driver:device for OpenCL, SAS_OPENCL=cuda:device for CUDA
    or SAS_OPENCL=none for DLL.

    An explicitly requested dtype is honoured even when the model is flagged
    as requiring double precision.  This is by design.  It allows us to test
    models in single precision so we can easily check the performance on
    different platforms without having to change the model definition.  The
    model's *opencl* flag is still respected: a model which does not allow
    OpenCL always runs as a DLL.
    """
    # Assign default platform; it may still be replaced by cuda or dll below.
    if platform is None:
        platform = "ocl"

    # Check if type indicates dll regardless of which platform is given
    if dtype is not None and dtype.endswith('!'):
        platform = "dll"
        dtype = dtype[:-1]

    # Make sure model allows opencl/gpu
    if not model_info.opencl:
        platform = "dll"

    # Make sure opencl is available, or fallback to cuda then to dll
    if platform == "ocl" and not kernelcl.use_opencl():
        platform = "cuda" if kernelcuda.use_cuda() else "dll"

    # Convert special type names "half", "fast", and "quad"
    fast = (dtype == "fast")
    if fast:
        dtype = "single"
    elif dtype == "quad":
        dtype = "longdouble"
    elif dtype == "half":
        dtype = "float16"

    # None and "default" both mean "choose for me".  Remember that choice so
    # the fallback at the end treats the two spellings identically.
    use_default = dtype is None or dtype == "default"

    # Convert dtype string to numpy dtype.  Use single precision for GPU
    # if model allows it, otherwise use double precision.
    if use_default:
        on_gpu = platform in ("ocl", "cuda")
        numpy_dtype = (generate.F32 if model_info.single and on_gpu
                       else generate.F64)
    else:
        numpy_dtype = np.dtype(dtype)

    # Make sure that the type is supported by GPU, otherwise use dll
    if platform == "ocl":
        env = kernelcl.environment()
    elif platform == "cuda":
        env = kernelcuda.environment()
    else:
        env = None
    if env is not None and not env.has_type(numpy_dtype):
        platform = "dll"
        # An automatically chosen precision becomes double on the DLL; an
        # explicitly requested one is kept.
        if use_default:
            numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
[0c24a82]323
def test_composite_order():
    """
    Check that mixture models produce the same result independent of order.

    This is a test generator: each item yielded is a *(test, description)*
    pair, where *test* is called with the description to run the comparison.
    """
    def lowercase_names(model):
        """Parameter names of *model* with the uppercase letters dropped"""
        # Un-disambiguate parameter names so that we can check if the same
        # parameters are in a pair of composite models. Since each parameter
        # in the mixture model is tagged as e.g., A_sld, we ought to use a
        # regex substitution s/^[A-Z]+_/_/, but removing all uppercase letters
        # is good enough.  Each name is kept as a list of characters.
        return [[ch for ch in par.name if ch == ch.lower()]
                for par in model.info.parameters.kernel_parameters]

    def compare_parameters(fst, snd):
        """Confirm that two models produce the same parameters"""
        fst_model = load_model(fst)
        snd_model = load_model(snd)
        fst_names = lowercase_names(fst_model)
        snd_names = lowercase_names(snd_model)
        assert sorted(fst_names) == sorted(snd_names), \
            "{} != {}".format(fst_names, snd_names)

    def build_test(first, second):
        """Construct pair model test"""
        def test(description):
            return compare_parameters(first, second)
        return test, first + " vs. " + second

    # Pairs of expressions which differ only in the order of their terms.
    equivalent_pairs = (
        ("cylinder+sphere",
         "sphere+cylinder"),
        ("cylinder*sphere",
         "sphere*cylinder"),
        ("cylinder@hardsphere*sphere",
         "sphere*cylinder@hardsphere"),
        ("barbell+sphere*cylinder@hardsphere",
         "sphere*cylinder@hardsphere+barbell"),
        ("barbell+cylinder@hardsphere*sphere",
         "cylinder@hardsphere*sphere+barbell"),
        ("barbell+sphere*cylinder@hardsphere",
         "barbell+cylinder@hardsphere*sphere"),
        ("sphere*cylinder@hardsphere+barbell",
         "cylinder@hardsphere*sphere+barbell"),
        ("barbell+sphere*cylinder@hardsphere",
         "cylinder@hardsphere*sphere+barbell"),
        ("barbell+cylinder@hardsphere*sphere",
         "sphere*cylinder@hardsphere+barbell"),
    )
    for first, second in equivalent_pairs:
        yield build_test(first, second)
374
def test_composite():
    # type: () -> None
    """Check that model load works"""
    # Test that the model produces the parameters that we would expect,
    # in the order that we expect them.
    target = [
        "sld", "sld_solvent", "radius", "length", "theta", "phi",
        "radius_effective", "volfraction",
        "structure_factor_type", "effective_radius_type",
        "A_sld", "A_sld_solvent", "A_radius",
    ]
    model = load_model("cylinder@hardsphere*sphere")
    actual = [par.name for par in model.info.parameters.kernel_parameters]
    assert target == actual, "%s != %s"%(target, actual)
386
def list_models_main():
    # type: () -> int
    """
    Run list_models as a main program.  See :func:`list_models` for the
    kinds of models that can be requested on the command line.

    The kind is taken from the first command line argument, defaulting to
    "all".  Returns 0 after printing one model name per line, or 1 after
    printing the :func:`list_models` help text if anything goes wrong.
    """
    import sys
    args = sys.argv[1:]
    kind = args[0] if args else "all"
    try:
        print("\n".join(list_models(kind)))
    except Exception:
        # Any failure is reported by showing the usage text.
        print(list_models.__doc__)
        return 1
    return 0
[d92182f]402
if __name__ == "__main__":
    # Pass the status returned by list_models_main() to the shell so that a
    # failed listing gives a non-zero exit code.
    import sys
    sys.exit(list_models_main())
Note: See TracBrowser for help on using the repository browser.