source: sasmodels/sasmodels/bumps_model.py @ 5d4777d

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 5d4777d was 5d4777d, checked in by Paul Kienzle <pkienzle@…>, 10 years ago

reorganize, check and update models

  • Property mode set to 100644
File size: 10.2 KB
Line 
1"""
2Sasmodels core.
3"""
4import sys, os
5import datetime
6
7import numpy as np
8
9
10def tic():
11    """
12    Timer function.
13
14    Use "toc=tic()" to start the clock and "toc()" to measure
15    a time interval.
16    """
17    then = datetime.datetime.now()
18    return lambda: (datetime.datetime.now()-then).total_seconds()
19
20
21def load_data(filename):
22    """
23    Load data using a sasview loader.
24    """
25    from sans.dataloader.loader import Loader
26    loader = Loader()
27    data = loader.load(filename)
28    if data is None:
29        raise IOError("Data %r could not be loaded"%filename)
30    return data
31
32
33def empty_data1D(q):
34    """
35    Create empty 1D data using the given *q* as the x value.
36
37    Resolutions dq/q is 5%.
38    """
39
40    from sans.dataloader.data_info import Data1D
41
42    Iq = 100*np.ones_like(q)
43    dIq = np.sqrt(Iq)
44    data = Data1D(q, Iq, dx=0.05*q, dy=dIq)
45    data.filename = "fake data"
46    data.qmin, data.qmax = q.min(), q.max()
47    return data
48
49
50def empty_data2D(qx, qy=None):
51    """
52    Create empty 2D data using the given mesh.
53
54    If *qy* is missing, create a square mesh with *qy=qx*.
55
56    Resolution dq/q is 5%.
57    """
58    from sans.dataloader.data_info import Data2D, Detector
59
60    if qy is None:
61        qy = qx
62    Qx,Qy = np.meshgrid(qx,qy)
63    Qx,Qy = Qx.flatten(), Qy.flatten()
64    Iq = 100*np.ones_like(Qx)
65    dIq = np.sqrt(Iq)
66    mask = np.ones(len(Iq), dtype='bool')
67
68    data = Data2D()
69    data.filename = "fake data"
70    data.qx_data = Qx
71    data.qy_data = Qy
72    data.data = Iq
73    data.err_data = dIq
74    data.mask = mask
75
76    # 5% dQ/Q resolution
77    data.dqx_data = 0.05*Qx
78    data.dqy_data = 0.05*Qy
79
80    detector = Detector()
81    detector.pixel_size.x = 5 # mm
82    detector.pixel_size.y = 5 # mm
83    detector.distance = 4 # m
84    data.detector.append(detector)
85    data.xbins = qx
86    data.ybins = qy
87    data.source.wavelength = 5 # angstroms
88    data.source.wavelength_unit = "A"
89    data.Q_unit = "1/A"
90    data.I_unit = "1/cm"
91    data.q_data = np.sqrt(Qx**2 + Qy**2)
92    data.xaxis("Q_x", "A^{-1}")
93    data.yaxis("Q_y", "A^{-1}")
94    data.zaxis("Intensity", r"\text{cm}^{-1}")
95    return data
96
97
98def set_beam_stop(data, radius, outer=None):
99    """
100    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
101    """
102    from sans.dataloader.manipulations import Ringcut
103    if hasattr(data, 'qx_data'):
104        data.mask = Ringcut(0, radius)(data)
105        if outer is not None:
106            data.mask += Ringcut(outer,np.inf)(data)
107    else:
108        data.mask = (data.x>=radius)
109        if outer is not None:
110            data.mask &= (data.x<outer)
111
112
113def set_half(data, half):
114    """
115    Select half of the data, either "right" or "left".
116    """
117    from sans.dataloader.manipulations import Boxcut
118    if half == 'right':
119        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
120    if half == 'left':
121        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
122
123
124def set_top(data, max):
125    """
126    Chop the top off the data, above *max*.
127    """
128    from sans.dataloader.manipulations import Boxcut
129    data.mask += Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=max)(data)
130
131
132def plot_data(data, iq, vmin=None, vmax=None, scale='log'):
133    """
134    Plot the target value for the data.  This could be the data itself,
135    the theory calculation, or the residuals.
136
137    *scale* can be 'log' for log scale data, or 'linear'.
138    """
139    from numpy.ma import masked_array, masked
140    import matplotlib.pyplot as plt
141    if hasattr(data, 'qx_data'):
142        iq = iq[:]
143        valid = np.isfinite(iq)
144        if scale == 'log':
145            valid[valid] = (iq[valid] > 0)
146            iq[valid] = np.log10(iq[valid])
147        iq[~valid|data.mask] = 0
148        #plottable = iq
149        plottable = masked_array(iq, ~valid|data.mask)
150        xmin, xmax = min(data.qx_data), max(data.qx_data)
151        ymin, ymax = min(data.qy_data), max(data.qy_data)
152        plt.imshow(plottable.reshape(128,128),
153                   interpolation='nearest', aspect=1, origin='upper',
154                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
155    else: # 1D data
156        if scale == 'linear':
157            idx = np.isfinite(iq)
158            plt.plot(data.x[idx], iq[idx])
159        else:
160            idx = np.isfinite(iq)
161            idx[idx] = (iq[idx]>0)
162            plt.loglog(data.x[idx], iq[idx])
163
164
165def _plot_result1D(data, theory, view):
166    """
167    Plot the data and residuals for 1D data.
168    """
169    import matplotlib.pyplot as plt
170    from numpy.ma import masked_array, masked
171    #print "not a number",sum(np.isnan(data.y))
172    #data.y[data.y<0.05] = 0.5
173    mdata = masked_array(data.y, data.mask)
174    mdata[np.isnan(mdata)] = masked
175    if view is 'log':
176        mdata[mdata <= 0] = masked
177    mtheory = masked_array(theory, mdata.mask)
178    mresid = masked_array((theory-data.y)/data.dy, mdata.mask)
179
180    plt.subplot(121)
181    plt.errorbar(data.x, mdata, yerr=data.dy)
182    plt.plot(data.x, mtheory, '-', hold=True)
183    plt.yscale(view)
184    plt.subplot(122)
185    plt.plot(data.x, mresid, 'x')
186    #plt.axhline(1, color='black', ls='--',lw=1, hold=True)
187    #plt.axhline(0, color='black', lw=1, hold=True)
188    #plt.axhline(-1, color='black', ls='--',lw=1, hold=True)
189
190
191def _plot_result2D(data, theory, view):
192    """
193    Plot the data and residuals for 2D data.
194    """
195    import matplotlib.pyplot as plt
196    resid = (theory-data.data)/data.err_data
197    plt.subplot(131)
198    plot_data(data, data.data, scale=view)
199    plt.colorbar()
200    plt.subplot(132)
201    plot_data(data, theory, scale=view)
202    plt.colorbar()
203    plt.subplot(133)
204    plot_data(data, resid, scale='linear')
205    plt.colorbar()
206
207def plot_result(data, theory, view='log'):
208    """
209    Plot the data and residuals.
210    """
211    if hasattr(data, 'qx_data'):
212        _plot_result2D(data, theory, view)
213    else:
214        _plot_result1D(data, theory, view)
215
216
217class BumpsModel(object):
218    """
219    Return a bumps wrapper for a SAS model.
220
221    *data* is the data to be fitted.
222
223    *model* is the SAS model, e.g., from :func:`gen.opencl_model`.
224
225    *cutoff* is the integration cutoff, which avoids computing the
226    the SAS model where the polydispersity weight is low.
227
228    Model parameters can be initialized with additional keyword
229    arguments, or by assigning to model.parameter_name.value.
230
231    The resulting bumps model can be used directly in a FitProblem call.
232    """
233    def __init__(self, data, model, cutoff=1e-5, **kw):
234        from bumps.names import Parameter
235        partype = model.info['partype']
236
237        # interpret data
238        self.data = data
239        if hasattr(data, 'qx_data'):
240            self.index = (data.mask==0) & (~np.isnan(data.data))
241            self.iq = data.data[self.index]
242            self.diq = data.err_data[self.index]
243            self._theory = np.zeros_like(data.data)
244            if not partype['orientation'] and not partype['magnetic']:
245                q_vectors = [np.sqrt(data.qx_data**2+data.qy_data**2)]
246            else:
247                q_vectors = [data.qx_data, data.qy_data]
248        else:
249            self.index = (data.x>=data.qmin) & (data.x<=data.qmax) & ~np.isnan(data.y)
250            self.iq = data.y[self.index]
251            self.diq = data.dy[self.index]
252            self._theory = np.zeros_like(data.y)
253            q_vectors = [data.x]
254        #input = model.make_input(q_vectors)
255        input = model.make_input([v[self.index] for v in q_vectors])
256
257        # create model
258        self.fn = model(input)
259        self.cutoff = cutoff
260
261        # define bumps parameters
262        pars = []
263        extras = []
264        for p in model.info['parameters']:
265            name, default, limits, ptype = p[0], p[2], p[3], p[4]
266            value = kw.pop(name, default)
267            setattr(self, name, Parameter.default(value, name=name, limits=limits))
268            pars.append(name)
269        for name in partype['pd-2d']:
270            for xpart,xdefault,xlimits in [
271                    ('_pd', 0, limits),
272                    ('_pd_n', 35, (0,1000)),
273                    ('_pd_nsigma', 3, (0, 10)),
274                    ('_pd_type', 'gaussian', None),
275                ]:
276                xname = name+xpart
277                xvalue = kw.pop(xname, xdefault)
278                if xlimits is not None:
279                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
280                    pars.append(xname)
281                setattr(self, xname, xvalue)
282        self._parameter_names = pars
283        if kw:
284            raise TypeError("unexpected parameters: %s"%(", ".join(sorted(kw.keys()))))
285        self.update()
286
287    def update(self):
288        self._cache = {}
289
290    def numpoints(self):
291        return len(self.iq)
292
293    def parameters(self):
294        return dict((k,getattr(self,k)) for k in self._parameter_names)
295
296    def theory(self):
297        if 'theory' not in self._cache:
298            pars = [getattr(self,p).value for p in self.fn.fixed_pars]
299            pd_pars = [self._get_weights(p) for p in self.fn.pd_pars]
300            #print pars
301            self._theory[self.index] = self.fn(pars, pd_pars, self.cutoff)
302            #self._theory[:] = self.fn.eval(pars, pd_pars)
303            self._cache['theory'] = self._theory
304        return self._cache['theory']
305
306    def residuals(self):
307        #if np.any(self.err ==0): print "zeros in err"
308        return (self.theory()[self.index]-self.iq)/self.diq
309
310    def nllf(self):
311        R = self.residuals()
312        #if np.any(np.isnan(R)): print "NaN in residuals"
313        return 0.5*np.sum(R**2)
314
315    def __call__(self):
316        return 2*self.nllf()/self.dof
317
318    def plot(self, view='log'):
319        plot_result(self.data, self.theory(), view=view)
320
321    def save(self, basename):
322        pass
323
324    def _get_weights(self, par):
325        from . import weights
326
327        relative = self.fn.info['partype']['pd-rel']
328        limits = self.fn.info['limits']
329        disperser,value,npts,width,nsigma = [getattr(self, par+ext)
330                for ext in ('_pd_type','','_pd_n','_pd','_pd_nsigma')]
331        v,w = weights.get_weights(
332            disperser, int(npts.value), width.value, nsigma.value,
333            value.value, limits[par], par in relative)
334        return v,w/w.max()
335
336
337def demo():
338    data = load_data('DEC07086.DAT')
339    set_beam_stop(data, 0.004)
340    plot_data(data)
341    import matplotlib.pyplot as plt; plt.show()
342
343
344if __name__ == "__main__":
345    demo()
Note: See TracBrowser for help on using the repository browser.