source: sasmodels/sasmodels/core.py @ ce176ca

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since ce176ca was def2c1b, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

honour platform request when selecting kernel

  • Property mode set to 100644
File size: 7.3 KB
RevLine 
[aa4946b]1"""
2Core model handling routines.
3"""
[6d6508e]4from __future__ import print_function
5
# Public API for ``from sasmodels.core import *``.  Every name listed here
# must be defined in this module, otherwise ``import *`` raises AttributeError.
__all__ = [
    "list_models", "load_model", "load_model_info",
    "build_model", "precompile_dlls",
    ]
[f734e7d]10
[7bf4757]11import os
[72a081d]12from os.path import basename, dirname, join as joinpath, splitext
[f734e7d]13from glob import glob
14
[7ae2b7f]15import numpy as np # type: ignore
[f734e7d]16
[aa4946b]17from . import generate
[6d6508e]18from . import modelinfo
19from . import product
[72a081d]20from . import mixture
[aa4946b]21from . import kernelpy
22from . import kerneldll
[f734e7d]23try:
[aa4946b]24    from . import kernelcl
25    HAVE_OPENCL = True
[ee8f734]26except Exception:
[aa4946b]27    HAVE_OPENCL = False
28
[f619de7]29try:
30    from typing import List, Union, Optional, Any
31    from .kernel import KernelModel
[dd7fc12]32    from .modelinfo import ModelInfo
[f619de7]33except ImportError:
34    pass
35
try:
    np.meshgrid([])
except Exception:
    # CRUFT: this numpy's meshgrid requires multiple vectors, so provide a
    # wrapper that handles the zero- and one-vector cases itself.
    def meshgrid(*args):
        """
        Compatibility stand-in for :func:`numpy.meshgrid`.

        With two or more vectors this defers to numpy; with fewer it
        returns the inputs converted to arrays.
        """
        if len(args) <= 1:
            return [np.asarray(v) for v in args]
        return np.meshgrid(*args)
else:
    # np.meshgrid accepts a single vector here, so use it directly.
    meshgrid = np.meshgrid
[f619de7]46
[4d76711]47# TODO: refactor composite model support
48# The current load_model_info/build_model does not reuse existing model
49# definitions when loading a composite model, instead reloading and
50# rebuilding the kernel for each component model in the expression.  This
51# is fine in a scripting environment where the model is built when the script
52# starts and is thrown away when the script ends, but may not be the best
53# solution in a long-lived application.  This affects the following functions:
54#
55#    load_model
56#    load_model_info
57#    build_model
[f734e7d]58
59def list_models():
[f619de7]60    # type: () -> List[str]
[aa4946b]61    """
62    Return the list of available models on the model path.
63    """
[f734e7d]64    root = dirname(__file__)
65    files = sorted(glob(joinpath(root, 'models', "[a-zA-Z]*.py")))
66    available_models = [basename(f)[:-3] for f in files]
67    return available_models
68
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Load model info and build model.

    *model_name* is the name of the model as used by :func:`load_model_info`,
    including composite names such as ``"A+B"`` or ``"P*S"``.

    *dtype* and *platform* are forwarded unchanged to :func:`build_model`;
    see that function for the accepted values.
    """
    model_info = load_model_info(model_name)
    return build_model(model_info, dtype=dtype, platform=platform)
[aa4946b]79
80
def load_model_info(model_name):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name.

    *model_name* may be a single model name, a mixture ``A+B+...`` of
    models, or a product ``P*S`` applying structure factor *S* to model *P*.
    ``+`` is split before ``*``, so ``A*S+B`` is the mixture of ``A*S`` and
    ``B``.  A product with more than two factors raises ValueError.

    Returns the model info structure produced by
    :func:`modelinfo.make_model_info` (or by the mixture/product builders for
    composite names).  It can be passed to :func:`build_model` or used with
    functions in generate to build the docs or extract model info.
    """
    if '+' in model_name:
        components = [load_model_info(part) for part in model_name.split('+')]
        return mixture.make_mixture_info(components)

    if '*' in model_name:
        factors = model_name.split('*')
        if len(factors) != 2:
            raise ValueError("use P*S to apply structure factor S to model P")
        form_info, structure_info = [load_model_info(name) for name in factors]
        return product.make_product_info(form_info, structure_info)

    kernel_module = generate.load_kernel_module(model_name)
    return modelinfo.make_model_info(kernel_module)
[d19962c]103
104
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Prepare the model for the default execution platform.

    This will return an OpenCL model, a DLL model or a python model depending
    on the model and the computing platform.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* indicates whether the model should use single or double precision
    for the calculation.  Choices are 'single', 'double', 'quad', 'half',
    or 'fast'.  If *dtype* ends with '!', then force the use of the DLL rather
    than OpenCL for the calculation.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl".

    Composite models (mixtures and products) are built by building each part
    with the same *dtype* and *platform*.  Raises ValueError for an unknown
    composition type.
    """
    composition = model_info.composition
    if composition is not None:
        composition_type, parts = composition
        models = [build_model(p, dtype=dtype, platform=platform) for p in parts]
        if composition_type == 'mixture':
            return mixture.MixtureModel(model_info, models)
        elif composition_type == 'product':
            P, S = models
            return product.ProductModel(model_info, P, S)
        else:
            raise ValueError('unknown mixture type %s'%composition_type)

    # If it is a python model, return it immediately
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    # parse_dtype may override the requested platform (e.g. falling back to
    # "dll"), so use the platform it returns rather than the argument.
    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)

    source = generate.make_source(model_info)
    if platform == "dll":
        # The fast flag is only passed to the OpenCL kernel below.
        return kerneldll.load_dll(source, model_info, numpy_dtype)
    else:
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
140
[def2c1b]141    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)
[7891a2a]142
[72a081d]143    source = generate.make_source(model_info)
[7891a2a]144    if platform == "dll":
[f2f67a6]145        #print("building dll", numpy_dtype)
[dd7fc12]146        return kerneldll.load_dll(source, model_info, numpy_dtype)
[aa4946b]147    else:
[f2f67a6]148        #print("building ocl", numpy_dtype)
[dd7fc12]149        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
[aa4946b]150
[7bf4757]151def precompile_dlls(path, dtype="double"):
[7891a2a]152    # type: (str, str) -> List[str]
[b8e5e21]153    """
[7bf4757]154    Precompile the dlls for all builtin models, returning a list of dll paths.
[b8e5e21]155
[7bf4757]156    *path* is the directory in which to save the dlls.  It will be created if
157    it does not already exist.
[b8e5e21]158
159    This can be used when build the windows distribution of sasmodels
[7bf4757]160    which may be missing the OpenCL driver and the dll compiler.
[b8e5e21]161    """
[7891a2a]162    numpy_dtype = np.dtype(dtype)
[7bf4757]163    if not os.path.exists(path):
164        os.makedirs(path)
165    compiled_dlls = []
166    for model_name in list_models():
167        model_info = load_model_info(model_name)
168        source = generate.make_source(model_info)
169        if source:
170            old_path = kerneldll.DLL_PATH
171            try:
172                kerneldll.DLL_PATH = path
[def2c1b]173                dll = kerneldll.make_dll(source, model_info, dtype=numpy_dtype)
[7bf4757]174            finally:
175                kerneldll.DLL_PATH = old_path
176            compiled_dlls.append(dll)
177    return compiled_dlls
[b8e5e21]178
[7891a2a]179def parse_dtype(model_info, dtype=None, platform=None):
180    # type: (ModelInfo, str, str) -> (np.dtype, bool, str)
[dd7fc12]181    """
182    Interpret dtype string, returning np.dtype and fast flag.
183
184    Possible types include 'half', 'single', 'double' and 'quad'.  If the
185    type is 'fast', then this is equivalent to dtype 'single' with the
186    fast flag set to True.
187    """
[7891a2a]188    # Assign default platform, overriding ocl with dll if OpenCL is unavailable
189    if platform is None:
190        platform = "ocl"
191    if platform=="ocl" and not HAVE_OPENCL:
192        platform = "dll"
193
194    # Check if type indicates dll regardless of which platform is given
195    if dtype is not None and dtype.endswith('!'):
196        platform = "dll"
[dd7fc12]197        dtype = dtype[:-1]
198
[7891a2a]199    # Convert special type names "half", "fast", and "quad"
200    fast = (dtype=="fast")
201    if fast:
202        dtype = "single"
203    elif dtype=="quad":
204        dtype = "longdouble"
205    elif dtype=="half":
206        dtype = "f16"
207
208    # Convert dtype string to numpy dtype.
209    if dtype is None:
210        numpy_dtype = generate.F32 if platform=="ocl" and model_info.single else generate.F64
[dd7fc12]211    else:
[7891a2a]212        numpy_dtype = np.dtype(dtype)
[9890053]213
[7891a2a]214    # Make sure that the type is supported by opencl, otherwise use dll
215    if platform=="ocl":
216        env = kernelcl.environment()
217        if not env.has_type(numpy_dtype):
218            platform = "dll"
219            if dtype is None:
220                numpy_dtype = generate.F64
[dd7fc12]221
[7891a2a]222    return numpy_dtype, fast, platform
Note: See TracBrowser for help on using the repository browser.