source: sasmodels/sasmodels/bumps_model.py @ 346bc88

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 346bc88 was 346bc88, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

Add 2D resolution smearing. Split BumpsModel? into Experiment and Model and fix up examples.

  • Property mode set to 100644
File size: 18.1 KB
Line 
1"""
2Wrap sasmodels for direct use by bumps.
3
4:class:`Model` is a wrapper for the sasmodels kernel which defines a
5bumps *Parameter* box for each kernel parameter.  *Model* accepts keyword
6arguments to set the initial value for each parameter.
7
8:class:`Experiment` combines the *Model* function with a data file loaded by the
9sasview data loader.  *Experiment* takes a *cutoff* parameter controlling
10how far the polydispersity integral extends.
11
12A variety of helper functions are provided:
13
14    :func:`load_data` loads a sasview data file.
15
16    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
17    a theory function before the data is measured.
18
19    :func:`empty_data2D` creates an empty 2D dataset.
20
21    :func:`set_beam_stop` masks the beam stop from the data.
22
23    :func:`set_half` selects the right or left half of the data, which can
24    be useful for shear measurements which have not been properly corrected
25    for path length and reflections.
26
27    :func:`set_top` cuts the top part off the data.
28
29    :func:`plot_data` plots the data file.
30
31    :func:`plot_theory` plots a calculated result from the model.
32
33"""
34
35import datetime
36import warnings
37
38import numpy as np
39
40from . import sesans
41from .resolution import Perfect1D, Pinhole1D, Slit1D
42from .resolution2d import Pinhole2D
43
44# CRUFT python 2.6
45if not hasattr(datetime.timedelta, 'total_seconds'):
46    def delay(dt):
47        """Return number date-time delta as number seconds"""
48        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
49else:
50    def delay(dt):
51        """Return number date-time delta as number seconds"""
52        return dt.total_seconds()
53
54
55# CRUFT: old style bumps wrapper which doesn't separate data and model
56def BumpsModel(data, model, cutoff=1e-5, **kw):
57    warnings.warn("Use of BumpsModel is deprecated.  Use bumps_model.Experiment instead.")
58    model = Model(model, **kw)
59    experiment = Experiment(data=data, model=model, cutoff=cutoff)
60    for k in model._parameter_names:
61        setattr(experiment, k, getattr(model, k))
62    return experiment
63
64
65
66def tic():
67    """
68    Timer function.
69
70    Use "toc=tic()" to start the clock and "toc()" to measure
71    a time interval.
72    """
73    then = datetime.datetime.now()
74    return lambda: delay(datetime.datetime.now() - then)
75
76
77def load_data(filename):
78    """
79    Load data using a sasview loader.
80    """
81    from sas.dataloader.loader import Loader
82    loader = Loader()
83    data = loader.load(filename)
84    if data is None:
85        raise IOError("Data %r could not be loaded" % filename)
86    return data
87
88def plot_data(data, view='log'):
89    """
90    Plot data loaded by the sasview loader.
91    """
92    if hasattr(data, 'qx_data'):
93        _plot_2d_signal(data, data.data, view=view)
94    else:
95        # Note: kind of weird using the _plot_result1D to plot just the
96        # data, but it handles the masking and graph markup already, so
97        # do not repeat.
98        _plot_result1D(data, None, None, view)
99
100def plot_theory(data, theory, view='log'):
101    if hasattr(data, 'qx_data'):
102        _plot_2d_signal(data, theory, view=view)
103    else:
104        _plot_result1D(data, theory, None, view, include_data=False)
105
106
107def empty_data1D(q, resolution=0.05):
108    """
109    Create empty 1D data using the given *q* as the x value.
110
111    *resolution* dq/q defaults to 5%.
112    """
113
114    from sas.dataloader.data_info import Data1D
115
116    Iq = 100 * np.ones_like(q)
117    dIq = np.sqrt(Iq)
118    data = Data1D(q, Iq, dx=resolution * q, dy=dIq)
119    data.filename = "fake data"
120    data.qmin, data.qmax = q.min(), q.max()
121    data.mask = np.zeros(len(Iq), dtype='bool')
122    return data
123
124
125def empty_data2D(qx, qy=None, resolution=0.05):
126    """
127    Create empty 2D data using the given mesh.
128
129    If *qy* is missing, create a square mesh with *qy=qx*.
130
131    *resolution* dq/q defaults to 5%.
132    """
133    from sas.dataloader.data_info import Data2D, Detector
134
135    if qy is None:
136        qy = qx
137    Qx, Qy = np.meshgrid(qx, qy)
138    Qx, Qy = Qx.flatten(), Qy.flatten()
139    Iq = 100 * np.ones_like(Qx)
140    dIq = np.sqrt(Iq)
141    mask = np.ones(len(Iq), dtype='bool')
142
143    data = Data2D()
144    data.filename = "fake data"
145    data.qx_data = Qx
146    data.qy_data = Qy
147    data.data = Iq
148    data.err_data = dIq
149    data.mask = mask
150    data.qmin = 1e-16
151    data.qmax = np.inf
152
153    # 5% dQ/Q resolution
154    if resolution != 0:
155        data.dqx_data = resolution * Qx
156        data.dqy_data = resolution * Qy
157
158    detector = Detector()
159    detector.pixel_size.x = 5 # mm
160    detector.pixel_size.y = 5 # mm
161    detector.distance = 4 # m
162    data.detector.append(detector)
163    data.xbins = qx
164    data.ybins = qy
165    data.source.wavelength = 5 # angstroms
166    data.source.wavelength_unit = "A"
167    data.Q_unit = "1/A"
168    data.I_unit = "1/cm"
169    data.q_data = np.sqrt(Qx ** 2 + Qy ** 2)
170    data.xaxis("Q_x", "A^{-1}")
171    data.yaxis("Q_y", "A^{-1}")
172    data.zaxis("Intensity", r"\text{cm}^{-1}")
173    return data
174
175
176def set_beam_stop(data, radius, outer=None):
177    """
178    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
179    """
180    from sas.dataloader.manipulations import Ringcut
181    if hasattr(data, 'qx_data'):
182        data.mask = Ringcut(0, radius)(data)
183        if outer is not None:
184            data.mask += Ringcut(outer, np.inf)(data)
185    else:
186        data.mask = (data.x >= radius)
187        if outer is not None:
188            data.mask &= (data.x < outer)
189
190
191def set_half(data, half):
192    """
193    Select half of the data, either "right" or "left".
194    """
195    from sas.dataloader.manipulations import Boxcut
196    if half == 'right':
197        data.mask += \
198            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
199    if half == 'left':
200        data.mask += \
201            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
202
203
204def set_top(data, cutoff):
205    """
206    Chop the top off the data, above *cutoff*.
207    """
208    from sas.dataloader.manipulations import Boxcut
209    data.mask += \
210        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
211
212
213def _plot_result1D(data, theory, resid, view, include_data=True):
214    """
215    Plot the data and residuals for 1D data.
216    """
217    import matplotlib.pyplot as plt
218    from numpy.ma import masked_array, masked
219    #print "not a number",sum(np.isnan(data.y))
220    #data.y[data.y<0.05] = 0.5
221    mdata = masked_array(data.y, data.mask)
222    mdata[~np.isfinite(mdata)] = masked
223    if view is 'log':
224        mdata[mdata <= 0] = masked
225
226    scale = data.x**4 if view == 'q4' else 1.0
227    if resid is not None:
228        plt.subplot(121)
229
230    if include_data:
231        plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
232    if theory is not None:
233        mtheory = masked_array(theory, mdata.mask)
234        plt.plot(data.x, scale*mtheory, '-', hold=True)
235    plt.xscale(view)
236    plt.yscale('linear' if view == 'q4' else view)
237    plt.xlabel('Q')
238    plt.ylabel('I(Q)')
239    if resid is not None:
240        mresid = masked_array(resid, mdata.mask)
241        plt.subplot(122)
242        plt.plot(data.x, mresid, 'x')
243        plt.ylabel('residuals')
244        plt.xlabel('Q')
245        plt.xscale(view)
246
247# pylint: disable=unused-argument
248def _plot_sesans(data, theory, resid, view):
249    import matplotlib.pyplot as plt
250    plt.subplot(121)
251    plt.errorbar(data.x, data.y, yerr=data.dy)
252    plt.plot(data.x, theory, '-', hold=True)
253    plt.xlabel('spin echo length (nm)')
254    plt.ylabel('polarization (P/P0)')
255    plt.subplot(122)
256    plt.plot(data.x, resid, 'x')
257    plt.xlabel('spin echo length (nm)')
258    plt.ylabel('residuals (P/P0)')
259
260def _plot_result2D(data, theory, resid, view):
261    """
262    Plot the data and residuals for 2D data.
263    """
264    import matplotlib.pyplot as plt
265    target = data.data[~data.mask]
266    if view == 'log':
267        vmin = min(target[target>0].min(), theory[theory>0].min())
268        vmax = max(target.max(), theory.max())
269    else:
270        vmin = min(target.min(), theory.min())
271        vmax = max(target.max(), theory.max())
272    #print vmin, vmax
273    plt.subplot(131)
274    _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
275    plt.title('data')
276    plt.colorbar()
277    plt.subplot(132)
278    _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
279    plt.title('theory')
280    plt.colorbar()
281    plt.subplot(133)
282    _plot_2d_signal(data, resid, view='linear')
283    plt.title('residuals')
284    plt.colorbar()
285
286def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'):
287    """
288    Plot the target value for the data.  This could be the data itself,
289    the theory calculation, or the residuals.
290
291    *scale* can be 'log' for log scale data, or 'linear'.
292    """
293    import matplotlib.pyplot as plt
294    from numpy.ma import masked_array
295
296    image = np.zeros_like(data.qx_data)
297    image[~data.mask] = signal
298    valid = np.isfinite(image)
299    if view == 'log':
300        valid[valid] = (image[valid] > 0)
301        image[valid] = np.log10(image[valid])
302    elif view == 'q4':
303        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
304    image[~valid | data.mask] = 0
305    #plottable = Iq
306    plottable = masked_array(image, ~valid | data.mask)
307    xmin, xmax = min(data.qx_data), max(data.qx_data)
308    ymin, ymax = min(data.qy_data), max(data.qy_data)
309    # TODO: fix vmin, vmax so it is shared for theory/resid
310    vmin = vmax = None
311    try:
312        if vmin is None: vmin = image[valid & ~data.mask].min()
313        if vmax is None: vmax = image[valid & ~data.mask].max()
314    except:
315        vmin, vmax = 0, 1
316    #print vmin,vmax
317    plt.imshow(plottable.reshape(128, 128),
318               interpolation='nearest', aspect=1, origin='upper',
319               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
320
321
322class Model(object):
323    def __init__(self, kernel, **kw):
324        from bumps.names import Parameter
325
326        self.kernel = kernel
327        partype = kernel.info['partype']
328
329        pars = []
330        for p in kernel.info['parameters']:
331            name, default, limits = p[0], p[2], p[3]
332            value = kw.pop(name, default)
333            setattr(self, name, Parameter.default(value, name=name, limits=limits))
334            pars.append(name)
335        for name in partype['pd-2d']:
336            for xpart, xdefault, xlimits in [
337                ('_pd', 0, limits),
338                ('_pd_n', 35, (0, 1000)),
339                ('_pd_nsigma', 3, (0, 10)),
340                ('_pd_type', 'gaussian', None),
341                ]:
342                xname = name + xpart
343                xvalue = kw.pop(xname, xdefault)
344                if xlimits is not None:
345                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
346                    pars.append(xname)
347                setattr(self, xname, xvalue)
348        self._parameter_names = pars
349        if kw:
350            raise TypeError("unexpected parameters: %s"
351                            % (", ".join(sorted(kw.keys()))))
352
353    def parameters(self):
354        """
355        Return a dictionary of parameters
356        """
357        return dict((k, getattr(self, k)) for k in self._parameter_names)
358
359class Experiment(object):
360    """
361    Return a bumps wrapper for a SAS model.
362
363    *data* is the data to be fitted.
364
365    *model* is the SAS model from :func:`core.load_model`.
366
367    *cutoff* is the integration cutoff, which avoids computing the
368    the SAS model where the polydispersity weight is low.
369
370    Model parameters can be initialized with additional keyword
371    arguments, or by assigning to model.parameter_name.value.
372
373    The resulting bumps model can be used directly in a FitProblem call.
374    """
375    def __init__(self, data, model, cutoff=1e-5):
376
377        # remember inputs so we can inspect from outside
378        self.data = data
379        self.model = model
380        self.cutoff = cutoff
381        if hasattr(data, 'lam'):
382            self.data_type = 'sesans'
383        elif hasattr(data, 'qx_data'):
384            self.data_type = 'Iqxy'
385        else:
386            self.data_type = 'Iq'
387
388        # interpret data
389        partype = model.kernel.info['partype']
390        if self.data_type == 'sesans':
391            q = sesans.make_q(data.sample.zacceptance, data.Rmax)
392            self.index = slice(None, None)
393            self.Iq = data.y
394            self.dIq = data.dy
395            #self._theory = np.zeros_like(q)
396            q_vectors = [q]
397        elif self.data_type == 'Iqxy':
398            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
399            qmin = getattr(data, 'qmin', 1e-16)
400            qmax = getattr(data, 'qmax', np.inf)
401            self.index = (~data.mask) & (~np.isnan(data.data)) \
402                         & (q >= qmin) & (q <= qmax)
403            self.Iq = data.data[self.index]
404            self.dIq = data.err_data[self.index]
405            self.resolution = Pinhole2D(data=data, index=self.index,
406                                        nsigma=3.0, accuracy='Low')
407            #self._theory = np.zeros_like(self.Iq)
408            if not partype['orientation'] and not partype['magnetic']:
409                raise ValueError("not 2D without orientation or magnetic parameters")
410                #qx,qy = self.resolution.q_calc
411                #q_vectors = [np.sqrt(qx**2 + qy**2)]
412            else:
413                q_vectors = self.resolution.q_calc
414        elif self.data_type == 'Iq':
415            self.index = (data.x >= data.qmin) & (data.x <= data.qmax) & ~np.isnan(data.y)
416            self.Iq = data.y[self.index]
417            self.dIq = data.dy[self.index]
418            if getattr(data, 'dx', None) is not None:
419                q, dq = data.x[self.index], data.dx[self.index]
420                if (dq>0).any():
421                    self.resolution = Pinhole1D(q, dq)
422                else:
423                    self.resolution = Perfect1D(q)
424            elif (getattr(data, 'dxl', None) is not None and
425                  getattr(data, 'dxw', None) is not None):
426                q = data.x[self.index]
427                width = data.dxh[self.index]  # Note: dx
428                self.resolution = Slit1D(data.x[self.index],
429                                         width=data.dxh[self.index],
430                                         height=data.dxw[self.index])
431            else:
432                self.resolution = Perfect1D(data.x[self.index])
433
434            #self._theory = np.zeros_like(self.Iq)
435            q_vectors = [self.resolution.q_calc]
436        else:
437            raise ValueError("Unknown data type") # never gets here
438
439        # Remember function inputs so we can delay loading the function and
440        # so we can save/restore state
441        self._fn_inputs = [v for v in q_vectors]
442        self._fn = None
443
444        self.update()
445
446    def update(self):
447        self._cache = {}
448
449    def numpoints(self):
450        """
451            Return the number of points
452        """
453        return len(self.Iq)
454
455    def parameters(self):
456        """
457        Return a dictionary of parameters
458        """
459        return self.model.parameters()
460
461    def theory(self):
462        if 'theory' not in self._cache:
463            if self._fn is None:
464                q_input = self.model.kernel.make_input(self._fn_inputs)
465                self._fn = self.model.kernel(q_input)
466
467            fixed_pars = [getattr(self.model, p).value for p in self._fn.fixed_pars]
468            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
469            #print fixed_pars,pd_pars
470            Iq_calc = self._fn(fixed_pars, pd_pars, self.cutoff)
471            #self._theory[:] = self._fn.eval(pars, pd_pars)
472            if self.data_type == 'sesans':
473                result = sesans.hankel(self.data.x, self.data.lam * 1e-9,
474                                       self.data.sample.thickness / 10,
475                                       self._fn_inputs[0], Iq_calc)
476                self._cache['theory'] = result
477            else:
478                Iq = self.resolution.apply(Iq_calc)
479                self._cache['theory'] = Iq
480        return self._cache['theory']
481
482    def residuals(self):
483        #if np.any(self.err ==0): print "zeros in err"
484        return (self.theory() - self.Iq) / self.dIq
485
486    def nllf(self):
487        delta = self.residuals()
488        #if np.any(np.isnan(R)): print "NaN in residuals"
489        return 0.5 * np.sum(delta ** 2)
490
491    #def __call__(self):
492    #    return 2 * self.nllf() / self.dof
493
494    def plot(self, view='log'):
495        """
496        Plot the data and residuals.
497        """
498        data, theory, resid = self.data, self.theory(), self.residuals()
499        if self.data_type == 'Iq':
500            _plot_result1D(data, theory, resid, view)
501        elif self.data_type == 'Iqxy':
502            _plot_result2D(data, theory, resid, view)
503        elif self.data_type == 'sesans':
504            _plot_sesans(data, theory, resid, view)
505        else:
506            raise ValueError("Unknown data type")
507
508    def simulate_data(self, noise=None):
509        theory = self.theory()
510        if noise is not None:
511            self.dIq = theory*noise*0.01
512        dy = self.dIq
513        y = theory + np.random.randn(*dy.shape) * dy
514        self.Iq = y
515        if self.data_type == 'Iq':
516            self.data.dy[self.index] = dy
517            self.data.y[self.index] = y
518        elif self.data_type == 'Iqxy':
519            self.data.data[self.index] = y
520        elif self.data_type == 'sesans':
521            self.data.y[self.index] = y
522        else:
523            raise ValueError("Unknown model")
524
525    def save(self, basename):
526        pass
527
528    def _get_weights(self, par):
529        """
530        Get parameter dispersion weights
531        """
532        from . import weights
533
534        relative = self.model.kernel.info['partype']['pd-rel']
535        limits = self.model.kernel.info['limits']
536        disperser, value, npts, width, nsigma = [
537            getattr(self.model, par + ext)
538            for ext in ('_pd_type', '', '_pd_n', '_pd', '_pd_nsigma')]
539        value, weight = weights.get_weights(
540            disperser, int(npts.value), width.value, nsigma.value,
541            value.value, limits[par], par in relative)
542        return value, weight / np.sum(weight)
543
544    def __getstate__(self):
545        # Can't pickle gpu functions, so instead make them lazy
546        state = self.__dict__.copy()
547        state['_fn'] = None
548        return state
549
550    def __setstate__(self, state):
551        # pylint: disable=attribute-defined-outside-init
552        self.__dict__ = state
553
554
555def demo():
556    data = load_data('DEC07086.DAT')
557    set_beam_stop(data, 0.004)
558    plot_data(data)
559    import matplotlib.pyplot as plt; plt.show()
560
561
562if __name__ == "__main__":
563    demo()
Note: See TracBrowser for help on using the repository browser.