source: sasmodels/sasmodels/direct_model.py @ 9e771a3

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 9e771a3 was 9e771a3, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

sort out weights (seems to be correct this time)

  • Property mode set to 100644
File size: 17.1 KB
Line 
1"""
2Class interface to the model calculator.
3
4Calling a model is somewhat non-trivial since the functions called depend
5on the data type.  For 1D data the *Iq* kernel needs to be called, for
62D data the *Iqxy* kernel needs to be called, and for SESANS data the
7*Iq* kernel needs to be called followed by a Hankel transform.  Before
8the kernel is called an appropriate *q* calculation vector needs to be
9constructed.  This is not the simple *q* vector where you have measured
10the data since the resolution calculation will require values beyond the
11range of the measured data.  After the calculation the resolution calculator
12must be called to return the predicted value for each measured data point.
13
14:class:`DirectModel` is a callable object that takes *parameter=value*
15keyword arguments and returns the appropriate theory values for the data.
16
17:class:`DataMixin` does the real work of interpreting the data and calling
18the model calculator.  This is used by :class:`DirectModel`, which uses
19direct parameter values and by :class:`bumps_model.Experiment` which wraps
20the parameter values in boxes so that the user can set fitting ranges, etc.
21on the individual parameters and send the model to the Bumps optimizers.
22"""
23from __future__ import print_function
24
25import numpy as np  # type: ignore
26
27# TODO: fix sesans module
28from . import sesans  # type: ignore
29from . import weights
30from . import resolution
31from . import resolution2d
32from .details import make_kernel_args, dispersion_mesh
33
34try:
35    from typing import Optional, Dict, Tuple
36except ImportError:
37    pass
38else:
39    from .data import Data
40    from .kernel import Kernel, KernelModel
41    from .modelinfo import Parameter, ParameterSet
42
def call_kernel(calculator, pars, cutoff=0., mono=False):
    # type: (Kernel, ParameterSet, float, bool) -> np.ndarray
    """
    Evaluate *calculator* (a kernel from *model.make_kernel*) at *pars*.

    The dispersity mesh for the kernel's dimension is built from *pars*,
    converted to kernel call arguments, and the kernel is invoked.

    *cutoff* is the limiting value for the product of dispersion weights used
    to perform the multidimensional dispersion calculation more quickly at a
    slight cost to accuracy.  With the default *cutoff=0* the entire
    dispersion cube is integrated.  Using *cutoff=1e-5* can be 50% faster,
    with an error of about 1%, which is usually less than the measurement
    uncertainty.

    *mono* is True if polydispersity should be set to none on all parameters.
    """
    dispersity_mesh = get_mesh(calculator.info, pars,
                               dim=calculator.dim, mono=mono)
    details, kernel_values, magnetic = make_kernel_args(calculator,
                                                        dispersity_mesh)
    return calculator(details, kernel_values, cutoff, magnetic)
62
def call_ER(model_info, pars):
    # type: (ModelInfo, ParameterSet) -> float
    """
    Call the model ER function using the parameter values in *pars*.

    *model_info* is either *model.info* if you have a loaded model,
    or *kernel.info* if you have a model kernel prepared for evaluation.

    Returns 1.0 if the model defines no ER function.  If the model has no
    form volume parameters, ER is called with no arguments.  Otherwise ER
    is evaluated at every point of the volume-parameter dispersity mesh and
    the weighted average of the results is returned.
    """
    if model_info.ER is None:
        return 1.0
    elif not model_info.parameters.form_volume_parameters:
        # ER is provided but there are no volume parameters to disperse over.
        return model_info.ER()
    else:
        value, weight = _vol_pars(model_info, pars)
        individual_radii = model_info.ER(*value)
        return np.sum(weight*individual_radii) / np.sum(weight)
80
81
def call_VR(model_info, pars):
    # type: (ModelInfo, ParameterSet) -> float
    """
    Call the model VR function using the parameter values in *pars*.

    *model_info* is either *model.info* if you have a loaded model,
    or *kernel.info* if you have a model kernel prepared for evaluation.

    Returns 1.0 if the model defines no VR function.  If the model has no
    form volume parameters, VR is called with no arguments.  Otherwise VR
    is evaluated over the volume-parameter dispersity mesh, returning
    *(whole, part)* per mesh point, and the result is the ratio of the
    weighted sums, *sum(weight*part) / sum(weight*whole)*.
    """
    if model_info.VR is None:
        return 1.0
    elif not model_info.parameters.form_volume_parameters:
        # VR is provided but there are no volume parameters to disperse over.
        return model_info.VR()
    else:
        value, weight = _vol_pars(model_info, pars)
        whole, part = model_info.VR(*value)
        return np.sum(weight*part)/np.sum(weight*whole)
99
100
def call_profile(model_info, **pars):
    # type: (ModelInfo, ...) -> Tuple[np.ndarray, np.ndarray, Tuple[str, str]]
    """
    Evaluate the model's profile function, *model_info.profile*.

    Each kernel parameter is looked up in *pars* by its id, falling back to
    the parameter default.  Vector parameters (length > 1) are looked up
    element by element as *id1*, *id2*, ... and passed on as an array.

    Returns *x, y, (xlabel, ylabel)*, where the labels come from
    *model_info.profile_axes*.
    """
    def lookup(par):
        if par.length <= 1:
            return pars.get(par.id, par.default)
        keys = [par.id + str(k) for k in range(1, par.length + 1)]
        return np.array([pars.get(key, par.default) for key in keys])

    kwargs = {par.id: lookup(par)
              for par in model_info.parameters.kernel_parameters}
    x, y = model_info.profile(**kwargs)
    return x, y, model_info.profile_axes
116
def get_mesh(model_info, values, dim='1d', mono=False):
    # type: (ModelInfo, Dict[str, float], str, bool) -> List[Tuple[float, np.ndarray, np.ndarray]]
    """
    Retrieve the dispersity mesh described by the parameter set.

    Returns a list of *(value, dispersity, weights)* with one tuple for each
    parameter in the model call parameters.  Inactive parameters get a
    single-point distribution with a weight of 1.0 (see
    :func:`_get_par_weights`).

    A parameter is active when *mono* is False and it is listed in
    *parameters.pd_1d* for *dim='1d'*, in *parameters.pd_2d* for
    *dim='2d'*, or always for any other *dim*.
    """
    parameters = model_info.parameters

    def active(name):
        # type: (str) -> bool
        """Return True if dispersity should be applied to parameter *name*."""
        if mono:
            return False
        if dim == '1d':
            return name in parameters.pd_1d
        if dim == '2d':
            return name in parameters.pd_2d
        return True

    return [_get_par_weights(p, values, active(p.name))
            for p in parameters.call_parameters]
140
141
142def _get_par_weights(parameter, values, active=True):
143    # type: (Parameter, Dict[str, float]) -> Tuple[float, np.ndarray, np.ndarray]
144    """
145    Generate the distribution for parameter *name* given the parameter values
146    in *pars*.
147
148    Uses "name", "name_pd", "name_pd_type", "name_pd_n", "name_pd_sigma"
149    from the *pars* dictionary for parameter value and parameter dispersion.
150    """
151    value = float(values.get(parameter.name, parameter.default))
152    npts = values.get(parameter.name+'_pd_n', 0)
153    width = values.get(parameter.name+'_pd', 0.0)
154    relative = parameter.relative_pd
155    if npts == 0 or width == 0.0 or not active:
156        # Note: orientation parameters have the viewing angle as the parameter
157        # value and the jitter in the distribution, so be sure to set the
158        # empty pd for orientation parameters to 0.
159        pd = [value if relative or not parameter.polydisperse else 0.0], [1.0]
160    else:
161        limits = parameter.limits
162        disperser = values.get(parameter.name+'_pd_type', 'gaussian')
163        nsigma = values.get(parameter.name+'_pd_nsigma', 3.0)
164        pd = weights.get_weights(disperser, npts, width, nsigma,
165                                value, limits, relative)
166    return value, pd[0], pd[1]
167
168
def _vol_pars(model_info, values):
    # type: (ModelInfo, ParameterSet) -> Tuple[np.ndarray, np.ndarray]
    """
    Build the dispersity mesh over the volume parameters of the model.

    Each call parameter of type 'volume' is expanded into its distribution
    from *values* (always treated as active), and the distributions are
    combined with :func:`details.dispersion_mesh`.

    Returns *(dispersity, weight)* as produced by dispersion_mesh.
    """
    volume_dists = []
    for par in model_info.parameters.call_parameters:
        if par.type == 'volume':
            volume_dists.append(_get_par_weights(par, values))
    dispersity, weight = dispersion_mesh(model_info, volume_dists)
    return dispersity, weight
177
178
class DataMixin(object):
    """
    DataMixin captures the common aspects of evaluating a SAS model for a
    particular data set, including calculating Iq and evaluating the
    resolution function.  It is used in particular by :class:`DirectModel`,
    which takes the SAS model parameters as keyword arguments to the
    calculator method, and by :class:`bumps_model.Experiment`, which wraps the
    model and data for use with the Bumps fitting engine.  It is not
    currently used by :class:`sasview_model.SasviewModel` since this will
    require a number of changes to SasView before we can do it.

    :meth:`_interpret_data` initializes the data structures necessary
    to manage the calculations.  This sets attributes in the child class
    such as *data_type* and *resolution*.

    :meth:`_calc_theory` evaluates the model at the given control values.

    :meth:`_set_data` sets the intensity data in the data object,
    possibly with random noise added.  This is useful for simulating a
    dataset with the results from :meth:`_calc_theory`.
    """
    def _interpret_data(self, data, model):
        # type: (Data, KernelModel) -> None
        """
        Inspect *data* and set up the q vectors and resolution for *model*.

        Sets *data_type* (one of 'sesans', 'Iqxy', 'Iq-oriented', 'Iq'),
        *Iq*, *dIq* and *index* (the measured values and the selection of
        points that are used), *resolution*, and the kernel inputs from
        which :meth:`_calc_theory` builds the kernel on first use.

        Raises *ValueError* for oriented 1D data without slit resolution.
        """
        # pylint: disable=attribute-defined-outside-init

        self._data = data
        self._model = model

        # interpret data
        if hasattr(data, 'isSesans') and data.isSesans:
            self.data_type = 'sesans'
        elif hasattr(data, 'qx_data'):
            self.data_type = 'Iqxy'
        elif getattr(data, 'oriented', False):
            self.data_type = 'Iq-oriented'
        else:
            self.data_type = 'Iq'

        if self.data_type == 'sesans':
            q = sesans.make_q(data.sample.zacceptance, data.Rmax)
            index = slice(None, None)
            res = None
            if data.y is not None:
                Iq, dIq = data.y, data.dy
            else:
                Iq, dIq = None, None
            q_vectors = [q]
            q_mono = sesans.make_all_q(data)
        elif self.data_type == 'Iqxy':
            # Note: there is no check that the model has 2D (orientation or
            # magnetic) parameters.
            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
            qmin = getattr(data, 'qmin', 1e-16)
            qmax = getattr(data, 'qmax', np.inf)
            accuracy = getattr(data, 'accuracy', 'Low')
            index = ~data.mask & (q >= qmin) & (q <= qmax)
            if data.data is not None:
                index &= ~np.isnan(data.data)
                Iq = data.data[index]
                dIq = data.err_data[index]
            else:
                Iq, dIq = None, None
            res = resolution2d.Pinhole2D(data=data, index=index,
                                         nsigma=3.0, accuracy=accuracy)
            q_vectors = res.q_calc
            q_mono = []
        elif self.data_type == 'Iq':
            index = (data.x >= data.qmin) & (data.x <= data.qmax)
            if data.y is not None:
                index &= ~np.isnan(data.y)
                Iq = data.y[index]
                dIq = data.dy[index]
            else:
                Iq, dIq = None, None
            if getattr(data, 'dx', None) is not None:
                q, dq = data.x[index], data.dx[index]
                if (dq > 0).any():
                    res = resolution.Pinhole1D(q, dq)
                else:
                    res = resolution.Perfect1D(q)
            elif (getattr(data, 'dxl', None) is not None
                  and getattr(data, 'dxw', None) is not None):
                res = resolution.Slit1D(data.x[index],
                                        qx_width=data.dxl[index],
                                        qy_width=data.dxw[index])
            else:
                res = resolution.Perfect1D(data.x[index])

            q_vectors = [res.q_calc]
            q_mono = []
        elif self.data_type == 'Iq-oriented':
            index = (data.x >= data.qmin) & (data.x <= data.qmax)
            if data.y is not None:
                index &= ~np.isnan(data.y)
                Iq = data.y[index]
                dIq = data.dy[index]
            else:
                Iq, dIq = None, None
            if (getattr(data, 'dxl', None) is None
                    or getattr(data, 'dxw', None) is None):
                raise ValueError("oriented sample with 1D data needs slit resolution")

            # NOTE(review): dxw/dxl map to qx_width/qy_width the opposite way
            # from the Slit1D call above; presumably intentional for the 2D
            # slit geometry -- confirm.
            res = resolution2d.Slit2D(data.x[index],
                                      qx_width=data.dxw[index],
                                      qy_width=data.dxl[index])
            q_vectors = res.q_calc
            q_mono = []
        else:
            raise ValueError("Unknown data type") # never gets here

        # Remember function inputs so we can delay loading the function and
        # so we can save/restore state
        self._kernel_inputs = q_vectors
        self._kernel_mono_inputs = q_mono
        self._kernel = None
        self.Iq, self.dIq, self.index = Iq, dIq, index
        self.resolution = res

    def _set_data(self, Iq, noise=None):
        # type: (np.ndarray, Optional[float]) -> None
        """
        Replace the measured intensity with *Iq* plus gaussian noise.

        *noise*, if given, is the relative uncertainty in percent, so that
        *dIq = Iq*noise/100*; otherwise the existing *dIq* is used.  The noisy
        intensity is stored in *self.Iq* and written back into the data object
        at the selected points *self.index*.

        Raises *ValueError* if *noise* is None and there is no existing *dIq*
        to draw the noise from.
        """
        # pylint: disable=attribute-defined-outside-init
        if noise is not None:
            self.dIq = Iq*noise*0.01
        elif self.dIq is None:
            # Without this check the failure is an obscure AttributeError
            # on dy.shape below.
            raise ValueError(
                "noise is required when the data has no uncertainty (dIq)")
        dy = self.dIq
        y = Iq + np.random.randn(*dy.shape) * dy
        self.Iq = y
        if self.data_type in ('Iq', 'Iq-oriented'):
            if self._data.y is None:
                self._data.y = np.empty(len(self._data.x), 'd')
            if self._data.dy is None:
                self._data.dy = np.empty(len(self._data.x), 'd')
            self._data.dy[self.index] = dy
            self._data.y[self.index] = y
        elif self.data_type == 'Iqxy':
            if self._data.data is None:
                self._data.data = np.empty_like(self._data.qx_data, 'd')
            if self._data.err_data is None:
                self._data.err_data = np.empty_like(self._data.qx_data, 'd')
            self._data.data[self.index] = y
            self._data.err_data[self.index] = dy
        elif self.data_type == 'sesans':
            if self._data.y is None:
                self._data.y = np.empty(len(self._data.x), 'd')
            self._data.y[self.index] = y
        else:
            raise ValueError("Unknown model")

    def _calc_theory(self, pars, cutoff=0.0):
        # type: (ParameterSet, float) -> np.ndarray
        """
        Evaluate the model at *pars* and return the theory for the data.

        The kernels are built from the stored kernel inputs on first use.
        For SESANS data the kernel output is passed through
        :func:`sesans.transform`; otherwise the resolution function is
        applied.  The raw kernel output is kept in *self.Iq_calc*; when the
        resolution object has an *nx* attribute it is stored instead as
        *(qx_calc, qy_calc, Iq_calc reshaped to (ny, nx))*.
        """
        if self._kernel is None:
            self._kernel = self._model.make_kernel(self._kernel_inputs)
            self._kernel_mono = (
                self._model.make_kernel(self._kernel_mono_inputs)
                if self._kernel_mono_inputs else None)

        Iq_calc = call_kernel(self._kernel, pars, cutoff=cutoff)
        # Storing the calculated Iq values so that they can be plotted.
        # Only applies to oriented USANS data for now.
        # TODO: extend plotting of calculated Iq to other measurement types
        # TODO: refactor so we don't store the result in the model
        self.Iq_calc = Iq_calc
        if self.data_type == 'sesans':
            Iq_mono = (call_kernel(self._kernel_mono, pars, mono=True)
                       if self._kernel_mono_inputs else None)
            result = sesans.transform(self._data,
                                      self._kernel_inputs[0], Iq_calc,
                                      self._kernel_mono_inputs, Iq_mono)
        else:
            result = self.resolution.apply(Iq_calc)
            if hasattr(self.resolution, 'nx'):
                self.Iq_calc = (
                    self.resolution.qx_calc, self.resolution.qy_calc,
                    np.reshape(Iq_calc, (self.resolution.ny, self.resolution.nx))
                )
        return result
357
358
class DirectModel(DataMixin):
    """
    Create a calculator object for a model.

    *data* is 1D SAS, 2D SAS or SESANS data.

    *model* is a model calculator, such as the one returned by
    :func:`core.build_model`.

    *cutoff* is the polydispersity weight cutoff, used on every evaluation
    (see :func:`call_kernel`).
    """
    def __init__(self, data, model, cutoff=1e-5):
        # type: (Data, KernelModel, float) -> None
        self.model = model
        self.cutoff = cutoff
        # Note: _interpret_data defines the model attributes
        self._interpret_data(data, model)

    def __call__(self, **pars):
        # type: (**float) -> np.ndarray
        """
        Return the theory values for the data at parameter values *pars*.
        """
        return self._calc_theory(pars, cutoff=self.cutoff)

    def simulate_data(self, noise=None, **pars):
        # type: (Optional[float], **float) -> None
        """
        Generate simulated data for the model.

        The theory is evaluated at *pars* and stored into the data with
        random noise added; *noise*, if given, is the relative uncertainty
        in percent.
        """
        Iq = self(**pars)
        self._set_data(Iq, noise=noise)

    def profile(self, **pars):
        # type: (**float) -> Tuple[np.ndarray, np.ndarray, Tuple[str, str]]
        """
        Generate a plottable profile.

        Returns *x, y, (xlabel, ylabel)* as computed by :func:`call_profile`.
        """
        return call_profile(self.model.info, **pars)
394
def main():
    # type: () -> None
    """
    Program to evaluate a particular model at a set of q values.

    Usage::

        python -m sasmodels.direct_model modelname (q|qx,qy|ER_VR) par=val ...

    Prints I(q) for a single *q* (1D) or *qx,qy* (2D) point, or the ER and
    VR values when the second argument is ER_VR.  Parameters ending in
    "_pd_type" are passed as strings; all others are converted to float.
    Exits with status 1 on malformed arguments.
    """
    import sys

    if len(sys.argv) < 3:
        print("usage: python -m sasmodels.direct_model modelname (q|qx,qy|ER_VR) par=val ...",
              file=sys.stderr)
        sys.exit(1)
    model_name = sys.argv[1]
    call = sys.argv[2].upper()
    if call != "ER_VR":
        try:
            values = [float(v) for v in call.split(',')]
        except ValueError:
            values = []
        if len(values) not in (1, 2):
            print("use q or qx,qy or ER_VR", file=sys.stderr)
            sys.exit(1)
    else:
        values = []

    pars = {}
    for pair in sys.argv[3:]:
        key, sep, text = pair.partition('=')
        if not sep:
            print("expected par=val but got %r" % pair, file=sys.stderr)
            sys.exit(1)
        pars[key] = text if key.endswith("_pd_type") else float(text)

    # Imported only after the arguments are validated so that usage errors
    # do not pay for loading the model machinery.
    from .data import empty_data1D, empty_data2D
    from .core import load_model_info, build_model

    model_info = load_model_info(model_name)
    if call == "ER_VR":
        print(call_ER(model_info, pars), call_VR(model_info, pars))
        return

    if len(values) == 1:
        data = empty_data1D(values)
    else:
        qx, qy = values
        data = empty_data2D([qx], [qy])
    calculator = DirectModel(data, build_model(model_info))
    Iq = calculator(**pars)
    print(Iq[0])


if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.