source: sasmodels/sasmodels/direct_model.py @ ae7b97b

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since ae7b97b was ae7b97b, checked in by pkienzle, 9 years ago

create a directly callable model without sasview or bumps parameters

  • Property mode set to 100644
File size: 2.9 KB
Line 
1import warnings
2
3import numpy as np
4
5from . import models
6from . import weights
7
# Prefer the kernelcl backend; if it cannot be imported, warn and fall back
# to the ctypes-based kerneldll backend.  Either way the module-level name
# ``load_model`` ends up bound to the selected backend's loader.
try:
    from .kernelcl import load_model
except ImportError as exc:
    warnings.warn(str(exc))
    warnings.warn("using ctypes instead")
    from .kerneldll import load_model
14
def load_model_definition(model_name):
    """
    Import and return the model definition module *model_name* from the
    ``models`` subpackage (e.g. ``sasmodels.models.<model_name>``).

    Raises ImportError if no such model module exists.
    """
    import importlib
    # import_module returns the submodule itself, so there is no need to call
    # __import__ and then fetch the attribute back off the package.  Using
    # models.__name__ keeps this correct even if the package is not importable
    # under the literal name "sasmodels".
    return importlib.import_module(models.__name__ + '.' + model_name)
19
# load_model is imported above (from kernelcl, or kerneldll as a fallback).
# It behaves roughly like the following pseudo-code:
#def load_model(model_definition, dtype='single'):
#    if kerneldll:
#        if source is newer than compiled: compile
#        load dll
#        return kernel
#    elif kernelcl:
#        compile source on context
#        return kernel
29
30
def make_kernel(model, q_vectors):
    """
    Return a computation kernel from the model definition and the q input.

    *model* is a loaded model (as returned by load_model).  It must provide
    ``make_input(q_vectors)`` and be callable with the resulting input
    object.  *q_vectors* is passed through unchanged to ``make_input``.
    """
    # Named q_input rather than input so the builtin is not shadowed.
    q_input = model.make_input(q_vectors)
    return model(q_input)
37
def get_weights(kernel, pars, name):
    """
    Generate the distribution for parameter *name* given the parameter values
    in *pars*.

    Looks up "name", "name_pd", "name_pd_type", "name_pd_n" and
    "name_pd_nsigma" in *pars*.  Missing entries fall back to:

    * "name": the model default from kernel.info['defaults'] (the same
      fallback call_kernel uses for fixed parameters)
    * "name_pd": 0.0
    * "name_pd_type": 'gaussian'
    * "name_pd_n": 0
    * "name_pd_nsigma": 3.0

    Returns (values, weights) as produced by weights.get_weights, with the
    weights rescaled so that they sum to 1.
    """
    info = kernel.info
    # Passed through to weights.get_weights; True when *name* is listed in
    # the model's 'pd-rel' parameter category.
    relative = name in info['partype']['pd-rel']
    limits = info['limits']
    disperser = pars.get(name + '_pd_type', 'gaussian')
    if name in pars:
        value = pars[name]
    else:
        # Previously an omitted parameter became None here; use the model
        # default so omitted polydisperse parameters behave like omitted
        # fixed parameters do in call_kernel.
        value = info['defaults'].get(name)
    npts = pars.get(name + '_pd_n', 0)
    width = pars.get(name + '_pd', 0.0)
    nsigma = pars.get(name + '_pd_nsigma', 3.0)
    v, w = weights.get_weights(
        disperser, npts, width, nsigma,
        value, limits[name], relative)
    return v, w / np.sum(w)
56
def call_kernel(kernel, pars):
    """
    Evaluate *kernel* for the parameter values in *pars*.

    Fixed parameters (kernel.fixed_pars) are taken from *pars*, falling back
    to kernel.info['defaults'].  Each polydisperse parameter
    (kernel.pd_pars) is expanded into a (values, weights) pair by
    get_weights.  The kernel is then called with the list of fixed values
    and the list of (values, weights) pairs, and its result is returned.
    """
    fixed_values = []
    for par_name in kernel.fixed_pars:
        fallback = kernel.info['defaults'][par_name]
        fixed_values.append(pars.get(par_name, fallback))
    dispersion = [get_weights(kernel, pars, par_name)
                  for par_name in kernel.pd_pars]
    return kernel(fixed_values, dispersion)
62
class DirectModel(object):
    """
    Directly callable model, evaluated at a fixed set of q points.

    *name* is the name of a model module in sasmodels.models.  *q_vectors*
    is a sequence of q arrays, e.g. [q] or [qx, qy].  *dtype* is the
    precision used both when loading the model and when converting the q
    arrays.

    Calling the instance with a mapping of parameter values returns the
    kernel's result (see call_kernel).
    """
    def __init__(self, name, q_vectors, dtype='single'):
        self.model_definition = load_model_definition(name)
        self.model = load_model(self.model_definition, dtype=dtype)
        # Convert the q inputs to contiguous arrays of the requested
        # precision; presumably the compiled kernels rely on this layout.
        q_vectors = [np.ascontiguousarray(q, dtype=dtype) for q in q_vectors]
        self.kernel = make_kernel(self.model, q_vectors)

    def __call__(self, pars):
        """Evaluate the model for the parameter values in *pars*."""
        return call_kernel(self.kernel, pars)
71
def demo():
    """
    Command-line driver: evaluate a model at a single q (or qx,qy) point.

    Usage::

        python -m sasmodels.direct_model modelname (q|qx,qy) par=val ...

    Prints the first element of the computed result.  Exits with status 1
    when the arguments are malformed.
    """
    import sys
    if len(sys.argv) < 3:
        print("usage: python -m sasmodels.direct_model modelname (q|qx,qy) par=val ...")
        sys.exit(1)
    model_name = sys.argv[1]
    values = [float(v) for v in sys.argv[2].split(',')]
    if len(values) == 1:
        q = values[0]
        q_vectors = [[q]]
    elif len(values) == 2:
        qx, qy = values
        q_vectors = [[qx], [qy]]
    else:
        print("use q or qx,qy")
        sys.exit(1)
    # Parse the par=val pairs before loading (and possibly compiling) the
    # model, so a malformed argument is reported immediately with a clear
    # message instead of a tuple-unpacking error.
    pars = {}
    for pair in sys.argv[3:]:
        key, sep, value = pair.partition('=')
        if not sep:
            print("expected par=val, got %r" % pair)
            sys.exit(1)
        pars[key] = float(value)
    model = DirectModel(model_name, q_vectors)
    Iq = model(pars)
    print(Iq[0])

if __name__ == "__main__":
    demo()
Note: See TracBrowser for help on using the repository browser.