source: sasmodels/sasmodels/bumps_model.py @ 7cf2cfd

core_shell_microgels costrafo411 magnetic_model release_v0.94 release_v0.95 ticket-1257-vesicle-product ticket_1156 ticket_1265_superball ticket_822_more_unit_tests
Last change on this file since 7cf2cfd was 7cf2cfd, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor compare.py so that bumps/sasview not required for simple tests

  • Property mode set to 100644
File size: 6.1 KB
Line 
1"""
2Wrap sasmodels for direct use by bumps.
3
4:class:`Model` is a wrapper for the sasmodels kernel which defines a
5bumps *Parameter* box for each kernel parameter.  *Model* accepts keyword
6arguments to set the initial value for each parameter.
7
8:class:`Experiment` combines the *Model* function with a data file loaded by the
9sasview data loader.  *Experiment* takes a *cutoff* parameter controlling
10how far the polydispersity integral extends.
11
12"""
13
14import datetime
15import warnings
16
17import numpy as np
18
19from bumps.names import Parameter
20
21from . import sesans
22from . import weights
23from .data import plot_theory
24from .direct_model import DataMixin
25
26# CRUFT: old style bumps wrapper which doesn't separate data and model
def BumpsModel(data, model, cutoff=1e-5, **kw):
    """
    Deprecated one-call constructor for an :class:`Experiment`.

    *model* and any extra keyword arguments are passed to :class:`Model`.
    *data* and *cutoff* are passed to :class:`Experiment`.  Every name in
    the resulting model's ``_parameter_names`` is also set as an attribute
    of the returned experiment, so old code that accessed parameters on the
    combined object keeps working.

    Emits a warning on every call; use :class:`Experiment` instead.
    """
    # stacklevel=2 attributes the warning to the caller's line rather than
    # to this module, so users can find the deprecated call site.
    warnings.warn("Use of BumpsModel is deprecated.  Use bumps_model.Experiment instead.",
                  stacklevel=2)
    model = Model(model, **kw)
    experiment = Experiment(data=data, model=model, cutoff=cutoff)
    for k in model._parameter_names:
        setattr(experiment, k, getattr(model, k))
    return experiment
34
35
class Model(object):
    """
    Bumps parameter box for a sasmodels kernel.

    *model* is the sasmodels model object.  Its ``info['parameters']`` and
    ``info['partype']`` entries determine which attributes are created.

    For every entry in ``info['parameters']``, a bumps *Parameter* with the
    same name is set as an attribute.  It is initialized from the matching
    keyword argument if one is given, otherwise from the model default.

    For every name in ``info['partype']['pd-2d']``, the polydispersity
    controls *name_pd*, *name_pd_n* and *name_pd_nsigma* are added as bumps
    *Parameter* attributes.  *name_pd_type* is added as a plain (non-fitted)
    value, with default ``'gaussian'``.

    Raises *TypeError* if a keyword argument matches none of these names.
    """
    def __init__(self, model, **kw):
        self._sasmodel = model
        partype = model.info['partype']

        pars = []
        # Limits of each model parameter, so the dispersion width of a
        # parameter can be bounded by that parameter's own limits below.
        par_limits = {}
        for p in model.info['parameters']:
            # Entries are read positionally: p[0] name, p[2] default,
            # p[3] limits.
            name, default, limits = p[0], p[2], p[3]
            value = kw.pop(name, default)
            setattr(self, name, Parameter.default(value, name=name, limits=limits))
            par_limits[name] = limits
            pars.append(name)
        for name in partype['pd-2d']:
            for xpart, xdefault, xlimits in [
                    # Bound the width by the limits of the dispersed
                    # parameter itself, not by a leftover loop variable.
                    ('_pd', 0, par_limits[name]),
                    ('_pd_n', 35, (0, 1000)),
                    ('_pd_nsigma', 3, (0, 10)),
                    ('_pd_type', 'gaussian', None),
                ]:
                xname = name + xpart
                xvalue = kw.pop(xname, xdefault)
                # Entries without limits (the distribution type) are stored
                # as plain attributes and are not listed as fit parameters.
                if xlimits is not None:
                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
                    pars.append(xname)
                setattr(self, xname, xvalue)
        # Names of the bumps Parameter attributes, in creation order.
        self._parameter_names = pars
        # Anything left in kw did not match a known parameter name.
        if kw:
            raise TypeError("unexpected parameters: %s"
                            % (", ".join(sorted(kw.keys()))))

    def parameters(self):
        """
        Return a dictionary of parameters
        """
        return dict((k, getattr(self, k)) for k in self._parameter_names)
70
71
class Experiment(DataMixin):
    """
    Bumps fit wrapper pairing a :class:`Model` with a data set.

    *data* is the data to be fitted.

    *model* is a :class:`Model` holding the bumps parameters.  The
    underlying sasmodels kernel is taken from its ``_sasmodel`` attribute.

    *cutoff* is the integration cutoff, which avoids computing the
    SAS model where the polydispersity weight is low.

    Model parameters are initialized with keyword arguments to
    :class:`Model`, or by assigning to ``model.parameter_name.value``.

    The resulting bumps model can be used directly in a FitProblem call.
    """
    def __init__(self, data, model, cutoff=1e-5):
        # remember inputs so we can inspect from outside
        self.model = model
        self.cutoff = cutoff
        self._interpret_data(data, model._sasmodel)
        self.update()

    def update(self):
        """
        Discard cached results so :meth:`theory` is recomputed on next use.
        """
        self._cache = {}

    def numpoints(self):
        """
        Return the number of points
        """
        return len(self.Iq)

    def parameters(self):
        """
        Return a dictionary of parameters
        """
        return self.model.parameters()

    def theory(self):
        """
        Return the model evaluated at the current parameter values.

        The result is computed by ``_calc_theory`` from the current
        ``.value`` of each model parameter and *cutoff*.  It is cached until
        :meth:`update` is called.
        """
        if 'theory' not in self._cache:
            pars = dict((k, v.value) for k, v in self.model.parameters().items())
            self._cache['theory'] = self._calc_theory(pars, cutoff=self.cutoff)
        return self._cache['theory']

    def residuals(self):
        """
        Return the normalized residuals (theory - Iq) / dIq.
        """
        return (self.theory() - self.Iq) / self.dIq

    def nllf(self):
        """
        Return half the sum of the squared residuals.
        """
        delta = self.residuals()
        return 0.5 * np.sum(delta ** 2)

    def plot(self, view='log'):
        """
        Plot the data and residuals.
        """
        data, theory, resid = self._data, self.theory(), self.residuals()
        plot_theory(data, theory, resid, view)

    def simulate_data(self, noise=None):
        """
        Replace the data values with the current theory.

        *noise* is passed unchanged to ``_set_data``.
        """
        Iq = self.theory()
        self._set_data(Iq, noise)

    def save(self, basename):
        """
        No-op; presumably present to satisfy the bumps fitness interface.
        """
        pass

    def remove_get_weights(self, name):
        """
        Get parameter dispersion weights

        NOTE(review): this method appears unused.  It reads
        ``self.model.kernel``, which :class:`Model` in this module does not
        define, so it would likely raise AttributeError.  Confirm before
        relying on it.
        """
        info = self.model.kernel.info
        relative = name in info['partype']['pd-rel']
        limits = info['limits'][name]
        disperser, value, npts, width, nsigma = [
            getattr(self.model, name + ext)
            for ext in ('_pd_type', '', '_pd_n', '_pd', '_pd_nsigma')]
        value, weight = weights.get_weights(
            disperser, int(npts.value), width.value, nsigma.value,
            value.value, limits, relative)
        return value, weight / np.sum(weight)

    def __getstate__(self):
        # Can't pickle gpu functions, so instead make them lazy
        state = self.__dict__.copy()
        state['_kernel'] = None
        return state

    def __setstate__(self, state):
        # pylint: disable=attribute-defined-outside-init
        self.__dict__ = state
Note: See TracBrowser for help on using the repository browser.