source: sasmodels/sasmodels/core.py @ 9f60c06

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 9f60c06 was 672978c, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

Merge branch 'master' of github.com:sasview/sasmodels

  • Property mode set to 100644
File size: 9.2 KB
RevLine 
[aa4946b]1"""
2Core model handling routines.
3"""
[6d6508e]4from __future__ import print_function
5
# Names exported by "from <module> import *".  parse_dtype and
# list_models_main are defined in this module but intentionally not listed.
__all__ = [
    "list_models", "load_model", "load_model_info",
    "build_model", "precompile_dlls",
    ]
[f734e7d]10
[7bf4757]11import os
[2547694]12from os.path import basename, dirname, join as joinpath
[f734e7d]13from glob import glob
14
[7ae2b7f]15import numpy as np # type: ignore
[f734e7d]16
[aa4946b]17from . import generate
[6d6508e]18from . import modelinfo
19from . import product
[72a081d]20from . import mixture
[aa4946b]21from . import kernelpy
22from . import kerneldll
[f734e7d]23try:
[aa4946b]24    from . import kernelcl
25    HAVE_OPENCL = True
[ee8f734]26except Exception:
[aa4946b]27    HAVE_OPENCL = False
28
[f619de7]29try:
30    from typing import List, Union, Optional, Any
31    from .kernel import KernelModel
[dd7fc12]32    from .modelinfo import ModelInfo
[f619de7]33except ImportError:
34    pass
35
try:
    # Probe whether this numpy accepts meshgrid with a single vector.
    np.meshgrid([])
except Exception:
    # CRUFT: np.meshgrid requires multiple vectors
    def meshgrid(*args):
        """Allow meshgrid with a single argument"""
        if len(args) <= 1:
            # Zero or one vector: just return the inputs as arrays.
            return [np.asarray(vec) for vec in args]
        return np.meshgrid(*args)
else:
    meshgrid = np.meshgrid
[f619de7]47
[4d76711]48# TODO: refactor composite model support
49# The current load_model_info/build_model does not reuse existing model
50# definitions when loading a composite model, instead reloading and
51# rebuilding the kernel for each component model in the expression.  This
52# is fine in a scripting environment where the model is built when the script
53# starts and is thrown away when the script ends, but may not be the best
54# solution in a long-lived application.  This affects the following functions:
55#
56#    load_model
57#    load_model_info
58#    build_model
[f734e7d]59
# Selection keywords accepted by list_models (see its docstring).
KINDS = ("all", "py", "c", "double", "single", "1d", "2d",
         "nonmagnetic", "magnetic")
def list_models(kind=None):
    # type: (Optional[str]) -> List[str]
    """
    Return the sorted list of available models on the model path.

    *kind* can be one of the following:

        * all: all models (also used when *kind* is None or empty)
        * py: python models only
        * c: compiled models only
        * single: models which support single precision
        * double: models which require double precision
        * 1d: models which are 1D only, or 2D using abs(q)
        * 2d: models which can be 2D
        * magnetic: models with an sld
        * nonmagnetic: models without an sld

    Raises ValueError if *kind* is not one of the names in KINDS.
    """
    if kind and kind not in KINDS:
        raise ValueError("kind not in " + ", ".join(KINDS))
    # A falsy kind ("" as well as None) selects everything, rather than
    # passing validation and then matching no model at all.
    if not kind:
        kind = "all"
    # Model definitions are the .py files in the model directory whose names
    # start with a letter, so __init__.py and "_"-prefixed files are skipped.
    files = sorted(glob(joinpath(generate.MODEL_PATH, "[a-zA-Z]*.py")))
    available_models = [basename(f)[:-3] for f in files]
    return [name for name in available_models if _matches(name, kind)]
86
87def _matches(name, kind):
[2547694]88    if kind is None or kind == "all":
[0b03001]89        return True
90    info = load_model_info(name)
91    pars = info.parameters.kernel_parameters
92    if kind == "py" and callable(info.Iq):
93        return True
94    elif kind == "c" and not callable(info.Iq):
95        return True
96    elif kind == "double" and not info.single:
97        return True
[d2d6100]98    elif kind == "single" and info.single:
99        return True
[2547694]100    elif kind == "2d" and any(p.type == 'orientation' for p in pars):
[d2d6100]101        return True
[40a87fa]102    elif kind == "1d" and all(p.type != 'orientation' for p in pars):
[0b03001]103        return True
[2547694]104    elif kind == "magnetic" and any(p.type == 'sld' for p in pars):
[0b03001]105        return True
[2547694]106    elif kind == "nonmagnetic" and any(p.type != 'sld' for p in pars):
[d2d6100]107        return True
[0b03001]108    return False
[f734e7d]109
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Load model info and build model.

    *model_name* is the name of the model as used by :func:`load_model_info`.
    *dtype* and *platform* are passed unchanged to :func:`build_model`.
    """
    info = load_model_info(model_name)
    return build_model(info, dtype=dtype, platform=platform)
[aa4946b]120
121
def load_model_info(model_name):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name, returning its model info.

    Composite names are also accepted: "A+B+..." builds a mixture from the
    terms (each term is loaded by this same function, so a term may itself
    be a product), and "P*S" applies structure factor S to model P.

    The result can be used with functions in generate to build the docs or
    extract model info.

    Raises ValueError if a product has more than two factors.
    """
    if '+' in model_name:
        terms = [load_model_info(term) for term in model_name.split('+')]
        return mixture.make_mixture_info(terms)

    if '*' in model_name:
        factors = model_name.split('*')
        if len(factors) != 2:
            raise ValueError("use P*S to apply structure factor S to model P")
        form_info, structure_info = [load_model_info(f) for f in factors]
        return product.make_product_info(form_info, structure_info)

    kernel_module = generate.load_kernel_module(model_name)
    return modelinfo.make_model_info(kernel_module)
[d19962c]144
145
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Prepare the model for the default execution platform.

    This will return an OpenCL model, a DLL model or a python model depending
    on the model and the computing platform.  Composite (mixture or product)
    models are assembled from component models built with the same *dtype*
    and *platform*.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* indicates whether the model should use single or double precision
    for the calculation.  Choices are 'single', 'double', 'quad', 'half',
    or 'fast'.  If *dtype* ends with '!', then force the use of the DLL rather
    than OpenCL for the calculation.  See :func:`parse_dtype`.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl".  :func:`parse_dtype` may still
    switch to "dll" when OpenCL is unavailable or lacks the requested type.

    Raises ValueError for an unrecognized composition type.
    """
    composition = model_info.composition
    if composition is not None:
        comp_type, components = composition
        sub_models = [build_model(part, dtype=dtype, platform=platform)
                      for part in components]
        if comp_type == 'mixture':
            return mixture.MixtureModel(model_info, sub_models)
        if comp_type == 'product':
            form_model, structure_model = sub_models
            return product.ProductModel(model_info, form_model, structure_model)
        raise ValueError('unknown mixture type %s'%comp_type)

    # Pure python models are returned directly; no C source is generated.
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
    source = generate.make_source(model_info)
    if platform != "dll":
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
    return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
180
[def2c1b]181    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
[7891a2a]182
[72a081d]183    source = generate.make_source(model_info)
[7891a2a]184    if platform == "dll":
[f2f67a6]185        #print("building dll", numpy_dtype)
[a4280bd]186        return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
[aa4946b]187    else:
[f2f67a6]188        #print("building ocl", numpy_dtype)
[dd7fc12]189        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
[aa4946b]190
[7bf4757]191def precompile_dlls(path, dtype="double"):
[7891a2a]192    # type: (str, str) -> List[str]
[b8e5e21]193    """
[7bf4757]194    Precompile the dlls for all builtin models, returning a list of dll paths.
[b8e5e21]195
[7bf4757]196    *path* is the directory in which to save the dlls.  It will be created if
197    it does not already exist.
[b8e5e21]198
199    This can be used when build the windows distribution of sasmodels
[7bf4757]200    which may be missing the OpenCL driver and the dll compiler.
[b8e5e21]201    """
[7891a2a]202    numpy_dtype = np.dtype(dtype)
[7bf4757]203    if not os.path.exists(path):
204        os.makedirs(path)
205    compiled_dlls = []
206    for model_name in list_models():
207        model_info = load_model_info(model_name)
[a4280bd]208        if not callable(model_info.Iq):
209            source = generate.make_source(model_info)['dll']
[7bf4757]210            old_path = kerneldll.DLL_PATH
211            try:
212                kerneldll.DLL_PATH = path
[def2c1b]213                dll = kerneldll.make_dll(source, model_info, dtype=numpy_dtype)
[7bf4757]214            finally:
215                kerneldll.DLL_PATH = old_path
216            compiled_dlls.append(dll)
217    return compiled_dlls
[b8e5e21]218
[7891a2a]219def parse_dtype(model_info, dtype=None, platform=None):
220    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
[dd7fc12]221    """
222    Interpret dtype string, returning np.dtype and fast flag.
223
224    Possible types include 'half', 'single', 'double' and 'quad'.  If the
225    type is 'fast', then this is equivalent to dtype 'single' with the
226    fast flag set to True.
227    """
[7891a2a]228    # Assign default platform, overriding ocl with dll if OpenCL is unavailable
229    if platform is None:
230        platform = "ocl"
[2547694]231    if platform == "ocl" and not HAVE_OPENCL:
[7891a2a]232        platform = "dll"
233
234    # Check if type indicates dll regardless of which platform is given
235    if dtype is not None and dtype.endswith('!'):
236        platform = "dll"
[dd7fc12]237        dtype = dtype[:-1]
238
[7891a2a]239    # Convert special type names "half", "fast", and "quad"
[2547694]240    fast = (dtype == "fast")
[7891a2a]241    if fast:
242        dtype = "single"
[2547694]243    elif dtype == "quad":
[7891a2a]244        dtype = "longdouble"
[2547694]245    elif dtype == "half":
[7891a2a]246        dtype = "f16"
247
248    # Convert dtype string to numpy dtype.
249    if dtype is None:
[2547694]250        numpy_dtype = (generate.F32 if platform == "ocl" and model_info.single
251                       else generate.F64)
[dd7fc12]252    else:
[7891a2a]253        numpy_dtype = np.dtype(dtype)
[9890053]254
[7891a2a]255    # Make sure that the type is supported by opencl, otherwise use dll
[2547694]256    if platform == "ocl":
[7891a2a]257        env = kernelcl.environment()
258        if not env.has_type(numpy_dtype):
259            platform = "dll"
260            if dtype is None:
261                numpy_dtype = generate.F64
[dd7fc12]262
[7891a2a]263    return numpy_dtype, fast, platform
[0c24a82]264
def list_models_main():
    # type: () -> None
    """
    Command line entry point: print the names of the models selected by the
    optional first argument (default "all"), one per line.  See
    :func:`list_models` for the kinds that can be requested.
    """
    import sys
    cli_args = sys.argv[1:]
    kind = cli_args[0] if cli_args else "all"
    names = list_models(kind)
    print("\n".join(names))

if __name__ == "__main__":
    list_models_main()
Note: See TracBrowser for help on using the repository browser.