source: sasmodels/sasmodels/bumps_model.py @ 31819c5

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 31819c5 was 31819c5, checked in by pkienzle, 9 years ago

python2.6 does not support timedelta.total_seconds()

  • Property mode set to 100644
File size: 11.5 KB
Line 
1"""
2Sasmodels core.
3"""
4import sys, os
5import datetime
6
7# CRUFT python 2.6
8if not hasattr(datetime.timedelta, 'total_seconds'):
9    def delay(dt): return dt.days*86400 + dt.seconds + 1e-6*dt.microseconds
10else:
11    def delay(dt): return dt.total_seconds()
12
13import numpy as np
14
15try:
16    from .gpu import load_model as _loader
17except ImportError,exc:
18    import warnings
19    warnings.warn(str(exc))
20    warnings.warn("OpenCL not available --- using ctypes instead")
21    from .dll import load_model as _loader
22
23def load_model(modelname, dtype='single'):
24    """
25    Load model by name.
26    """
27    sasmodels = __import__('sasmodels.models.'+modelname)
28    module = getattr(sasmodels.models, modelname, None)
29    model = _loader(module, dtype=dtype)
30    return model
31
32
33def tic():
34    """
35    Timer function.
36
37    Use "toc=tic()" to start the clock and "toc()" to measure
38    a time interval.
39    """
40    then = datetime.datetime.now()
41    return lambda: delay(datetime.datetime.now()-then)
42
43
44def load_data(filename):
45    """
46    Load data using a sasview loader.
47    """
48    from sans.dataloader.loader import Loader
49    loader = Loader()
50    data = loader.load(filename)
51    if data is None:
52        raise IOError("Data %r could not be loaded"%filename)
53    return data
54
55
56def empty_data1D(q):
57    """
58    Create empty 1D data using the given *q* as the x value.
59
60    Resolutions dq/q is 5%.
61    """
62
63    from sans.dataloader.data_info import Data1D
64
65    Iq = 100*np.ones_like(q)
66    dIq = np.sqrt(Iq)
67    data = Data1D(q, Iq, dx=0.05*q, dy=dIq)
68    data.filename = "fake data"
69    data.qmin, data.qmax = q.min(), q.max()
70    return data
71
72
73def empty_data2D(qx, qy=None):
74    """
75    Create empty 2D data using the given mesh.
76
77    If *qy* is missing, create a square mesh with *qy=qx*.
78
79    Resolution dq/q is 5%.
80    """
81    from sans.dataloader.data_info import Data2D, Detector
82
83    if qy is None:
84        qy = qx
85    Qx,Qy = np.meshgrid(qx,qy)
86    Qx,Qy = Qx.flatten(), Qy.flatten()
87    Iq = 100*np.ones_like(Qx)
88    dIq = np.sqrt(Iq)
89    mask = np.ones(len(Iq), dtype='bool')
90
91    data = Data2D()
92    data.filename = "fake data"
93    data.qx_data = Qx
94    data.qy_data = Qy
95    data.data = Iq
96    data.err_data = dIq
97    data.mask = mask
98
99    # 5% dQ/Q resolution
100    data.dqx_data = 0.05*Qx
101    data.dqy_data = 0.05*Qy
102
103    detector = Detector()
104    detector.pixel_size.x = 5 # mm
105    detector.pixel_size.y = 5 # mm
106    detector.distance = 4 # m
107    data.detector.append(detector)
108    data.xbins = qx
109    data.ybins = qy
110    data.source.wavelength = 5 # angstroms
111    data.source.wavelength_unit = "A"
112    data.Q_unit = "1/A"
113    data.I_unit = "1/cm"
114    data.q_data = np.sqrt(Qx**2 + Qy**2)
115    data.xaxis("Q_x", "A^{-1}")
116    data.yaxis("Q_y", "A^{-1}")
117    data.zaxis("Intensity", r"\text{cm}^{-1}")
118    return data
119
120
121def set_beam_stop(data, radius, outer=None):
122    """
123    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
124    """
125    from sans.dataloader.manipulations import Ringcut
126    if hasattr(data, 'qx_data'):
127        data.mask = Ringcut(0, radius)(data)
128        if outer is not None:
129            data.mask += Ringcut(outer,np.inf)(data)
130    else:
131        data.mask = (data.x>=radius)
132        if outer is not None:
133            data.mask &= (data.x<outer)
134
135
136def set_half(data, half):
137    """
138    Select half of the data, either "right" or "left".
139    """
140    from sans.dataloader.manipulations import Boxcut
141    if half == 'right':
142        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
143    if half == 'left':
144        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
145
146
147def set_top(data, max):
148    """
149    Chop the top off the data, above *max*.
150    """
151    from sans.dataloader.manipulations import Boxcut
152    data.mask += Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=max)(data)
153
154
155def plot_data(data, iq, vmin=None, vmax=None, scale='log'):
156    """
157    Plot the target value for the data.  This could be the data itself,
158    the theory calculation, or the residuals.
159
160    *scale* can be 'log' for log scale data, or 'linear'.
161    """
162    from numpy.ma import masked_array, masked
163    import matplotlib.pyplot as plt
164    if hasattr(data, 'qx_data'):
165        iq = iq[:]
166        valid = np.isfinite(iq)
167        if scale == 'log':
168            valid[valid] = (iq[valid] > 0)
169            iq[valid] = np.log10(iq[valid])
170        iq[~valid|data.mask] = 0
171        #plottable = iq
172        plottable = masked_array(iq, ~valid|data.mask)
173        xmin, xmax = min(data.qx_data), max(data.qx_data)
174        ymin, ymax = min(data.qy_data), max(data.qy_data)
175        if vmin is None: vmin = iq[valid&~data.mask].min()
176        if vmax is None: vmax = iq[valid&~data.mask].max()
177        plt.imshow(plottable.reshape(128,128),
178                   interpolation='nearest', aspect=1, origin='upper',
179                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
180    else: # 1D data
181        if scale == 'linear':
182            idx = np.isfinite(iq)
183            plt.plot(data.x[idx], iq[idx])
184        else:
185            idx = np.isfinite(iq)
186            idx[idx] = (iq[idx]>0)
187            plt.loglog(data.x[idx], iq[idx])
188
189
190def _plot_result1D(data, theory, view):
191    """
192    Plot the data and residuals for 1D data.
193    """
194    import matplotlib.pyplot as plt
195    from numpy.ma import masked_array, masked
196    #print "not a number",sum(np.isnan(data.y))
197    #data.y[data.y<0.05] = 0.5
198    mdata = masked_array(data.y, data.mask)
199    mdata[np.isnan(mdata)] = masked
200    if view is 'log':
201        mdata[mdata <= 0] = masked
202    mtheory = masked_array(theory, mdata.mask)
203    mresid = masked_array((theory-data.y)/data.dy, mdata.mask)
204
205    plt.subplot(121)
206    plt.errorbar(data.x, mdata, yerr=data.dy)
207    plt.plot(data.x, mtheory, '-', hold=True)
208    plt.yscale(view)
209    plt.subplot(122)
210    plt.plot(data.x, mresid, 'x')
211    #plt.axhline(1, color='black', ls='--',lw=1, hold=True)
212    #plt.axhline(0, color='black', lw=1, hold=True)
213    #plt.axhline(-1, color='black', ls='--',lw=1, hold=True)
214
215
216def _plot_result2D(data, theory, view):
217    """
218    Plot the data and residuals for 2D data.
219    """
220    import matplotlib.pyplot as plt
221    resid = (theory-data.data)/data.err_data
222    plt.subplot(131)
223    plot_data(data, data.data, scale=view)
224    plt.colorbar()
225    plt.subplot(132)
226    plot_data(data, theory, scale=view)
227    plt.colorbar()
228    plt.subplot(133)
229    plot_data(data, resid, scale='linear')
230    plt.colorbar()
231
232def plot_result(data, theory, view='log'):
233    """
234    Plot the data and residuals.
235    """
236    if hasattr(data, 'qx_data'):
237        _plot_result2D(data, theory, view)
238    else:
239        _plot_result1D(data, theory, view)
240
241
242class BumpsModel(object):
243    """
244    Return a bumps wrapper for a SAS model.
245
246    *data* is the data to be fitted.
247
248    *model* is the SAS model, e.g., from :func:`gen.opencl_model`.
249
250    *cutoff* is the integration cutoff, which avoids computing the
251    the SAS model where the polydispersity weight is low.
252
253    Model parameters can be initialized with additional keyword
254    arguments, or by assigning to model.parameter_name.value.
255
256    The resulting bumps model can be used directly in a FitProblem call.
257    """
258    def __init__(self, data, model, cutoff=1e-5, **kw):
259        from bumps.names import Parameter
260
261        # remember inputs so we can inspect from outside
262        self.data = data
263        self.model = model
264        self.cutoff = cutoff
265
266        partype = model.info['partype']
267
268        # interpret data
269        if hasattr(data, 'qx_data'):
270            self.index = (data.mask==0) & (~np.isnan(data.data))
271            self.iq = data.data[self.index]
272            self.diq = data.err_data[self.index]
273            self._theory = np.zeros_like(data.data)
274            if not partype['orientation'] and not partype['magnetic']:
275                q_vectors = [np.sqrt(data.qx_data**2+data.qy_data**2)]
276            else:
277                q_vectors = [data.qx_data, data.qy_data]
278        else:
279            self.index = (data.x>=data.qmin) & (data.x<=data.qmax) & ~np.isnan(data.y)
280            self.iq = data.y[self.index]
281            self.diq = data.dy[self.index]
282            self._theory = np.zeros_like(data.y)
283            q_vectors = [data.x]
284
285        # Remember function inputs so we can delay loading the function and
286        # so we can save/restore state
287        self._fn_inputs = [v[self.index] for v in q_vectors]
288        self._fn = None
289
290        # define bumps parameters
291        pars = []
292        for p in model.info['parameters']:
293            name, default, limits, ptype = p[0], p[2], p[3], p[4]
294            value = kw.pop(name, default)
295            setattr(self, name, Parameter.default(value, name=name, limits=limits))
296            pars.append(name)
297        for name in partype['pd-2d']:
298            for xpart,xdefault,xlimits in [
299                    ('_pd', 0, limits),
300                    ('_pd_n', 35, (0,1000)),
301                    ('_pd_nsigma', 3, (0, 10)),
302                    ('_pd_type', 'gaussian', None),
303                ]:
304                xname = name+xpart
305                xvalue = kw.pop(xname, xdefault)
306                if xlimits is not None:
307                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
308                    pars.append(xname)
309                setattr(self, xname, xvalue)
310        self._parameter_names = pars
311        if kw:
312            raise TypeError("unexpected parameters: %s"%(", ".join(sorted(kw.keys()))))
313        self.update()
314
315    def update(self):
316        self._cache = {}
317
318    def numpoints(self):
319        return len(self.iq)
320
321    def parameters(self):
322        return dict((k,getattr(self,k)) for k in self._parameter_names)
323
324    def theory(self):
325        if 'theory' not in self._cache:
326            if self._fn is None:
327                input = self.model.make_input(self._fn_inputs)
328                self._fn = self.model(input)
329
330            pars = [getattr(self,p).value for p in self._fn.fixed_pars]
331            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
332            #print pars
333            self._theory[self.index] = self._fn(pars, pd_pars, self.cutoff)
334            #self._theory[:] = self._fn.eval(pars, pd_pars)
335            self._cache['theory'] = self._theory
336        return self._cache['theory']
337
338    def residuals(self):
339        #if np.any(self.err ==0): print "zeros in err"
340        return (self.theory()[self.index]-self.iq)/self.diq
341
342    def nllf(self):
343        R = self.residuals()
344        #if np.any(np.isnan(R)): print "NaN in residuals"
345        return 0.5*np.sum(R**2)
346
347    def __call__(self):
348        return 2*self.nllf()/self.dof
349
350    def plot(self, view='log'):
351        plot_result(self.data, self.theory(), view=view)
352
353    def save(self, basename):
354        pass
355
356    def _get_weights(self, par):
357        from . import weights
358
359        relative = self.model.info['partype']['pd-rel']
360        limits = self.model.info['limits']
361        disperser,value,npts,width,nsigma = [getattr(self, par+ext)
362                for ext in ('_pd_type','','_pd_n','_pd','_pd_nsigma')]
363        v,w = weights.get_weights(
364            disperser, int(npts.value), width.value, nsigma.value,
365            value.value, limits[par], par in relative)
366        return v,w/w.max()
367
368    def __getstate__(self):
369        # Can't pickle gpu functions, so instead make them lazy
370        state = self.__dict__.copy()
371        state['_fn'] = None
372        return state
373
374    def __setstate__(self, state):
375        self.__dict__ = state
376
377
378def demo():
379    data = load_data('DEC07086.DAT')
380    set_beam_stop(data, 0.004)
381    plot_data(data)
382    import matplotlib.pyplot as plt; plt.show()
383
384
385if __name__ == "__main__":
386    demo()
Note: See TracBrowser for help on using the repository browser.