source: sasmodels/sasmodels/direct_model.py @ a738209

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since a738209 was a738209, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

simplify kernels by removing coordination parameter logic

  • Property mode set to 100644
File size: 12.9 KB
Line 
1"""
2Class interface to the model calculator.
3
4Calling a model is somewhat non-trivial since the functions called depend
5on the data type.  For 1D data the *Iq* kernel needs to be called, for
62D data the *Iqxy* kernel needs to be called, and for SESANS data the
7*Iq* kernel needs to be called followed by a Hankel transform.  Before
8the kernel is called an appropriate *q* calculation vector needs to be
9constructed.  This is not the simple *q* vector where you have measured
10the data since the resolution calculation will require values beyond the
11range of the measured data.  After the calculation the resolution calculator
12must be called to return the predicted value for each measured data point.
13
14:class:`DirectModel` is a callable object that takes *parameter=value*
15keyword arguments and returns the appropriate theory values for the data.
16
17:class:`DataMixin` does the real work of interpreting the data and calling
18the model calculator.  This is used by :class:`DirectModel`, which uses
19direct parameter values and by :class:`bumps_model.Experiment` which wraps
20the parameter values in boxes so that the user can set fitting ranges, etc.
21on the individual parameters and send the model to the Bumps optimizers.
22"""
23from __future__ import print_function
24
25import numpy as np  # type: ignore
26
27# TODO: fix sesans module
28from . import sesans  # type: ignore
29from . import weights
30from . import resolution
31from . import resolution2d
32from . import kernel
33
34try:
35    from typing import Optional, Dict, Tuple
36except ImportError:
37    pass
38else:
39    from .data import Data
40    from .kernel import Kernel, KernelModel
41    from .modelinfo import Parameter, ParameterSet
42
def call_kernel(calculator, pars, cutoff=0., mono=False):
    # type: (Kernel, ParameterSet, float, bool) -> np.ndarray
    """
    Evaluate *calculator* (a kernel returned from *model.make_kernel*) for
    the parameter values in *pars*.

    *cutoff* is the smallest product of dispersion weights that is still
    included in the multidimensional dispersion calculation.  The default
    *cutoff=0* integrates over the entire dispersion cube.  Using
    *cutoff=1e-5* can be 50% faster, but with an error of about 1%, which
    is usually less than the measurement uncertainty.

    *mono* is True if polydispersity should be ignored on all parameters,
    so that every parameter is evaluated at its single value.

    Returns whatever the kernel call returns.
    """
    parameters = calculator.info.parameters

    # Work out which parameters may carry a dispersion.  For 1d and 2d
    # kernels this is the corresponding list from the parameter table; for
    # any other dimension every parameter is eligible; for mono, none is.
    dispersed_names = ()
    every_parameter = False
    if not mono:
        if calculator.dim == '1d':
            dispersed_names = parameters.pd_1d
        elif calculator.dim == '2d':
            dispersed_names = parameters.pd_2d
        else:
            every_parameter = True

    # One (values, weights) pair per call parameter, in call order.
    vw_pairs = []
    for par in parameters.call_parameters:
        if every_parameter or par.name in dispersed_names:
            pair = get_weights(par, pars)
        else:
            # Monodisperse: a single point with unit weight.
            pair = ([pars.get(par.name, par.default)], [1.0])
        vw_pairs.append(pair)

    call_details, values = kernel.build_details(calculator, vw_pairs)
    return calculator(call_details, values, cutoff)
73
def get_weights(parameter, values):
    # type: (Parameter, Dict[str, float]) -> Tuple[np.ndarray, np.ndarray]
    """
    Generate the dispersion distribution for *parameter* given the
    parameter values in *values*.

    For a parameter called "name" the entries "name", "name_pd",
    "name_pd_type", "name_pd_n" and "name_pd_nsigma" of *values* supply
    the central value, the dispersion width, the distribution type, the
    number of points and the number of sigmas.  Missing entries default to
    *parameter.default*, 0.0, 'gaussian', 0 and 3.0 respectively.

    Returns *(value, weight)*.  When the number of points or the width is
    zero the result is the single-point distribution *([value], [1.0])*;
    otherwise the weights are normalized to sum to one.
    """
    name = parameter.name
    center = float(values.get(name, parameter.default))
    relative, limits = parameter.relative_pd, parameter.limits

    nsigma = values.get(name + '_pd_nsigma', 3.0)
    width = values.get(name + '_pd', 0.0)
    npts = values.get(name + '_pd_n', 0)
    disperser = values.get(name + '_pd_type', 'gaussian')

    # No dispersion requested: single point carrying all of the weight.
    if npts == 0 or width == 0:
        return [center], [1.0]

    points, weight = weights.get_weights(
        disperser, npts, width, nsigma, center, limits, relative)
    total = np.sum(weight)
    return points, weight / total
95
class DataMixin(object):
    """
    DataMixin captures the common aspects of evaluating a SAS model for a
    particular data set, including calculating Iq and evaluating the
    resolution function.  It is used in particular by :class:`DirectModel`,
    which evaluates a SAS model parameters as key word arguments to the
    calculator method, and by :class:`bumps_model.Experiment`, which wraps the
    model and data for use with the Bumps fitting engine.  It is not
    currently used by :class:`sasview_model.SasviewModel` since this will
    require a number of changes to SasView before we can do it.

    :meth:`_interpret_data` initializes the data structures necessary
    to manage the calculations.  This sets attributes in the child class
    such as *data_type* and *resolution*.

    :meth:`_calc_theory` evaluates the model at the given control values.

    :meth:`_set_data` sets the intensity data in the data object,
    possibly with random noise added.  This is useful for simulating a
    dataset with the results from :meth:`_calc_theory`.
    """
    def _interpret_data(self, data, model):
        # type: (Data, KernelModel) -> None
        """
        Classify *data* and prepare everything needed to evaluate *model*.

        The data type is chosen from the attributes of *data*: a *lam*
        attribute means 'sesans', a *qx_data* attribute means 'Iqxy', a
        true *oriented* attribute means 'Iq-oriented', anything else is
        'Iq'.

        Sets *data_type*, *Iq*, *dIq*, *index* and *resolution* along with
        the private attributes *_data*, *_model*, *_kernel_inputs*,
        *_kernel_mono_inputs* and *_kernel*.  *Iq* and *dIq* are None when
        the data set carries no measured values.

        Raises *ValueError* for oriented 1D data without slit resolution.
        """
        # pylint: disable=attribute-defined-outside-init

        self._data = data
        self._model = model

        # interpret data
        if hasattr(data, 'lam'):
            self.data_type = 'sesans'
        elif hasattr(data, 'qx_data'):
            self.data_type = 'Iqxy'
        elif getattr(data, 'oriented', False):
            self.data_type = 'Iq-oriented'
        else:
            self.data_type = 'Iq'

        if self.data_type == 'sesans':
            q = sesans.make_q(data.sample.zacceptance, data.Rmax)
            # All points are used; there is no resolution object for sesans.
            index = slice(None, None)
            res = None
            if data.y is not None:
                Iq, dIq = data.y, data.dy
            else:
                Iq, dIq = None, None
            #self._theory = np.zeros_like(q)
            q_vectors = [q]
            q_mono = sesans.make_all_q(data)
        elif self.data_type == 'Iqxy':
            #if not model.info.parameters.has_2d:
            #    raise ValueError("not 2D without orientation or magnetic parameters")
            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
            qmin = getattr(data, 'qmin', 1e-16)
            qmax = getattr(data, 'qmax', np.inf)
            accuracy = getattr(data, 'accuracy', 'Low')
            # Keep unmasked points whose |q| lies within [qmin, qmax].
            index = ~data.mask & (q >= qmin) & (q <= qmax)
            if data.data is not None:
                # Also drop points whose measured value is NaN.
                index &= ~np.isnan(data.data)
                Iq = data.data[index]
                dIq = data.err_data[index]
            else:
                Iq, dIq = None, None
            res = resolution2d.Pinhole2D(data=data, index=index,
                                         nsigma=3.0, accuracy=accuracy)
            #self._theory = np.zeros_like(self.Iq)
            q_vectors = res.q_calc
            q_mono = []
        elif self.data_type == 'Iq':
            index = (data.x >= data.qmin) & (data.x <= data.qmax)
            if data.y is not None:
                index &= ~np.isnan(data.y)
                Iq = data.y[index]
                dIq = data.dy[index]
            else:
                Iq, dIq = None, None
            # Pick the resolution function from the available widths:
            # dx => pinhole (or perfect if all widths are zero),
            # dxl and dxw => slit, neither => perfect.
            if getattr(data, 'dx', None) is not None:
                q, dq = data.x[index], data.dx[index]
                if (dq > 0).any():
                    res = resolution.Pinhole1D(q, dq)
                else:
                    res = resolution.Perfect1D(q)
            elif (getattr(data, 'dxl', None) is not None
                  and getattr(data, 'dxw', None) is not None):
                res = resolution.Slit1D(data.x[index],
                                        qx_width=data.dxl[index],
                                        qy_width=data.dxw[index])
            else:
                res = resolution.Perfect1D(data.x[index])

            #self._theory = np.zeros_like(self.Iq)
            q_vectors = [res.q_calc]
            q_mono = []
        elif self.data_type == 'Iq-oriented':
            index = (data.x >= data.qmin) & (data.x <= data.qmax)
            if data.y is not None:
                index &= ~np.isnan(data.y)
                Iq = data.y[index]
                dIq = data.dy[index]
            else:
                Iq, dIq = None, None
            if (getattr(data, 'dxl', None) is None
                    or getattr(data, 'dxw', None) is None):
                raise ValueError("oriented sample with 1D data needs slit resolution")

            res = resolution2d.Slit2D(data.x[index],
                                      qx_width=data.dxw[index],
                                      qy_width=data.dxl[index])
            q_vectors = res.q_calc
            q_mono = []
        else:
            raise ValueError("Unknown data type") # never gets here

        # Remember function inputs so we can delay loading the function and
        # so we can save/restore state
        self._kernel_inputs = q_vectors
        self._kernel_mono_inputs = q_mono
        self._kernel = None
        self.Iq, self.dIq, self.index = Iq, dIq, index
        self.resolution = res

    @staticmethod
    def _nan_like(template):
        # type: (np.ndarray) -> np.ndarray
        """Return a float array shaped like *template* and filled with NaN."""
        return np.full(np.shape(template), np.nan)

    def _set_data(self, Iq, noise=None):
        # type: (np.ndarray, Optional[float]) -> None
        """
        Store *Iq*, with gaussian noise added, as the measured data.

        *noise* is the uncertainty as a percentage of *Iq*.  When *noise*
        is None the uncertainty *dIq* already held by the object is used.

        The noisy values replace *self.Iq* and are written into the
        selected points (*self.index*) of the underlying data object.  For
        1D and 2D data the uncertainty is written alongside the values.
        Value or uncertainty arrays which are missing from the data object
        are created first, with NaN in the points that are not selected
        (NaN values are the ones :meth:`_interpret_data` excludes).

        Raises *ValueError* if there is no uncertainty to draw the noise
        from, or if the data type is not recognized.
        """
        # pylint: disable=attribute-defined-outside-init
        if noise is not None:
            self.dIq = Iq*noise*0.01
        dy = self.dIq
        if dy is None:
            # Data without measured values has dIq=None after
            # _interpret_data; fail clearly instead of on None.shape.
            raise ValueError("data has no uncertainty; noise must be given")
        y = Iq + np.random.randn(*dy.shape) * dy
        self.Iq = y

        data = self._data
        if self.data_type in ('Iq', 'Iq-oriented'):
            if data.dy is None:
                data.dy = self._nan_like(data.x)
            if data.y is None:
                data.y = self._nan_like(data.x)
            data.dy[self.index] = dy
            data.y[self.index] = y
        elif self.data_type == 'Iqxy':
            if data.data is None:
                data.data = self._nan_like(data.qx_data)
            if getattr(data, 'err_data', None) is None:
                data.err_data = self._nan_like(data.qx_data)
            data.data[self.index] = y
            # Keep the stored uncertainty in step with the stored values,
            # as is done for 1D data.
            data.err_data[self.index] = dy
        elif self.data_type == 'sesans':
            if data.y is None:
                # The sesans index selects every point, so the data array
                # has the same shape as the simulated values.
                data.y = self._nan_like(y)
            data.y[self.index] = y
        else:
            raise ValueError("Unknown model")

    def _calc_theory(self, pars, cutoff=0.0):
        # type: (ParameterSet, float) -> np.ndarray
        """
        Evaluate the model for parameter values *pars* at the data points.

        The kernels are created on first use from the inputs saved by
        :meth:`_interpret_data`.  *cutoff* is passed on to
        :func:`call_kernel`.

        For sesans data the kernel output goes through *sesans.transform*;
        otherwise the resolution function is applied to it.  *self.Iq_calc*
        is reset to None on each call and, when the resolution object has
        an *nx* attribute, set to *(qx_calc, qy_calc, Iq)* with the kernel
        output reshaped to *(ny, nx)*.
        """
        if self._kernel is None:
            self._kernel = self._model.make_kernel(self._kernel_inputs)
            self._kernel_mono = (self._model.make_kernel(self._kernel_mono_inputs)
                                 if self._kernel_mono_inputs else None)

        Iq_calc = call_kernel(self._kernel, pars, cutoff=cutoff)
        # TODO: may want to plot the raw Iq for other than oriented usans
        self.Iq_calc = None
        if self.data_type == 'sesans':
            Iq_mono = (call_kernel(self._kernel_mono, pars, mono=True)
                       if self._kernel_mono_inputs else None)
            result = sesans.transform(self._data,
                                      self._kernel_inputs[0], Iq_calc,
                                      self._kernel_mono_inputs, Iq_mono)
        else:
            result = self.resolution.apply(Iq_calc)
            if hasattr(self.resolution, 'nx'):
                self.Iq_calc = (
                    self.resolution.qx_calc, self.resolution.qy_calc,
                    np.reshape(Iq_calc, (self.resolution.ny, self.resolution.nx))
                )
        return result
259
260
class DirectModel(DataMixin):
    """
    Callable theory calculator tied to one data set and one model.

    *data* is 1D SAS, 2D SAS or SESANS data

    *model* is a model calculator return from :func:`generate.load_model`

    *cutoff* is the polydispersity weight cutoff.

    Calling the object with *parameter=value* keyword arguments returns
    the theory values for the data.
    """
    def __init__(self, data, model, cutoff=1e-5):
        # type: (Data, KernelModel, float) -> None
        self.cutoff = cutoff
        self.model = model
        # The remaining attributes (data type, resolution, ...) are
        # defined by _interpret_data.
        self._interpret_data(data, model)

    def __call__(self, **pars):
        # type: (**float) -> np.ndarray
        theory = self._calc_theory(pars, cutoff=self.cutoff)
        return theory

    def simulate_data(self, noise=None, **pars):
        # type: (Optional[float], **float) -> None
        """
        Replace the data with theory values computed for *pars*.

        The theory is passed, together with *noise*, to :meth:`_set_data`,
        which adds the random noise and stores the result.
        """
        self._set_data(self(**pars), noise=noise)
289
def main():
    # type: () -> None
    """
    Program to evaluate a particular model at a set of q values.

    Usage::

        python -m sasmodels.direct_model modelname (q|qx,qy|ER_VR) par=val ...

    A single number is taken as *q* for 1D data and a pair *qx,qy* as a
    point of 2D data.  Parameter values are floats, except for the
    *name_pd_type* parameters, which are passed on as text.

    Prints the first computed value.  Exits with status 1 after printing
    a message if the arguments cannot be interpreted.
    """
    import sys

    if len(sys.argv) < 3:
        print("usage: python -m sasmodels.direct_model modelname (q|qx,qy|ER_VR) par=val ...")
        sys.exit(1)

    # Project imports are delayed until the command line looks usable.
    from .data import empty_data1D, empty_data2D
    from .core import load_model_info, build_model

    model_name = sys.argv[1]
    call = sys.argv[2].upper()
    if call != "ER_VR":
        try:
            values = [float(v) for v in call.split(',')]
        except ValueError:
            values = []
        if len(values) == 1:
            q, = values
            data = empty_data1D([q])
        elif len(values) == 2:
            qx, qy = values
            data = empty_data2D([qx], [qy])
        else:
            print("use q or qx,qy or ER_VR")
            sys.exit(1)
    else:
        data = empty_data1D([0.001])  # Data not used in ER/VR

    # Parse par=val pairs before the (potentially slow) model build so that
    # malformed arguments are reported immediately.
    pars = {}
    for pair in sys.argv[3:]:
        name, sep, text = pair.partition('=')
        if not sep or not name:
            print("expected par=val but got %r" % pair)
            sys.exit(1)
        if name.endswith('_pd_type'):
            # The dispersion type is a distribution name, not a number.
            pars[name] = text
            continue
        try:
            pars[name] = float(text)
        except ValueError:
            print("expected a number for %s but got %r" % (name, text))
            sys.exit(1)

    model_info = load_model_info(model_name)
    model = build_model(model_info)
    calculator = DirectModel(data, model)
    if call == "ER_VR":
        # DirectModel as defined in this file has no ER_VR method; report
        # that instead of failing with an AttributeError traceback.
        er_vr = getattr(calculator, 'ER_VR', None)
        if er_vr is None:
            print("ER_VR is not available for this calculator")
            sys.exit(1)
        print(er_vr(**pars))
    else:
        Iq = calculator(**pars)
        print(Iq[0])


if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.