source: sasmodels/sasmodels/bumps_model.py @ 2c4a190

Branches: core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 2c4a190 was 2c4a190, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

simplify use of multiple scattering in bumps fits

  • Property mode set to 100644
File size: 10.0 KB
RevLine 
[3330bb4]1"""
2Wrap sasmodels for direct use by bumps.
3
4:class:`Model` is a wrapper for the sasmodels kernel which defines a
5bumps *Parameter* box for each kernel parameter.  *Model* accepts keyword
6arguments to set the initial value for each parameter.
7
8:class:`Experiment` combines the *Model* function with a data file loaded by
9the sasview data loader.  *Experiment* takes a *cutoff* parameter controlling
10how far the polydispersity integral extends.
11
12"""
13from __future__ import print_function
14
15__all__ = ["Model", "Experiment"]
16
17import numpy as np  # type: ignore
18
19from .data import plot_theory
20from .direct_model import DataMixin
21
[2d81cfe]22# pylint: disable=unused-import
[3330bb4]23try:
24    from typing import Dict, Union, Tuple, Any
25    from .data import Data1D, Data2D
26    from .kernel import KernelModel
27    from .modelinfo import ModelInfo
28    Data = Union[Data1D, Data2D]
29except ImportError:
30    pass
[2d81cfe]31# pylint: enable=unused-import
[3330bb4]32
33try:
34    # Optional import. This allows the doc builder and nosetests to run even
35    # when bumps is not on the path.
36    from bumps.names import Parameter # type: ignore
[2c4a190]37    from bumps.parameter import Reference # type: ignore
[3330bb4]38except ImportError:
39    pass
40
41
def create_parameters(model_info,  # type: ModelInfo
                      **kwargs     # type: Union[float, str, Parameter]
                     ):
    # type: (...) -> Tuple[Dict[str, Parameter], Dict[str, str]]
    """
    Build the bumps parameters for every call parameter of a model.

    *model_info* supplies the parameter table through
    *model_info.parameters.call_parameters*; each entry is read for its
    *name*, *default*, *limits* and *polydisperse* attributes.

    Each *key=value* pair in *kwargs* overrides the initial value of the
    parameter called *key*.  Parameters that are not mentioned start from
    the model default.  Values are handed to *Parameter.default*, except
    for the *<name>_pd_type* entries, which are converted with *str*.

    For a polydisperse parameter *<name>*, the extra parameters
    *<name>_pd* (default 0, same limits as *<name>*), *<name>_pd_n*
    (default 35, limits 0 to 1000) and *<name>_pd_nsigma* (default 3,
    limits 0 to 10) are created, and *<name>_pd_type* defaults to
    'gaussian'.

    Returns *(pars, pd_types)*, where *pars* is *{name: Parameter}* and
    *pd_types* is *{name: str}* holding the distribution type names.

    Raises *TypeError* if any keyword is left over after all the model
    parameters have been consumed.
    """
    bumps_pars = {}   # type: Dict[str, Parameter]
    dist_types = {}   # type: Dict[str, str]
    for par in model_info.parameters.call_parameters:
        base = par.name
        initial = kwargs.pop(base, par.default)
        main = Parameter.default(initial, name=base, limits=par.limits)
        bumps_pars[base] = main
        if not par.polydisperse:
            continue

        # (suffix, default value, limits) for the polydispersity controls.
        # The width shares the limits of the freshly created base parameter.
        pd_specs = (
            ('_pd', 0., main.limits),
            ('_pd_n', 35., (0, 1000)),
            ('_pd_nsigma', 3., (0, 10)),
        )
        for suffix, fallback, bounds in pd_specs:
            full_name = base + suffix
            initial = kwargs.pop(full_name, fallback)
            bumps_pars[full_name] = Parameter.default(
                initial, name=full_name, limits=bounds)

        # The distribution type is a plain string rather than a Parameter.
        type_key = base + '_pd_type'
        dist_types[type_key] = str(kwargs.pop(type_key, 'gaussian'))

    # Whatever is still in kwargs did not match any model parameter.
    if not kwargs:
        return bumps_pars, dist_types

    leftovers = ", ".join(sorted(kwargs.keys()))
    raise TypeError("unexpected parameters: %s" % (leftovers))
83
class Model(object):
    """
    Bumps wrapper for a SAS model.

    *model* is the kernel model object; it is stored as *sasmodel* and its
    *info* attribute is used to build the parameter set.

    Any additional *key=value* pairs are initial values for the model
    parameters.  They are passed straight to :func:`create_parameters`.

    Every parameter, and every polydispersity distribution type, becomes
    an attribute of the instance under its own name.
    """
    def __init__(self, model, **kwargs):
        # type: (KernelModel, **Dict[str, Union[float, Parameter]]) -> None
        self.sasmodel = model
        fit_pars, dist_types = create_parameters(model.info, **kwargs)

        # Expose each parameter and each distribution type as an attribute.
        for attr in fit_pars:
            setattr(self, attr, fit_pars[attr])
        for attr in dist_types:
            setattr(self, attr, dist_types[attr])

        # Remember which attributes belong to which group so they can be
        # collected again later.
        self._parameter_names = list(fit_pars)
        self._pd_type_names = list(dist_types)

    def parameters(self):
        # type: () -> Dict[str, Parameter]
        """
        Return a new *{name: parameter}* dictionary holding the parameter
        objects, without the polydispersity distribution types.
        """
        return {name: getattr(self, name) for name in self._parameter_names}

    def state(self):
        # type: () -> Dict[str, Union[float, str]]
        """
        Return a new dictionary with the current *value* of each parameter
        together with the polydispersity distribution type names.
        """
        snapshot = {}
        for name in self._parameter_names:
            snapshot[name] = getattr(self, name).value
        for name in self._pd_type_names:
            snapshot[name] = getattr(self, name)
        return snapshot
122
class Experiment(DataMixin):
    r"""
    Bumps wrapper for a SAS experiment.

    *data* is a :class:`data.Data1D`, :class:`data.Data2D` or
    :class:`data.Sesans` object.  Use :func:`data.empty_data1D` or
    :func:`data.empty_data2D` to define $q, \Delta q$ calculation
    points for displaying the SANS curve when there is no measured data.

    *model* is a :class:`Model` object.

    *cutoff* is the integration cutoff, which avoids computing the
    SAS model where the polydispersity weight is low.

    *name* labels the experiment; when omitted, *data.filename* is used.

    *extra_pars* is an optional dictionary of additional parameters which
    is merged into the result of :meth:`parameters`.

    The resulting model can be used directly in a Bumps FitProblem call.
    """
    _cache = None # type: Dict[str, np.ndarray]

    def __init__(self, data, model, cutoff=1e-5, name=None, extra_pars=None):
        # type: (Data, Model, float, Union[str, None], Union[Dict[str, Parameter], None]) -> None
        # The resolution function may define fittable parameters.  Rather
        # than changing the resolution object to hold bumps parameters,
        # reference parameters pointing into it are created here, and they
        # are rebuilt whenever the resolution object is replaced.  The
        # fitting engine must be able to see them, and so must the user,
        # so they are also attached as top-level attributes of the
        # experiment.  All that is asked of the resolution object is an
        # optional 'fittable' attribute mapping the internal name to the
        # user visible name for the parameter.
        #
        # These two must exist before the resolution setter can run.
        self._resolution = None
        self._resolution_pars = {}

        # Keep the inputs around so they can be inspected from outside.
        if name is None:
            name = data.filename
        self.name = name
        self.model = model
        self.cutoff = cutoff

        # Helper not defined in this class (presumably from DataMixin).
        self._interpret_data(data, model.sasmodel)
        self._cache = {}

        # CRUFT: extra parameters are no longer needed.  The multiple
        # scattering probability is now retrieved directly from the
        # multiple scattering resolution function.
        self.extra_pars = extra_pars

    def update(self):
        # type: () -> None
        """
        Discard cached results.  Call this after the model parameters
        have changed so that the theory is recalculated on next use.
        """
        self._cache.clear()

    def numpoints(self):
        # type: () -> int
        """
        Return the number of data points, which is the length of *Iq*.
        """
        return len(self.Iq)

    @property
    def resolution(self):
        """
        Resolution object for the experiment.  Assigning a new one
        rebuilds the fittable resolution parameters on the experiment.
        """
        return self._resolution

    @resolution.setter
    def resolution(self, value):
        self._resolution = value

        # Drop the attributes created for the previous resolution object.
        for stale in self._resolution_pars:
            delattr(self, stale)

        # Build a reference parameter for each fittable entry.  The
        # 'fittable' mapping is {internal name: user visible name}; a
        # resolution object without it contributes no parameters.
        fittable = getattr(self._resolution, 'fittable', {})
        references = {}
        for refname, public_name in fittable.items():
            references[public_name] = Reference(
                self._resolution, refname, name=public_name)
        self._resolution_pars = references

        # Publish the new reference parameters as experiment attributes.
        for public_name, ref in self._resolution_pars.items():
            setattr(self, public_name, ref)

    def parameters(self):
        # type: () -> Dict[str, Parameter]
        """
        Return a dictionary of parameters: the model parameters, then
        *extra_pars* if given, then the resolution parameters.  Later
        entries replace earlier ones of the same name.
        """
        combined = self.model.parameters()
        extras = self.extra_pars
        if extras is not None:
            combined.update(extras)
        combined.update(self._resolution_pars)
        return combined

    def theory(self):
        # type: () -> np.ndarray
        """
        Return the theory corresponding to the model parameters.

        The result is computed once and then served from the cache, so
        :meth:`update` must be called when the parameters have changed.
        """
        if 'theory' in self._cache:
            return self._cache['theory']
        current = self.model.state()
        self._cache['theory'] = self._calc_theory(current, cutoff=self.cutoff)
        return self._cache['theory']

    def residuals(self):
        # type: () -> np.ndarray
        """
        Return theory minus data, divided by the data uncertainty *dIq*.
        """
        difference = self.theory() - self.Iq
        return difference / self.dIq

    def nllf(self):
        # type: () -> float
        """
        Return the negative log likelihood of seeing data given the model
        parameters, up to a normalizing constant which depends on the data
        uncertainty.  This is half the sum of the squared residuals.
        """
        resid = self.residuals()
        sum_sq = np.sum(resid**2)
        return 0.5 * sum_sq

    def plot(self, view='log'):
        # type: (str) -> None
        """
        Plot the data, the theory and the residuals using *view*.
        """
        data = self._data
        theory = self.theory()
        resid = self.residuals()
        # TODO: hack to display oriented usans 2-D pattern
        # Only a tuple-valued Iq_calc is forwarded to the plotter.
        Iq_calc = None
        if isinstance(self.Iq_calc, tuple):
            Iq_calc = self.Iq_calc
        plot_theory(data, theory, resid, view, Iq_calc=Iq_calc)

    def simulate_data(self, noise=None):
        # type: (float) -> None
        """
        Generate simulated data by passing the current theory and *noise*
        to *_set_data* (not defined in this class).
        """
        simulated = self.theory()
        self._set_data(simulated, noise)

    def save(self, basename):
        # type: (str) -> None
        """
        Save the theory curve to *basename*.dat.

        Only implemented for SESANS data, where two columns are written:
        the data *x* values and the theory.  For any other data type
        nothing is written.
        """
        if self.data_type != "sesans":
            return
        path = basename+".dat"
        columns = np.array([self._data.x, self.theory()]).T
        np.savetxt(path, columns)

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        # Can't pickle gpu functions, so store the kernel as None and
        # leave it to be created again lazily.
        state = dict(self.__dict__)
        state['_kernel'] = None
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        # Restore the instance by adopting the pickled attribute dictionary.
        # pylint: disable=attribute-defined-outside-init
        self.__dict__ = state
Note: See TracBrowser for help on using the repository browser.