source: sasmodels/sasmodels/bumps_model.py @ 87985ca

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 87985ca was 87985ca, checked in by Paul Kienzle <pkienzle@…>, 10 years ago

clean up source tree

  • Property mode set to 100644
File size: 10.8 KB
Line 
1"""
2Sasmodels core.
3"""
4import sys, os
5import datetime
6
7import numpy as np
8
9try:
10    from .gpu import load_model as _loader
11except ImportError,exc:
12    warnings.warn(str(exc))
13    warnings.warn("OpenCL not available --- using ctypes instead")
14    from .dll import load_model as _loader
15
16def load_model(modelname, dtype='single'):
17    """
18    Load model by name.
19    """
20    sasmodels = __import__('sasmodels.models.'+modelname)
21    module = getattr(sasmodels.models, modelname, None)
22    model = _loader(module, dtype=dtype)
23    return model
24
25
26def tic():
27    """
28    Timer function.
29
30    Use "toc=tic()" to start the clock and "toc()" to measure
31    a time interval.
32    """
33    then = datetime.datetime.now()
34    return lambda: (datetime.datetime.now()-then).total_seconds()
35
36
37def load_data(filename):
38    """
39    Load data using a sasview loader.
40    """
41    from sans.dataloader.loader import Loader
42    loader = Loader()
43    data = loader.load(filename)
44    if data is None:
45        raise IOError("Data %r could not be loaded"%filename)
46    return data
47
48
49def empty_data1D(q):
50    """
51    Create empty 1D data using the given *q* as the x value.
52
53    Resolutions dq/q is 5%.
54    """
55
56    from sans.dataloader.data_info import Data1D
57
58    Iq = 100*np.ones_like(q)
59    dIq = np.sqrt(Iq)
60    data = Data1D(q, Iq, dx=0.05*q, dy=dIq)
61    data.filename = "fake data"
62    data.qmin, data.qmax = q.min(), q.max()
63    return data
64
65
66def empty_data2D(qx, qy=None):
67    """
68    Create empty 2D data using the given mesh.
69
70    If *qy* is missing, create a square mesh with *qy=qx*.
71
72    Resolution dq/q is 5%.
73    """
74    from sans.dataloader.data_info import Data2D, Detector
75
76    if qy is None:
77        qy = qx
78    Qx,Qy = np.meshgrid(qx,qy)
79    Qx,Qy = Qx.flatten(), Qy.flatten()
80    Iq = 100*np.ones_like(Qx)
81    dIq = np.sqrt(Iq)
82    mask = np.ones(len(Iq), dtype='bool')
83
84    data = Data2D()
85    data.filename = "fake data"
86    data.qx_data = Qx
87    data.qy_data = Qy
88    data.data = Iq
89    data.err_data = dIq
90    data.mask = mask
91
92    # 5% dQ/Q resolution
93    data.dqx_data = 0.05*Qx
94    data.dqy_data = 0.05*Qy
95
96    detector = Detector()
97    detector.pixel_size.x = 5 # mm
98    detector.pixel_size.y = 5 # mm
99    detector.distance = 4 # m
100    data.detector.append(detector)
101    data.xbins = qx
102    data.ybins = qy
103    data.source.wavelength = 5 # angstroms
104    data.source.wavelength_unit = "A"
105    data.Q_unit = "1/A"
106    data.I_unit = "1/cm"
107    data.q_data = np.sqrt(Qx**2 + Qy**2)
108    data.xaxis("Q_x", "A^{-1}")
109    data.yaxis("Q_y", "A^{-1}")
110    data.zaxis("Intensity", r"\text{cm}^{-1}")
111    return data
112
113
114def set_beam_stop(data, radius, outer=None):
115    """
116    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
117    """
118    from sans.dataloader.manipulations import Ringcut
119    if hasattr(data, 'qx_data'):
120        data.mask = Ringcut(0, radius)(data)
121        if outer is not None:
122            data.mask += Ringcut(outer,np.inf)(data)
123    else:
124        data.mask = (data.x>=radius)
125        if outer is not None:
126            data.mask &= (data.x<outer)
127
128
129def set_half(data, half):
130    """
131    Select half of the data, either "right" or "left".
132    """
133    from sans.dataloader.manipulations import Boxcut
134    if half == 'right':
135        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
136    if half == 'left':
137        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
138
139
140def set_top(data, max):
141    """
142    Chop the top off the data, above *max*.
143    """
144    from sans.dataloader.manipulations import Boxcut
145    data.mask += Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=max)(data)
146
147
148def plot_data(data, iq, vmin=None, vmax=None, scale='log'):
149    """
150    Plot the target value for the data.  This could be the data itself,
151    the theory calculation, or the residuals.
152
153    *scale* can be 'log' for log scale data, or 'linear'.
154    """
155    from numpy.ma import masked_array, masked
156    import matplotlib.pyplot as plt
157    if hasattr(data, 'qx_data'):
158        iq = iq[:]
159        valid = np.isfinite(iq)
160        if scale == 'log':
161            valid[valid] = (iq[valid] > 0)
162            iq[valid] = np.log10(iq[valid])
163        iq[~valid|data.mask] = 0
164        #plottable = iq
165        plottable = masked_array(iq, ~valid|data.mask)
166        xmin, xmax = min(data.qx_data), max(data.qx_data)
167        ymin, ymax = min(data.qy_data), max(data.qy_data)
168        plt.imshow(plottable.reshape(128,128),
169                   interpolation='nearest', aspect=1, origin='upper',
170                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
171    else: # 1D data
172        if scale == 'linear':
173            idx = np.isfinite(iq)
174            plt.plot(data.x[idx], iq[idx])
175        else:
176            idx = np.isfinite(iq)
177            idx[idx] = (iq[idx]>0)
178            plt.loglog(data.x[idx], iq[idx])
179
180
181def _plot_result1D(data, theory, view):
182    """
183    Plot the data and residuals for 1D data.
184    """
185    import matplotlib.pyplot as plt
186    from numpy.ma import masked_array, masked
187    #print "not a number",sum(np.isnan(data.y))
188    #data.y[data.y<0.05] = 0.5
189    mdata = masked_array(data.y, data.mask)
190    mdata[np.isnan(mdata)] = masked
191    if view is 'log':
192        mdata[mdata <= 0] = masked
193    mtheory = masked_array(theory, mdata.mask)
194    mresid = masked_array((theory-data.y)/data.dy, mdata.mask)
195
196    plt.subplot(121)
197    plt.errorbar(data.x, mdata, yerr=data.dy)
198    plt.plot(data.x, mtheory, '-', hold=True)
199    plt.yscale(view)
200    plt.subplot(122)
201    plt.plot(data.x, mresid, 'x')
202    #plt.axhline(1, color='black', ls='--',lw=1, hold=True)
203    #plt.axhline(0, color='black', lw=1, hold=True)
204    #plt.axhline(-1, color='black', ls='--',lw=1, hold=True)
205
206
207def _plot_result2D(data, theory, view):
208    """
209    Plot the data and residuals for 2D data.
210    """
211    import matplotlib.pyplot as plt
212    resid = (theory-data.data)/data.err_data
213    plt.subplot(131)
214    plot_data(data, data.data, scale=view)
215    plt.colorbar()
216    plt.subplot(132)
217    plot_data(data, theory, scale=view)
218    plt.colorbar()
219    plt.subplot(133)
220    plot_data(data, resid, scale='linear')
221    plt.colorbar()
222
223def plot_result(data, theory, view='log'):
224    """
225    Plot the data and residuals.
226    """
227    if hasattr(data, 'qx_data'):
228        _plot_result2D(data, theory, view)
229    else:
230        _plot_result1D(data, theory, view)
231
232
233class BumpsModel(object):
234    """
235    Return a bumps wrapper for a SAS model.
236
237    *data* is the data to be fitted.
238
239    *model* is the SAS model, e.g., from :func:`gen.opencl_model`.
240
241    *cutoff* is the integration cutoff, which avoids computing the
242    the SAS model where the polydispersity weight is low.
243
244    Model parameters can be initialized with additional keyword
245    arguments, or by assigning to model.parameter_name.value.
246
247    The resulting bumps model can be used directly in a FitProblem call.
248    """
249    def __init__(self, data, model, cutoff=1e-5, **kw):
250        from bumps.names import Parameter
251        partype = model.info['partype']
252
253        # remember inputs so we can inspect from outside
254        self.data = data
255        self.model = model
256
257        # interpret data
258        if hasattr(data, 'qx_data'):
259            self.index = (data.mask==0) & (~np.isnan(data.data))
260            self.iq = data.data[self.index]
261            self.diq = data.err_data[self.index]
262            self._theory = np.zeros_like(data.data)
263            if not partype['orientation'] and not partype['magnetic']:
264                q_vectors = [np.sqrt(data.qx_data**2+data.qy_data**2)]
265            else:
266                q_vectors = [data.qx_data, data.qy_data]
267        else:
268            self.index = (data.x>=data.qmin) & (data.x<=data.qmax) & ~np.isnan(data.y)
269            self.iq = data.y[self.index]
270            self.diq = data.dy[self.index]
271            self._theory = np.zeros_like(data.y)
272            q_vectors = [data.x]
273        #input = model.make_input(q_vectors)
274        input = model.make_input([v[self.index] for v in q_vectors])
275
276        # create model
277        self.fn = model(input)
278        self.cutoff = cutoff
279
280        # define bumps parameters
281        pars = []
282        extras = []
283        for p in model.info['parameters']:
284            name, default, limits, ptype = p[0], p[2], p[3], p[4]
285            value = kw.pop(name, default)
286            setattr(self, name, Parameter.default(value, name=name, limits=limits))
287            pars.append(name)
288        for name in partype['pd-2d']:
289            for xpart,xdefault,xlimits in [
290                    ('_pd', 0, limits),
291                    ('_pd_n', 35, (0,1000)),
292                    ('_pd_nsigma', 3, (0, 10)),
293                    ('_pd_type', 'gaussian', None),
294                ]:
295                xname = name+xpart
296                xvalue = kw.pop(xname, xdefault)
297                if xlimits is not None:
298                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
299                    pars.append(xname)
300                setattr(self, xname, xvalue)
301        self._parameter_names = pars
302        if kw:
303            raise TypeError("unexpected parameters: %s"%(", ".join(sorted(kw.keys()))))
304        self.update()
305
306    def update(self):
307        self._cache = {}
308
309    def numpoints(self):
310        return len(self.iq)
311
312    def parameters(self):
313        return dict((k,getattr(self,k)) for k in self._parameter_names)
314
315    def theory(self):
316        if 'theory' not in self._cache:
317            pars = [getattr(self,p).value for p in self.fn.fixed_pars]
318            pd_pars = [self._get_weights(p) for p in self.fn.pd_pars]
319            #print pars
320            self._theory[self.index] = self.fn(pars, pd_pars, self.cutoff)
321            #self._theory[:] = self.fn.eval(pars, pd_pars)
322            self._cache['theory'] = self._theory
323        return self._cache['theory']
324
325    def residuals(self):
326        #if np.any(self.err ==0): print "zeros in err"
327        return (self.theory()[self.index]-self.iq)/self.diq
328
329    def nllf(self):
330        R = self.residuals()
331        #if np.any(np.isnan(R)): print "NaN in residuals"
332        return 0.5*np.sum(R**2)
333
334    def __call__(self):
335        return 2*self.nllf()/self.dof
336
337    def plot(self, view='log'):
338        plot_result(self.data, self.theory(), view=view)
339
340    def save(self, basename):
341        pass
342
343    def _get_weights(self, par):
344        from . import weights
345
346        relative = self.fn.info['partype']['pd-rel']
347        limits = self.fn.info['limits']
348        disperser,value,npts,width,nsigma = [getattr(self, par+ext)
349                for ext in ('_pd_type','','_pd_n','_pd','_pd_nsigma')]
350        v,w = weights.get_weights(
351            disperser, int(npts.value), width.value, nsigma.value,
352            value.value, limits[par], par in relative)
353        return v,w/w.max()
354
355
356def demo():
357    data = load_data('DEC07086.DAT')
358    set_beam_stop(data, 0.004)
359    plot_data(data)
360    import matplotlib.pyplot as plt; plt.show()
361
362
363if __name__ == "__main__":
364    demo()
Note: See TracBrowser for help on using the repository browser.