source: sasmodels/sasmodels/core.py @ 5124c969

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 5124c969 was 5124c969, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

allow multicompare to select opencl+single+1d

  • Property mode set to 100644
File size: 9.6 KB
RevLine 
[aa4946b]1"""
2Core model handling routines.
3"""
[6d6508e]4from __future__ import print_function
5
[98f60fc]6__all__ = [
[6d6508e]7    "list_models", "load_model", "load_model_info",
[2547694]8    "build_model", "precompile_dlls",
[98f60fc]9    ]
[f734e7d]10
[7bf4757]11import os
[2547694]12from os.path import basename, dirname, join as joinpath
[f734e7d]13from glob import glob
14
[7ae2b7f]15import numpy as np # type: ignore
[f734e7d]16
[aa4946b]17from . import generate
[6d6508e]18from . import modelinfo
19from . import product
[72a081d]20from . import mixture
[aa4946b]21from . import kernelpy
22from . import kerneldll
[880a2ed]23
[a557a99]24if os.environ.get("SAS_OPENCL", "").lower() == "none":
[aa4946b]25    HAVE_OPENCL = False
[880a2ed]26else:
27    try:
28        from . import kernelcl
29        HAVE_OPENCL = True
30    except Exception:
31        HAVE_OPENCL = False
[aa4946b]32
[f619de7]33try:
34    from typing import List, Union, Optional, Any
35    from .kernel import KernelModel
[dd7fc12]36    from .modelinfo import ModelInfo
[f619de7]37except ImportError:
38    pass
39
[4d76711]40# TODO: refactor composite model support
41# The current load_model_info/build_model does not reuse existing model
42# definitions when loading a composite model, instead reloading and
43# rebuilding the kernel for each component model in the expression.  This
44# is fine in a scripting environment where the model is built when the script
45# starts and is thrown away when the script ends, but may not be the best
46# solution in a long-lived application.  This affects the following functions:
47#
48#    load_model
49#    load_model_info
50#    build_model
[f734e7d]51
KINDS = (
    "all", "py", "c", "double", "single", "opencl", "1d", "2d",
    "nonmagnetic", "magnetic",
)


def list_models(kind=None):
    # type: (str) -> List[str]
    """
    Return the names of the models found on the model path.

    The names are taken from the ``[a-zA-Z]*.py`` files in
    *generate.MODEL_PATH*, in sorted file order, with the ``.py``
    extension removed.

    *kind* restricts the result to one category of model:

        * all: every model (same as leaving *kind* unset)
        * py: python models only
        * c: compiled models only
        * single: models which support single precision
        * double: models which require double precision
        * opencl: models for which OpenCL has not been suppressed
        * 1d: models without orientation parameters
        * 2d: models with orientation parameters
        * magnetic: models with an sld parameter
        * nonmagnetic: models without an sld parameter

    Several categories can be joined with a plus sign, in which case a
    model must satisfy every one of them.  For example, *c+single+2d*
    returns the oriented models implemented in C which can be computed
    accurately with single precision arithmetic.

    Raises *ValueError* if any requested category is not in *KINDS*.
    """
    # A single kind (or None) is handled as a one-element list so that the
    # selection below does not need a separate code path for "a+b+c".
    requested = kind.split('+') if kind else [kind]
    if kind and not all(k in KINDS for k in requested):
        raise ValueError("kind not in " + ", ".join(KINDS))

    pattern = joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")
    names = (basename(path)[:-3] for path in sorted(glob(pattern)))
    return [name for name in names
            if all(_matches(name, k) for k in requested)]
88
def _matches(name, kind):
    # type: (str, Optional[str]) -> bool
    """
    Return True if the model called *name* belongs to category *kind*.

    *kind* is one of the entries in *KINDS*, or None.  None and "all"
    match every model without loading it; any other value loads the model
    definition with :func:`load_model_info` and inspects it.  An
    unrecognized *kind* matches nothing.
    """
    if kind is None or kind == "all":
        return True

    info = load_model_info(name)
    pars = info.parameters.kernel_parameters

    if kind == "py":
        # Python models define Iq as a callable rather than as C source.
        return callable(info.Iq)
    if kind == "c":
        return not callable(info.Iq)
    if kind == "double":
        return not info.single
    if kind == "single":
        return bool(info.single)
    if kind == "opencl":
        return bool(info.opencl)
    if kind == "2d":
        return any(p.type == 'orientation' for p in pars)
    if kind == "1d":
        return all(p.type != 'orientation' for p in pars)
    if kind == "magnetic":
        return any(p.type == 'sld' for p in pars)
    if kind == "nonmagnetic":
        # Must be the exact complement of "magnetic": *no* parameter may be
        # an sld.  Testing any(p.type != 'sld') instead would accept every
        # model that has at least one non-sld parameter, magnetic or not.
        return all(p.type != 'sld' for p in pars)
    return False
[f734e7d]113
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Look up a model definition by name and build the model from it.

    *model_name* is interpreted by :func:`load_model_info`.  The *dtype*
    and *platform* arguments are handed on unchanged to
    :func:`build_model`.
    """
    model_info = load_model_info(model_name)
    return build_model(model_info, dtype=dtype, platform=platform)
[aa4946b]124
125
def load_model_info(model_name):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load the model definition for *model_name*.

    A plain name is loaded with *generate.load_kernel_module* and turned
    into a model info structure by *modelinfo.make_model_info*.

    Composite names are also accepted.  *A+B+...* loads each term and
    combines them with *mixture.make_mixture_info*.  *P*S* loads both
    factors and combines them with *product.make_product_info*, applying
    structure factor *S* to model *P*.  The name is split on '+' before
    '*', so *A*B+C* is the mixture of the product *A*B* with *C*.

    Raises *ValueError* if a product has more than two factors.
    """
    if '+' in model_name:
        terms = [load_model_info(term) for term in model_name.split('+')]
        return mixture.make_mixture_info(terms)

    factors = model_name.split('*')
    if len(factors) > 2:
        raise ValueError("use P*S to apply structure factor S to model P")
    if len(factors) == 2:
        form_info, structure_info = (load_model_info(f) for f in factors)
        return product.make_product_info(form_info, structure_info)

    return modelinfo.make_model_info(generate.load_kernel_module(model_name))
[d19962c]148
149
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Build the model described by *model_info* for the execution platform.

    Depending on the model and the computing platform, the result is an
    OpenCL model, a DLL model or a python model.  Composite models are
    built by first building each of their parts with the same *dtype* and
    *platform*.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* selects the precision of the calculation.  Choices are
    'single', 'double', 'quad', 'half', or 'fast'.  If *dtype* ends with
    '!', then the DLL is used rather than OpenCL.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl".

    Raises *ValueError* if the composition type of *model_info* is neither
    'mixture' nor 'product'.
    """
    composition = model_info.composition
    if composition is not None:
        kind, components = composition
        kernels = [build_model(part, dtype=dtype, platform=platform)
                   for part in components]
        if kind == 'mixture':
            return mixture.MixtureModel(model_info, kernels)
        if kind == 'product':
            form, structure = kernels
            return product.ProductModel(model_info, form, structure)
        raise ValueError('unknown mixture type %s'%kind)

    # Python models need no compilation, so dtype and platform are ignored.
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    # parse_dtype may override the requested platform, e.g. when OpenCL is
    # not available.
    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
    source = generate.make_source(model_info)
    if platform != "dll":
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
    return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
[aa4946b]194
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Compile the dll for every builtin C model and return the dll paths.

    *path* is the directory in which to save the dlls.  It will be created
    if it does not already exist.

    *dtype* is converted with *np.dtype* and passed to the dll builder.

    Python models (those with a callable *Iq*) are skipped.

    This can be used when building the windows distribution of sasmodels,
    which may be missing the OpenCL driver and the dll compiler.
    """
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)

    dll_paths = []
    for name in list_models():
        info = load_model_info(name)
        if callable(info.Iq):
            continue
        c_source = generate.make_source(info)['dll']
        # Point the dll builder at the target directory for this build only,
        # restoring the previous location even if compilation fails.
        saved_path = kerneldll.DLL_PATH
        kerneldll.DLL_PATH = path
        try:
            built = kerneldll.make_dll(c_source, info, dtype=numpy_dtype)
        finally:
            kerneldll.DLL_PATH = saved_path
        dll_paths.append(built)
    return dll_paths
[b8e5e21]222
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
    """
    Interpret the *dtype* and *platform* requested for a model.

    Returns the tuple *(numpy_dtype, fast, platform)* where *numpy_dtype*
    is the numpy type to compute with, *fast* is True if the 'fast' type
    was requested, and *platform* is "ocl" or "dll" after any overrides
    have been applied.

    Possible types include 'half', 'single', 'double' and 'quad'.  If the
    type is 'fast', then this is equivalent to dtype 'single' with the
    fast flag set to True.  A trailing '!' on *dtype* forces the dll
    platform.  If *dtype* is None, single precision is used for OpenCL
    models which support it, and double precision otherwise.

    *platform* defaults to "ocl".  It is replaced by "dll" if OpenCL could
    not be loaded, if the model has switched OpenCL off, or if the OpenCL
    environment does not support the requested type.
    """
    # Assign default platform, overriding ocl with dll if OpenCL is
    # unavailable.
    if platform is None:
        platform = "ocl"
    if platform == "ocl" and not HAVE_OPENCL:
        platform = "dll"
    # A model with opencl=False switches OpenCL off whatever was requested.
    if not model_info.opencl:
        platform = "dll"

    # Check if type indicates dll regardless of which platform is given
    if dtype is not None and dtype.endswith('!'):
        platform = "dll"
        dtype = dtype[:-1]

    # Convert special type names "half", "fast", and "quad"
    fast = (dtype == "fast")
    if fast:
        dtype = "single"
    elif dtype == "quad":
        dtype = "longdouble"
    elif dtype == "half":
        # Spell the type out: in numpy type strings the number is a size in
        # bytes, so "f16" would be a 16-byte float rather than 16 bits.
        dtype = "float16"

    # Convert dtype string to numpy dtype.
    if dtype is None:
        numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
                       else generate.F64)
    else:
        numpy_dtype = np.dtype(dtype)

    # Make sure that the type is supported by opencl, otherwise use dll
    if platform == "ocl":
        env = kernelcl.environment()
        if not env.has_type(numpy_dtype):
            platform = "dll"
            # Only replace the type if the caller did not ask for one.
            if dtype is None:
                numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
[0c24a82]270
def list_models_main():
    # type: () -> None
    """
    Command line entry point for :func:`list_models`.

    The first command line argument, if present, is the kind of model to
    list; otherwise all models are listed.  The model names are printed
    one per line.  See :func:`list_models` for the kinds of models that
    can be requested.
    """
    import sys
    args = sys.argv[1:]
    kind = args[0] if args else "all"
    print("\n".join(list_models(kind)))


if __name__ == "__main__":
    list_models_main()
Note: See TracBrowser for help on using the repository browser.