source: sasmodels/sasmodels/bumps_model.py @ ff7119b

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since ff7119b was ff7119b, checked in by Paul Kienzle <pkienzle@…>, 10 years ago

docu update

  • Property mode set to 100644
File size: 9.9 KB
Line 
1"""
2Sasmodels core.
3"""
4import sys, os
5import datetime
6
7import numpy as np
8
9
10def tic():
11    """
12    Timer function.
13
14    Use "toc=tic()" to start the clock and "toc()" to measure
15    a time interval.
16    """
17    then = datetime.datetime.now()
18    return lambda: (datetime.datetime.now()-then).total_seconds()
19
20
21def load_data(filename):
22    """
23    Load data using a sasview loader.
24    """
25    from sans.dataloader.loader import Loader
26    loader = Loader()
27    data = loader.load(filename)
28    if data is None:
29        raise IOError("Data %r could not be loaded"%filename)
30    return data
31
32
33def empty_data1D(q):
34    """
35    Create empty 1D data using the given *q* as the x value.
36
37    Resolutions dq/q is 5%.
38    """
39
40    from sans.dataloader.data_info import Data1D
41
42    Iq = 100*np.ones_like(q)
43    dIq = np.sqrt(Iq)
44    data = Data1D(q, Iq, dx=0.05*q, dy=dIq)
45    data.filename = "fake data"
46    data.qmin, data.qmax = q.min(), q.max()
47    return data
48
49
50def empty_data2D(qx, qy=None):
51    """
52    Create empty 2D data using the given mesh.
53
54    If *qy* is missing, create a square mesh with *qy=qx*.
55
56    Resolution dq/q is 5%.
57    """
58    from sans.dataloader.data_info import Data2D, Detector
59
60    if qy is None:
61        qy = qx
62    Qx,Qy = np.meshgrid(qx,qy)
63    Qx,Qy = Qx.flatten(), Qy.flatten()
64    Iq = 100*np.ones_like(Qx)
65    dIq = np.sqrt(Iq)
66    mask = np.ones(len(Iq), dtype='bool')
67
68    data = Data2D()
69    data.filename = "fake data"
70    data.qx_data = Qx
71    data.qy_data = Qy
72    data.data = Iq
73    data.err_data = dIq
74    data.mask = mask
75
76    # 5% dQ/Q resolution
77    data.dqx_data = 0.05*Qx
78    data.dqy_data = 0.05*Qy
79
80    detector = Detector()
81    detector.pixel_size.x = 5 # mm
82    detector.pixel_size.y = 5 # mm
83    detector.distance = 4 # m
84    data.detector.append(detector)
85    data.xbins = qx
86    data.ybins = qy
87    data.source.wavelength = 5 # angstroms
88    data.source.wavelength_unit = "A"
89    data.Q_unit = "1/A"
90    data.I_unit = "1/cm"
91    data.q_data = np.sqrt(Qx**2 + Qy**2)
92    data.xaxis("Q_x", "A^{-1}")
93    data.yaxis("Q_y", "A^{-1}")
94    data.zaxis("Intensity", r"\text{cm}^{-1}")
95    return data
96
97
98def set_beam_stop(data, radius, outer=None):
99    """
100    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
101    """
102    from sans.dataloader.manipulations import Ringcut
103    if hasattr(data, 'qx_data'):
104        data.mask = Ringcut(0, radius)(data)
105        if outer is not None:
106            data.mask += Ringcut(outer,np.inf)(data)
107    else:
108        data.mask = (data.x>=radius)
109        if outer is not None:
110            data.mask &= (data.x<outer)
111
112
113def set_half(data, half):
114    """
115    Select half of the data, either "right" or "left".
116    """
117    from sans.dataloader.manipulations import Boxcut
118    if half == 'right':
119        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
120    if half == 'left':
121        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
122
123
124def set_top(data, max):
125    """
126    Chop the top off the data, above *max*.
127    """
128    from sans.dataloader.manipulations import Boxcut
129    data.mask += Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=max)(data)
130
131
132def plot_data(data, iq, vmin=None, vmax=None, scale='log'):
133    """
134    Plot the target value for the data.  This could be the data itself,
135    the theory calculation, or the residuals.
136
137    *scale* can be 'log' for log scale data, or 'linear'.
138    """
139    from numpy.ma import masked_array, masked
140    import matplotlib.pyplot as plt
141    if hasattr(data, 'qx_data'):
142        img = masked_array(iq, data.mask)
143        if scale == 'log':
144            img[(img <= 0) | ~np.isfinite(img)] = masked
145            img = np.log10(img)
146        xmin, xmax = min(data.qx_data), max(data.qx_data)
147        ymin, ymax = min(data.qy_data), max(data.qy_data)
148        plt.imshow(img.reshape(128,128),
149                   interpolation='nearest', aspect=1, origin='upper',
150                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
151    else: # 1D data
152        if scale == 'linear':
153            idx = np.isfinite(iq)
154            plt.plot(data.x[idx], iq[idx])
155        else:
156            idx = np.isfinite(iq) & (iq>0)
157            plt.loglog(data.x[idx], iq[idx])
158
159
160def _plot_result1D(data, theory, view):
161    """
162    Plot the data and residuals for 1D data.
163    """
164    import matplotlib.pyplot as plt
165    from numpy.ma import masked_array, masked
166    #print "not a number",sum(np.isnan(data.y))
167    #data.y[data.y<0.05] = 0.5
168    mdata = masked_array(data.y, data.mask)
169    mdata[np.isnan(mdata)] = masked
170    if view is 'log':
171        mdata[mdata <= 0] = masked
172    mtheory = masked_array(theory, mdata.mask)
173    mresid = masked_array((theory-data.y)/data.dy, mdata.mask)
174
175    plt.subplot(121)
176    plt.errorbar(data.x, mdata, yerr=data.dy)
177    plt.plot(data.x, mtheory, '-', hold=True)
178    plt.yscale(view)
179    plt.subplot(122)
180    plt.plot(data.x, mresid, 'x')
181    #plt.axhline(1, color='black', ls='--',lw=1, hold=True)
182    #plt.axhline(0, color='black', lw=1, hold=True)
183    #plt.axhline(-1, color='black', ls='--',lw=1, hold=True)
184
185
186def _plot_result2D(data, theory, view):
187    """
188    Plot the data and residuals for 2D data.
189    """
190    import matplotlib.pyplot as plt
191    resid = (theory-data.data)/data.err_data
192    plt.subplot(131)
193    plot_data(data, data.data, scale=view)
194    plt.colorbar()
195    plt.subplot(132)
196    plot_data(data, theory, scale=view)
197    plt.colorbar()
198    plt.subplot(133)
199    plot_data(data, resid, scale='linear')
200    plt.colorbar()
201
202def plot_result(data, theory, view='log'):
203    """
204    Plot the data and residuals.
205    """
206    if hasattr(data, 'qx_data'):
207        _plot_result2D(data, theory, view)
208    else:
209        _plot_result1D(data, theory, view)
210
211
212class BumpsModel(object):
213    """
214    Return a bumps wrapper for a SAS model.
215
216    *data* is the data to be fitted.
217
218    *model* is the SAS model, e.g., from :func:`gen.opencl_model`.
219
220    *cutoff* is the integration cutoff, which avoids computing the
221    the SAS model where the polydispersity weight is low.
222
223    Model parameters can be initialized with additional keyword
224    arguments, or by assigning to model.parameter_name.value.
225
226    The resulting bumps model can be used directly in a FitProblem call.
227    """
228    def __init__(self, data, model, cutoff=1e-5, **kw):
229        from bumps.names import Parameter
230
231        # interpret data
232        self.data = data
233        if hasattr(data, 'qx_data'):
234            self.index = (data.mask==0) & (~np.isnan(data.data))
235            self.iq = data.data[self.index]
236            self.diq = data.err_data[self.index]
237            self._theory = np.zeros_like(data.data)
238            q_vectors = [data.qx_data, data.qy_data]
239        else:
240            self.index = (data.x>=data.qmin) & (data.x<=data.qmax) & ~np.isnan(data.y)
241            self.iq = data.y[self.index]
242            self.diq = data.dy[self.index]
243            self._theory = np.zeros_like(data.y)
244            q_vectors = [data.x]
245        #input = model.make_input(q_vectors)
246        input = model.make_input([v[self.index] for v in q_vectors])
247
248        # create model
249        self.fn = model(input)
250        self.cutoff = cutoff
251
252        # define bumps parameters
253        pars = []
254        extras = []
255        for p in model.info['parameters']:
256            name, default, limits, ptype = p[0], p[2], p[3], p[4]
257            value = kw.pop(name, default)
258            setattr(self, name, Parameter.default(value, name=name, limits=limits))
259            pars.append(name)
260        for name in model.info['partype']['pd-2d']:
261            for xpart,xdefault,xlimits in [
262                    ('_pd', 0, limits),
263                    ('_pd_n', 35, (0,1000)),
264                    ('_pd_nsigma', 3, (0, 10)),
265                    ('_pd_type', 'gaussian', None),
266                ]:
267                xname = name+xpart
268                xvalue = kw.pop(xname, xdefault)
269                if xlimits is not None:
270                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
271                    pars.append(xname)
272                setattr(self, xname, xvalue)
273        self._parameter_names = pars
274        if kw:
275            raise TypeError("unexpected parameters: %s"%(", ".join(sorted(kw.keys()))))
276        self.update()
277
278    def update(self):
279        self._cache = {}
280
281    def numpoints(self):
282        return len(self.iq)
283
284    def parameters(self):
285        return dict((k,getattr(self,k)) for k in self._parameter_names)
286
287    def theory(self):
288        if 'theory' not in self._cache:
289            pars = [getattr(self,p).value for p in self.fn.fixed_pars]
290            pd_pars = [self._get_weights(p) for p in self.fn.pd_pars]
291            #print pars
292            self._theory[self.index] = self.fn(pars, pd_pars, self.cutoff)
293            #self._theory[:] = self.fn.eval(pars, pd_pars)
294            self._cache['theory'] = self._theory
295        return self._cache['theory']
296
297    def residuals(self):
298        #if np.any(self.err ==0): print "zeros in err"
299        return (self.theory()[self.index]-self.iq)/self.diq
300
301    def nllf(self):
302        R = self.residuals()
303        #if np.any(np.isnan(R)): print "NaN in residuals"
304        return 0.5*np.sum(R**2)
305
306    def __call__(self):
307        return 2*self.nllf()/self.dof
308
309    def plot(self, view='log'):
310        plot_result(self.data, self.theory(), view=view)
311
312    def save(self, basename):
313        pass
314
315    def _get_weights(self, par):
316        from . import weights
317
318        relative = self.fn.info['partype']['pd-rel']
319        limits = self.fn.info['limits']
320        disperser,value,npts,width,nsigma = [getattr(self, par+ext)
321                for ext in ('_pd_type','','_pd_n','_pd','_pd_nsigma')]
322        v,w = weights.get_weights(
323            disperser, int(npts.value), width.value, nsigma.value,
324            value.value, limits[par], par in relative)
325        return v,w/w.max()
326
327
328def demo():
329    data = load_data('DEC07086.DAT')
330    set_beam_stop(data, 0.004)
331    plot_data(data)
332    import matplotlib.pyplot as plt; plt.show()
333
334
335if __name__ == "__main__":
336    demo()
Note: See TracBrowser for help on using the repository browser.