source: sasmodels/sasmodels/bumps_model.py @ 5134b2c

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 5134b2c was 5134b2c, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

fix compare plots so they show both positive and negative error

  • Property mode set to 100644
File size: 13.6 KB
Line 
1"""
2Sasmodels core.
3"""
4import sys, os
5import datetime
6
7from sasmodels import sesans
8
9# CRUFT python 2.6
10if not hasattr(datetime.timedelta, 'total_seconds'):
11    def delay(dt): return dt.days*86400 + dt.seconds + 1e-6*dt.microseconds
12else:
13    def delay(dt): return dt.total_seconds()
14
15import numpy as np
16
17try:
18    from .kernelcl import load_model as _loader
19except RuntimeError,exc:
20    import warnings
21    warnings.warn(str(exc))
22    warnings.warn("OpenCL not available --- using ctypes instead")
23    from .kerneldll import load_model as _loader
24
25def load_model(modelname, dtype='single'):
26    """
27    Load model by name.
28    """
29    sasmodels = __import__('sasmodels.models.'+modelname)
30    module = getattr(sasmodels.models, modelname, None)
31    model = _loader(module, dtype=dtype)
32    return model
33
34
35def tic():
36    """
37    Timer function.
38
39    Use "toc=tic()" to start the clock and "toc()" to measure
40    a time interval.
41    """
42    then = datetime.datetime.now()
43    return lambda: delay(datetime.datetime.now()-then)
44
45
46def load_data(filename):
47    """
48    Load data using a sasview loader.
49    """
50    from sas.dataloader.loader import Loader
51    loader = Loader()
52    data = loader.load(filename)
53    if data is None:
54        raise IOError("Data %r could not be loaded"%filename)
55    return data
56
57
58def empty_data1D(q):
59    """
60    Create empty 1D data using the given *q* as the x value.
61
62    Resolutions dq/q is 5%.
63    """
64
65    from sas.dataloader.data_info import Data1D
66
67    Iq = 100*np.ones_like(q)
68    dIq = np.sqrt(Iq)
69    data = Data1D(q, Iq, dx=0.05*q, dy=dIq)
70    data.filename = "fake data"
71    data.qmin, data.qmax = q.min(), q.max()
72    return data
73
74
75def empty_data2D(qx, qy=None):
76    """
77    Create empty 2D data using the given mesh.
78
79    If *qy* is missing, create a square mesh with *qy=qx*.
80
81    Resolution dq/q is 5%.
82    """
83    from sas.dataloader.data_info import Data2D, Detector
84
85    if qy is None:
86        qy = qx
87    Qx,Qy = np.meshgrid(qx,qy)
88    Qx,Qy = Qx.flatten(), Qy.flatten()
89    Iq = 100*np.ones_like(Qx)
90    dIq = np.sqrt(Iq)
91    mask = np.ones(len(Iq), dtype='bool')
92
93    data = Data2D()
94    data.filename = "fake data"
95    data.qx_data = Qx
96    data.qy_data = Qy
97    data.data = Iq
98    data.err_data = dIq
99    data.mask = mask
100
101    # 5% dQ/Q resolution
102    data.dqx_data = 0.05*Qx
103    data.dqy_data = 0.05*Qy
104
105    detector = Detector()
106    detector.pixel_size.x = 5 # mm
107    detector.pixel_size.y = 5 # mm
108    detector.distance = 4 # m
109    data.detector.append(detector)
110    data.xbins = qx
111    data.ybins = qy
112    data.source.wavelength = 5 # angstroms
113    data.source.wavelength_unit = "A"
114    data.Q_unit = "1/A"
115    data.I_unit = "1/cm"
116    data.q_data = np.sqrt(Qx**2 + Qy**2)
117    data.xaxis("Q_x", "A^{-1}")
118    data.yaxis("Q_y", "A^{-1}")
119    data.zaxis("Intensity", r"\text{cm}^{-1}")
120    return data
121
122
123def set_beam_stop(data, radius, outer=None):
124    """
125    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
126    """
127    from sas.dataloader.manipulations import Ringcut
128    if hasattr(data, 'qx_data'):
129        data.mask = Ringcut(0, radius)(data)
130        if outer is not None:
131            data.mask += Ringcut(outer,np.inf)(data)
132    else:
133        data.mask = (data.x>=radius)
134        if outer is not None:
135            data.mask &= (data.x<outer)
136
137
138def set_half(data, half):
139    """
140    Select half of the data, either "right" or "left".
141    """
142    from sas.dataloader.manipulations import Boxcut
143    if half == 'right':
144        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
145    if half == 'left':
146        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
147
148
149def set_top(data, max):
150    """
151    Chop the top off the data, above *max*.
152    """
153    from sas.dataloader.manipulations import Boxcut
154    data.mask += Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=max)(data)
155
156
157def plot_data(data, iq, vmin=None, vmax=None, scale='log'):
158    """
159    Plot the target value for the data.  This could be the data itself,
160    the theory calculation, or the residuals.
161
162    *scale* can be 'log' for log scale data, or 'linear'.
163    """
164    from numpy.ma import masked_array, masked
165    import matplotlib.pyplot as plt
166    if hasattr(data, 'qx_data'):
167        iq = iq+0
168        valid = np.isfinite(iq)
169        if scale == 'log':
170            valid[valid] = (iq[valid] > 0)
171            iq[valid] = np.log10(iq[valid])
172        iq[~valid|data.mask] = 0
173        #plottable = iq
174        plottable = masked_array(iq, ~valid|data.mask)
175        xmin, xmax = min(data.qx_data), max(data.qx_data)
176        ymin, ymax = min(data.qy_data), max(data.qy_data)
177        try:
178            if vmin is None: vmin = iq[valid&~data.mask].min()
179            if vmax is None: vmax = iq[valid&~data.mask].max()
180        except:
181            vmin, vmax = 0, 1
182        plt.imshow(plottable.reshape(128,128),
183                   interpolation='nearest', aspect=1, origin='upper',
184                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
185    else: # 1D data
186        if scale == 'linear':
187            idx = np.isfinite(iq)
188            plt.plot(data.x[idx], iq[idx])
189        else:
190            idx = np.isfinite(iq)
191            idx[idx] = (iq[idx]>0)
192            plt.loglog(data.x[idx], iq[idx])
193
194
195def _plot_result1D(data, theory, view):
196    """
197    Plot the data and residuals for 1D data.
198    """
199    import matplotlib.pyplot as plt
200    from numpy.ma import masked_array, masked
201    #print "not a number",sum(np.isnan(data.y))
202    #data.y[data.y<0.05] = 0.5
203    mdata = masked_array(data.y, data.mask)
204    mdata[np.isnan(mdata)] = masked
205    if view is 'log':
206        mdata[mdata <= 0] = masked
207    mtheory = masked_array(theory, mdata.mask)
208    mresid = masked_array((theory-data.y)/data.dy, mdata.mask)
209
210    plt.subplot(121)
211    plt.errorbar(data.x, mdata, yerr=data.dy)
212    plt.plot(data.x, mtheory, '-', hold=True)
213    plt.yscale(view)
214    plt.subplot(122)
215    plt.plot(data.x, mresid, 'x')
216    #plt.axhline(1, color='black', ls='--',lw=1, hold=True)
217    #plt.axhline(0, color='black', lw=1, hold=True)
218    #plt.axhline(-1, color='black', ls='--',lw=1, hold=True)
219
220def _plot_sesans(data, theory, view):
221    import matplotlib.pyplot as plt
222    resid = (theory - data.y)/data.dy
223    plt.subplot(121)
224    plt.errorbar(data.x, data.y, yerr=data.dy)
225    plt.plot(data.x, theory, '-', hold=True)
226    plt.xlabel('spin echo length (A)')
227    plt.ylabel('polarization')
228    plt.subplot(122)
229    plt.plot(data.x, resid, 'x')
230    plt.xlabel('spin echo length (A)')
231    plt.ylabel('residuals')
232
233def _plot_result2D(data, theory, view):
234    """
235    Plot the data and residuals for 2D data.
236    """
237    import matplotlib.pyplot as plt
238    resid = (theory-data.data)/data.err_data
239    plt.subplot(131)
240    plot_data(data, data.data, scale=view)
241    plt.colorbar()
242    plt.subplot(132)
243    plot_data(data, theory, scale=view)
244    plt.colorbar()
245    plt.subplot(133)
246    plot_data(data, resid, scale='linear')
247    plt.colorbar()
248
249class BumpsModel(object):
250    """
251    Return a bumps wrapper for a SAS model.
252
253    *data* is the data to be fitted.
254
255    *model* is the SAS model, e.g., from :func:`gen.opencl_model`.
256
257    *cutoff* is the integration cutoff, which avoids computing the
258    the SAS model where the polydispersity weight is low.
259
260    Model parameters can be initialized with additional keyword
261    arguments, or by assigning to model.parameter_name.value.
262
263    The resulting bumps model can be used directly in a FitProblem call.
264    """
265    def __init__(self, data, model, cutoff=1e-5, **kw):
266        from bumps.names import Parameter
267
268        # remember inputs so we can inspect from outside
269        self.data = data
270        self.model = model
271        self.cutoff = cutoff
272# TODO       if  isinstance(data,SESANSData1D)       
273        if hasattr(data, 'lam'):
274            self.data_type = 'sesans'
275        elif hasattr(data, 'qx_data'):
276            self.data_type = 'Iqxy'
277        else:
278            self.data_type = 'Iq'
279
280        partype = model.info['partype']
281
282        # interpret data
283        if self.data_type == 'sesans':
284            q = sesans.make_q(data.sample.zacceptance, data.Rmax)
285            self.index = slice(None,None)
286            self.iq = data.y
287            self.diq = data.dy
288            self._theory = np.zeros_like(q)
289            q_vectors = [q]
290        elif self.data_type == 'Iqxy':
291            self.index = (data.mask==0) & (~np.isnan(data.data))
292            self.iq = data.data[self.index]
293            self.diq = data.err_data[self.index]
294            self._theory = np.zeros_like(data.data)
295            if not partype['orientation'] and not partype['magnetic']:
296                q_vectors = [np.sqrt(data.qx_data**2+data.qy_data**2)]
297            else:
298                q_vectors = [data.qx_data, data.qy_data]
299        elif self.data_type == 'Iq':
300            self.index = (data.x>=data.qmin) & (data.x<=data.qmax) & ~np.isnan(data.y)
301            self.iq = data.y[self.index]
302            self.diq = data.dy[self.index]
303            self._theory = np.zeros_like(data.y)
304            q_vectors = [data.x]
305        else:
306            raise ValueError("Unknown data type") # never gets here
307
308        # Remember function inputs so we can delay loading the function and
309        # so we can save/restore state
310        self._fn_inputs = [v[self.index] for v in q_vectors]
311        self._fn = None
312
313        # define bumps parameters
314        pars = []
315        for p in model.info['parameters']:
316            name, default, limits, ptype = p[0], p[2], p[3], p[4]
317            value = kw.pop(name, default)
318            setattr(self, name, Parameter.default(value, name=name, limits=limits))
319            pars.append(name)
320        for name in partype['pd-2d']:
321            for xpart,xdefault,xlimits in [
322                    ('_pd', 0, limits),
323                    ('_pd_n', 35, (0,1000)),
324                    ('_pd_nsigma', 3, (0, 10)),
325                    ('_pd_type', 'gaussian', None),
326                ]:
327                xname = name+xpart
328                xvalue = kw.pop(xname, xdefault)
329                if xlimits is not None:
330                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
331                    pars.append(xname)
332                setattr(self, xname, xvalue)
333        self._parameter_names = pars
334        if kw:
335            raise TypeError("unexpected parameters: %s"%(", ".join(sorted(kw.keys()))))
336        self.update()
337
338    def update(self):
339        self._cache = {}
340
341    def numpoints(self):
342        return len(self.iq)
343
344    def parameters(self):
345        return dict((k,getattr(self,k)) for k in self._parameter_names)
346
347    def theory(self):
348        if 'theory' not in self._cache:
349            if self._fn is None:
350                input = self.model.make_input(self._fn_inputs)
351                self._fn = self.model(input)
352
353            fixed_pars = [getattr(self,p).value for p in self._fn.fixed_pars]
354            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
355            #print fixed_pars,pd_pars
356            self._theory[self.index] = self._fn(fixed_pars, pd_pars, self.cutoff)
357            #self._theory[:] = self._fn.eval(pars, pd_pars)
358            if self.data_type == 'sesans':
359                P = sesans.hankel(self.data.x, self.data.lam*1e-9,
360                                  self.data.sample.thickness/10, self._fn_inputs[0],
361                                  self._theory)
362                self._cache['theory'] = P
363            else:
364                self._cache['theory'] = self._theory
365        return self._cache['theory']
366
367    def residuals(self):
368        #if np.any(self.err ==0): print "zeros in err"
369        return (self.theory()[self.index]-self.iq)/self.diq
370
371    def nllf(self):
372        R = self.residuals()
373        #if np.any(np.isnan(R)): print "NaN in residuals"
374        return 0.5*np.sum(R**2)
375
376    def __call__(self):
377        return 2*self.nllf()/self.dof
378
379    def plot(self, view='log'):
380        """
381        Plot the data and residuals.
382        """
383        data, theory = self.data, self.theory()
384        if self.data_type == 'Iq':
385            _plot_result1D(data, theory, view)
386        elif self.data_type == 'Iqxy':
387            _plot_result2D(data, theory, view)
388        elif self.data_type == 'sesans':
389            _plot_sesans(data, theory, view)
390        else:
391            raise ValueError("Unknown data type")
392
393    def simulate_data(self, noise=None):
394        print "noise", noise
395        if noise is None:
396            noise = self.diq[self.index]
397        else:
398            noise = 0.01*noise
399            self.diq[self.index] = noise
400        y = self.theory()
401        y += y*np.random.randn(*y.shape)*noise
402        if self.data_type == 'Iq':
403            self.data.y[self.index] = y
404        elif self.data_type == 'Iqxy':
405            self.data.data[self.index] = y
406        elif self.data_type == 'sesans':
407            self.data.y[self.index] = y
408        else:
409            raise ValueError("Unknown model")
410
411    def save(self, basename):
412        pass
413
414    def _get_weights(self, par):
415        from . import weights
416
417        relative = self.model.info['partype']['pd-rel']
418        limits = self.model.info['limits']
419        disperser,value,npts,width,nsigma = [getattr(self, par+ext)
420                for ext in ('_pd_type','','_pd_n','_pd','_pd_nsigma')]
421        v,w = weights.get_weights(
422            disperser, int(npts.value), width.value, nsigma.value,
423            value.value, limits[par], par in relative)
424        return v,w/w.max()
425
426    def __getstate__(self):
427        # Can't pickle gpu functions, so instead make them lazy
428        state = self.__dict__.copy()
429        state['_fn'] = None
430        return state
431
432    def __setstate__(self, state):
433        self.__dict__ = state
434
435
436def demo():
437    data = load_data('DEC07086.DAT')
438    set_beam_stop(data, 0.004)
439    plot_data(data)
440    import matplotlib.pyplot as plt; plt.show()
441
442
443if __name__ == "__main__":
444    demo()
Note: See TracBrowser for help on using the repository browser.