source: sasmodels/sasmodels/bumps_model.py @ 3e6aaad

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 3e6aaad was 3e6aaad, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

add resolution accuracy setting to compare

  • Property mode set to 100644
File size: 18.2 KB
Line 
1"""
2Wrap sasmodels for direct use by bumps.
3
4:class:`Model` is a wrapper for the sasmodels kernel which defines a
5bumps *Parameter* box for each kernel parameter.  *Model* accepts keyword
6arguments to set the initial value for each parameter.
7
8:class:`Experiment` combines the *Model* function with a data file loaded by the
9sasview data loader.  *Experiment* takes a *cutoff* parameter controlling
10how far the polydispersity integral extends.
11
12A variety of helper functions are provided:
13
14    :func:`load_data` loads a sasview data file.
15
16    :func:`empty_data1D` creates an empty dataset, which is useful for plotting
17    a theory function before the data is measured.
18
19    :func:`empty_data2D` creates an empty 2D dataset.
20
21    :func:`set_beam_stop` masks the beam stop from the data.
22
23    :func:`set_half` selects the right or left half of the data, which can
24    be useful for shear measurements which have not been properly corrected
25    for path length and reflections.
26
27    :func:`set_top` cuts the top part off the data.
28
29    :func:`plot_data` plots the data file.
30
31    :func:`plot_theory` plots a calculated result from the model.
32
33"""
34
35import datetime
36import warnings
37
38import numpy as np
39
40from . import sesans
41from .resolution import Perfect1D, Pinhole1D, Slit1D
42from .resolution2d import Pinhole2D
43
44# CRUFT python 2.6
45if not hasattr(datetime.timedelta, 'total_seconds'):
46    def delay(dt):
47        """Return number date-time delta as number seconds"""
48        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
49else:
50    def delay(dt):
51        """Return number date-time delta as number seconds"""
52        return dt.total_seconds()
53
54
55# CRUFT: old style bumps wrapper which doesn't separate data and model
56def BumpsModel(data, model, cutoff=1e-5, **kw):
57    warnings.warn("Use of BumpsModel is deprecated.  Use bumps_model.Experiment instead.")
58    model = Model(model, **kw)
59    experiment = Experiment(data=data, model=model, cutoff=cutoff)
60    for k in model._parameter_names:
61        setattr(experiment, k, getattr(model, k))
62    return experiment
63
64
65
66def tic():
67    """
68    Timer function.
69
70    Use "toc=tic()" to start the clock and "toc()" to measure
71    a time interval.
72    """
73    then = datetime.datetime.now()
74    return lambda: delay(datetime.datetime.now() - then)
75
76
77def load_data(filename):
78    """
79    Load data using a sasview loader.
80    """
81    from sas.dataloader.loader import Loader
82    loader = Loader()
83    data = loader.load(filename)
84    if data is None:
85        raise IOError("Data %r could not be loaded" % filename)
86    return data
87
88def plot_data(data, view='log'):
89    """
90    Plot data loaded by the sasview loader.
91    """
92    if hasattr(data, 'qx_data'):
93        _plot_2d_signal(data, data.data, view=view)
94    else:
95        # Note: kind of weird using the _plot_result1D to plot just the
96        # data, but it handles the masking and graph markup already, so
97        # do not repeat.
98        _plot_result1D(data, None, None, view)
99
100def plot_theory(data, theory, view='log'):
101    if hasattr(data, 'qx_data'):
102        _plot_2d_signal(data, theory, view=view)
103    else:
104        _plot_result1D(data, theory, None, view, include_data=False)
105
106
107def empty_data1D(q, resolution=0.05):
108    """
109    Create empty 1D data using the given *q* as the x value.
110
111    *resolution* dq/q defaults to 5%.
112    """
113
114    from sas.dataloader.data_info import Data1D
115
116    Iq = 100 * np.ones_like(q)
117    dIq = np.sqrt(Iq)
118    data = Data1D(q, Iq, dx=resolution * q, dy=dIq)
119    data.filename = "fake data"
120    data.qmin, data.qmax = q.min(), q.max()
121    data.mask = np.zeros(len(Iq), dtype='bool')
122    return data
123
124
125def empty_data2D(qx, qy=None, resolution=0.05):
126    """
127    Create empty 2D data using the given mesh.
128
129    If *qy* is missing, create a square mesh with *qy=qx*.
130
131    *resolution* dq/q defaults to 5%.
132    """
133    from sas.dataloader.data_info import Data2D, Detector
134
135    if qy is None:
136        qy = qx
137    Qx, Qy = np.meshgrid(qx, qy)
138    Qx, Qy = Qx.flatten(), Qy.flatten()
139    Iq = 100 * np.ones_like(Qx)
140    dIq = np.sqrt(Iq)
141    mask = np.ones(len(Iq), dtype='bool')
142
143    data = Data2D()
144    data.filename = "fake data"
145    data.qx_data = Qx
146    data.qy_data = Qy
147    data.data = Iq
148    data.err_data = dIq
149    data.mask = mask
150    data.qmin = 1e-16
151    data.qmax = np.inf
152
153    # 5% dQ/Q resolution
154    if resolution != 0:
155        data.dqx_data = resolution * Qx
156        data.dqy_data = resolution * Qy
157
158    detector = Detector()
159    detector.pixel_size.x = 5 # mm
160    detector.pixel_size.y = 5 # mm
161    detector.distance = 4 # m
162    data.detector.append(detector)
163    data.xbins = qx
164    data.ybins = qy
165    data.source.wavelength = 5 # angstroms
166    data.source.wavelength_unit = "A"
167    data.Q_unit = "1/A"
168    data.I_unit = "1/cm"
169    data.q_data = np.sqrt(Qx ** 2 + Qy ** 2)
170    data.xaxis("Q_x", "A^{-1}")
171    data.yaxis("Q_y", "A^{-1}")
172    data.zaxis("Intensity", r"\text{cm}^{-1}")
173    return data
174
175
176def set_beam_stop(data, radius, outer=None):
177    """
178    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
179    """
180    from sas.dataloader.manipulations import Ringcut
181    if hasattr(data, 'qx_data'):
182        data.mask = Ringcut(0, radius)(data)
183        if outer is not None:
184            data.mask += Ringcut(outer, np.inf)(data)
185    else:
186        data.mask = (data.x >= radius)
187        if outer is not None:
188            data.mask &= (data.x < outer)
189
190
191def set_half(data, half):
192    """
193    Select half of the data, either "right" or "left".
194    """
195    from sas.dataloader.manipulations import Boxcut
196    if half == 'right':
197        data.mask += \
198            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
199    if half == 'left':
200        data.mask += \
201            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
202
203
204def set_top(data, cutoff):
205    """
206    Chop the top off the data, above *cutoff*.
207    """
208    from sas.dataloader.manipulations import Boxcut
209    data.mask += \
210        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
211
212
213def _plot_result1D(data, theory, resid, view, include_data=True):
214    """
215    Plot the data and residuals for 1D data.
216    """
217    import matplotlib.pyplot as plt
218    from numpy.ma import masked_array, masked
219    #print "not a number",sum(np.isnan(data.y))
220    #data.y[data.y<0.05] = 0.5
221    mdata = masked_array(data.y, data.mask)
222    mdata[~np.isfinite(mdata)] = masked
223    if view is 'log':
224        mdata[mdata <= 0] = masked
225
226    scale = data.x**4 if view == 'q4' else 1.0
227    if resid is not None:
228        plt.subplot(121)
229
230    if include_data:
231        plt.errorbar(data.x, scale*mdata, yerr=data.dy, fmt='.')
232    if theory is not None:
233        mtheory = masked_array(theory, mdata.mask)
234        plt.plot(data.x, scale*mtheory, '-', hold=True)
235    plt.xscale(view)
236    plt.yscale('linear' if view == 'q4' else view)
237    plt.xlabel('Q')
238    plt.ylabel('I(Q)')
239    if resid is not None:
240        mresid = masked_array(resid, mdata.mask)
241        plt.subplot(122)
242        plt.plot(data.x, mresid, 'x')
243        plt.ylabel('residuals')
244        plt.xlabel('Q')
245        plt.xscale(view)
246
247# pylint: disable=unused-argument
248def _plot_sesans(data, theory, resid, view):
249    import matplotlib.pyplot as plt
250    plt.subplot(121)
251    plt.errorbar(data.x, data.y, yerr=data.dy)
252    plt.plot(data.x, theory, '-', hold=True)
253    plt.xlabel('spin echo length (nm)')
254    plt.ylabel('polarization (P/P0)')
255    plt.subplot(122)
256    plt.plot(data.x, resid, 'x')
257    plt.xlabel('spin echo length (nm)')
258    plt.ylabel('residuals (P/P0)')
259
260def _plot_result2D(data, theory, resid, view):
261    """
262    Plot the data and residuals for 2D data.
263    """
264    import matplotlib.pyplot as plt
265    target = data.data[~data.mask]
266    if view == 'log':
267        vmin = min(target[target>0].min(), theory[theory>0].min())
268        vmax = max(target.max(), theory.max())
269    else:
270        vmin = min(target.min(), theory.min())
271        vmax = max(target.max(), theory.max())
272    #print vmin, vmax
273    plt.subplot(131)
274    _plot_2d_signal(data, target, view=view, vmin=vmin, vmax=vmax)
275    plt.title('data')
276    plt.colorbar()
277    plt.subplot(132)
278    _plot_2d_signal(data, theory, view=view, vmin=vmin, vmax=vmax)
279    plt.title('theory')
280    plt.colorbar()
281    plt.subplot(133)
282    _plot_2d_signal(data, resid, view='linear')
283    plt.title('residuals')
284    plt.colorbar()
285
286def _plot_2d_signal(data, signal, vmin=None, vmax=None, view='log'):
287    """
288    Plot the target value for the data.  This could be the data itself,
289    the theory calculation, or the residuals.
290
291    *scale* can be 'log' for log scale data, or 'linear'.
292    """
293    import matplotlib.pyplot as plt
294    from numpy.ma import masked_array
295
296    image = np.zeros_like(data.qx_data)
297    image[~data.mask] = signal
298    valid = np.isfinite(image)
299    if view == 'log':
300        valid[valid] = (image[valid] > 0)
301        image[valid] = np.log10(image[valid])
302    elif view == 'q4':
303        image[valid] *= (data.qx_data[valid]**2+data.qy_data[valid]**2)**2
304    image[~valid | data.mask] = 0
305    #plottable = Iq
306    plottable = masked_array(image, ~valid | data.mask)
307    xmin, xmax = min(data.qx_data), max(data.qx_data)
308    ymin, ymax = min(data.qy_data), max(data.qy_data)
309    # TODO: fix vmin, vmax so it is shared for theory/resid
310    vmin = vmax = None
311    try:
312        if vmin is None: vmin = image[valid & ~data.mask].min()
313        if vmax is None: vmax = image[valid & ~data.mask].max()
314    except:
315        vmin, vmax = 0, 1
316    #print vmin,vmax
317    plt.imshow(plottable.reshape(128, 128),
318               interpolation='nearest', aspect=1, origin='upper',
319               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
320
321
322class Model(object):
323    def __init__(self, kernel, **kw):
324        from bumps.names import Parameter
325
326        self.kernel = kernel
327        partype = kernel.info['partype']
328
329        pars = []
330        for p in kernel.info['parameters']:
331            name, default, limits = p[0], p[2], p[3]
332            value = kw.pop(name, default)
333            setattr(self, name, Parameter.default(value, name=name, limits=limits))
334            pars.append(name)
335        for name in partype['pd-2d']:
336            for xpart, xdefault, xlimits in [
337                ('_pd', 0, limits),
338                ('_pd_n', 35, (0, 1000)),
339                ('_pd_nsigma', 3, (0, 10)),
340                ('_pd_type', 'gaussian', None),
341                ]:
342                xname = name + xpart
343                xvalue = kw.pop(xname, xdefault)
344                if xlimits is not None:
345                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
346                    pars.append(xname)
347                setattr(self, xname, xvalue)
348        self._parameter_names = pars
349        if kw:
350            raise TypeError("unexpected parameters: %s"
351                            % (", ".join(sorted(kw.keys()))))
352
353    def parameters(self):
354        """
355        Return a dictionary of parameters
356        """
357        return dict((k, getattr(self, k)) for k in self._parameter_names)
358
359class Experiment(object):
360    """
361    Return a bumps wrapper for a SAS model.
362
363    *data* is the data to be fitted.
364
365    *model* is the SAS model from :func:`core.load_model`.
366
367    *cutoff* is the integration cutoff, which avoids computing the
368    the SAS model where the polydispersity weight is low.
369
370    Model parameters can be initialized with additional keyword
371    arguments, or by assigning to model.parameter_name.value.
372
373    The resulting bumps model can be used directly in a FitProblem call.
374    """
375    def __init__(self, data, model, cutoff=1e-5):
376
377        # remember inputs so we can inspect from outside
378        self.data = data
379        self.model = model
380        self.cutoff = cutoff
381        if hasattr(data, 'lam'):
382            self.data_type = 'sesans'
383        elif hasattr(data, 'qx_data'):
384            self.data_type = 'Iqxy'
385        else:
386            self.data_type = 'Iq'
387
388        # interpret data
389        partype = model.kernel.info['partype']
390        if self.data_type == 'sesans':
391            q = sesans.make_q(data.sample.zacceptance, data.Rmax)
392            self.index = slice(None, None)
393            self.Iq = data.y
394            self.dIq = data.dy
395            #self._theory = np.zeros_like(q)
396            q_vectors = [q]
397        elif self.data_type == 'Iqxy':
398            q = np.sqrt(data.qx_data**2 + data.qy_data**2)
399            qmin = getattr(data, 'qmin', 1e-16)
400            qmax = getattr(data, 'qmax', np.inf)
401            accuracy = getattr(data, 'accuracy', 'Low')
402            self.index = (~data.mask) & (~np.isnan(data.data)) \
403                         & (q >= qmin) & (q <= qmax)
404            self.Iq = data.data[self.index]
405            self.dIq = data.err_data[self.index]
406            self.resolution = Pinhole2D(data=data, index=self.index,
407                                        nsigma=3.0, accuracy=accuracy)
408            #self._theory = np.zeros_like(self.Iq)
409            if not partype['orientation'] and not partype['magnetic']:
410                raise ValueError("not 2D without orientation or magnetic parameters")
411                #qx,qy = self.resolution.q_calc
412                #q_vectors = [np.sqrt(qx**2 + qy**2)]
413            else:
414                q_vectors = self.resolution.q_calc
415        elif self.data_type == 'Iq':
416            self.index = (data.x >= data.qmin) & (data.x <= data.qmax) & ~np.isnan(data.y)
417            self.Iq = data.y[self.index]
418            self.dIq = data.dy[self.index]
419            if getattr(data, 'dx', None) is not None:
420                q, dq = data.x[self.index], data.dx[self.index]
421                if (dq>0).any():
422                    self.resolution = Pinhole1D(q, dq)
423                else:
424                    self.resolution = Perfect1D(q)
425            elif (getattr(data, 'dxl', None) is not None and
426                  getattr(data, 'dxw', None) is not None):
427                q = data.x[self.index]
428                width = data.dxh[self.index]  # Note: dx
429                self.resolution = Slit1D(data.x[self.index],
430                                         width=data.dxh[self.index],
431                                         height=data.dxw[self.index])
432            else:
433                self.resolution = Perfect1D(data.x[self.index])
434
435            #self._theory = np.zeros_like(self.Iq)
436            q_vectors = [self.resolution.q_calc]
437        else:
438            raise ValueError("Unknown data type") # never gets here
439
440        # Remember function inputs so we can delay loading the function and
441        # so we can save/restore state
442        self._fn_inputs = [v for v in q_vectors]
443        self._fn = None
444
445        self.update()
446
447    def update(self):
448        self._cache = {}
449
450    def numpoints(self):
451        """
452            Return the number of points
453        """
454        return len(self.Iq)
455
456    def parameters(self):
457        """
458        Return a dictionary of parameters
459        """
460        return self.model.parameters()
461
462    def theory(self):
463        if 'theory' not in self._cache:
464            if self._fn is None:
465                q_input = self.model.kernel.make_input(self._fn_inputs)
466                self._fn = self.model.kernel(q_input)
467
468            fixed_pars = [getattr(self.model, p).value for p in self._fn.fixed_pars]
469            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
470            #print fixed_pars,pd_pars
471            Iq_calc = self._fn(fixed_pars, pd_pars, self.cutoff)
472            #self._theory[:] = self._fn.eval(pars, pd_pars)
473            if self.data_type == 'sesans':
474                result = sesans.hankel(self.data.x, self.data.lam * 1e-9,
475                                       self.data.sample.thickness / 10,
476                                       self._fn_inputs[0], Iq_calc)
477                self._cache['theory'] = result
478            else:
479                Iq = self.resolution.apply(Iq_calc)
480                self._cache['theory'] = Iq
481        return self._cache['theory']
482
483    def residuals(self):
484        #if np.any(self.err ==0): print "zeros in err"
485        return (self.theory() - self.Iq) / self.dIq
486
487    def nllf(self):
488        delta = self.residuals()
489        #if np.any(np.isnan(R)): print "NaN in residuals"
490        return 0.5 * np.sum(delta ** 2)
491
492    #def __call__(self):
493    #    return 2 * self.nllf() / self.dof
494
495    def plot(self, view='log'):
496        """
497        Plot the data and residuals.
498        """
499        data, theory, resid = self.data, self.theory(), self.residuals()
500        if self.data_type == 'Iq':
501            _plot_result1D(data, theory, resid, view)
502        elif self.data_type == 'Iqxy':
503            _plot_result2D(data, theory, resid, view)
504        elif self.data_type == 'sesans':
505            _plot_sesans(data, theory, resid, view)
506        else:
507            raise ValueError("Unknown data type")
508
509    def simulate_data(self, noise=None):
510        theory = self.theory()
511        if noise is not None:
512            self.dIq = theory*noise*0.01
513        dy = self.dIq
514        y = theory + np.random.randn(*dy.shape) * dy
515        self.Iq = y
516        if self.data_type == 'Iq':
517            self.data.dy[self.index] = dy
518            self.data.y[self.index] = y
519        elif self.data_type == 'Iqxy':
520            self.data.data[self.index] = y
521        elif self.data_type == 'sesans':
522            self.data.y[self.index] = y
523        else:
524            raise ValueError("Unknown model")
525
526    def save(self, basename):
527        pass
528
529    def _get_weights(self, par):
530        """
531        Get parameter dispersion weights
532        """
533        from . import weights
534
535        relative = self.model.kernel.info['partype']['pd-rel']
536        limits = self.model.kernel.info['limits']
537        disperser, value, npts, width, nsigma = [
538            getattr(self.model, par + ext)
539            for ext in ('_pd_type', '', '_pd_n', '_pd', '_pd_nsigma')]
540        value, weight = weights.get_weights(
541            disperser, int(npts.value), width.value, nsigma.value,
542            value.value, limits[par], par in relative)
543        return value, weight / np.sum(weight)
544
545    def __getstate__(self):
546        # Can't pickle gpu functions, so instead make them lazy
547        state = self.__dict__.copy()
548        state['_fn'] = None
549        return state
550
551    def __setstate__(self, state):
552        # pylint: disable=attribute-defined-outside-init
553        self.__dict__ = state
554
555
556def demo():
557    data = load_data('DEC07086.DAT')
558    set_beam_stop(data, 0.004)
559    plot_data(data)
560    import matplotlib.pyplot as plt; plt.show()
561
562
563if __name__ == "__main__":
564    demo()
Note: See TracBrowser for help on using the repository browser.