source: sasmodels/sasmodels/bumps_model.py @ 49d1f8b8

Branches: core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 49d1f8b8 was 49d1f8b8, checked in by Paul Kienzle <pkienzle@…>, 21 months ago

move multiscat into sasmodels namespace to make it easier to run the example

  • Property mode set to 100644
File size: 8.2 KB
Line 
1"""
2Wrap sasmodels for direct use by bumps.
3
4:class:`Model` is a wrapper for the sasmodels kernel which defines a
5bumps *Parameter* box for each kernel parameter.  *Model* accepts keyword
6arguments to set the initial value for each parameter.
7
8:class:`Experiment` combines the *Model* function with a data file loaded by
9the sasview data loader.  *Experiment* takes a *cutoff* parameter controlling
10how far the polydispersity integral extends.
11
12"""
13from __future__ import print_function
14
15__all__ = ["Model", "Experiment"]
16
17import numpy as np  # type: ignore
18
19from .data import plot_theory
20from .direct_model import DataMixin
21
22# pylint: disable=unused-import
23try:
24    from typing import Dict, Union, Tuple, Any
25    from .data import Data1D, Data2D
26    from .kernel import KernelModel
27    from .modelinfo import ModelInfo
28    Data = Union[Data1D, Data2D]
29except ImportError:
30    pass
31# pylint: enable=unused-import
32
33try:
34    # Optional import. This allows the doc builder and nosetests to run even
35    # when bumps is not on the path.
36    from bumps.names import Parameter # type: ignore
37except ImportError:
38    pass
39
40
def create_parameters(model_info,  # type: ModelInfo
                      **kwargs     # type: Union[float, str, Parameter]
                     ):
    # type: (...) -> Tuple[Dict[str, Parameter], Dict[str, str]]
    """
    Build the bumps parameter set for a model.

    *model_info* describes the model; it is the object returned by
    :func:`generate.model_info` for a model definition module.

    Extra *key=value* arguments supply initial values.  A value may be a
    float or a bumps parameter; the polydispersity distribution type
    (``<name>_pd_type``) takes a string instead.  Parameters that are not
    given keep the default from the model definition.

    Returns a pair *(pars, pd_types)*.  *pars* maps each parameter name
    to its bumps *Parameter*.  *pd_types* maps each ``<name>_pd_type`` to
    the name of its distribution.

    Raises *TypeError* if a keyword does not match any model parameter.
    """
    overrides = kwargs
    pars = {}      # type: Dict[str, Parameter]
    pd_types = {}  # type: Dict[str, str]

    def _make(par_name, fallback, bounds):
        # Consume the caller's override (if any) so that leftover keywords
        # can be reported as unexpected once every parameter is built.
        start = overrides.pop(par_name, fallback)
        pars[par_name] = Parameter.default(start, name=par_name, limits=bounds)
        return pars[par_name]

    for par in model_info.parameters.call_parameters:
        base = _make(par.name, par.default, par.limits)
        if not par.polydisperse:
            continue
        # Dispersity width, number of points and number of sigmas.  The
        # width reuses the limits of the parameter it disperses.
        _make(par.name + '_pd', 0., base.limits)
        _make(par.name + '_pd_n', 35., (0, 1000))
        _make(par.name + '_pd_nsigma', 3., (0, 10))
        type_key = par.name + '_pd_type'
        pd_types[type_key] = str(overrides.pop(type_key, 'gaussian'))

    if overrides:
        # Anything still left did not correspond to a model parameter.
        raise TypeError("unexpected parameters: %s"
                        % (", ".join(sorted(overrides.keys()))))

    return pars, pd_types
82
class Model(object):
    """
    Bumps wrapper for a SAS model.

    *model* is the model object returned by :func:`core.load_model`.  Its
    *info* attribute supplies the parameter definitions.

    Any additional *key=value* pairs set the initial value of the
    corresponding model parameter.  The value is a float or a bumps
    *Parameter*, or a string for a ``<name>_pd_type`` distribution type.
    Unknown keys raise *TypeError*.

    Each parameter becomes an attribute of the instance: a bumps
    *Parameter* for each numeric parameter, and a plain string for each
    ``<name>_pd_type``.

    The polydispersity weight cutoff is not accepted here.  Pass it as
    *cutoff* to :class:`Experiment` instead.
    """
    def __init__(self, model, **kwargs):
        # type: (KernelModel, **Union[float, str, Parameter]) -> None
        self.sasmodel = model
        pars, pd_types = create_parameters(model.info, **kwargs)
        # Expose every parameter as an attribute so callers can reach it
        # as model.<name>.
        for k, v in pars.items():
            setattr(self, k, v)
        for k, v in pd_types.items():
            setattr(self, k, v)
        # Both kinds of attribute live in the instance dict, so remember
        # which names are Parameter objects and which are type strings.
        self._parameter_names = list(pars.keys())
        self._pd_type_names = list(pd_types.keys())

    def parameters(self):
        # type: () -> Dict[str, Parameter]
        """
        Return a dictionary of parameter objects for the parameters,
        excluding polydispersity distribution type.
        """
        return {k: getattr(self, k) for k in self._parameter_names}

    def state(self):
        # type: () -> Dict[str, Union[float, str]]
        """
        Return a dictionary of current values for all the parameters,
        including polydispersity distribution type.

        Numeric parameters map to their ``.value``; distribution types
        map to their type string.
        """
        pars = {k: getattr(self, k).value for k in self._parameter_names}
        pars.update((k, getattr(self, k)) for k in self._pd_type_names)
        return pars
111
112    def state(self):
113        # type: () -> Dict[str, Union[Parameter, str]]
114        """
115        Return a dictionary of current values for all the parameters,
116        including polydispersity distribution type.
117        """
118        pars = dict((k, getattr(self, k).value) for k in self._parameter_names)
119        pars.update((k, getattr(self, k)) for k in self._pd_type_names)
120        return pars
121
class Experiment(DataMixin):
    r"""
    Bumps wrapper for a SAS experiment.

    *data* is a :class:`data.Data1D`, :class:`data.Data2D` or
    :class:`data.Sesans` object.  Use :func:`data.empty_data1D` or
    :func:`data.empty_data2D` to define $q, \Delta q$ calculation
    points for displaying the SANS curve when there is no measured data.

    *model* is a :class:`Model` object.

    *cutoff* is the integration cutoff, which avoids computing the
    SAS model where the polydispersity weight is low.

    *name* labels the experiment.  It defaults to the *filename* attribute
    of *data*, or None if *data* has no such attribute.

    *extra_pars* is an optional dictionary of additional bumps parameters.
    :meth:`parameters` reports them alongside the model parameters.

    The resulting model can be used directly in a Bumps FitProblem call.
    """
    _cache = None  # type: Dict[str, np.ndarray]

    def __init__(self, data, model, cutoff=1e-5, name=None, extra_pars=None):
        # type: (Data, Model, float, Union[str, None], Union[Dict[str, Parameter], None]) -> None
        # Remember the inputs so they can be inspected from outside.
        # Not every data object need carry a filename, so fall back to
        # None instead of raising AttributeError.
        self.name = getattr(data, 'filename', None) if name is None else name
        self.model = model
        self.cutoff = cutoff
        self._interpret_data(data, model.sasmodel)
        self._cache = {}
        self.extra_pars = extra_pars

    def update(self):
        # type: () -> None
        """
        Call when model parameters have changed and theory needs to be
        recalculated.
        """
        self._cache.clear()

    def numpoints(self):
        # type: () -> int
        """
        Return the number of data points.
        """
        return len(self.Iq)

    def parameters(self):
        # type: () -> Dict[str, Parameter]
        """
        Return a dictionary of parameters: the model parameters plus any
        *extra_pars* given to the constructor.
        """
        pars = self.model.parameters()
        if self.extra_pars:
            pars.update(self.extra_pars)
        return pars

    def theory(self):
        # type: () -> np.ndarray
        """
        Return the theory corresponding to the model parameters.

        The result is cached.  Call :meth:`update` after the parameters
        have changed so that the theory is recomputed on the next call.
        """
        if 'theory' not in self._cache:
            pars = self.model.state()
            self._cache['theory'] = self._calc_theory(pars, cutoff=self.cutoff)
        return self._cache['theory']

    def residuals(self):
        # type: () -> np.ndarray
        """
        Return theory minus data normalized by uncertainty.
        """
        return (self.theory() - self.Iq) / self.dIq

    def nllf(self):
        # type: () -> float
        """
        Return the negative log likelihood of seeing data given the model
        parameters, up to a normalizing constant which depends on the data
        uncertainty.
        """
        delta = self.residuals()
        return 0.5 * np.sum(delta**2)

    def plot(self, view='log'):
        # type: (str) -> None
        """
        Plot the data and residuals.
        """
        # theory() runs before Iq_calc is read below.
        data, theory, resid = self._data, self.theory(), self.residuals()
        # TODO: hack to display oriented usans 2-D pattern
        Iq_calc = self.Iq_calc if isinstance(self.Iq_calc, tuple) else None
        plot_theory(data, theory, resid, view, Iq_calc=Iq_calc)

    def simulate_data(self, noise=None):
        # type: (Union[float, None]) -> None
        """
        Replace the data with the current theory, passing *noise* on to
        the data setter.
        """
        Iq = self.theory()
        self._set_data(Iq, noise)

    def save(self, basename):
        # type: (str) -> None
        """
        Save the theory to ``<basename>.dat``.

        Only SESANS data is supported.  The file holds two columns, the
        data *x* values and the theory.  For any other data type nothing
        is written.
        """
        if self.data_type == "sesans":
            np.savetxt(basename+".dat", np.array([self._data.x, self.theory()]).T)

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        # Can't pickle gpu functions, so instead make them lazy: drop the
        # kernel from the pickled state.  NOTE(review): presumably DataMixin
        # rebuilds the kernel when _kernel is None -- confirm.
        state = self.__dict__.copy()
        state['_kernel'] = None
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        # pylint: disable=attribute-defined-outside-init
        self.__dict__ = state
Note: See TracBrowser for help on using the repository browser.