source: sasmodels/sasmodels/bumps_model.py @ ba69383

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since ba69383 was ba69383, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

add relerr histogram to compare

  • Property mode set to 100644
File size: 11.3 KB
Line 
1"""
2Sasmodels core.
3"""
4import sys, os
5import datetime
6
7import numpy as np
8
9try:
10    from .gpu import load_model as _loader
11except ImportError,exc:
12    import warnings
13    warnings.warn(str(exc))
14    warnings.warn("OpenCL not available --- using ctypes instead")
15    from .dll import load_model as _loader
16
17def load_model(modelname, dtype='single'):
18    """
19    Load model by name.
20    """
21    sasmodels = __import__('sasmodels.models.'+modelname)
22    module = getattr(sasmodels.models, modelname, None)
23    model = _loader(module, dtype=dtype)
24    return model
25
26
27def tic():
28    """
29    Timer function.
30
31    Use "toc=tic()" to start the clock and "toc()" to measure
32    a time interval.
33    """
34    then = datetime.datetime.now()
35    return lambda: (datetime.datetime.now()-then).total_seconds()
36
37
38def load_data(filename):
39    """
40    Load data using a sasview loader.
41    """
42    from sans.dataloader.loader import Loader
43    loader = Loader()
44    data = loader.load(filename)
45    if data is None:
46        raise IOError("Data %r could not be loaded"%filename)
47    return data
48
49
50def empty_data1D(q):
51    """
52    Create empty 1D data using the given *q* as the x value.
53
54    Resolutions dq/q is 5%.
55    """
56
57    from sans.dataloader.data_info import Data1D
58
59    Iq = 100*np.ones_like(q)
60    dIq = np.sqrt(Iq)
61    data = Data1D(q, Iq, dx=0.05*q, dy=dIq)
62    data.filename = "fake data"
63    data.qmin, data.qmax = q.min(), q.max()
64    return data
65
66
67def empty_data2D(qx, qy=None):
68    """
69    Create empty 2D data using the given mesh.
70
71    If *qy* is missing, create a square mesh with *qy=qx*.
72
73    Resolution dq/q is 5%.
74    """
75    from sans.dataloader.data_info import Data2D, Detector
76
77    if qy is None:
78        qy = qx
79    Qx,Qy = np.meshgrid(qx,qy)
80    Qx,Qy = Qx.flatten(), Qy.flatten()
81    Iq = 100*np.ones_like(Qx)
82    dIq = np.sqrt(Iq)
83    mask = np.ones(len(Iq), dtype='bool')
84
85    data = Data2D()
86    data.filename = "fake data"
87    data.qx_data = Qx
88    data.qy_data = Qy
89    data.data = Iq
90    data.err_data = dIq
91    data.mask = mask
92
93    # 5% dQ/Q resolution
94    data.dqx_data = 0.05*Qx
95    data.dqy_data = 0.05*Qy
96
97    detector = Detector()
98    detector.pixel_size.x = 5 # mm
99    detector.pixel_size.y = 5 # mm
100    detector.distance = 4 # m
101    data.detector.append(detector)
102    data.xbins = qx
103    data.ybins = qy
104    data.source.wavelength = 5 # angstroms
105    data.source.wavelength_unit = "A"
106    data.Q_unit = "1/A"
107    data.I_unit = "1/cm"
108    data.q_data = np.sqrt(Qx**2 + Qy**2)
109    data.xaxis("Q_x", "A^{-1}")
110    data.yaxis("Q_y", "A^{-1}")
111    data.zaxis("Intensity", r"\text{cm}^{-1}")
112    return data
113
114
115def set_beam_stop(data, radius, outer=None):
116    """
117    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
118    """
119    from sans.dataloader.manipulations import Ringcut
120    if hasattr(data, 'qx_data'):
121        data.mask = Ringcut(0, radius)(data)
122        if outer is not None:
123            data.mask += Ringcut(outer,np.inf)(data)
124    else:
125        data.mask = (data.x>=radius)
126        if outer is not None:
127            data.mask &= (data.x<outer)
128
129
130def set_half(data, half):
131    """
132    Select half of the data, either "right" or "left".
133    """
134    from sans.dataloader.manipulations import Boxcut
135    if half == 'right':
136        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
137    if half == 'left':
138        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
139
140
141def set_top(data, max):
142    """
143    Chop the top off the data, above *max*.
144    """
145    from sans.dataloader.manipulations import Boxcut
146    data.mask += Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=max)(data)
147
148
149def plot_data(data, iq, vmin=None, vmax=None, scale='log'):
150    """
151    Plot the target value for the data.  This could be the data itself,
152    the theory calculation, or the residuals.
153
154    *scale* can be 'log' for log scale data, or 'linear'.
155    """
156    from numpy.ma import masked_array, masked
157    import matplotlib.pyplot as plt
158    if hasattr(data, 'qx_data'):
159        iq = iq[:]
160        valid = np.isfinite(iq)
161        if scale == 'log':
162            valid[valid] = (iq[valid] > 0)
163            iq[valid] = np.log10(iq[valid])
164        iq[~valid|data.mask] = 0
165        #plottable = iq
166        plottable = masked_array(iq, ~valid|data.mask)
167        xmin, xmax = min(data.qx_data), max(data.qx_data)
168        ymin, ymax = min(data.qy_data), max(data.qy_data)
169        if vmin is None: vmin = iq[valid&~data.mask].min()
170        if vmax is None: vmax = iq[valid&~data.mask].max()
171        plt.imshow(plottable.reshape(128,128),
172                   interpolation='nearest', aspect=1, origin='upper',
173                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
174    else: # 1D data
175        if scale == 'linear':
176            idx = np.isfinite(iq)
177            plt.plot(data.x[idx], iq[idx])
178        else:
179            idx = np.isfinite(iq)
180            idx[idx] = (iq[idx]>0)
181            plt.loglog(data.x[idx], iq[idx])
182
183
184def _plot_result1D(data, theory, view):
185    """
186    Plot the data and residuals for 1D data.
187    """
188    import matplotlib.pyplot as plt
189    from numpy.ma import masked_array, masked
190    #print "not a number",sum(np.isnan(data.y))
191    #data.y[data.y<0.05] = 0.5
192    mdata = masked_array(data.y, data.mask)
193    mdata[np.isnan(mdata)] = masked
194    if view is 'log':
195        mdata[mdata <= 0] = masked
196    mtheory = masked_array(theory, mdata.mask)
197    mresid = masked_array((theory-data.y)/data.dy, mdata.mask)
198
199    plt.subplot(121)
200    plt.errorbar(data.x, mdata, yerr=data.dy)
201    plt.plot(data.x, mtheory, '-', hold=True)
202    plt.yscale(view)
203    plt.subplot(122)
204    plt.plot(data.x, mresid, 'x')
205    #plt.axhline(1, color='black', ls='--',lw=1, hold=True)
206    #plt.axhline(0, color='black', lw=1, hold=True)
207    #plt.axhline(-1, color='black', ls='--',lw=1, hold=True)
208
209
210def _plot_result2D(data, theory, view):
211    """
212    Plot the data and residuals for 2D data.
213    """
214    import matplotlib.pyplot as plt
215    resid = (theory-data.data)/data.err_data
216    plt.subplot(131)
217    plot_data(data, data.data, scale=view)
218    plt.colorbar()
219    plt.subplot(132)
220    plot_data(data, theory, scale=view)
221    plt.colorbar()
222    plt.subplot(133)
223    plot_data(data, resid, scale='linear')
224    plt.colorbar()
225
226def plot_result(data, theory, view='log'):
227    """
228    Plot the data and residuals.
229    """
230    if hasattr(data, 'qx_data'):
231        _plot_result2D(data, theory, view)
232    else:
233        _plot_result1D(data, theory, view)
234
235
236class BumpsModel(object):
237    """
238    Return a bumps wrapper for a SAS model.
239
240    *data* is the data to be fitted.
241
242    *model* is the SAS model, e.g., from :func:`gen.opencl_model`.
243
244    *cutoff* is the integration cutoff, which avoids computing the
245    the SAS model where the polydispersity weight is low.
246
247    Model parameters can be initialized with additional keyword
248    arguments, or by assigning to model.parameter_name.value.
249
250    The resulting bumps model can be used directly in a FitProblem call.
251    """
252    def __init__(self, data, model, cutoff=1e-5, **kw):
253        from bumps.names import Parameter
254
255        # remember inputs so we can inspect from outside
256        self.data = data
257        self.model = model
258        self.cutoff = cutoff
259
260        partype = model.info['partype']
261
262        # interpret data
263        if hasattr(data, 'qx_data'):
264            self.index = (data.mask==0) & (~np.isnan(data.data))
265            self.iq = data.data[self.index]
266            self.diq = data.err_data[self.index]
267            self._theory = np.zeros_like(data.data)
268            if not partype['orientation'] and not partype['magnetic']:
269                q_vectors = [np.sqrt(data.qx_data**2+data.qy_data**2)]
270            else:
271                q_vectors = [data.qx_data, data.qy_data]
272        else:
273            self.index = (data.x>=data.qmin) & (data.x<=data.qmax) & ~np.isnan(data.y)
274            self.iq = data.y[self.index]
275            self.diq = data.dy[self.index]
276            self._theory = np.zeros_like(data.y)
277            q_vectors = [data.x]
278
279        # Remember function inputs so we can delay loading the function and
280        # so we can save/restore state
281        self._fn_inputs = [v[self.index] for v in q_vectors]
282        self._fn = None
283
284        # define bumps parameters
285        pars = []
286        for p in model.info['parameters']:
287            name, default, limits, ptype = p[0], p[2], p[3], p[4]
288            value = kw.pop(name, default)
289            setattr(self, name, Parameter.default(value, name=name, limits=limits))
290            pars.append(name)
291        for name in partype['pd-2d']:
292            for xpart,xdefault,xlimits in [
293                    ('_pd', 0, limits),
294                    ('_pd_n', 35, (0,1000)),
295                    ('_pd_nsigma', 3, (0, 10)),
296                    ('_pd_type', 'gaussian', None),
297                ]:
298                xname = name+xpart
299                xvalue = kw.pop(xname, xdefault)
300                if xlimits is not None:
301                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
302                    pars.append(xname)
303                setattr(self, xname, xvalue)
304        self._parameter_names = pars
305        if kw:
306            raise TypeError("unexpected parameters: %s"%(", ".join(sorted(kw.keys()))))
307        self.update()
308
309    def update(self):
310        self._cache = {}
311
312    def numpoints(self):
313        return len(self.iq)
314
315    def parameters(self):
316        return dict((k,getattr(self,k)) for k in self._parameter_names)
317
318    def theory(self):
319        if 'theory' not in self._cache:
320            if self._fn is None:
321                input = self.model.make_input(self._fn_inputs)
322                self._fn = self.model(input)
323
324            pars = [getattr(self,p).value for p in self._fn.fixed_pars]
325            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
326            #print pars
327            self._theory[self.index] = self._fn(pars, pd_pars, self.cutoff)
328            #self._theory[:] = self._fn.eval(pars, pd_pars)
329            self._cache['theory'] = self._theory
330        return self._cache['theory']
331
332    def residuals(self):
333        #if np.any(self.err ==0): print "zeros in err"
334        return (self.theory()[self.index]-self.iq)/self.diq
335
336    def nllf(self):
337        R = self.residuals()
338        #if np.any(np.isnan(R)): print "NaN in residuals"
339        return 0.5*np.sum(R**2)
340
341    def __call__(self):
342        return 2*self.nllf()/self.dof
343
344    def plot(self, view='log'):
345        plot_result(self.data, self.theory(), view=view)
346
347    def save(self, basename):
348        pass
349
350    def _get_weights(self, par):
351        from . import weights
352
353        relative = self.model.info['partype']['pd-rel']
354        limits = self.model.info['limits']
355        disperser,value,npts,width,nsigma = [getattr(self, par+ext)
356                for ext in ('_pd_type','','_pd_n','_pd','_pd_nsigma')]
357        v,w = weights.get_weights(
358            disperser, int(npts.value), width.value, nsigma.value,
359            value.value, limits[par], par in relative)
360        return v,w/w.max()
361
362    def __getstate__(self):
363        # Can't pickle gpu functions, so instead make them lazy
364        state = self.__dict__.copy()
365        state['_fn'] = None
366        return state
367
368    def __setstate__(self, state):
369        self.__dict__ = state
370
371
372def demo():
373    data = load_data('DEC07086.DAT')
374    set_beam_stop(data, 0.004)
375    plot_data(data)
376    import matplotlib.pyplot as plt; plt.show()
377
378
379if __name__ == "__main__":
380    demo()
Note: See TracBrowser for help on using the repository browser.