source: sasmodels/sasmodels/core.py @ 91c5fdc

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 91c5fdc was 7b3e62c, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

support load_model('p*q')

  • Property mode set to 100644
File size: 8.4 KB
RevLine 
[aa4946b]1"""
2Core model handling routines.
3"""
[f734e7d]4
5from os.path import basename, dirname, join as joinpath
6from glob import glob
7
8import numpy as np
9
10from . import models
11from . import weights
[aa4946b]12from . import generate
[f734e7d]13
[aa4946b]14from . import kernelpy
15from . import kerneldll
try:
    # The OpenCL backend is optional.  If kernelcl cannot be imported
    # (presumably because pyopencl or an OpenCL driver is unavailable),
    # build_model falls back to the dll kernels.  Catch Exception rather than
    # using a bare except so KeyboardInterrupt/SystemExit still propagate.
    from . import kernelcl
    HAVE_OPENCL = True
except Exception:
    HAVE_OPENCL = False
21
# Public API of this module: the names exported by
# ``from sasmodels.core import *``.  load_model is the main entry point.
__all__ = [
    "list_models", "load_model", "load_model_info", "precompile_dll",
    "build_model", "make_kernel", "call_kernel", "call_ER_VR",
]
[f734e7d]26
def list_models():
    """
    Return the sorted names of the models available on the model path.

    A model is any ``.py`` file in the ``models`` directory beside this
    module whose file name starts with a letter.  The name is reported
    without its ``.py`` extension.
    """
    pattern = joinpath(dirname(__file__), 'models', "[a-zA-Z]*.py")
    # Every match ends in ".py", so dropping the last three characters
    # leaves the bare model name.
    return [basename(path)[:-3] for path in sorted(glob(pattern))]
35
def isstr(s):
    """
    Return True if *s* is a string-like object.

    This is a duck-typed test: any object that can be concatenated with
    ``''`` counts as a string (so str subclasses are accepted), while
    objects such as numbers, None, lists and, on Python 3, bytes are not.
    """
    try:
        s + ''
    except Exception:
        # Any ordinary failure means "not a string".  A bare except would
        # also swallow KeyboardInterrupt and SystemExit.
        return False
    return True
43
def load_model(model_name, **kw):
    """
    Load model info and build model.

    *model_name* is the name of a model on the model path.  Two composition
    operators are recognized:

    * ``"A+B"`` loads each component and combines them in a
      :class:`sasmodels.mixture.MixtureModel`;
    * ``"P*S"`` loads model *P* and structure factor *S* and combines them
      in a :class:`sasmodels.product.ProductModel`.

    The string is split on ``+`` first, so ``"A*S+B"`` is the mixture of
    ``A*S`` and ``B``.  Whitespace around component names is ignored.  A
    ``+`` component may contain at most one ``*``; otherwise ValueError is
    raised.

    Keyword arguments are passed to :func:`build_model` for every component.
    """
    parts = [part.strip() for part in model_name.split('+')]
    if len(parts) > 1:
        from .mixture import MixtureModel
        # Named 'components' so the module-level 'models' package is not
        # shadowed inside this function.
        components = [load_model(part, **kw) for part in parts]
        return MixtureModel(components)

    parts = [part.strip() for part in model_name.split('*')]
    if len(parts) > 1:
        # Note: currently have a circular reference with .product, hence
        # the function-level import.
        from .product import ProductModel
        if len(parts) > 2:
            raise ValueError("use P*S to apply structure factor S to model P")
        P, S = [load_model(part, **kw) for part in parts]
        return ProductModel(P, S)

    return build_model(load_model_info(parts[0]), **kw)
[aa4946b]64
def load_model_info(model_name):
    """
    Load a model definition given the model name.

    This imports the module *model_name* from the ``models`` package and
    returns the model info structure that :func:`generate.make_model_info`
    builds from it.  The result can be used with functions in generate to
    build the docs or extract model info, or passed to :func:`build_model`.

    Raises ImportError if there is no model module with that name.
    """
    import importlib
    # Import relative to the package this module actually imported as
    # 'models', and use the returned submodule directly.  The name is no
    # longer hard-coded as 'sasmodels', and getattr cannot silently yield
    # None.
    kernel_module = importlib.import_module(models.__name__ + '.' + model_name)
    return generate.make_model_info(kernel_module)
[f734e7d]76
[aa4946b]77
def build_model(model_info, dtype=None, platform="ocl"):
    """
    Prepare the model for the default execution platform.

    This will return an OpenCL model, a DLL model or a python model depending
    on the model and the computing platform.

    *model_info* is the model definition structure returned from
    :func:`load_model_info`.

    *dtype* indicates whether the model should use single or double precision
    for the calculation. Any valid numpy single or double precision identifier
    is valid, such as 'single', 'f', 'f32', or np.float32 for single, or
    'double', 'd', 'f64' and np.float64 for double.  If *None*, then use
    'single' unless the model defines single=False.

    *platform* should be "dll" to force the dll to be used for C models,
    otherwise it uses the default "ocl".  The dll is also used when OpenCL
    is not available or the OpenCL environment does not support *dtype*.

    Pure python models (those whose 'Iq' entry is callable) are always
    returned as a kernelpy.PyModel; *dtype* and *platform* are ignored.
    """
    # Pure python models need neither generated C source nor a precision
    # choice, so check for them before doing any source generation.
    if callable(model_info.get('Iq', None)):
        return kernelpy.PyModel(model_info)

    if dtype is None:
        dtype = 'single' if model_info['single'] else 'double'
    source = generate.make_source(model_info)

    ## for debugging:
    ##  1. uncomment open().write so that the source will be saved next time
    ##  2. run "python -m sasmodels.direct_model $MODELNAME" to save the source
    ##  3. recomment the open.write() and uncomment open().read()
    ##  4. rerun "python -m sasmodels.direct_model $MODELNAME"
    ##  5. uncomment open().read() so that source will be regenerated from model
    # open(model_info['name']+'.c','w').write(source)
    # source = open(model_info['name']+'.cl','r').read()

    if (platform == "dll"
            or not HAVE_OPENCL
            or not kernelcl.environment().has_type(dtype)):
        return kerneldll.load_dll(source, model_info, dtype)
    else:
        return kernelcl.GpuModel(source, model_info, dtype)
[aa4946b]118
def precompile_dll(model_name, dtype="double"):
    """
    Precompile the dll for a model.

    Returns the path to the compiled model, or None if the model is a pure
    python model (or no C source is generated for it).

    This can be used when building the windows distribution of sasmodels
    (which may be missing the OpenCL driver and the dll compiler), or
    otherwise sharing models with windows users who do not have a compiler.

    See :func:`sasmodels.kerneldll.make_dll` for details on controlling the
    dll path and the allowed floating point precision.
    """
    model_info = load_model_info(model_name)
    # Use the same pure-python test as build_model: a model with a callable
    # Iq is never run from a dll, so there is nothing to compile.
    if callable(model_info.get('Iq', None)):
        return None
    source = generate.make_source(model_info)
    return kerneldll.make_dll(source, model_info, dtype=dtype) if source else None
136
137
def make_kernel(model, q_vectors):
    """
    Build a computation kernel for *model* at the q points *q_vectors*.

    *model* is a callable model object, such as one returned by
    :func:`build_model`; the kernel is whatever calling it with
    *q_vectors* produces.
    """
    kernel = model(q_vectors)
    return kernel
[f734e7d]143
def get_weights(model_info, pars, name):
    """
    Generate the distribution for parameter *name* given the parameter values
    in *pars*.

    Uses "name", "name_pd", "name_pd_type", "name_pd_n" and "name_pd_nsigma"
    from the *pars* dictionary for parameter value and parameter dispersion.
    Missing entries fall back as follows:

    * "name" defaults to the model default in model_info['defaults'];
    * "name_pd_type" defaults to 'gaussian';
    * "name_pd_n" defaults to 0;
    * "name_pd" defaults to 0.0;
    * "name_pd_nsigma" defaults to 3.0.

    Returns (value, weight) as produced by :func:`weights.get_weights`,
    with the weights normalized so that they sum to one.
    """
    # Whether the dispersion of this parameter is flagged as relative
    # ('pd-rel'); passed through to weights.get_weights.
    relative = name in model_info['partype']['pd-rel']
    limits = model_info['limits'][name]
    disperser = pars.get(name+'_pd_type', 'gaussian')
    value = pars.get(name, model_info['defaults'][name])
    npts = pars.get(name+'_pd_n', 0)
    width = pars.get(name+'_pd', 0.0)
    nsigma = pars.get(name+'_pd_nsigma', 3.0)
    value, weight = weights.get_weights(
        disperser, npts, width, nsigma, value, limits, relative)
    return value, weight / np.sum(weight)
[f734e7d]162
def dispersion_mesh(pars):
    """
    Create a mesh grid of dispersion parameters and weights.

    *pars* is a sequence of (value, weight) pairs, one per parameter, such
    as those returned by :func:`get_weights`.

    Returns [p1,p2,...],w where pj is a vector of values for parameter j
    and w is a vector containing the products for weights for each
    parameter set in the vector.  With a single parameter, its values and
    weights are returned flattened but otherwise unchanged.  With no
    parameters the result is ([], [1.0]), a single mesh point with unit
    weight (the empty product).
    """
    pars = list(pars)
    if not pars:
        return [], np.ones(1)
    value, weight = zip(*pars)
    if len(value) > 1:
        value = [v.flatten() for v in np.meshgrid(*value)]
        weight = np.vstack([v.flatten() for v in np.meshgrid(*weight)])
        weight = np.prod(weight, axis=0)
    else:
        # Return the same shapes as the multi-parameter case: a list of
        # value vectors and a flat weight vector (not a 1-tuple of arrays).
        value = [np.asarray(value[0]).flatten()]
        weight = np.asarray(weight[0]).flatten()
    return value, weight
[9890053]177
def call_kernel(kernel, pars, cutoff=0):
    """
    Evaluate *kernel* (as returned by :func:`make_kernel`) for parameters
    *pars*.

    Values of the kernel's fixed parameters are taken from *pars*, falling
    back to the model defaults in kernel.info['defaults'].  Each
    polydisperse parameter is expanded into a (value, weight) distribution
    with :func:`get_weights`.

    *cutoff* is the limiting value for the product of dispersion weights used
    to perform the multidimensional dispersion calculation more quickly at a
    slight cost to accuracy. The default value of *cutoff=0* integrates over
    the entire dispersion cube.  Using *cutoff=1e-5* can be 50% faster, but
    with an error of about 1%, which is usually less than the measurement
    uncertainty.
    """
    fixed_values = []
    for par_name in kernel.fixed_pars:
        default = kernel.info['defaults'][par_name]
        fixed_values.append(pars.get(par_name, default))
    dispersity = [get_weights(kernel.info, pars, par_name)
                  for par_name in kernel.pd_pars]
    return kernel(fixed_values, dispersity, cutoff=cutoff)
193
def call_ER_VR(model_info, vol_pars):
    """
    Return effective radius and volume ratio for the model.

    *model_info* is the model info structure, either *kernel.info* for
    *kernel=make_kernel(model,q)* or *model.info*.  Its optional 'ER' and
    'VR' entries are the effective radius and volume ratio functions.

    *vol_pars* is a list of (value, weight) pairs for the volume parameters,
    such as those produced by :func:`get_weights`.

    The effective radius is the weight-averaged ER over the dispersion mesh
    of the volume parameters.  The volume ratio is the weighted sum of the
    "part" volumes returned by VR divided by the weighted sum of the
    "whole" volumes.  A missing ER or VR yields 1.0 for that quantity.
    """
    ER = model_info.get('ER', None)
    VR = model_info.get('VR', None)
    if ER is None and VR is None:
        # Nothing to average, so skip building the dispersion mesh.  This
        # matches call_ER/call_VR, which return 1.0 directly in this case.
        return 1.0, 1.0
    value, weight = dispersion_mesh(vol_pars)

    individual_radii = ER(*value) if ER else 1.0
    whole, part = VR(*value) if VR else (1.0, 1.0)

    effect_radius = np.sum(weight*individual_radii) / np.sum(weight)
    volume_ratio = np.sum(weight*part)/np.sum(weight*whole)
    return effect_radius, volume_ratio
213
214
def call_ER(info, pars):
    """
    Return the model effective radius, averaged over the dispersion of the
    volume parameters described by *pars*.

    *info* is either *model.info* if you have a loaded model, or *kernel.info*
    if you have a model kernel prepared for evaluation.  If the model has no
    ER function, 1.0 is returned.
    """
    radius_fn = info.get('ER', None)
    if radius_fn is None:
        return 1.0
    distributions = [get_weights(info, pars, par_name)
                     for par_name in info['partype']['volume']]
    points, mesh_weight = dispersion_mesh(distributions)
    radii = radius_fn(*points)
    weighted_total = np.sum(mesh_weight*radii)
    return weighted_total / np.sum(mesh_weight)
[9890053]231
def call_VR(info, pars):
    """
    Return the model volume ratio, averaged over the dispersion of the
    volume parameters described by *pars*.

    The ratio is the weighted sum of the "part" volumes divided by the
    weighted sum of the "whole" volumes returned by the model VR function.
    *info* is either *model.info* if you have a loaded model, or *kernel.info*
    if you have a model kernel prepared for evaluation.  If the model has no
    VR function, 1.0 is returned.
    """
    ratio_fn = info.get('VR', None)
    if ratio_fn is None:
        return 1.0
    distributions = [get_weights(info, pars, par_name)
                     for par_name in info['partype']['volume']]
    points, mesh_weight = dispersion_mesh(distributions)
    whole, part = ratio_fn(*points)
    return np.sum(mesh_weight*part)/np.sum(mesh_weight*whole)
[9890053]247
[17bbadd]248# TODO: remove call_ER, call_VR
249
Note: See TracBrowser for help on using the repository browser.