source: sasmodels/sasmodels/direct_model.py @ 2d81cfe

core_shell_microgels magnetic_model ticket-1257-vesicle-product ticket_1156 ticket_1265_superball ticket_822_more_unit_tests
Last change on this file since 2d81cfe was 2d81cfe, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

lint

  • Property mode set to 100644
File size: 17.1 KB
RevLine 
"""
Class interface to the model calculator.

Calling a model is somewhat non-trivial since the functions called depend
on the data type.  For 1D data the *Iq* kernel needs to be called, for
2D data the *Iqxy* kernel needs to be called, and for SESANS data the
*Iq* kernel needs to be called followed by a Hankel transform.  Before
the kernel is called an appropriate *q* calculation vector needs to be
constructed.  This is not simply the *q* vector at which the data were
measured, since the resolution calculation requires values beyond the
range of the measured data.  After the calculation the resolution calculator
must be called to return the predicted value for each measured data point.

:class:`DirectModel` is a callable object that takes *parameter=value*
keyword arguments and returns the appropriate theory values for the data.

:class:`DataMixin` does the real work of interpreting the data and calling
the model calculator.  This is used by :class:`DirectModel`, which uses
direct parameter values, and by :class:`bumps_model.Experiment`, which wraps
the parameter values in boxes so that the user can set fitting ranges, etc.
on the individual parameters and send the model to the Bumps optimizers.
"""
23from __future__ import print_function
[ae7b97b]24
[7ae2b7f]25import numpy as np  # type: ignore
[ae7b97b]26
[7ae2b7f]27# TODO: fix sesans module
28from . import sesans  # type: ignore
[6d6508e]29from . import weights
[7cf2cfd]30from . import resolution
31from . import resolution2d
[bde38b5]32from .details import make_kernel_args, dispersion_mesh
[6d6508e]33
[2d81cfe]34# pylint: disable=unused-import
[a5b8477]35try:
36    from typing import Optional, Dict, Tuple
37except ImportError:
38    pass
39else:
40    from .data import Data
41    from .kernel import Kernel, KernelModel
42    from .modelinfo import Parameter, ParameterSet
[2d81cfe]43# pylint: enable=unused-import
[a5b8477]44
def call_kernel(calculator, pars, cutoff=0., mono=False):
    # type: (Kernel, ParameterSet, float, bool) -> np.ndarray
    """
    Evaluate *calculator*, a kernel from *model.make_kernel*, at *pars*.

    *cutoff* is the lower limit on the product of dispersion weights that
    is still included in the multidimensional dispersion sum.  The default
    *cutoff=0* visits the entire dispersion cube.  A value such as
    *cutoff=1e-5* can be about 50% faster with an error of about 1%, which
    is usually less than the measurement uncertainty.

    When *mono* is True, polydispersity is switched off for every parameter.

    Returns the array produced by the kernel call.
    """
    info = calculator.info
    dispersity_mesh = get_mesh(info, pars, dim=calculator.dim, mono=mono)
    details, kernel_values, magnetic = make_kernel_args(calculator,
                                                        dispersity_mesh)
    return calculator(details, kernel_values, cutoff, magnetic)
54    with an error of about 1%, which is usually less than the measurement
55    uncertainty.
56
57    *mono* is True if polydispersity should be set to none on all parameters.
58    """
[3c24ccd]59    mesh = get_mesh(calculator.info, pars, dim=calculator.dim, mono=mono)
[9e771a3]60    #print("pars", list(zip(*mesh))[0])
[8698a0d]61    call_details, values, is_magnetic = make_kernel_args(calculator, mesh)
[32e3c9b]62    #print("values:", values)
[9eb3632]63    return calculator(call_details, values, cutoff, is_magnetic)
[6d6508e]64
[40a87fa]65def call_ER(model_info, pars):
66    # type: (ModelInfo, ParameterSet) -> float
67    """
68    Call the model ER function using *values*.
69
70    *model_info* is either *model.info* if you have a loaded model,
71    or *kernel.info* if you have a model kernel prepared for evaluation.
72    """
73    if model_info.ER is None:
74        return 1.0
[4cc161e]75    elif not model_info.parameters.form_volume_parameters:
76        # handle the case where ER is provided but model is not polydisperse
77        return model_info.ER()
[40a87fa]78    else:
79        value, weight = _vol_pars(model_info, pars)
80        individual_radii = model_info.ER(*value)
81        return np.sum(weight*individual_radii) / np.sum(weight)
82
83
84def call_VR(model_info, pars):
85    # type: (ModelInfo, ParameterSet) -> float
86    """
87    Call the model VR function using *pars*.
88
89    *model_info* is either *model.info* if you have a loaded model,
90    or *kernel.info* if you have a model kernel prepared for evaluation.
91    """
92    if model_info.VR is None:
93        return 1.0
[4cc161e]94    elif not model_info.parameters.form_volume_parameters:
95        # handle the case where ER is provided but model is not polydisperse
96        return model_info.VR()
[40a87fa]97    else:
98        value, weight = _vol_pars(model_info, pars)
99        whole, part = model_info.VR(*value)
100        return np.sum(weight*part)/np.sum(weight*whole)
101
102
103def call_profile(model_info, **pars):
104    # type: (ModelInfo, ...) -> Tuple[np.ndarray, np.ndarray, Tuple[str, str]]
105    """
106    Returns the profile *x, y, (xlabel, ylabel)* representing the model.
107    """
108    args = {}
109    for p in model_info.parameters.kernel_parameters:
110        if p.length > 1:
111            value = np.array([pars.get(p.id+str(j), p.default)
112                              for j in range(1, p.length+1)])
113        else:
114            value = pars.get(p.id, p.default)
115        args[p.id] = value
116    x, y = model_info.profile(**args)
117    return x, y, model_info.profile_axes
118
[3c24ccd]119def get_mesh(model_info, values, dim='1d', mono=False):
120    # type: (ModelInfo, Dict[str, float], str, bool) -> List[Tuple[float, np.ndarray, np.ndarry]]
121    """
122    Retrieve the dispersity mesh described by the parameter set.
123
124    Returns a list of *(value, dispersity, weights)* with one tuple for each
125    parameter in the model call parameters.  Inactive parameters return the
126    default value with a weight of 1.0.
127    """
128    parameters = model_info.parameters
129    if mono:
130        active = lambda name: False
131    elif dim == '1d':
132        active = lambda name: name in parameters.pd_1d
133    elif dim == '2d':
134        active = lambda name: name in parameters.pd_2d
135    else:
136        active = lambda name: True
137
138    #print("pars",[p.id for p in parameters.call_parameters])
139    mesh = [_get_par_weights(p, values, active(p.name))
140            for p in parameters.call_parameters]
141    return mesh
142
[40a87fa]143
[3c24ccd]144def _get_par_weights(parameter, values, active=True):
145    # type: (Parameter, Dict[str, float]) -> Tuple[float, np.ndarray, np.ndarray]
[6d6508e]146    """
147    Generate the distribution for parameter *name* given the parameter values
148    in *pars*.
149
150    Uses "name", "name_pd", "name_pd_type", "name_pd_n", "name_pd_sigma"
151    from the *pars* dictionary for parameter value and parameter dispersion.
152    """
153    value = float(values.get(parameter.name, parameter.default))
154    npts = values.get(parameter.name+'_pd_n', 0)
155    width = values.get(parameter.name+'_pd', 0.0)
[32f87a5]156    relative = parameter.relative_pd
[9e771a3]157    if npts == 0 or width == 0.0 or not active:
[8698a0d]158        # Note: orientation parameters have the viewing angle as the parameter
159        # value and the jitter in the distribution, so be sure to set the
160        # empty pd for orientation parameters to 0.
[9e771a3]161        pd = [value if relative or not parameter.polydisperse else 0.0], [1.0]
[8698a0d]162    else:
163        limits = parameter.limits
164        disperser = values.get(parameter.name+'_pd_type', 'gaussian')
165        nsigma = values.get(parameter.name+'_pd_nsigma', 3.0)
166        pd = weights.get_weights(disperser, npts, width, nsigma,
[2d81cfe]167                                 value, limits, relative)
[8698a0d]168    return value, pd[0], pd[1]
[ae7b97b]169
[745b7bb]170
[9e771a3]171def _vol_pars(model_info, values):
[40a87fa]172    # type: (ModelInfo, ParameterSet) -> Tuple[np.ndarray, np.ndarray]
[9e771a3]173    vol_pars = [_get_par_weights(p, values)
[40a87fa]174                for p in model_info.parameters.call_parameters
175                if p.type == 'volume']
[4cc161e]176    #import pylab; pylab.plot(vol_pars[0][0],vol_pars[0][1]); pylab.show()
[9e771a3]177    dispersity, weight = dispersion_mesh(model_info, vol_pars)
178    return dispersity, weight
[745b7bb]179
180
[fa79f5c]181def _make_sesans_transform(data):
182    from sas.sascalc.data_util.nxsunit import Converter
183
184    # Pre-compute the Hankel matrix (H)
185    SElength = Converter(data._xunit)(data.x, "A")
186
187    theta_max = Converter("radians")(data.sample.zacceptance)[0]
188    q_max = 2 * np.pi / np.max(data.source.wavelength) * np.sin(theta_max)
189    zaccept = Converter("1/A")(q_max, "1/" + data.source.wavelength_unit),
190
191    Rmax = 10000000
192    hankel = sesans.SesansTransform(data.x, SElength,
193                                    data.source.wavelength,
194                                    zaccept, Rmax)
195    return hankel
196
197
[7cf2cfd]198class DataMixin(object):
199    """
200    DataMixin captures the common aspects of evaluating a SAS model for a
201    particular data set, including calculating Iq and evaluating the
202    resolution function.  It is used in particular by :class:`DirectModel`,
203    which evaluates a SAS model parameters as key word arguments to the
204    calculator method, and by :class:`bumps_model.Experiment`, which wraps the
205    model and data for use with the Bumps fitting engine.  It is not
206    currently used by :class:`sasview_model.SasviewModel` since this will
207    require a number of changes to SasView before we can do it.
[803f835]208
209    :meth:`_interpret_data` initializes the data structures necessary
210    to manage the calculations.  This sets attributes in the child class
211    such as *data_type* and *resolution*.
212
213    :meth:`_calc_theory` evaluates the model at the given control values.
214
215    :meth:`_set_data` sets the intensity data in the data object,
216    possibly with random noise added.  This is useful for simulating a
217    dataset with the results from :meth:`_calc_theory`.
[7cf2cfd]218    """
219    def _interpret_data(self, data, model):
[a5b8477]220        # type: (Data, KernelModel) -> None
[803f835]221        # pylint: disable=attribute-defined-outside-init
222
[7cf2cfd]223        self._data = data
224        self._model = model
225
226        # interpret data
[a769b54]227        if hasattr(data, 'isSesans') and data.isSesans:
[7cf2cfd]228            self.data_type = 'sesans'
229        elif hasattr(data, 'qx_data'):
230            self.data_type = 'Iqxy'
[ea75043]231        elif getattr(data, 'oriented', False):
232            self.data_type = 'Iq-oriented'
[7cf2cfd]233        else:
234            self.data_type = 'Iq'
235
236        if self.data_type == 'sesans':
[fa79f5c]237            res = _make_sesans_transform(data)
[803f835]238            index = slice(None, None)
[7cf2cfd]239            if data.y is not None:
[803f835]240                Iq, dIq = data.y, data.dy
241            else:
242                Iq, dIq = None, None
[7cf2cfd]243            #self._theory = np.zeros_like(q)
[fa79f5c]244            q_vectors = [res.q_calc]
[7cf2cfd]245        elif self.data_type == 'Iqxy':
[6d6508e]246            #if not model.info.parameters.has_2d:
[60eab2a]247            #    raise ValueError("not 2D without orientation or magnetic parameters")
[7cf2cfd]248            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
249            qmin = getattr(data, 'qmin', 1e-16)
250            qmax = getattr(data, 'qmax', np.inf)
251            accuracy = getattr(data, 'accuracy', 'Low')
[803f835]252            index = ~data.mask & (q >= qmin) & (q <= qmax)
[7cf2cfd]253            if data.data is not None:
[803f835]254                index &= ~np.isnan(data.data)
255                Iq = data.data[index]
256                dIq = data.err_data[index]
257            else:
258                Iq, dIq = None, None
259            res = resolution2d.Pinhole2D(data=data, index=index,
260                                         nsigma=3.0, accuracy=accuracy)
[7cf2cfd]261            #self._theory = np.zeros_like(self.Iq)
[803f835]262            q_vectors = res.q_calc
[7cf2cfd]263        elif self.data_type == 'Iq':
[803f835]264            index = (data.x >= data.qmin) & (data.x <= data.qmax)
[7cf2cfd]265            if data.y is not None:
[803f835]266                index &= ~np.isnan(data.y)
267                Iq = data.y[index]
268                dIq = data.dy[index]
269            else:
270                Iq, dIq = None, None
[7cf2cfd]271            if getattr(data, 'dx', None) is not None:
[803f835]272                q, dq = data.x[index], data.dx[index]
273                if (dq > 0).any():
274                    res = resolution.Pinhole1D(q, dq)
[7cf2cfd]275                else:
[803f835]276                    res = resolution.Perfect1D(q)
277            elif (getattr(data, 'dxl', None) is not None
278                  and getattr(data, 'dxw', None) is not None):
279                res = resolution.Slit1D(data.x[index],
[4d8e0bb]280                                        qx_width=data.dxl[index],
281                                        qy_width=data.dxw[index])
[7cf2cfd]282            else:
[803f835]283                res = resolution.Perfect1D(data.x[index])
[7cf2cfd]284
285            #self._theory = np.zeros_like(self.Iq)
[803f835]286            q_vectors = [res.q_calc]
[ea75043]287        elif self.data_type == 'Iq-oriented':
288            index = (data.x >= data.qmin) & (data.x <= data.qmax)
289            if data.y is not None:
290                index &= ~np.isnan(data.y)
291                Iq = data.y[index]
292                dIq = data.dy[index]
293            else:
294                Iq, dIq = None, None
295            if (getattr(data, 'dxl', None) is None
[40a87fa]296                    or getattr(data, 'dxw', None) is None):
[ea75043]297                raise ValueError("oriented sample with 1D data needs slit resolution")
298
299            res = resolution2d.Slit2D(data.x[index],
300                                      qx_width=data.dxw[index],
301                                      qy_width=data.dxl[index])
302            q_vectors = res.q_calc
[7cf2cfd]303        else:
304            raise ValueError("Unknown data type") # never gets here
305
306        # Remember function inputs so we can delay loading the function and
307        # so we can save/restore state
[02e70ff]308        self._kernel_inputs = q_vectors
[7cf2cfd]309        self._kernel = None
[803f835]310        self.Iq, self.dIq, self.index = Iq, dIq, index
311        self.resolution = res
[7cf2cfd]312
313    def _set_data(self, Iq, noise=None):
[a5b8477]314        # type: (np.ndarray, Optional[float]) -> None
[803f835]315        # pylint: disable=attribute-defined-outside-init
[7cf2cfd]316        if noise is not None:
317            self.dIq = Iq*noise*0.01
318        dy = self.dIq
319        y = Iq + np.random.randn(*dy.shape) * dy
320        self.Iq = y
[ea75043]321        if self.data_type in ('Iq', 'Iq-oriented'):
[d1ff3a5]322            if self._data.y is None:
323                self._data.y = np.empty(len(self._data.x), 'd')
324            if self._data.dy is None:
325                self._data.dy = np.empty(len(self._data.x), 'd')
[7cf2cfd]326            self._data.dy[self.index] = dy
327            self._data.y[self.index] = y
328        elif self.data_type == 'Iqxy':
[d1ff3a5]329            if self._data.data is None:
330                self._data.data = np.empty_like(self._data.qx_data, 'd')
331            if self._data.err_data is None:
332                self._data.err_data = np.empty_like(self._data.qx_data, 'd')
[7cf2cfd]333            self._data.data[self.index] = y
[d1ff3a5]334            self._data.err_data[self.index] = dy
[7cf2cfd]335        elif self.data_type == 'sesans':
[d1ff3a5]336            if self._data.y is None:
337                self._data.y = np.empty(len(self._data.x), 'd')
[7cf2cfd]338            self._data.y[self.index] = y
339        else:
340            raise ValueError("Unknown model")
341
342    def _calc_theory(self, pars, cutoff=0.0):
[a5b8477]343        # type: (ParameterSet, float) -> np.ndarray
[7cf2cfd]344        if self._kernel is None:
[68e7f9d]345            self._kernel = self._model.make_kernel(self._kernel_inputs)
[7cf2cfd]346
347        Iq_calc = call_kernel(self._kernel, pars, cutoff=cutoff)
[40a87fa]348        # Storing the calculated Iq values so that they can be plotted.
349        # Only applies to oriented USANS data for now.
350        # TODO: extend plotting of calculate Iq to other measurement types
351        # TODO: refactor so we don't store the result in the model
[d1ff3a5]352        self.Iq_calc = Iq_calc
[fa79f5c]353        result = self.resolution.apply(Iq_calc)
354        if hasattr(self.resolution, 'nx'):
355            self.Iq_calc = (
356                self.resolution.qx_calc, self.resolution.qy_calc,
357                np.reshape(Iq_calc, (self.resolution.ny, self.resolution.nx))
358            )
[40a87fa]359        return result
[7cf2cfd]360
361
362class DirectModel(DataMixin):
[803f835]363    """
364    Create a calculator object for a model.
365
366    *data* is 1D SAS, 2D SAS or SESANS data
367
368    *model* is a model calculator return from :func:`generate.load_model`
369
370    *cutoff* is the polydispersity weight cutoff.
371    """
[7cf2cfd]372    def __init__(self, data, model, cutoff=1e-5):
[a5b8477]373        # type: (Data, KernelModel, float) -> None
[7cf2cfd]374        self.model = model
375        self.cutoff = cutoff
[803f835]376        # Note: _interpret_data defines the model attributes
[7cf2cfd]377        self._interpret_data(data, model)
[803f835]378
[16bc3fc]379    def __call__(self, **pars):
[a5b8477]380        # type: (**float) -> np.ndarray
[7cf2cfd]381        return self._calc_theory(pars, cutoff=self.cutoff)
[803f835]382
[7cf2cfd]383    def simulate_data(self, noise=None, **pars):
[a5b8477]384        # type: (Optional[float], **float) -> None
[803f835]385        """
386        Generate simulated data for the model.
387        """
[7cf2cfd]388        Iq = self.__call__(**pars)
389        self._set_data(Iq, noise=noise)
[ae7b97b]390
[745b7bb]391    def profile(self, **pars):
392        # type: (**float) -> None
393        """
394        Generate a plottable profile.
395        """
396        return call_profile(self.model.info, **pars)
397
[803f835]398def main():
[a5b8477]399    # type: () -> None
[803f835]400    """
401    Program to evaluate a particular model at a set of q values.
402    """
[ae7b97b]403    import sys
[7cf2cfd]404    from .data import empty_data1D, empty_data2D
[17bbadd]405    from .core import load_model_info, build_model
[7cf2cfd]406
[ae7b97b]407    if len(sys.argv) < 3:
[9404dd3]408        print("usage: python -m sasmodels.direct_model modelname (q|qx,qy) par=val ...")
[ae7b97b]409        sys.exit(1)
410    model_name = sys.argv[1]
[aa4946b]411    call = sys.argv[2].upper()
[17bbadd]412    if call != "ER_VR":
[7cf2cfd]413        try:
414            values = [float(v) for v in call.split(',')]
[2d81cfe]415        except ValueError:
[7cf2cfd]416            values = []
[aa4946b]417        if len(values) == 1:
[7cf2cfd]418            q, = values
419            data = empty_data1D([q])
[aa4946b]420        elif len(values) == 2:
[803f835]421            qx, qy = values
422            data = empty_data2D([qx], [qy])
[aa4946b]423        else:
[9404dd3]424            print("use q or qx,qy or ER or VR")
[aa4946b]425            sys.exit(1)
[7cf2cfd]426    else:
427        data = empty_data1D([0.001])  # Data not used in ER/VR
428
[17bbadd]429    model_info = load_model_info(model_name)
430    model = build_model(model_info)
[7cf2cfd]431    calculator = DirectModel(data, model)
[4cc161e]432    pars = dict((k, (float(v) if not k.endswith("_pd_type") else v))
[ae7b97b]433                for pair in sys.argv[3:]
[803f835]434                for k, v in [pair.split('=')])
[17bbadd]435    if call == "ER_VR":
[40a87fa]436        ER = call_ER(model_info, pars)
437        VR = call_VR(model_info, pars)
438        print(ER, VR)
[aa4946b]439    else:
[7cf2cfd]440        Iq = calculator(**pars)
[9404dd3]441        print(Iq[0])
[ae7b97b]442
443if __name__ == "__main__":
[803f835]444    main()
Note: See TracBrowser for help on using the repository browser.