source: sasmodels/sasmodels/core.py @ 9890053

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 9890053 was 9890053, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

add smoke tests for ER/VR; check that smoke test results are valid floats

  • Property mode set to 100644
File size: 3.2 KB
Line 
1__all__ = ["list_models", "load_model_cl", "load_model_dll",
2           "load_model_definition", ]
3
4from os.path import basename, dirname, join as joinpath
5from glob import glob
6
7import numpy as np
8
9from . import models
10from . import weights
11
try:
    from .kernelcl import load_model as load_model_cl
except Exception:
    # The kernelcl backend is optional.  Any failure while importing it
    # (not only ImportError) is deliberately tolerated; load_model_cl is
    # then published as None so that the name always exists.
    # Note: "except Exception:" is valid on both Python 2 and Python 3,
    # unlike the older "except Exception, exc:" spelling.
    load_model_cl = None
16from .kerneldll import load_model as load_model_dll
17
def list_models():
    """
    Return the names of the model modules shipped next to this file.

    Looks in the ``models`` directory beside this module for ``*.py`` files
    whose name starts with an ASCII letter, and returns their base names
    with the ``.py`` suffix removed, ordered by sorted file path.
    """
    pattern = joinpath(dirname(__file__), 'models', "[a-zA-Z]*.py")
    names = []
    for path in sorted(glob(pattern)):
        # Strip the three-character ".py" extension from the file name.
        names.append(basename(path)[:-3])
    return names
23
def load_model_definition(model_name):
    """
    Import ``sasmodels.models.<model_name>`` and return its definition.

    After the import, the attribute called *model_name* is looked up on the
    ``models`` package and returned; ``None`` is returned if the package has
    no such attribute.  An import failure propagates to the caller.
    """
    module_path = 'sasmodels.models.' + model_name
    # Importing the submodule is what makes it reachable as an attribute
    # of the already-imported ``models`` package.
    __import__(module_path)
    return getattr(models, model_name, None)
28
def make_kernel(model, q_vectors):
    """
    Build a computation kernel for *model* evaluated at *q_vectors*.

    The q input is first wrapped with ``model.make_input`` and the wrapped
    input is then passed to *model* itself; the result of that call is
    returned unchanged.
    """
    return model(model.make_input(q_vectors))
35
def get_weights(kernel, pars, name):
    """
    Build the value/weight distribution for parameter *name*.

    The following keys are read from *pars*, each with a fallback:

    * ``name`` - central value (default: ``kernel.info['defaults'][name]``)
    * ``name_pd`` - distribution width (default 0.0)
    * ``name_pd_type`` - distribution name (default ``'gaussian'``)
    * ``name_pd_n`` - number of points (default 0)
    * ``name_pd_nsigma`` - number of sigmas (default 3.0)

    The parameter limits come from ``kernel.info['limits'][name]`` and the
    width is treated as relative when *name* is listed in
    ``kernel.info['partype']['pd-rel']``.

    Returns ``(values, weights)`` as produced by ``weights.get_weights``,
    with the weights divided by their sum.
    """
    info = kernel.info
    is_relative = name in info['partype']['pd-rel']
    all_limits = info['limits']
    pd_type = pars.get(name + '_pd_type', 'gaussian')
    center = pars.get(name, info['defaults'][name])
    pd_n = pars.get(name + '_pd_n', 0)
    pd_width = pars.get(name + '_pd', 0.0)
    pd_nsigma = pars.get(name + '_pd_nsigma', 3.0)
    points, point_weights = weights.get_weights(
        pd_type, pd_n, pd_width, pd_nsigma,
        center, all_limits[name], is_relative)
    # Normalise so that the returned weights sum to one.
    total = np.sum(point_weights)
    return points, point_weights/total
54
def dispersion_mesh(pars):
    """
    Create a mesh grid of dispersion parameters and weights.

    *pars* is a non-empty sequence of ``(values, weights)`` pairs, one pair
    per parameter.

    Returns ``[p1, p2, ...], w`` where ``pj`` is a flat vector of values for
    parameter j over every point of the mesh, and ``w`` is a flat vector
    holding, for each mesh point, the product of the individual weights.

    The result has the same form for one parameter as for several: a list
    of 1-D value arrays and a single 1-D weight array.

    Raises ValueError if *pars* is empty (there is nothing to unpack).
    """
    value_sets, weight_sets = zip(*pars)
    if len(value_sets) > 1:
        mesh_values = [v.flatten() for v in np.meshgrid(*value_sets)]
        stacked = np.vstack([w.flatten() for w in np.meshgrid(*weight_sets)])
        # One row per parameter; multiply down the rows to get the joint
        # weight of each mesh point.
        mesh_weights = np.prod(stacked, axis=0)
    else:
        # A single parameter needs no mesh, but is still returned in the
        # documented form (list of vectors, weight vector) rather than as
        # the raw 1-tuples produced by zip().
        mesh_values = [np.asarray(value_sets[0]).flatten()]
        mesh_weights = np.asarray(weight_sets[0]).flatten()
    return mesh_values, mesh_weights
69
def call_kernel(kernel, pars, cutoff=1e-5):
    """
    Evaluate *kernel* using the parameter values in *pars*.

    Each name in ``kernel.fixed_pars`` is looked up in *pars*, falling back
    to ``kernel.info['defaults']``.  Each name in ``kernel.pd_pars`` is
    expanded into a ``(values, weights)`` distribution with ``get_weights``.
    Both lists, together with *cutoff*, are handed to the kernel and its
    result is returned.
    """
    defaults = kernel.info['defaults']
    fixed_values = []
    for par in kernel.fixed_pars:
        # The default is looked up unconditionally, so every fixed
        # parameter must have an entry in the defaults table.
        fixed_values.append(pars.get(par, defaults[par]))
    distributions = [get_weights(kernel, pars, par) for par in kernel.pd_pars]
    return kernel(fixed_values, distributions, cutoff=cutoff)
75
def call_ER(kernel, pars):
    """
    Return the weighted average of the model's ``ER`` function.

    If ``kernel.info`` has no ``'ER'`` entry (or it is None), 1.0 is
    returned.  Otherwise the distributions of all parameters listed in
    ``kernel.info['partype']['volume']`` are combined into a mesh, ``ER`` is
    evaluated over that mesh, and the weight-averaged result is returned.
    """
    er_function = kernel.info.get('ER', None)
    if er_function is None:
        return 1.0

    volume_names = kernel.info['partype']['volume']
    distributions = [get_weights(kernel, pars, par) for par in volume_names]
    mesh_values, mesh_weights = dispersion_mesh(distributions)
    er_values = er_function(*mesh_values)
    return np.sum(mesh_weights*er_values) / np.sum(mesh_weights)
87
def call_VR(kernel, pars):
    """
    Return the weighted ratio computed from the model's ``VR`` function.

    If ``kernel.info`` has no ``'VR'`` entry (or it is None), 1.0 is
    returned.  Otherwise the distributions of all parameters listed in
    ``kernel.info['partype']['volume']`` are combined into a mesh and ``VR``
    is evaluated over it.  ``VR`` returns a ``(whole, part)`` pair; the
    result is the weighted sum of *part* divided by the weighted sum of
    *whole*.
    """
    vr_function = kernel.info.get('VR', None)
    if vr_function is None:
        return 1.0

    volume_names = kernel.info['partype']['volume']
    distributions = [get_weights(kernel, pars, par) for par in volume_names]
    mesh_values, mesh_weights = dispersion_mesh(distributions)
    whole, part = vr_function(*mesh_values)
    weighted_part = np.sum(mesh_weights*part)
    weighted_whole = np.sum(mesh_weights*whole)
    return weighted_part/weighted_whole
98
Note: See TracBrowser for help on using the repository browser.