source: sasmodels/sasmodels/core.py @ a69d8cd

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since a69d8cd was a69d8cd, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

add support for pytest and use it on travis/appveyor

  • Property mode set to 100644
File size: 14.0 KB
RevLine 
[aa4946b]1"""
2Core model handling routines.
3"""
[6d6508e]4from __future__ import print_function
5
[98f60fc]6__all__ = [
[6d6508e]7    "list_models", "load_model", "load_model_info",
[2547694]8    "build_model", "precompile_dlls",
[98f60fc]9    ]
[f734e7d]10
[7bf4757]11import os
[e65c3ba]12from os.path import basename, join as joinpath
[f734e7d]13from glob import glob
[e65c3ba]14import re
[f734e7d]15
[7ae2b7f]16import numpy as np # type: ignore
[f734e7d]17
[aa4946b]18from . import generate
[6d6508e]19from . import modelinfo
20from . import product
[72a081d]21from . import mixture
[aa4946b]22from . import kernelpy
23from . import kerneldll
[60335cc]24from . import custom
[880a2ed]25
# OpenCL support is optional.  Assume it is unavailable unless the kernelcl
# module imports cleanly; setting the environment variable SAS_OPENCL=none
# (case-insensitive) skips the import attempt entirely.
HAVE_OPENCL = False
if os.environ.get("SAS_OPENCL", "").lower() != "none":
    try:
        from . import kernelcl
    except Exception:
        # Any failure while importing the OpenCL kernel support leaves
        # HAVE_OPENCL at False.
        pass
    else:
        HAVE_OPENCL = True
[aa4946b]34
# Directory searched for 'custom.<name>' models.  It is taken from the
# SAS_MODELPATH environment variable when that is set to a non-empty value,
# otherwise it defaults to ~/.sasmodels/custom_models, which is created on
# first use.
CUSTOM_MODEL_PATH = os.environ.get('SAS_MODELPATH', "")
if CUSTOM_MODEL_PATH == "":
    CUSTOM_MODEL_PATH = joinpath(os.path.expanduser("~"), ".sasmodels", "custom_models")
    # Create the directory EAFP-style rather than testing isdir() first:
    # several processes importing this module at the same time could all see
    # the directory as missing, and all but the first makedirs() would then
    # fail.  Only re-raise when the directory really does not exist afterwards.
    try:
        os.makedirs(CUSTOM_MODEL_PATH)
    except OSError:
        if not os.path.isdir(CUSTOM_MODEL_PATH):
            raise
[60335cc]40
[2d81cfe]41# pylint: disable=unused-import
[f619de7]42try:
43    from typing import List, Union, Optional, Any
44    from .kernel import KernelModel
[dd7fc12]45    from .modelinfo import ModelInfo
[f619de7]46except ImportError:
47    pass
[2d81cfe]48# pylint: enable=unused-import
[f619de7]49
[4d76711]50# TODO: refactor composite model support
51# The current load_model_info/build_model does not reuse existing model
52# definitions when loading a composite model, instead reloading and
53# rebuilding the kernel for each component model in the expression.  This
54# is fine in a scripting environment where the model is built when the script
55# starts and is thrown away when the script ends, but may not be the best
56# solution in a long-lived application.  This affects the following functions:
57#
58#    load_model
59#    load_model_info
60#    build_model
[f734e7d]61
# Category names accepted by list_models(), either alone or joined with '+'.
KINDS = (
    "all",
    "py",
    "c",
    "double",
    "single",
    "opencl",
    "1d",
    "2d",
    "nonmagnetic",
    "magnetic",
)
def list_models(kind=None):
    # type: (str) -> List[str]
    """
    Return the sorted names of the models found on the model path.

    *kind* restricts the result to one category of model:

        * all: all models
        * py: python models only
        * c: compiled models only
        * single: models which support single precision
        * double: models which require double precision
        * opencl: models for which OpenCL is not suppressed
        * 1d: models which are 1D only, or 2D using abs(q)
        * 2d: models which can be 2D
        * magnetic: models with an sld
        * nonmagnetic: models without an sld

    Several categories can be joined with a plus sign, in which case a model
    must belong to every one of them.  For example, *c+single+2d* selects the
    oriented models implemented in C which can be computed accurately with
    single precision arithmetic.

    Raises *ValueError* if any requested category is not listed in KINDS.
    """
    # An empty or missing kind requests no validation at all.
    requested = kind.split('+') if kind else []
    for item in requested:
        if item not in KINDS:
            raise ValueError("kind not in " + ", ".join(KINDS))

    # Model names are the python file names in the model directory that
    # start with a letter, with the ".py" extension removed.
    pattern = joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")
    names = [basename(path)[:-3] for path in sorted(glob(pattern))]

    if len(requested) > 1:
        # Compound request: the model must satisfy every listed category.
        return [name for name in names
                if all(_matches(name, item) for item in requested)]
    return [name for name in names if _matches(name, kind)]
98
def _matches(name, kind):
    # type: (str, Optional[str]) -> bool
    """
    Return True if the model called *name* belongs to the category *kind*.

    *kind* is a single category from KINDS (not a '+' combination).  A *kind*
    of None or "all" matches every model without loading its definition.
    For any other category the model info is loaded with
    :func:`load_model_info` and tested as follows:

        * py: the model's Iq is a python callable
        * c: the model's Iq is not a python callable
        * double: the model is not flagged as single precision
        * single: the model is flagged as single precision
        * opencl: the model's opencl flag is set
        * 2d: at least one kernel parameter has type 'orientation'
        * 1d: no kernel parameter has type 'orientation'
        * magnetic: at least one kernel parameter has type 'sld'
        * nonmagnetic: no kernel parameter has type 'sld'

    An unrecognized *kind* matches nothing.
    """
    if kind is None or kind == "all":
        return True
    info = load_model_info(name)
    pars = info.parameters.kernel_parameters
    if kind == "py":
        return callable(info.Iq)
    if kind == "c":
        return not callable(info.Iq)
    if kind == "double":
        return not info.single
    if kind == "single":
        return bool(info.single)
    if kind == "opencl":
        return bool(info.opencl)
    if kind == "2d":
        return any(p.type == 'orientation' for p in pars)
    if kind == "1d":
        return all(p.type != 'orientation' for p in pars)
    if kind == "magnetic":
        return any(p.type == 'sld' for p in pars)
    if kind == "nonmagnetic":
        # Must be the exact complement of "magnetic".  Testing for *any*
        # non-sld parameter would accept every model that has at least one
        # ordinary parameter, magnetic or not.
        return not any(p.type == 'sld' for p in pars)
    return False
[f734e7d]123
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Load the definition of a model and build the corresponding kernel model.

    *model_name* is a model name or a model expression such as
    sphere*hardsphere or sphere+cylinder; it is interpreted by
    :func:`load_model_info`.

    *dtype* and *platform* are passed unchanged to :func:`build_model`.
    """
    model_info = load_model_info(model_name)
    return build_model(model_info, dtype=dtype, platform=platform)
[aa4946b]136
def load_model_info(model_string):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name.

    *model_string* is the name of the model, or a model expression such as
    sphere*cylinder or sphere+cylinder.  Use '@' for a structure factor
    product, e.g. sphere@hardsphere.  Custom models are selected by prefixing
    the model name with 'custom.', e.g. 'custom.MyModel+sphere'; they are
    loaded from CUSTOM_MODEL_PATH.

    The expression is split on '+' first, then on '*', then on '@', with each
    piece loaded recursively.

    Returns the model info built from the model definition(s).  Raises
    *ValueError* if a custom model name is malformed or its file is missing.
    """
    # Mixtures: '+' is split before '*', so '*' groups more tightly.
    for operator in ("+", "*"):
        if operator in model_string:
            parts = [load_model_info(part)
                     for part in model_string.split(operator)]
            return mixture.make_mixture_info(parts, operation=operator)

    # Product: exactly two operands are unpacked.
    if "@" in model_string:
        p_info, q_info = [load_model_info(part)
                          for part in model_string.split("@")]
        return product.make_product_info(p_info, q_info)

    # No operators remain, so this is a single model.
    if "custom." not in model_string:
        kernel_module = generate.load_kernel_module(model_string)
        return modelinfo.make_model_info(kernel_module)

    found = re.match("custom.([A-Za-z0-9_-]+)", model_string)
    if found is None:
        raise ValueError("Model name in invalid format: " + model_string)
    # The custom model lives in <CUSTOM_MODEL_PATH>/<name>.py
    model_path = joinpath(CUSTOM_MODEL_PATH, found.group(1) + ".py")
    if not os.path.isfile(model_path):
        raise ValueError("The model file {} doesn't exist".format(model_path))
    kernel_module = custom.load_custom_kernel_module(model_path)
    return modelinfo.make_model_info(kernel_module)
[d19962c]177
178
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Prepare the model for the default execution platform.

    Depending on the model and the computing platform this returns an OpenCL
    model, a DLL model or a python model.  Composite models (mixtures and
    products) are built by first building each of their parts with the same
    *dtype* and *platform*.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* indicates whether the model should use single or double precision
    for the calculation.  Choices are 'single', 'double', 'quad', 'half',
    or 'fast'.  If *dtype* ends with '!', then force the use of the DLL rather
    than OpenCL for the calculation.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl".

    Raises *ValueError* if the composition type is not recognized.
    """
    composition = model_info.composition
    if composition is not None:
        kind, components = composition
        built = [build_model(part, dtype=dtype, platform=platform)
                 for part in components]
        if kind == 'mixture':
            return mixture.MixtureModel(model_info, built)
        if kind == 'product':
            # A product has exactly two component models.
            p_model, s_model = built
            return product.ProductModel(model_info, p_model, s_model)
        raise ValueError('unknown mixture type %s'%kind)

    # Python models need no compilation, so dtype/platform are not consulted.
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
    source = generate.make_source(model_info)
    if platform != "dll":
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
    return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
[aa4946b]223
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Precompile the dlls for all builtin models, returning a list of dll paths.

    *path* is the directory in which to save the dlls.  It will be created if
    it does not already exist.

    *dtype* is the precision to compile for, given as a numpy dtype name.

    This can be used when building the windows distribution of sasmodels,
    which may be missing the OpenCL driver and the dll compiler.
    """
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)

    compiled_dlls = []
    for model_name in list_models():
        model_info = load_model_info(model_name)
        # Pure python models have nothing to compile.
        if callable(model_info.Iq):
            continue
        source = generate.make_source(model_info)['dll']
        # Redirect the dll output directory for this build only, restoring the
        # previous location even if compilation fails.
        saved_path = kerneldll.DLL_PATH
        kerneldll.DLL_PATH = path
        try:
            built = kerneldll.make_dll(source, model_info, dtype=numpy_dtype)
        finally:
            kerneldll.DLL_PATH = saved_path
        compiled_dlls.append(built)
    return compiled_dlls
[b8e5e21]251
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
    """
    Interpret dtype string, returning np.dtype, fast flag and platform.

    Possible types include 'half', 'single', 'double' and 'quad'.  If the
    type is 'fast', then this is equivalent to dtype 'single' but using
    fast native functions rather than those with the precision level
    guaranteed by the OpenCL standard.  'default' (or None) will choose the
    appropriate default for the model and platform.

    Platform preference can be specified ("ocl" vs "dll"), with the default
    being OpenCL if it is available.  If the dtype name ends with '!' then
    platform is forced to be DLL rather than OpenCL.  The platform is also
    switched to DLL when the model has its *opencl* flag turned off, or when
    the OpenCL environment does not support the requested type.

    Apart from the *opencl* flag and the choice of default precision, this
    routine ignores the preferences within the model definition.  This
    is by design.  It allows us to test models in single precision even when
    we have flagged them as requiring double precision so we can easily check
    the performance on different platforms without having to change the model
    definition.

    Returns the tuple *(numpy_dtype, fast, platform)*.
    """
    # Assign default platform, overriding ocl with dll if OpenCL is unavailable
    # or if the model has OpenCL switched off (opencl=False).
    if platform is None:
        platform = "ocl"
    # Parentheses spell out the existing and/or precedence: a model with
    # opencl=False always goes to the dll, whatever platform was requested.
    if (platform == "ocl" and not HAVE_OPENCL) or not model_info.opencl:
        platform = "dll"

    # Check if type indicates dll regardless of which platform is given
    if dtype is not None and dtype.endswith('!'):
        platform = "dll"
        dtype = dtype[:-1]

    # Convert special type names "half", "fast", and "quad"
    fast = (dtype == "fast")
    if fast:
        dtype = "single"
    elif dtype == "quad":
        dtype = "longdouble"
    elif dtype == "half":
        dtype = "float16"

    # Convert dtype string to numpy dtype.
    use_default = dtype is None or dtype == "default"
    if use_default:
        numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
                       else generate.F64)
    else:
        numpy_dtype = np.dtype(dtype)

    # Make sure that the type is supported by opencl, otherwise use dll
    if platform == "ocl":
        env = kernelcl.environment()
        if not env.has_type(numpy_dtype):
            platform = "dll"
            # When the precision was left to us, fall back to double on the
            # dll.  "default" is treated the same as None here, matching the
            # selection step above.
            if use_default:
                numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
[0c24a82]311
def list_models_main():
    # type: () -> None
    """
    Command line entry point for :func:`list_models`.

    The first command line argument, if present, is the kind of model to
    list (see :func:`list_models` for the choices); otherwise all models are
    listed.  The matching model names are printed one per line.
    """
    import sys
    args = sys.argv[1:]
    kind = args[0] if args else "all"
    print("\n".join(list_models(kind)))
[2547694]321
def test_composite_order():
    """
    Generate tests checking that composite models do not depend on the order
    of their operands.

    Each yielded item is a *(test, description)* pair, where calling
    *test(description)* loads the two model expressions named in the
    description and asserts that they define the same parameters.
    """
    def test_models(fst, snd):
        """Confirm that two models produce the same parameters"""
        first_model = load_model(fst)
        second_model = load_model(snd)
        # Parameters in a mixture model are tagged with an uppercase prefix,
        # e.g., A_sld, so the names differ between orderings.  Strictly we
        # ought to apply the regex substitution s/^[A-Z]+_/_/ before
        # comparing, but dropping every uppercase letter is good enough.
        fst = [[ch for ch in p.name if ch == ch.lower()]
               for p in first_model.info.parameters.kernel_parameters]
        snd = [[ch for ch in p.name if ch == ch.lower()]
               for p in second_model.info.parameters.kernel_parameters]
        assert sorted(fst) == sorted(snd), "{} != {}".format(fst, snd)

    def build_test(first, second):
        # Bind the pair through the function arguments so that each yielded
        # test keeps its own model names.
        description = first + " vs. " + second
        test = lambda description: test_models(first, second)
        return test, description

    pairs = (
        ("cylinder+sphere",
         "sphere+cylinder"),
        ("cylinder*sphere",
         "sphere*cylinder"),
        ("cylinder@hardsphere*sphere",
         "sphere*cylinder@hardsphere"),
        ("barbell+sphere*cylinder@hardsphere",
         "sphere*cylinder@hardsphere+barbell"),
        ("barbell+cylinder@hardsphere*sphere",
         "cylinder@hardsphere*sphere+barbell"),
        ("barbell+sphere*cylinder@hardsphere",
         "barbell+cylinder@hardsphere*sphere"),
        ("sphere*cylinder@hardsphere+barbell",
         "cylinder@hardsphere*sphere+barbell"),
        ("barbell+sphere*cylinder@hardsphere",
         "cylinder@hardsphere*sphere+barbell"),
        ("barbell+cylinder@hardsphere*sphere",
         "sphere*cylinder@hardsphere+barbell"),
    )
    for first, second in pairs:
        yield build_test(first, second)
368
def test_composite():
    # type: () -> None
    """Check that model load works"""
    # The composite model should expose exactly these parameters, in order.
    target = [
        "sld", "sld_solvent", "radius", "length", "theta", "phi",
        "volfraction", "A_sld", "A_sld_solvent", "A_radius",
    ]
    model = load_model("cylinder@hardsphere*sphere")
    actual = [par.name for par in model.info.parameters.kernel_parameters]
    assert target == actual, "%s != %s"%(target, actual)
378
# When run as a script, list the models of the kind named by the first
# command line argument (all models if no argument is given).
if __name__ == "__main__":
    list_models_main()
Note: See TracBrowser for help on using the repository browser.