source: sasmodels/sasmodels/direct_model.py @ 56547a8

Branches/tags: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 56547a8 was 9eb3632, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

restructure kernels using fixed PD loops

  • Property mode set to 100644
File size: 13.0 KB
Line 
1"""
2Class interface to the model calculator.
3
4Calling a model is somewhat non-trivial since the functions called depend
5on the data type.  For 1D data the *Iq* kernel needs to be called, for
62D data the *Iqxy* kernel needs to be called, and for SESANS data the
7*Iq* kernel needs to be called followed by a Hankel transform.  Before
8the kernel is called an appropriate *q* calculation vector needs to be
9constructed.  This is not the simple *q* vector where you have measured
10the data since the resolution calculation will require values beyond the
11range of the measured data.  After the calculation the resolution calculator
12must be called to return the predicted value for each measured data point.
13
14:class:`DirectModel` is a callable object that takes *parameter=value*
15keyword arguments and returns the appropriate theory values for the data.
16
17:class:`DataMixin` does the real work of interpreting the data and calling
18the model calculator.  This is used by :class:`DirectModel`, which uses
19direct parameter values and by :class:`bumps_model.Experiment` which wraps
20the parameter values in boxes so that the user can set fitting ranges, etc.
21on the individual parameters and send the model to the Bumps optimizers.
22"""
23from __future__ import print_function
24
25import numpy as np  # type: ignore
26
27# TODO: fix sesans module
28from . import sesans  # type: ignore
29from . import weights
30from . import resolution
31from . import resolution2d
32from .details import build_details
33
34try:
35    from typing import Optional, Dict, Tuple
36except ImportError:
37    pass
38else:
39    from .data import Data
40    from .kernel import Kernel, KernelModel
41    from .modelinfo import Parameter, ParameterSet
42
def call_kernel(calculator, pars, cutoff=0., mono=False):
    # type: (Kernel, ParameterSet, float, bool) -> np.ndarray
    """
    Evaluate *calculator*, a kernel from *model.make_kernel*, at *pars*.

    *cutoff* is the limiting value for the product of dispersion weights used
    to perform the multidimensional dispersion calculation more quickly at a
    slight cost to accuracy. The default value of *cutoff=0* integrates over
    the entire dispersion cube.  Using *cutoff=1e-5* can be 50% faster, but
    with an error of about 1%, which is usually less than the measurement
    uncertainty.

    *mono* is True if polydispersity should be set to none on all parameters.
    """
    parameters = calculator.info.parameters

    # Collection of names that get a full dispersion distribution; None means
    # "every parameter" (kernels that are neither 1d nor 2d).
    if mono:
        pd_names = ()
    elif calculator.dim == '1d':
        pd_names = parameters.pd_1d
    elif calculator.dim == '2d':
        pd_names = parameters.pd_2d
    else:
        pd_names = None

    vw_pairs = []
    for par in parameters.call_parameters:
        if pd_names is None or par.name in pd_names:
            vw_pairs.append(get_weights(par, pars))
        else:
            # Inactive parameters contribute a single point with unit weight.
            vw_pairs.append(([pars.get(par.name, par.default)], [1.0]))

    call_details, values, is_magnetic = build_details(calculator, vw_pairs)
    return calculator(call_details, values, cutoff, is_magnetic)
75
def get_weights(parameter, values):
    # type: (Parameter, Dict[str, float]) -> Tuple[np.ndarray, np.ndarray]
    """
    Generate the dispersion distribution for *parameter* from *values*.

    With *name* = *parameter.name*, the keys "name", "name_pd",
    "name_pd_type", "name_pd_n" and "name_pd_nsigma" are read from the
    *values* dictionary for the parameter value and its dispersion.  Missing
    keys default to *parameter.default*, a width of 0, a "gaussian"
    distribution, 0 points and 3 sigma respectively.

    Returns *(points, weights)*.  If the number of points or the width is
    zero the parameter is not dispersed and *([value], [1.0])* is returned;
    otherwise the weights from :func:`weights.get_weights` are normalized to
    sum to one.
    """
    name = parameter.name
    value = float(values.get(name, parameter.default))
    # "_pd_n" is a point count, but callers may hold it as a float (main()
    # float-converts every command line value), so coerce it to an integer.
    npts = int(values.get(name + '_pd_n', 0))
    width = values.get(name + '_pd', 0.0)
    if npts == 0 or width == 0:
        return [value], [1.0]
    disperser = values.get(name + '_pd_type', 'gaussian')
    nsigma = values.get(name + '_pd_nsigma', 3.0)
    points, weight = weights.get_weights(
        disperser, npts, width, nsigma, value,
        parameter.limits, parameter.relative_pd)
    return points, weight / np.sum(weight)
97
class DataMixin(object):
    """
    DataMixin captures the common aspects of evaluating a SAS model for a
    particular data set, including calculating Iq and evaluating the
    resolution function.  It is used in particular by :class:`DirectModel`,
    which evaluates a SAS model with parameters given as keyword arguments
    to the calculator method, and by :class:`bumps_model.Experiment`, which
    wraps the model and data for use with the Bumps fitting engine.  It is
    not currently used by :class:`sasview_model.SasviewModel` since this
    will require a number of changes to SasView before we can do it.

    :meth:`_interpret_data` initializes the data structures necessary
    to manage the calculations.  This sets attributes in the child class
    such as *data_type* and *resolution*.

    :meth:`_calc_theory` evaluates the model at the given control values.

    :meth:`_set_data` sets the intensity data in the data object,
    possibly with random noise added.  This is useful for simulating a
    dataset with the results from :meth:`_calc_theory`.
    """
    def _interpret_data(self, data, model):
        # type: (Data, KernelModel) -> None
        """
        Classify *data* and prepare the q vectors and resolution for *model*.

        The data type is chosen from the attributes present on *data*: *lam*
        means SESANS, *qx_data* means 2D, a true *oriented* flag means
        oriented 1D (slit) data, and anything else is plain 1D data.

        Sets *data_type*, *index* (the selected points), *Iq* and *dIq* (the
        measured values at those points, or None if *data* has none),
        *resolution*, and the kernel inputs used by :meth:`_calc_theory`.

        Raises ValueError for oriented 1D data without slit resolution.
        """
        # pylint: disable=attribute-defined-outside-init

        self._data = data
        self._model = model

        # interpret data
        if hasattr(data, 'lam'):
            self.data_type = 'sesans'
        elif hasattr(data, 'qx_data'):
            self.data_type = 'Iqxy'
        elif getattr(data, 'oriented', False):
            self.data_type = 'Iq-oriented'
        else:
            self.data_type = 'Iq'

        if self.data_type == 'sesans':
            q = sesans.make_q(data.sample.zacceptance, data.Rmax)
            index = slice(None, None)
            res = None
            if data.y is not None:
                Iq, dIq = data.y, data.dy
            else:
                Iq, dIq = None, None
            q_vectors = [q]
            q_mono = sesans.make_all_q(data)
        elif self.data_type == 'Iqxy':
            #if not model.info.parameters.has_2d:
            #    raise ValueError("not 2D without orientation or magnetic parameters")
            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
            qmin = getattr(data, 'qmin', 1e-16)
            qmax = getattr(data, 'qmax', np.inf)
            accuracy = getattr(data, 'accuracy', 'Low')
            index = ~data.mask & (q >= qmin) & (q <= qmax)
            if data.data is not None:
                index &= ~np.isnan(data.data)
                Iq = data.data[index]
                dIq = data.err_data[index]
            else:
                Iq, dIq = None, None
            res = resolution2d.Pinhole2D(data=data, index=index,
                                         nsigma=3.0, accuracy=accuracy)
            q_vectors = res.q_calc
            q_mono = []
        elif self.data_type == 'Iq':
            index = (data.x >= data.qmin) & (data.x <= data.qmax)
            if data.y is not None:
                index &= ~np.isnan(data.y)
                Iq = data.y[index]
                dIq = data.dy[index]
            else:
                Iq, dIq = None, None
            if getattr(data, 'dx', None) is not None:
                q, dq = data.x[index], data.dx[index]
                if (dq > 0).any():
                    res = resolution.Pinhole1D(q, dq)
                else:
                    res = resolution.Perfect1D(q)
            elif (getattr(data, 'dxl', None) is not None
                  and getattr(data, 'dxw', None) is not None):
                res = resolution.Slit1D(data.x[index],
                                        qx_width=data.dxl[index],
                                        qy_width=data.dxw[index])
            else:
                res = resolution.Perfect1D(data.x[index])

            q_vectors = [res.q_calc]
            q_mono = []
        elif self.data_type == 'Iq-oriented':
            index = (data.x >= data.qmin) & (data.x <= data.qmax)
            if data.y is not None:
                index &= ~np.isnan(data.y)
                Iq = data.y[index]
                dIq = data.dy[index]
            else:
                Iq, dIq = None, None
            if (getattr(data, 'dxl', None) is None
                or getattr(data, 'dxw', None) is None):
                raise ValueError("oriented sample with 1D data needs slit resolution")

            res = resolution2d.Slit2D(data.x[index],
                                      qx_width=data.dxw[index],
                                      qy_width=data.dxl[index])
            q_vectors = res.q_calc
            q_mono = []
        else:
            raise ValueError("Unknown data type") # never gets here

        # Remember function inputs so we can delay loading the function and
        # so we can save/restore state
        self._kernel_inputs = q_vectors
        self._kernel_mono_inputs = q_mono
        self._kernel = None
        self.Iq, self.dIq, self.index = Iq, dIq, index
        self.resolution = res

    def _set_data(self, Iq, noise=None):
        # type: (np.ndarray, Optional[float]) -> None
        """
        Store *Iq* plus gaussian noise as the intensity of the data.

        If *noise* is given, the uncertainty *dIq* becomes *noise* percent of
        *Iq*; otherwise the existing uncertainty *dIq* is used.  The noisy
        intensity and its uncertainty are written into the data object at
        the selected points.  If the data object has no intensity or
        uncertainty array yet, one is allocated and filled with NaN so that
        unselected points are marked as missing.

        Raises ValueError if *noise* is None and there is no uncertainty.
        """
        # pylint: disable=attribute-defined-outside-init
        if noise is not None:
            self.dIq = Iq*noise*0.01
        dy = self.dIq
        if dy is None:
            raise ValueError("noise level required: data has no uncertainty")
        y = Iq + np.random.randn(*dy.shape) * dy
        self.Iq = y
        data = self._data
        if self.data_type in ('Iq', 'Iq-oriented', 'sesans'):
            if data.y is None:
                data.y = np.full(len(data.x), np.nan)
            if data.dy is None:
                data.dy = np.full(len(data.x), np.nan)
            data.y[self.index] = y
            data.dy[self.index] = dy
        elif self.data_type == 'Iqxy':
            if data.data is None:
                data.data = np.full(np.shape(data.qx_data), np.nan)
            if data.err_data is None:
                data.err_data = np.full(np.shape(data.qx_data), np.nan)
            data.data[self.index] = y
            data.err_data[self.index] = dy
        else:
            raise ValueError("Unknown model")

    def _calc_theory(self, pars, cutoff=0.0):
        # type: (ParameterSet, float) -> np.ndarray
        """
        Evaluate the model at *pars* and return the theory for the data.

        The kernels are built on first use.  For SESANS data the kernel
        output goes through :func:`sesans.transform`; otherwise it goes
        through the resolution function.  When the resolution has a 2D
        calculation grid (an *nx* attribute), the raw kernel output is kept
        in *Iq_calc* as (qx_calc, qy_calc, Iq) for plotting; otherwise
        *Iq_calc* is None.
        """
        if self._kernel is None:
            self._kernel = self._model.make_kernel(self._kernel_inputs)
            self._kernel_mono = (self._model.make_kernel(self._kernel_mono_inputs)
                                 if self._kernel_mono_inputs else None)

        Iq_calc = call_kernel(self._kernel, pars, cutoff=cutoff)
        # TODO: may want to plot the raw Iq for other than oriented usans
        self.Iq_calc = None
        if self.data_type == 'sesans':
            Iq_mono = (call_kernel(self._kernel_mono, pars, mono=True)
                       if self._kernel_mono_inputs else None)
            result = sesans.transform(self._data,
                                      self._kernel_inputs[0], Iq_calc,
                                      self._kernel_mono_inputs, Iq_mono)
        else:
            result = self.resolution.apply(Iq_calc)
            if hasattr(self.resolution, 'nx'):
                self.Iq_calc = (
                    self.resolution.qx_calc, self.resolution.qy_calc,
                    np.reshape(Iq_calc, (self.resolution.ny, self.resolution.nx))
                )
        return result
261
262
class DirectModel(DataMixin):
    """
    Callable calculator that binds a model to a data set.

    *data* is 1D SAS, 2D SAS or SESANS data.

    *model* is a model calculator, such as the one returned by
    :func:`core.build_model`.

    *cutoff* is the polydispersity weight cutoff used on every evaluation.

    Call the object with *parameter=value* keywords to obtain the theory
    values for the data points.
    """
    def __init__(self, data, model, cutoff=1e-5):
        # type: (Data, KernelModel, float) -> None
        self.model, self.cutoff = model, cutoff
        # _interpret_data sets data_type, index, Iq, dIq, resolution and the
        # kernel inputs; the kernel itself is built on first evaluation.
        self._interpret_data(data, model)

    def __call__(self, **pars):
        # type: (**float) -> np.ndarray
        """Return the theory values for the data at parameters *pars*."""
        theory = self._calc_theory(pars, cutoff=self.cutoff)
        return theory

    def simulate_data(self, noise=None, **pars):
        # type: (Optional[float], **float) -> None
        """
        Replace the data intensities with the theory at *pars* plus noise.

        *noise*, if given, is the uncertainty as a percentage of the theory;
        otherwise the existing data uncertainty is used.
        """
        self._set_data(self.__call__(**pars), noise=noise)
291
292def main():
293    # type: () -> None
294    """
295    Program to evaluate a particular model at a set of q values.
296    """
297    import sys
298    from .data import empty_data1D, empty_data2D
299    from .core import load_model_info, build_model
300
301    if len(sys.argv) < 3:
302        print("usage: python -m sasmodels.direct_model modelname (q|qx,qy) par=val ...")
303        sys.exit(1)
304    model_name = sys.argv[1]
305    call = sys.argv[2].upper()
306    if call != "ER_VR":
307        try:
308            values = [float(v) for v in call.split(',')]
309        except Exception:
310            values = []
311        if len(values) == 1:
312            q, = values
313            data = empty_data1D([q])
314        elif len(values) == 2:
315            qx, qy = values
316            data = empty_data2D([qx], [qy])
317        else:
318            print("use q or qx,qy or ER or VR")
319            sys.exit(1)
320    else:
321        data = empty_data1D([0.001])  # Data not used in ER/VR
322
323    model_info = load_model_info(model_name)
324    model = build_model(model_info)
325    calculator = DirectModel(data, model)
326    pars = dict((k, float(v))
327                for pair in sys.argv[3:]
328                for k, v in [pair.split('=')])
329    if call == "ER_VR":
330        print(calculator.ER_VR(**pars))
331    else:
332        Iq = calculator(**pars)
333        print(Iq[0])
334
335if __name__ == "__main__":
336    main()
Note: See TracBrowser for help on using the repository browser.