source: sasmodels/sasmodels/bumps_model.py @ aa4946b

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since aa4946b was aa4946b, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

refactor so kernels are loaded via core.load_model

  • Property mode set to 100644
File size: 13.9 KB
Line 
1"""
2Wrap sasmodels for direct use by bumps.
3"""
4
5import datetime
6
7import numpy as np
8
9from . import sesans
10
11# CRUFT python 2.6
12if not hasattr(datetime.timedelta, 'total_seconds'):
13    def delay(dt):
14        """Return number date-time delta as number seconds"""
15        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
16else:
17    def delay(dt):
18        """Return number date-time delta as number seconds"""
19        return dt.total_seconds()
20
21
22def tic():
23    """
24    Timer function.
25
26    Use "toc=tic()" to start the clock and "toc()" to measure
27    a time interval.
28    """
29    then = datetime.datetime.now()
30    return lambda: delay(datetime.datetime.now() - then)
31
32
33def load_data(filename):
34    """
35    Load data using a sasview loader.
36    """
37    from sas.dataloader.loader import Loader
38    loader = Loader()
39    data = loader.load(filename)
40    if data is None:
41        raise IOError("Data %r could not be loaded" % filename)
42    return data
43
44
45def empty_data1D(q):
46    """
47    Create empty 1D data using the given *q* as the x value.
48
49    Resolutions dq/q is 5%.
50    """
51
52    from sas.dataloader.data_info import Data1D
53
54    Iq = 100 * np.ones_like(q)
55    dIq = np.sqrt(Iq)
56    data = Data1D(q, Iq, dx=0.05 * q, dy=dIq)
57    data.filename = "fake data"
58    data.qmin, data.qmax = q.min(), q.max()
59    return data
60
61
62def empty_data2D(qx, qy=None):
63    """
64    Create empty 2D data using the given mesh.
65
66    If *qy* is missing, create a square mesh with *qy=qx*.
67
68    Resolution dq/q is 5%.
69    """
70    from sas.dataloader.data_info import Data2D, Detector
71
72    if qy is None:
73        qy = qx
74    Qx, Qy = np.meshgrid(qx, qy)
75    Qx, Qy = Qx.flatten(), Qy.flatten()
76    Iq = 100 * np.ones_like(Qx)
77    dIq = np.sqrt(Iq)
78    mask = np.ones(len(Iq), dtype='bool')
79
80    data = Data2D()
81    data.filename = "fake data"
82    data.qx_data = Qx
83    data.qy_data = Qy
84    data.data = Iq
85    data.err_data = dIq
86    data.mask = mask
87
88    # 5% dQ/Q resolution
89    data.dqx_data = 0.05 * Qx
90    data.dqy_data = 0.05 * Qy
91
92    detector = Detector()
93    detector.pixel_size.x = 5 # mm
94    detector.pixel_size.y = 5 # mm
95    detector.distance = 4 # m
96    data.detector.append(detector)
97    data.xbins = qx
98    data.ybins = qy
99    data.source.wavelength = 5 # angstroms
100    data.source.wavelength_unit = "A"
101    data.Q_unit = "1/A"
102    data.I_unit = "1/cm"
103    data.q_data = np.sqrt(Qx ** 2 + Qy ** 2)
104    data.xaxis("Q_x", "A^{-1}")
105    data.yaxis("Q_y", "A^{-1}")
106    data.zaxis("Intensity", r"\text{cm}^{-1}")
107    return data
108
109
110def set_beam_stop(data, radius, outer=None):
111    """
112    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
113    """
114    from sas.dataloader.manipulations import Ringcut
115    if hasattr(data, 'qx_data'):
116        data.mask = Ringcut(0, radius)(data)
117        if outer is not None:
118            data.mask += Ringcut(outer, np.inf)(data)
119    else:
120        data.mask = (data.x >= radius)
121        if outer is not None:
122            data.mask &= (data.x < outer)
123
124
125def set_half(data, half):
126    """
127    Select half of the data, either "right" or "left".
128    """
129    from sas.dataloader.manipulations import Boxcut
130    if half == 'right':
131        data.mask += \
132            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
133    if half == 'left':
134        data.mask += \
135            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
136
137
138def set_top(data, cutoff):
139    """
140    Chop the top off the data, above *cutoff*.
141    """
142    from sas.dataloader.manipulations import Boxcut
143    data.mask += \
144        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
145
146
147def plot_data(data, Iq, vmin=None, vmax=None, view='log'):
148    """
149    Plot the target value for the data.  This could be the data itself,
150    the theory calculation, or the residuals.
151
152    *scale* can be 'log' for log scale data, or 'linear'.
153    """
154    from numpy.ma import masked_array
155    import matplotlib.pyplot as plt
156    if hasattr(data, 'qx_data'):
157        Iq = Iq + 0
158        valid = np.isfinite(Iq)
159        if view == 'log':
160            valid[valid] = (Iq[valid] > 0)
161            Iq[valid] = np.log10(Iq[valid])
162        elif view == 'q4':
163            Iq[valid] = Iq*(data.qx_data[valid]**2+data.qy_data[valid]**2)**2
164        Iq[~valid | data.mask] = 0
165        #plottable = Iq
166        plottable = masked_array(Iq, ~valid | data.mask)
167        xmin, xmax = min(data.qx_data), max(data.qx_data)
168        ymin, ymax = min(data.qy_data), max(data.qy_data)
169        try:
170            if vmin is None: vmin = Iq[valid & ~data.mask].min()
171            if vmax is None: vmax = Iq[valid & ~data.mask].max()
172        except:
173            vmin, vmax = 0, 1
174        plt.imshow(plottable.reshape(128, 128),
175                   interpolation='nearest', aspect=1, origin='upper',
176                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
177    else: # 1D data
178        if view == 'linear' or view == 'q4':
179            #idx = np.isfinite(Iq)
180            scale = data.x**4 if view == 'q4' else 1.0
181            plt.plot(data.x, scale*Iq) #, '.')
182        else:
183            # Find the values that are finite and positive
184            idx = np.isfinite(Iq)
185            idx[idx] = (Iq[idx] > 0)
186            Iq[~idx] = np.nan
187            plt.loglog(data.x, Iq)
188
189
190def _plot_result1D(data, theory, view):
191    """
192    Plot the data and residuals for 1D data.
193    """
194    import matplotlib.pyplot as plt
195    from numpy.ma import masked_array, masked
196    #print "not a number",sum(np.isnan(data.y))
197    #data.y[data.y<0.05] = 0.5
198    mdata = masked_array(data.y, data.mask)
199    mdata[np.isnan(mdata)] = masked
200    if view is 'log':
201        mdata[mdata <= 0] = masked
202    mtheory = masked_array(theory, mdata.mask)
203    mresid = masked_array((theory - data.y) / data.dy, mdata.mask)
204
205    scale = data.x**4 if view == 'q4' else 1.0
206    plt.subplot(121)
207    plt.errorbar(data.x, scale*mdata, yerr=data.dy)
208    plt.plot(data.x, scale*mtheory, '-', hold=True)
209    plt.yscale('linear' if view == 'q4' else view)
210    plt.subplot(122)
211    plt.plot(data.x, mresid, 'x')
212
213# pylint: disable=unused-argument
214def _plot_sesans(data, theory, view):
215    import matplotlib.pyplot as plt
216    resid = (theory - data.y) / data.dy
217    plt.subplot(121)
218    plt.errorbar(data.x, data.y, yerr=data.dy)
219    plt.plot(data.x, theory, '-', hold=True)
220    plt.xlabel('spin echo length (A)')
221    plt.ylabel('polarization')
222    plt.subplot(122)
223    plt.plot(data.x, resid, 'x')
224    plt.xlabel('spin echo length (A)')
225    plt.ylabel('residuals')
226
227def _plot_result2D(data, theory, view):
228    """
229    Plot the data and residuals for 2D data.
230    """
231    import matplotlib.pyplot as plt
232    resid = (theory - data.data) / data.err_data
233    plt.subplot(131)
234    plot_data(data, data.data, view=view)
235    plt.colorbar()
236    plt.subplot(132)
237    plot_data(data, theory, view=view)
238    plt.colorbar()
239    plt.subplot(133)
240    plot_data(data, resid, view='linear')
241    plt.colorbar()
242
243class BumpsModel(object):
244    """
245    Return a bumps wrapper for a SAS model.
246
247    *data* is the data to be fitted.
248
249    *model* is the SAS model from :func:`core.load_model`.
250
251    *cutoff* is the integration cutoff, which avoids computing the
252    the SAS model where the polydispersity weight is low.
253
254    Model parameters can be initialized with additional keyword
255    arguments, or by assigning to model.parameter_name.value.
256
257    The resulting bumps model can be used directly in a FitProblem call.
258    """
259    def __init__(self, data, model, cutoff=1e-5, **kw):
260        from bumps.names import Parameter
261
262        # remember inputs so we can inspect from outside
263        self.data = data
264        self.model = model
265        self.cutoff = cutoff
266        if hasattr(data, 'lam'):
267            self.data_type = 'sesans'
268        elif hasattr(data, 'qx_data'):
269            self.data_type = 'Iqxy'
270        else:
271            self.data_type = 'Iq'
272
273        partype = model.info['partype']
274
275        # interpret data
276        if self.data_type == 'sesans':
277            q = sesans.make_q(data.sample.zacceptance, data.Rmax)
278            self.index = slice(None, None)
279            self.Iq = data.y
280            self.dIq = data.dy
281            self._theory = np.zeros_like(q)
282            q_vectors = [q]
283        elif self.data_type == 'Iqxy':
284            self.index = (data.mask == 0) & (~np.isnan(data.data))
285            self.Iq = data.data[self.index]
286            self.dIq = data.err_data[self.index]
287            self._theory = np.zeros_like(data.data)
288            if not partype['orientation'] and not partype['magnetic']:
289                q_vectors = [np.sqrt(data.qx_data ** 2 + data.qy_data ** 2)]
290            else:
291                q_vectors = [data.qx_data, data.qy_data]
292        elif self.data_type == 'Iq':
293            self.index = (data.x >= data.qmin) & (data.x <= data.qmax) & ~np.isnan(data.y)
294            self.Iq = data.y[self.index]
295            self.dIq = data.dy[self.index]
296            self._theory = np.zeros_like(data.y)
297            q_vectors = [data.x]
298        else:
299            raise ValueError("Unknown data type") # never gets here
300
301        # Remember function inputs so we can delay loading the function and
302        # so we can save/restore state
303        self._fn_inputs = [v[self.index] for v in q_vectors]
304        self._fn = None
305
306        # define bumps parameters
307        pars = []
308        for p in model.info['parameters']:
309            name, default, limits = p[0], p[2], p[3]
310            value = kw.pop(name, default)
311            setattr(self, name, Parameter.default(value, name=name, limits=limits))
312            pars.append(name)
313        for name in partype['pd-2d']:
314            for xpart, xdefault, xlimits in [
315                    ('_pd', 0, limits),
316                    ('_pd_n', 35, (0, 1000)),
317                    ('_pd_nsigma', 3, (0, 10)),
318                    ('_pd_type', 'gaussian', None),
319                ]:
320                xname = name + xpart
321                xvalue = kw.pop(xname, xdefault)
322                if xlimits is not None:
323                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
324                    pars.append(xname)
325                setattr(self, xname, xvalue)
326        self._parameter_names = pars
327        if kw:
328            raise TypeError("unexpected parameters: %s"
329                            % (", ".join(sorted(kw.keys()))))
330        self.update()
331
332    def update(self):
333        self._cache = {}
334
335    def numpoints(self):
336        """
337            Return the number of points
338        """
339        return len(self.Iq)
340
341    def parameters(self):
342        """
343            Return a dictionary of parameters
344        """
345        return dict((k, getattr(self, k)) for k in self._parameter_names)
346
347    def theory(self):
348        if 'theory' not in self._cache:
349            if self._fn is None:
350                q_input = self.model.make_input(self._fn_inputs)
351                self._fn = self.model(q_input)
352
353            fixed_pars = [getattr(self, p).value for p in self._fn.fixed_pars]
354            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
355            #print fixed_pars,pd_pars
356            self._theory[self.index] = self._fn(fixed_pars, pd_pars, self.cutoff)
357            #self._theory[:] = self._fn.eval(pars, pd_pars)
358            if self.data_type == 'sesans':
359                result = sesans.hankel(self.data.x, self.data.lam * 1e-9,
360                                       self.data.sample.thickness / 10,
361                                       self._fn_inputs[0], self._theory)
362                self._cache['theory'] = result
363            else:
364                self._cache['theory'] = self._theory
365        return self._cache['theory']
366
367    def residuals(self):
368        #if np.any(self.err ==0): print "zeros in err"
369        return (self.theory()[self.index] - self.Iq) / self.dIq
370
371    def nllf(self):
372        delta = self.residuals()
373        #if np.any(np.isnan(R)): print "NaN in residuals"
374        return 0.5 * np.sum(delta ** 2)
375
376    #def __call__(self):
377    #    return 2 * self.nllf() / self.dof
378
379    def plot(self, view='log'):
380        """
381        Plot the data and residuals.
382        """
383        data, theory = self.data, self.theory()
384        if self.data_type == 'Iq':
385            _plot_result1D(data, theory, view)
386        elif self.data_type == 'Iqxy':
387            _plot_result2D(data, theory, view)
388        elif self.data_type == 'sesans':
389            _plot_sesans(data, theory, view)
390        else:
391            raise ValueError("Unknown data type")
392
393    def simulate_data(self, noise=None):
394        print "noise", noise
395        if noise is None:
396            noise = self.dIq[self.index]
397        else:
398            noise = 0.01 * noise
399            self.dIq[self.index] = noise
400        y = self.theory()
401        y += y * np.random.randn(*y.shape) * noise
402        if self.data_type == 'Iq':
403            self.data.y[self.index] = y
404        elif self.data_type == 'Iqxy':
405            self.data.data[self.index] = y
406        elif self.data_type == 'sesans':
407            self.data.y[self.index] = y
408        else:
409            raise ValueError("Unknown model")
410
411    def save(self, basename):
412        pass
413
414    def _get_weights(self, par):
415        """
416        Get parameter dispersion weights
417        """
418        from . import weights
419
420        relative = self.model.info['partype']['pd-rel']
421        limits = self.model.info['limits']
422        disperser, value, npts, width, nsigma = [
423            getattr(self, par + ext)
424            for ext in ('_pd_type', '', '_pd_n', '_pd', '_pd_nsigma')]
425        value, weight = weights.get_weights(
426            disperser, int(npts.value), width.value, nsigma.value,
427            value.value, limits[par], par in relative)
428        return value, weight / np.sum(weight)
429
430    def __getstate__(self):
431        # Can't pickle gpu functions, so instead make them lazy
432        state = self.__dict__.copy()
433        state['_fn'] = None
434        return state
435
436    def __setstate__(self, state):
437        # pylint: disable=attribute-defined-outside-init
438        self.__dict__ = state
439
440
441def demo():
442    data = load_data('DEC07086.DAT')
443    set_beam_stop(data, 0.004)
444    plot_data(data, data.data)
445    import matplotlib.pyplot as plt; plt.show()
446
447
448if __name__ == "__main__":
449    demo()
Note: See TracBrowser for help on using the repository browser.