source: sasmodels/sasmodels/core.py @ 0b03001

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 0b03001 was 0b03001, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

list models by kind of model

  • Property mode set to 100644
File size: 8.2 KB
RevLine 
[aa4946b]1"""
2Core model handling routines.
3"""
[6d6508e]4from __future__ import print_function
5
# Public API of this module.  Every name listed here must be defined below,
# otherwise "from sasmodels.core import *" fails with AttributeError.
__all__ = [
    "list_models", "load_model", "load_model_info",
    "build_model", "precompile_dlls",
    ]
[f734e7d]10
[7bf4757]11import os
[72a081d]12from os.path import basename, dirname, join as joinpath, splitext
[f734e7d]13from glob import glob
14
[7ae2b7f]15import numpy as np # type: ignore
[f734e7d]16
[aa4946b]17from . import generate
[6d6508e]18from . import modelinfo
19from . import product
[72a081d]20from . import mixture
[aa4946b]21from . import kernelpy
22from . import kerneldll
[f734e7d]23try:
[aa4946b]24    from . import kernelcl
25    HAVE_OPENCL = True
[ee8f734]26except Exception:
[aa4946b]27    HAVE_OPENCL = False
28
[f619de7]29try:
30    from typing import List, Union, Optional, Any
31    from .kernel import KernelModel
[dd7fc12]32    from .modelinfo import ModelInfo
[f619de7]33except ImportError:
34    pass
35
# Probe np.meshgrid with a single vector once at import time; when this numpy
# rejects that call, install a small shim that handles the one-vector case.
try:
    np.meshgrid([])
except Exception:
    # CRUFT: np.meshgrid requires multiple vectors
    def meshgrid(*args):
        """
        Return np.meshgrid(*args), or the arguments as arrays when at most
        one vector is given.
        """
        if len(args) <= 1:
            return [np.asarray(vec) for vec in args]
        return np.meshgrid(*args)
else:
    meshgrid = np.meshgrid
[f619de7]46
[4d76711]47# TODO: refactor composite model support
48# The current load_model_info/build_model does not reuse existing model
49# definitions when loading a composite model, instead reloading and
50# rebuilding the kernel for each component model in the expression.  This
51# is fine in a scripting environment where the model is built when the script
52# starts and is thrown away when the script ends, but may not be the best
53# solution in a long-lived application.  This affects the following functions:
54#
55#    load_model
56#    load_model_info
57#    build_model
[f734e7d]58
def list_models(kind=None):
    # type: (Optional[str]) -> List[str]
    """
    Return the list of available models on the model path.

    *kind* selects a subset of the models found in the models directory:

    * None or "all": every model
    * "py": models whose *Iq* is a python callable
    * "c": models whose *Iq* is not callable (C models)
    * "double": models with ``info.single`` false
    * "oriented": models with at least one 'orientation' kernel parameter
    * "magnetic": models with at least one 'sld' kernel parameter

    Raises *ValueError* if *kind* is not one of the names above.
    """
    kinds = ("all", "py", "c", "double", "oriented", "magnetic")
    if kind and kind not in kinds:
        raise ValueError("kind not in " + ", ".join(kinds))
    root = dirname(__file__)
    files = sorted(glob(joinpath(root, 'models', "[a-zA-Z]*.py")))
    available_models = [splitext(basename(f))[0] for f in files]
    if kind is None or kind == "all":
        # No filtering needed; avoid calling the per-model filter at all.
        return available_models
    # Any other kind requires loading each model definition to inspect it.
    return [name for name in available_models if _matches(name, kind)]
73
74def _matches(name, kind):
75    if kind is None or kind=="all":
76        return True
77    info = load_model_info(name)
78    pars = info.parameters.kernel_parameters
79    if kind == "py" and callable(info.Iq):
80        return True
81    elif kind == "c" and not callable(info.Iq):
82        return True
83    elif kind == "double" and not info.single:
84        return True
85    elif kind == "oriented" and any(p.type=='orientation' for p in pars):
86        return True
87    elif kind == "magnetic" and any(p.type=='sld' for p in pars):
88        return True
89    return False
[f734e7d]90
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, str, str) -> KernelModel
    """
    Load model info and build model.

    *model_name* is the name of the model as used by :func:`load_model_info`,
    which also accepts composite expressions such as "A+B" or "P*S".

    *dtype* and *platform* are forwarded unchanged to :func:`build_model`;
    see that function for the accepted values.
    """
    model_info = load_model_info(model_name)
    return build_model(model_info, dtype=dtype, platform=platform)
[aa4946b]101
102
def load_model_info(model_name):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name.

    *model_name* may be a single model name, a mixture "A+B+..." of models,
    or a product "P*S" applying structure factor S to model P.  The '+'
    operator is split first, so "P*S+B" is a mixture whose first part is
    the product P*S.

    Returns the model info structure (from :func:`modelinfo.make_model_info`,
    or from the mixture/product helpers for composite names), which can be
    used with functions in generate to build the docs or extract model info.

    Raises *ValueError* if a product has more than two factors.
    """
    if '+' in model_name:
        components = [load_model_info(part) for part in model_name.split('+')]
        return mixture.make_mixture_info(components)

    if '*' in model_name:
        factors = model_name.split('*')
        if len(factors) != 2:
            raise ValueError("use P*S to apply structure factor S to model P")
        p_info, s_info = [load_model_info(factor) for factor in factors]
        return product.make_product_info(p_info, s_info)

    kernel_module = generate.load_kernel_module(model_name)
    return modelinfo.make_model_info(kernel_module)
[d19962c]125
126
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, str, str) -> KernelModel
    """
    Prepare the model for the default execution platform.

    This will return an OpenCL model, a DLL model or a python model depending
    on the model and the computing platform.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.  Composite models (mixtures and products) are
    built by building each component with the same *dtype* and *platform*.
    Models whose *Iq* is a python callable are returned as python models
    regardless of *dtype* and *platform*.

    *dtype* indicates whether the model should use single or double precision
    for the calculation.  Choices are 'single', 'double', 'quad', 'half',
    or 'fast'.  If *dtype* ends with '!', then force the use of the DLL rather
    than OpenCL for the calculation.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl".

    Raises *ValueError* if *model_info* has an unknown composition type.
    """
    composition = model_info.composition
    if composition is not None:
        composition_type, parts = composition
        # Validate before building any component kernels.
        if composition_type not in ('mixture', 'product'):
            raise ValueError('unknown composition type %s' % composition_type)
        models = [build_model(p, dtype=dtype, platform=platform) for p in parts]
        if composition_type == 'mixture':
            return mixture.MixtureModel(model_info, models)
        P, S = models
        return product.ProductModel(model_info, P, S)

    # If it is a python model, return it immediately
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    # parse_dtype may switch "ocl" to "dll" when OpenCL is unavailable or
    # does not support the requested precision.
    numpy_dtype, fast, platform = parse_dtype(model_info, dtype, platform)

    source = generate.make_source(model_info)
    if platform == "dll":
        return kerneldll.load_dll(source['dll'], model_info, numpy_dtype)
    else:
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
[aa4946b]172
def precompile_dlls(path, dtype="double"):
    # type: (str, str) -> List[str]
    """
    Precompile the dlls for all builtin models, returning a list of dll paths.

    *path* is the directory in which to save the dlls.  It will be created if
    it does not already exist.

    *dtype* is a numpy type name used for every compiled dll.  Models whose
    *Iq* is a python callable have nothing to compile and are skipped.

    This can be used when building the windows distribution of sasmodels,
    which may be missing the OpenCL driver and the dll compiler.
    """
    numpy_dtype = np.dtype(dtype)
    if not os.path.exists(path):
        os.makedirs(path)
    dll_paths = []
    for name in list_models():
        info = load_model_info(name)
        if callable(info.Iq):
            continue  # python model: no C source to compile
        dll_source = generate.make_source(info)['dll']
        # Temporarily point kerneldll.DLL_PATH at *path* so the dll is saved
        # there; the original value is restored even if compilation fails.
        saved_dll_path = kerneldll.DLL_PATH
        try:
            kerneldll.DLL_PATH = path
            dll = kerneldll.make_dll(dll_source, info, dtype=numpy_dtype)
        finally:
            kerneldll.DLL_PATH = saved_dll_path
        dll_paths.append(dll)
    return dll_paths
[b8e5e21]200
def parse_dtype(model_info, dtype=None, platform=None):
    # type: (ModelInfo, Optional[str], Optional[str]) -> (np.dtype, bool, str)
    """
    Interpret dtype string, returning np.dtype, fast flag and platform.

    Possible types include 'half', 'single', 'double' and 'quad'.  If the
    type is 'fast', then this is equivalent to dtype 'single' with the
    fast flag set to True.  Any other string is passed to :func:`numpy.dtype`.

    A trailing '!' (e.g., 'double!') forces the "dll" platform; a bare '!'
    selects the default precision on the "dll" platform.

    *platform* defaults to "ocl"; it is replaced by "dll" when OpenCL is not
    available or does not support the requested type.

    If *dtype* is None, generate.F32 is used on OpenCL for models with
    ``model_info.single`` set, otherwise generate.F64.

    Returns *(numpy_dtype, fast, platform)*.
    """
    # Assign default platform, overriding ocl with dll if OpenCL is unavailable
    if platform is None:
        platform = "ocl"
    if platform == "ocl" and not HAVE_OPENCL:
        platform = "dll"

    # Check if type indicates dll regardless of which platform is given
    if dtype is not None and dtype.endswith('!'):
        platform = "dll"
        # A bare "!" leaves an empty string, which numpy rejects; treat it
        # as "use the default precision".
        dtype = dtype[:-1] or None

    # Convert special type names "half", "fast", and "quad"
    fast = (dtype == "fast")
    if fast:
        dtype = "single"
    elif dtype == "quad":
        dtype = "longdouble"
    elif dtype == "half":
        # Note: "f16" would be a 16-byte (quad) float in numpy, not half.
        dtype = "float16"

    # Convert dtype string to numpy dtype.
    if dtype is None:
        use_single = platform == "ocl" and model_info.single
        numpy_dtype = generate.F32 if use_single else generate.F64
    else:
        numpy_dtype = np.dtype(dtype)

    # Make sure that the type is supported by opencl, otherwise use dll
    if platform == "ocl":
        env = kernelcl.environment()
        if not env.has_type(numpy_dtype):
            platform = "dll"
            if dtype is None:
                numpy_dtype = generate.F64

    return numpy_dtype, fast, platform
[0c24a82]245
246if __name__ == "__main__":
[0b03001]247    import sys
248    kind = sys.argv[1] if len(sys.argv) > 1 else "all"
249    print("\n".join(list_models(kind)))
Note: See TracBrowser for help on using the repository browser.