source: sasmodels/sasmodels/core.py @ dd7fc12

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since dd7fc12 was dd7fc12, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

fix kerneldll dtype problem; more type hinting

  • Property mode set to 100644
File size: 6.7 KB
Line 
1"""
2Core model handling routines.
3"""
4from __future__ import print_function
5
6__all__ = [
7    "list_models", "load_model", "load_model_info",
8    "build_model", "precompile_dll",
9    ]
10
11from os.path import basename, dirname, join as joinpath
12from glob import glob
13
14import numpy as np # type: ignore
15
16from . import generate
17from . import modelinfo
18from . import product
19from . import mixture
20from . import kernelpy
21from . import kerneldll
22try:
23    from . import kernelcl
24    HAVE_OPENCL = True
25except Exception:
26    HAVE_OPENCL = False
27
try:
    from typing import List, Tuple, Union, Optional, Any
    from .kernel import KernelModel
    from .modelinfo import ModelInfo
except ImportError:
    pass
34
35
36# TODO: refactor composite model support
37# The current load_model_info/build_model does not reuse existing model
38# definitions when loading a composite model, instead reloading and
39# rebuilding the kernel for each component model in the expression.  This
40# is fine in a scripting environment where the model is built when the script
41# starts and is thrown away when the script ends, but may not be the best
42# solution in a long-lived application.  This affects the following functions:
43#
44#    load_model
45#    load_model_info
46#    build_model
47
def list_models():
    # type: () -> List[str]
    """
    Return the sorted names of the models bundled with this package.

    Names are taken from the ``models`` directory next to this module:
    every ``*.py`` file whose name starts with a letter contributes its
    file name with the ``.py`` extension removed.
    """
    pattern = joinpath(dirname(__file__), 'models', "[a-zA-Z]*.py")
    return [basename(path)[:-3] for path in sorted(glob(pattern))]
57
def load_model(model_name, dtype=None, platform='ocl'):
    # type: (str, Optional[str], str) -> KernelModel
    """
    Build a ready-to-run kernel model directly from its name.

    This is a convenience wrapper: *model_name* is resolved with
    :func:`load_model_info` (so composite names such as "A+B" or "P*S" are
    accepted), and the resulting model info is handed to :func:`build_model`
    together with *dtype* and *platform*.  See :func:`build_model` for the
    meaning of those two arguments.
    """
    info = load_model_info(model_name)
    return build_model(info, dtype=dtype, platform=platform)
68
69
def load_model_info(model_name):
    # type: (str) -> modelinfo.ModelInfo
    """
    Load a model definition given the model name.

    *model_name* is either the name of a single model or a composite
    expression.  Terms joined by '+' are each loaded and combined with
    :func:`mixture.make_mixture_info`.  Exactly two factors joined by '*' are
    combined with :func:`product.make_product_info` as model P times
    structure factor S.  '+' is split first, so "A*B+C" is the mixture of
    "A*B" with "C".  Any other name is loaded through
    :func:`generate.load_kernel_module`.

    Returns the model info structure.  It can be passed to
    :func:`build_model` or used with functions in generate, for example to
    build the docs.

    Raises ValueError if more than two factors are joined by '*'.
    """
    # Mixtures bind loosest, so handle '+' before '*'.
    if '+' in model_name:
        terms = [load_model_info(term) for term in model_name.split('+')]
        return mixture.make_mixture_info(terms)

    if '*' in model_name:
        factors = model_name.split('*')
        # Validate before loading anything, so bad input fails fast.
        if len(factors) != 2:
            raise ValueError("use P*S to apply structure factor S to model P")
        form_info, structure_info = [load_model_info(f) for f in factors]
        return product.make_product_info(form_info, structure_info)

    # A plain model name: load its definition module and extract the info.
    module = generate.load_kernel_module(model_name)
    return modelinfo.make_model_info(module)
92
93
def build_model(model_info, dtype=None, platform="ocl"):
    # type: (modelinfo.ModelInfo, Optional[str], str) -> KernelModel
    """
    Prepare the model for the default execution platform.

    This will return an OpenCL model, a DLL model or a python model depending
    on the model and the computing platform.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* indicates whether the model should use single or double precision
    for the calculation.  Choices are 'single', 'double', 'quad', 'half',
    or 'fast'.  If *dtype* ends with '!', then force the use of the DLL rather
    than OpenCL for the calculation.  If *dtype* is None (the default),
    :func:`parse_dtype` picks 'single' or 'double' from *model_info.single*.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl".  The DLL is also used when OpenCL
    is not available or does not support the requested precision.

    Composite models (mixtures and products) are built by recursively
    building each component with the same *dtype* and *platform*.

    Raises ValueError if the composition type is not recognized.
    """
    composition = model_info.composition
    if composition is not None:
        composition_type, parts = composition
        models = [build_model(p, dtype=dtype, platform=platform) for p in parts]
        if composition_type == 'mixture':
            return mixture.MixtureModel(model_info, models)
        elif composition_type == 'product':
            P, S = models
            return product.ProductModel(model_info, P, S)
        else:
            raise ValueError('unknown mixture type %s'%composition_type)

    # If it is a python model, return it immediately
    if callable(model_info.Iq):
        return kernelpy.PyModel(model_info)

    source = generate.make_source(model_info)
    ## for debugging:
    ##  1. uncomment open().write so that the source will be saved next time
    ##  2. run "python -m sasmodels.direct_model $MODELNAME" to save the source
    ##  3. recomment the open().write() and uncomment open().read()
    ##  4. rerun "python -m sasmodels.direct_model $MODELNAME"
    ##  5. recomment open().read() so that source will be regenerated from model
    # open(model_info.name+'.c','w').write(source)
    # source = open(model_info.name+'.c','r').read()
    numpy_dtype, fast = parse_dtype(model_info, dtype)
    # dtype defaults to None, so guard before checking for the '!' suffix
    # that forces the DLL; calling None.endswith would raise AttributeError.
    force_dll = dtype is not None and dtype.endswith('!')
    # Short-circuit order matters: only query the OpenCL environment when
    # OpenCL is actually available and no DLL was requested.
    if (platform == "dll"
            or force_dll
            or not HAVE_OPENCL
            or not kernelcl.environment().has_type(numpy_dtype)):
        return kerneldll.load_dll(source, model_info, numpy_dtype)
    else:
        return kernelcl.GpuModel(source, model_info, numpy_dtype, fast=fast)
147
def precompile_dll(model_name, dtype="double"):
    # type: (str, str) -> Optional[str]
    """
    Precompile the dll for a model.

    Returns the path to the compiled model, or None if the model is a pure
    python model.

    This can be used when building the windows distribution of sasmodels
    (which may be missing the OpenCL driver and the dll compiler), or
    otherwise sharing models with windows users who do not have a compiler.

    *dtype* is a precision string as accepted by :func:`parse_dtype`; a
    trailing '!' is accepted and ignored.

    See :func:`sasmodels.kerneldll.make_dll` for details on controlling the
    dll path and the allowed floating point precision.
    """
    model_info = load_model_info(model_name)
    # A model whose Iq is a python callable has no C kernel to compile, so
    # report None before asking generate for C source.
    if callable(model_info.Iq):
        return None
    numpy_dtype, _ = parse_dtype(model_info, dtype)
    source = generate.make_source(model_info)
    return kerneldll.make_dll(source, model_info, dtype=numpy_dtype) if source else None
167
def parse_dtype(model_info, dtype):
    # type: (ModelInfo, Optional[str]) -> Tuple[np.dtype, bool]
    """
    Convert a precision name into a (numpy dtype, fast flag) pair.

    *dtype* may be 'half', 'single', 'double' or 'quad', or any other name
    understood by :func:`numpy.dtype`.  'fast' selects single precision
    with the fast flag set to True; every other name returns fast=False.
    One trailing '!' (the platform indicator) is stripped before the name
    is interpreted.  If *dtype* is None, 'single' is used when
    *model_info.single* is true and 'double' otherwise.
    """
    name = dtype
    if name is None:
        # Pick the default precision from the model definition.
        name = 'double' if not model_info.single else 'single'
    # Drop a single platform marker; it does not affect the precision.
    name = name[:-1] if name.endswith('!') else name

    if name == 'fast':
        return generate.F32, True
    if name == 'half':
        return generate.F16, False
    if name == 'quad':
        return generate.F128, False
    return np.dtype(name), False
194
Note: See TracBrowser for help on using the repository browser.