source: sasmodels/sasmodels/bumps_model.py @ a430f5f

ticket-1257-vesicle-product ticket_1156 ticket_822_more_unit_tests
Last change on this file since a430f5f was b297ba9, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

lint

  • Property mode set to 100644
File size: 10.3 KB
Line 
1"""
2Wrap sasmodels for direct use by bumps.
3
4:class:`Model` is a wrapper for the sasmodels kernel which defines a
5bumps *Parameter* box for each kernel parameter.  *Model* accepts keyword
6arguments to set the initial value for each parameter.
7
8:class:`Experiment` combines the *Model* function with a data file loaded by
9the sasview data loader.  *Experiment* takes a *cutoff* parameter controlling
10how far the polydispersity integral extends.
11
12"""
13from __future__ import print_function
14
15__all__ = ["Model", "Experiment"]
16
17import numpy as np  # type: ignore
18
19from .data import plot_theory
20from .direct_model import DataMixin
21
22# pylint: disable=unused-import
23try:
24    from typing import Dict, Union, Tuple, Any
25    from .data import Data1D, Data2D
26    from .kernel import KernelModel
27    from .modelinfo import ModelInfo
28    from .resolution import Resolution
29    Data = Union[Data1D, Data2D]
30except ImportError:
31    pass
32# pylint: enable=unused-import
33
34try:
35    # Optional import. This allows the doc builder and nosetests to run even
36    # when bumps is not on the path.
37    from bumps.names import Parameter # type: ignore
38    from bumps.parameter import Reference # type: ignore
39except ImportError:
40    pass
41
42
def create_parameters(model_info,  # type: ModelInfo
                      **kwargs     # type: Union[float, str, Parameter]
                     ):
    # type: (...) -> Tuple[Dict[str, Parameter], Dict[str, str]]
    """
    Build the bumps parameter set described by *model_info*.

    *model_info* is returned from :func:`generate.model_info` on the
    model definition module.

    Keyword arguments give initial values, keyed by parameter name.  A
    value may be a float or a bumps parameter; the ``<name>_pd_type``
    entries take a distribution name string.  Names that are not given
    fall back to the model default.  For polydisperse parameters the
    ``_pd``, ``_pd_n`` and ``_pd_nsigma`` controls default to 0, 35 and 3,
    and ``_pd_type`` defaults to 'gaussian'.

    Returns *(pars, pd_types)*.  *pars* maps every name to its bumps
    parameter.  *pd_types* maps each ``<name>_pd_type`` key to its
    distribution name.

    Raises :class:`TypeError` if any keyword does not match a model
    parameter.
    """
    pars = {}      # type: Dict[str, Parameter]
    pd_types = {}  # type: Dict[str, str]

    def _make(name, default, limits):
        # Consume the caller's initial value for *name* (if any) and wrap
        # it in a bumps parameter stored under the same name.
        initial = kwargs.pop(name, default)
        pars[name] = Parameter.default(initial, name=name, limits=limits)
        return pars[name]

    for par in model_info.parameters.call_parameters:
        base = _make(par.name, par.default, par.limits)
        if not par.polydisperse:
            continue
        # Polydispersity controls: (suffix, default, limits).  The width
        # reuses the limits of the parameter it belongs to.
        pd_controls = (
            ('_pd', 0., base.limits),
            ('_pd_n', 35., (0, 1000)),
            ('_pd_nsigma', 3., (0, 10)),
        )
        for suffix, default, limits in pd_controls:
            _make(par.name + suffix, default, limits)
        type_key = par.name + '_pd_type'
        pd_types[type_key] = str(kwargs.pop(type_key, 'gaussian'))

    # Anything still left in kwargs did not correspond to a parameter.
    leftover = sorted(kwargs)
    if leftover:
        raise TypeError("unexpected parameters: %s" % ", ".join(leftover))
    return pars, pd_types
84
class Model(object):
    """
    Bumps wrapper for a SAS model.

    *model* is a runnable module as returned from :func:`core.load_model`.

    Any additional *key=value* pairs are initial values for the model
    parameters.  See :func:`create_parameters` for the accepted names and
    value types.

    Each generated bumps parameter is exposed as an instance attribute of
    the same name.  This includes the polydispersity controls
    ``<name>_pd``, ``<name>_pd_n`` and ``<name>_pd_nsigma``.  Each
    ``<name>_pd_type`` distribution string is exposed the same way.

    The polydispersity weight *cutoff* is not set here; it is an argument
    of :class:`Experiment`.
    """
    def __init__(self, model, **kwargs):
        # type: (KernelModel, **Union[float, str, Parameter]) -> None
        self.sasmodel = model
        pars, pd_types = create_parameters(model.info, **kwargs)
        # Expose every parameter and distribution type as an attribute, so
        # callers can reach them as ``model.<name>``.
        for k, v in pars.items():
            setattr(self, k, v)
        for k, v in pd_types.items():
            setattr(self, k, v)
        # Remember which attributes were generated, for parameters()/state().
        self._parameter_names = list(pars.keys())
        self._pd_type_names = list(pd_types.keys())

    def parameters(self):
        # type: () -> Dict[str, Parameter]
        """
        Return a dictionary of parameters objects for the parameters,
        excluding polydispersity distribution type.
        """
        return dict((k, getattr(self, k)) for k in self._parameter_names)

    def state(self):
        # type: () -> Dict[str, Union[float, str]]
        """
        Return a dictionary of current values for all the parameters,
        including polydispersity distribution type.

        Bumps parameters contribute their current ``.value``.  The
        ``<name>_pd_type`` entries contribute the distribution name string.
        """
        pars = dict((k, getattr(self, k).value) for k in self._parameter_names)
        pars.update((k, getattr(self, k)) for k in self._pd_type_names)
        return pars
123
class Experiment(DataMixin):
    r"""
    Bumps wrapper for a SAS experiment.

    *data* is a :class:`data.Data1D`, :class:`data.Data2D` or
    :class:`data.Sesans` object.  Use :func:`data.empty_data1D` or
    :func:`data.empty_data2D` to define $q, \Delta q$ calculation
    points for displaying the SANS curve when there is no measured data.

    *model* is a :class:`Model` object.

    *cutoff* is the integration cutoff, which avoids computing the
    SAS model where the polydispersity weight is low.

    *name* labels the experiment.  It defaults to *data.filename*.

    *extra_pars* is an optional dictionary of additional bumps parameters
    to include in :meth:`parameters`.  It is kept for backward
    compatibility only.  Resolution functions can now expose their own
    fittable parameters through a ``fittable`` attribute (see
    :attr:`resolution`).

    The resulting model can be used directly in a Bumps FitProblem call.
    """
    _cache = None # type: Dict[str, np.ndarray]

    def __init__(self, data, model, cutoff=1e-5, name=None, extra_pars=None):
        # type: (Data, Model, float, Union[None, str], Union[None, Dict[str, Parameter]]) -> None
        # Allow resolution function to define fittable parameters.  We do this
        # by creating reference parameters within the resolution object rather
        # than modifying the object itself to use bumps parameters.  We need
        # to reset the parameters each time the object has changed.  These
        # additional parameters need to be returned from the fitting engine.
        # To make them available to the user, they are added as top-level
        # attributes to the experiment object.  The only change to the
        # resolution function is that it needs an optional 'fittable' attribute
        # which maps the internal name to the user visible name for the
        # parameter.
        # The resolution setter reads _resolution_pars, so both attributes
        # must exist before _interpret_data runs.
        # NOTE(review): this assumes _interpret_data may assign
        # self.resolution; confirm in DataMixin.
        self._resolution = None
        self._resolution_pars = {}
        # remember inputs so we can inspect from outside
        self.name = data.filename if name is None else name
        self.model = model
        self.cutoff = cutoff
        self._interpret_data(data, model.sasmodel)
        self._cache = {}
        # CRUFT: no longer need extra parameters
        # Multiple scattering probability is now retrieved directly from the
        # multiple scattering resolution function.
        self.extra_pars = extra_pars

    def update(self):
        # type: () -> None
        """
        Call when model parameters have changed and theory needs to be
        recalculated.  This discards the cached theory.
        """
        self._cache.clear()

    def numpoints(self):
        # type: () -> int
        """
        Return the number of data points
        """
        return len(self.Iq)

    @property
    def resolution(self):
        # type: () -> Union[None, Resolution]
        """
        :class:`sasmodels.Resolution` applied to the data, if any.

        Assigning a new resolution removes the attributes created for the
        previous resolution's fittable parameters.  It then exposes the
        entries of the new resolution's ``fittable`` mapping (internal name
        -> user visible name) as bumps Reference parameters on this
        experiment.
        """
        return self._resolution

    @resolution.setter
    def resolution(self, value):
        # type: (Resolution) -> None
        """
        :class:`sasmodels.Resolution` applied to the data, if any.
        """
        self._resolution = value

        # Remove old resolution fitting parameters from experiment
        for name in self._resolution_pars:
            delattr(self, name)

        # Create new resolution fitting parameters
        res_pars = getattr(self._resolution, 'fittable', {})
        self._resolution_pars = {
            name: Reference(self._resolution, refname, name=name)
            for refname, name in res_pars.items()
        }

        # Add new resolution fitting parameters as experiment attributes
        for name, ref in self._resolution_pars.items():
            setattr(self, name, ref)

    def parameters(self):
        # type: () -> Dict[str, Parameter]
        """
        Return a dictionary of the fittable parameters.

        This combines the model parameters, any *extra_pars*, and the
        parameters exposed by the current resolution function.
        """
        pars = self.model.parameters()
        if self.extra_pars is not None:
            pars.update(self.extra_pars)
        pars.update(self._resolution_pars)
        return pars

    def theory(self):
        # type: () -> np.ndarray
        """
        Return the theory corresponding to the model parameters.

        This method uses lazy evaluation.  The result is cached until
        :meth:`update` (on this experiment) is called, which must be done
        whenever the parameters have changed.
        """
        if 'theory' not in self._cache:
            pars = self.model.state()
            self._cache['theory'] = self._calc_theory(pars, cutoff=self.cutoff)
        return self._cache['theory']

    def residuals(self):
        # type: () -> np.ndarray
        """
        Return theory minus data normalized by uncertainty.
        """
        return (self.theory() - self.Iq) / self.dIq

    def nllf(self):
        # type: () -> float
        """
        Return the negative log likelihood of seeing data given the model
        parameters, up to a normalizing constant which depends on the data
        uncertainty.
        """
        delta = self.residuals()
        return 0.5 * np.sum(delta**2)

    def plot(self, view='log'):
        # type: (str) -> None
        """
        Plot the data, theory and residuals using :func:`data.plot_theory`.

        *view* is passed through to :func:`data.plot_theory`.
        """
        data, theory, resid = self._data, self.theory(), self.residuals()
        # TODO: hack to display oriented usans 2-D pattern
        Iq_calc = self.Iq_calc if isinstance(self.Iq_calc, tuple) else None
        plot_theory(data, theory, resid, view, Iq_calc=Iq_calc)

    def simulate_data(self, noise=None):
        # type: (float) -> None
        """
        Generate simulated data from the current theory.

        *noise* is passed through to :meth:`_set_data`.
        """
        Iq = self.theory()
        self._set_data(Iq, noise)

    def save(self, basename):
        # type: (str) -> None
        """
        Save the fit result into a file.

        Only SESANS data is supported.  For SESANS, this writes
        *basename*.dat with two columns: the data x values and the theory.
        For other data types nothing is written.
        """
        if self.data_type == "sesans":
            np.savetxt(basename+".dat", np.array([self._data.x, self.theory()]).T)

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        # Can't pickle gpu functions, so instead make them lazy: drop the
        # kernel from the pickled state.
        # NOTE(review): assumes '_kernel' is rebuilt on demand after
        # unpickling; confirm in DataMixin.
        state = self.__dict__.copy()
        state['_kernel'] = None
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        # pylint: disable=attribute-defined-outside-init
        self.__dict__ = state
Note: See TracBrowser for help on using the repository browser.