source: sasmodels/sasmodels/bumps_model.py @ d1fe925

gh-pages
Last change on this file since d1fe925 was d1fe925, checked in by ajj, 8 years ago

Creating gh_pages branch for docs

  • Property mode set to 100644
File size: 18.5 KB
Line 
1"""
2Wrap sasmodels for direct use by bumps.
3
4:class:`Model` is a wrapper for the sasmodels kernel which defines a
5bumps *Parameter* box for each kernel parameter.  *Model* accepts keyword
6arguments to set the initial value for each parameter.
7
8:class:`Experiment` combines the *Model* function with a data file loaded by the
9sasview data loader.  *Experiment* takes a *cutoff* parameter controlling
10how far the polydispersity integral extends.
11
12A variety of helper functions are provided:
13
14    :func:`load_data` loads a sasview data file.
15
16    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
17    a theory function before the data is measured.
18
19    :func:`empty_data2D` creates an empty 2D dataset.
20
21    :func:`set_beam_stop` masks the beam stop from the data.
22
23    :func:`set_half` selects the right or left half of the data, which can
24    be useful for shear measurements which have not been properly corrected
25    for path length and reflections.
26
27    :func:`set_top` cuts the top part off the data.
28
29    :func:`plot_data` plots the data file.
30
31    :func:`plot_theory` plots a calculated result from the model.
32
33"""
34
35import datetime
36import warnings
37
38import numpy as np
39
40from . import sesans
41from .resolution import Perfect1D, Pinhole1D, Slit1D
42from .resolution2d import Pinhole2D
43
44# CRUFT python 2.6
45if not hasattr(datetime.timedelta, 'total_seconds'):
46    def delay(dt):
47        """Return number date-time delta as number seconds"""
48        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
49else:
50    def delay(dt):
51        """Return number date-time delta as number seconds"""
52        return dt.total_seconds()
53
54
55# CRUFT: old style bumps wrapper which doesn't separate data and model
56def BumpsModel(data, model, cutoff=1e-5, **kw):
57    warnings.warn("Use of BumpsModel is deprecated.  Use bumps_model.Experiment instead.")
58    model = Model(model, **kw)
59    experiment = Experiment(data=data, model=model, cutoff=cutoff)
60    for k in model._parameter_names:
61        setattr(experiment, k, getattr(model, k))
62    return experiment
63
64
65
66def tic():
67    """
68    Timer function.
69
70    Use "toc=tic()" to start the clock and "toc()" to measure
71    a time interval.
72    """
73    then = datetime.datetime.now()
74    return lambda: delay(datetime.datetime.now() - then)
75
76
77def load_data(filename):
78    """
79    Load data using a sasview loader.
80    """
81    from sas.dataloader.loader import Loader
82    loader = Loader()
83    data = loader.load(filename)
84    if data is None:
85        raise IOError("Data %r could not be loaded" % filename)
86    return data
87
88def plot_data(data, view='log'):
89    """
90    Plot data loaded by the sasview loader.
91    """
92    if hasattr(data, 'qx_data'):
93        _plot_2d_signal(data, data.data, view=view)
94    else:
95        # Note: kind of weird using the _plot_result1D to plot just the
96        # data, but it handles the masking and graph markup already, so
97        # do not repeat.
98        _plot_result1D(data, None, None, view)
99
100def plot_theory(data, theory, view='log'):
101    if hasattr(data, 'qx_data'):
102        _plot_2d_signal(data, theory, view=view)
103    else:
104        _plot_result1D(data, theory, None, view, include_data=False)
105
106
107def empty_data1D(q, resolution=0.05):
108    """
109    Create empty 1D data using the given *q* as the x value.
110
111    *resolution* dq/q defaults to 5%.
112    """
113
114    from sas.dataloader.data_info import Data1D
115
116    Iq = 100 * np.ones_like(q)
117    dIq = np.sqrt(Iq)
118    data = Data1D(q, Iq, dx=resolution * q, dy=dIq)
119    data.filename = "fake data"
120    data.qmin, data.qmax = q.min(), q.max()
121    data.mask = np.zeros(len(Iq), dtype='bool')
122    return data
123
124
125def empty_data2D(qx, qy=None, resolution=0.05):
126    """
127    Create empty 2D data using the given mesh.
128
129    If *qy* is missing, create a square mesh with *qy=qx*.
130
131    *resolution* dq/q defaults to 5%.
132    """
133    from sas.dataloader.data_info import Data2D, Detector
134
135    if qy is None:
136        qy = qx
137    Qx, Qy = np.meshgrid(qx, qy)
138    Qx, Qy = Qx.flatten(), Qy.flatten()
139    Iq = 100 * np.ones_like(Qx)
140    dIq = np.sqrt(Iq)
141    mask = np.ones(len(Iq), dtype='bool')
142
143    data = Data2D()
144    data.filename = "fake data"
145    data.qx_data = Qx
146    data.qy_data = Qy
147    data.data = Iq
148    data.err_data = dIq
149    data.mask = mask
150    data.qmin = 1e-16
151    data.qmax = np.inf
152
153    # 5% dQ/Q resolution
154    if resolution != 0:
155        # https://www.ncnr.nist.gov/staff/hammouda/distance_learning/chapter_15.pdf
156        # Should have an additional constant which depends on distances and
157        # radii of the aperture, pixel dimensions and wavelength spread
158        # Instead, assume radial dQ/Q is constant, and perpendicular matches
159        # radial (which instead it should be inverse).
160        Q = np.sqrt(Qx**2 + Qy**2)
161        data.dqx_data = resolution * Q
162        data.dqy_data = resolution * Q
163
164    detector = Detector()
165    detector.pixel_size.x = 5 # mm
166    detector.pixel_size.y = 5 # mm
167    detector.distance = 4 # m
168    data.detector.append(detector)
169    data.xbins = qx
170    data.ybins = qy
171    data.source.wavelength = 5 # angstroms
172    data.source.wavelength_unit = "A"
173    data.Q_unit = "1/A"
174    data.I_unit = "1/cm"
175    data.q_data = np.sqrt(Qx ** 2 + Qy ** 2)
176    data.xaxis("Q_x", "A^{-1}")
177    data.yaxis("Q_y", "A^{-1}")
178    data.zaxis("Intensity", r"\text{cm}^{-1}")
179    return data
180
181
182def set_beam_stop(data, radius, outer=None):
183    """
184    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
185    """
186    from sas.dataloader.manipulations import Ringcut
187    if hasattr(data, 'qx_data'):
188        data.mask = Ringcut(0, radius)(data)
189        if outer is not None:
190            data.mask += Ringcut(outer, np.inf)(data)
191    else:
192        data.mask = (data.x >= radius)
193        if outer is not None:
194            data.mask &= (data.x < outer)
195
196
197def set_half(data, half):
198    """
199    Select half of the data, either "right" or "left".
200    """
201    from sas.dataloader.manipulations import Boxcut
202    if half == 'right':
203        data.mask += \
204            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
205    if half == 'left':
206        data.mask += \
207            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
208
209
210def set_top(data, cutoff):
211    """
212    Chop the top off the data, above *cutoff*.
213    """
214    from sas.dataloader.manipulations import Boxcut
215    data.mask += \
216        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
217
218
219def _plot_result1D(data, theory, resid, view, include_data=True):
220    """
221    Plot the data and residuals for 1D data.
222    """
223    import matplotlib.pyplot as plt
224    from numpy.ma import masked_array, masked
225    #print "not a number",sum(np.isnan(data.y))
226    #data.y[data.y<0.05] = 0.5
227    mdata = masked_array(data.y, data.mask)
228    mdata[~np.isfinite(mdata)] = masked
229    if view is 'log':
230        mdata[mdata <= 0] = masked
231
232    scale = data.x**4 if view == 'q4' else 1.0
233    if resid is not None:
234        plt.subplot(121)
235
236    if include_data:
237        plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
238    if theory is not None:
239        mtheory = masked_array(theory, mdata.mask)
240        plt.plot(data.x, scale*mtheory, '-', hold=True)
241    plt.xscale(view)
242    plt.yscale('linear' if view == 'q4' else view)
243    plt.xlabel('Q')
244    plt.ylabel('I(Q)')
245    if resid is not None:
246        mresid = masked_array(resid, mdata.mask)
247        plt.subplot(122)
248        plt.plot(data.x, mresid, 'x')
249        plt.ylabel('residuals')
250        plt.xlabel('Q')
251        plt.xscale(view)
252
253# pylint: disable=unused-argument
254def _plot_sesans(data, theory, resid, view):
255    import matplotlib.pyplot as plt
256    plt.subplot(121)
257    plt.errorbar(data.x, data.y, yerr=data.dy)
258    plt.plot(data.x, theory, '-', hold=True)
259    plt.xlabel('spin echo length (nm)')
260    plt.ylabel('polarization (P/P0)')
261    plt.subplot(122)
262    plt.plot(data.x, resid, 'x')
263    plt.xlabel('spin echo length (nm)')
264    plt.ylabel('residuals (P/P0)')
265
266def _plot_result2D(data, theory, resid, view):
267    """
268    Plot the data and residuals for 2D data.
269    """
270    import matplotlib.pyplot as plt
271    target = data.data[~data.mask]
272    if view == 'log':
273        vmin = min(target[target>0].min(), theory[theory>0].min())
274        vmax = max(target.max(), theory.max())
275    else:
276        vmin = min(target.min(), theory.min())
277        vmax = max(target.max(), theory.max())
278    #print vmin, vmax
279    plt.subplot(131)
280    _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
281    plt.title('data')
282    plt.colorbar()
283    plt.subplot(132)
284    _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
285    plt.title('theory')
286    plt.colorbar()
287    plt.subplot(133)
288    _plot_2d_signal(data, resid, view='linear')
289    plt.title('residuals')
290    plt.colorbar()
291
292def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'):
293    """
294    Plot the target value for the data.  This could be the data itself,
295    the theory calculation, or the residuals.
296
297    *scale* can be 'log' for log scale data, or 'linear'.
298    """
299    import matplotlib.pyplot as plt
300    from numpy.ma import masked_array
301
302    image = np.zeros_like(data.qx_data)
303    image[~data.mask] = signal
304    valid = np.isfinite(image)
305    if view == 'log':
306        valid[valid] = (image[valid] > 0)
307        image[valid] = np.log10(image[valid])
308    elif view == 'q4':
309        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
310    image[~valid | data.mask] = 0
311    #plottable = Iq
312    plottable = masked_array(image, ~valid | data.mask)
313    xmin, xmax = min(data.qx_data), max(data.qx_data)
314    ymin, ymax = min(data.qy_data), max(data.qy_data)
315    # TODO: fix vmin, vmax so it is shared for theory/resid
316    vmin = vmax = None
317    try:
318        if vmin is None: vmin = image[valid & ~data.mask].min()
319        if vmax is None: vmax = image[valid & ~data.mask].max()
320    except:
321        vmin, vmax = 0, 1
322    #print vmin,vmax
323    plt.imshow(plottable.reshape(128, 128),
324               interpolation='nearest', aspect=1, origin='upper',
325               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
326
327
328class Model(object):
329    def __init__(self, kernel, **kw):
330        from bumps.names import Parameter
331
332        self.kernel = kernel
333        partype = kernel.info['partype']
334
335        pars = []
336        for p in kernel.info['parameters']:
337            name, default, limits = p[0], p[2], p[3]
338            value = kw.pop(name, default)
339            setattr(self, name, Parameter.default(value, name=name, limits=limits))
340            pars.append(name)
341        for name in partype['pd-2d']:
342            for xpart, xdefault, xlimits in [
343                ('_pd', 0, limits),
344                ('_pd_n', 35, (0, 1000)),
345                ('_pd_nsigma', 3, (0, 10)),
346                ('_pd_type', 'gaussian', None),
347                ]:
348                xname = name + xpart
349                xvalue = kw.pop(xname, xdefault)
350                if xlimits is not None:
351                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
352                    pars.append(xname)
353                setattr(self, xname, xvalue)
354        self._parameter_names = pars
355        if kw:
356            raise TypeError("unexpected parameters: %s"
357                            % (", ".join(sorted(kw.keys()))))
358
359    def parameters(self):
360        """
361        Return a dictionary of parameters
362        """
363        return dict((k, getattr(self, k)) for k in self._parameter_names)
364
365class Experiment(object):
366    """
367    Return a bumps wrapper for a SAS model.
368
369    *data* is the data to be fitted.
370
371    *model* is the SAS model from :func:`core.load_model`.
372
373    *cutoff* is the integration cutoff, which avoids computing the
374    the SAS model where the polydispersity weight is low.
375
376    Model parameters can be initialized with additional keyword
377    arguments, or by assigning to model.parameter_name.value.
378
379    The resulting bumps model can be used directly in a FitProblem call.
380    """
381    def __init__(self, data, model, cutoff=1e-5):
382
383        # remember inputs so we can inspect from outside
384        self.data = data
385        self.model = model
386        self.cutoff = cutoff
387        if hasattr(data, 'lam'):
388            self.data_type = 'sesans'
389        elif hasattr(data, 'qx_data'):
390            self.data_type = 'Iqxy'
391        else:
392            self.data_type = 'Iq'
393
394        # interpret data
395        partype = model.kernel.info['partype']
396        if self.data_type == 'sesans':
397            q = sesans.make_q(data.sample.zacceptance, data.Rmax)
398            self.index = slice(None, None)
399            self.Iq = data.y
400            self.dIq = data.dy
401            #self._theory = np.zeros_like(q)
402            q_vectors = [q]
403        elif self.data_type == 'Iqxy':
404            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
405            qmin = getattr(data, 'qmin', 1e-16)
406            qmax = getattr(data, 'qmax', np.inf)
407            accuracy = getattr(data, 'accuracy', 'Low')
408            self.index = (~data.mask) & (~np.isnan(data.data)) \
409                         & (q >= qmin) & (q <= qmax)
410            self.Iq = data.data[self.index]
411            self.dIq = data.err_data[self.index]
412            self.resolution = Pinhole2D(data=data, index=self.index,
413                                        nsigma=3.0, accuracy=accuracy)
414            #self._theory = np.zeros_like(self.Iq)
415            if not partype['orientation'] and not partype['magnetic']:
416                raise ValueError("not 2D without orientation or magnetic parameters")
417                #qx,qy = self.resolution.q_calc
418                #q_vectors = [np.sqrt(qx**2 + qy**2)]
419            else:
420                q_vectors = self.resolution.q_calc
421        elif self.data_type == 'Iq':
422            self.index = (data.x >= data.qmin) & (data.x <= data.qmax) & ~np.isnan(data.y)
423            self.Iq = data.y[self.index]
424            self.dIq = data.dy[self.index]
425            if getattr(data, 'dx', None) is not None:
426                q, dq = data.x[self.index], data.dx[self.index]
427                if (dq>0).any():
428                    self.resolution = Pinhole1D(q, dq)
429                else:
430                    self.resolution = Perfect1D(q)
431            elif (getattr(data, 'dxl', None) is not None and
432                  getattr(data, 'dxw', None) is not None):
433                q = data.x[self.index]
434                width = data.dxh[self.index]  # Note: dx
435                self.resolution = Slit1D(data.x[self.index],
436                                         width=data.dxh[self.index],
437                                         height=data.dxw[self.index])
438            else:
439                self.resolution = Perfect1D(data.x[self.index])
440
441            #self._theory = np.zeros_like(self.Iq)
442            q_vectors = [self.resolution.q_calc]
443        else:
444            raise ValueError("Unknown data type") # never gets here
445
446        # Remember function inputs so we can delay loading the function and
447        # so we can save/restore state
448        self._fn_inputs = [v for v in q_vectors]
449        self._fn = None
450
451        self.update()
452
453    def update(self):
454        self._cache = {}
455
456    def numpoints(self):
457        """
458            Return the number of points
459        """
460        return len(self.Iq)
461
462    def parameters(self):
463        """
464        Return a dictionary of parameters
465        """
466        return self.model.parameters()
467
468    def theory(self):
469        if 'theory' not in self._cache:
470            if self._fn is None:
471                q_input = self.model.kernel.make_input(self._fn_inputs)
472                self._fn = self.model.kernel(q_input)
473
474            fixed_pars = [getattr(self.model, p).value for p in self._fn.fixed_pars]
475            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
476            #print fixed_pars,pd_pars
477            Iq_calc = self._fn(fixed_pars, pd_pars, self.cutoff)
478            #self._theory[:] = self._fn.eval(pars, pd_pars)
479            if self.data_type == 'sesans':
480                result = sesans.hankel(self.data.x, self.data.lam * 1e-9,
481                                       self.data.sample.thickness / 10,
482                                       self._fn_inputs[0], Iq_calc)
483                self._cache['theory'] = result
484            else:
485                Iq = self.resolution.apply(Iq_calc)
486                self._cache['theory'] = Iq
487        return self._cache['theory']
488
489    def residuals(self):
490        #if np.any(self.err ==0): print "zeros in err"
491        return (self.theory() - self.Iq) / self.dIq
492
493    def nllf(self):
494        delta = self.residuals()
495        #if np.any(np.isnan(R)): print "NaN in residuals"
496        return 0.5 * np.sum(delta ** 2)
497
498    #def __call__(self):
499    #    return 2 * self.nllf() / self.dof
500
501    def plot(self, view='log'):
502        """
503        Plot the data and residuals.
504        """
505        data, theory, resid = self.data, self.theory(), self.residuals()
506        if self.data_type == 'Iq':
507            _plot_result1D(data, theory, resid, view)
508        elif self.data_type == 'Iqxy':
509            _plot_result2D(data, theory, resid, view)
510        elif self.data_type == 'sesans':
511            _plot_sesans(data, theory, resid, view)
512        else:
513            raise ValueError("Unknown data type")
514
515    def simulate_data(self, noise=None):
516        theory = self.theory()
517        if noise is not None:
518            self.dIq = theory*noise*0.01
519        dy = self.dIq
520        y = theory + np.random.randn(*dy.shape) * dy
521        self.Iq = y
522        if self.data_type == 'Iq':
523            self.data.dy[self.index] = dy
524            self.data.y[self.index] = y
525        elif self.data_type == 'Iqxy':
526            self.data.data[self.index] = y
527        elif self.data_type == 'sesans':
528            self.data.y[self.index] = y
529        else:
530            raise ValueError("Unknown model")
531
532    def save(self, basename):
533        pass
534
535    def _get_weights(self, par):
536        """
537        Get parameter dispersion weights
538        """
539        from . import weights
540
541        relative = self.model.kernel.info['partype']['pd-rel']
542        limits = self.model.kernel.info['limits']
543        disperser, value, npts, width, nsigma = [
544            getattr(self.model, par + ext)
545            for ext in ('_pd_type', '', '_pd_n', '_pd', '_pd_nsigma')]
546        value, weight = weights.get_weights(
547            disperser, int(npts.value), width.value, nsigma.value,
548            value.value, limits[par], par in relative)
549        return value, weight / np.sum(weight)
550
551    def __getstate__(self):
552        # Can't pickle gpu functions, so instead make them lazy
553        state = self.__dict__.copy()
554        state['_fn'] = None
555        return state
556
557    def __setstate__(self, state):
558        # pylint: disable=attribute-defined-outside-init
559        self.__dict__ = state
560
561
562def demo():
563    data = load_data('DEC07086.DAT')
564    set_beam_stop(data, 0.004)
565    plot_data(data)
566    import matplotlib.pyplot as plt; plt.show()
567
568
569if __name__ == "__main__":
570    demo()
Note: See TracBrowser for help on using the repository browser.