source: sasmodels/sasmodels/bumps_model.py @ c97724e

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since c97724e was c97724e, checked in by pkienzle, 9 years ago

add sesans support to bumps model

  • Property mode set to 100644
File size: 13.5 KB
Line 
1"""
2Sasmodels core.
3"""
4import sys, os
5import datetime
6
7from sasmodels import sesans
8
9# CRUFT python 2.6
10if not hasattr(datetime.timedelta, 'total_seconds'):
11    def delay(dt): return dt.days*86400 + dt.seconds + 1e-6*dt.microseconds
12else:
13    def delay(dt): return dt.total_seconds()
14
15import numpy as np
16
17try:
18    from .kernelcl import load_model as _loader
19except ImportError,exc:
20    import warnings
21    warnings.warn(str(exc))
22    warnings.warn("OpenCL not available --- using ctypes instead")
23    from .kerneldll import load_model as _loader
24
25def load_model(modelname, dtype='single'):
26    """
27    Load model by name.
28    """
29    sasmodels = __import__('sasmodels.models.'+modelname)
30    module = getattr(sasmodels.models, modelname, None)
31    model = _loader(module, dtype=dtype)
32    return model
33
34
35def tic():
36    """
37    Timer function.
38
39    Use "toc=tic()" to start the clock and "toc()" to measure
40    a time interval.
41    """
42    then = datetime.datetime.now()
43    return lambda: delay(datetime.datetime.now()-then)
44
45
46def load_data(filename):
47    """
48    Load data using a sasview loader.
49    """
50    from sas.dataloader.loader import Loader
51    loader = Loader()
52    data = loader.load(filename)
53    if data is None:
54        raise IOError("Data %r could not be loaded"%filename)
55    return data
56
57
58def empty_data1D(q):
59    """
60    Create empty 1D data using the given *q* as the x value.
61
62    Resolutions dq/q is 5%.
63    """
64
65    from sas.dataloader.data_info import Data1D
66
67    Iq = 100*np.ones_like(q)
68    dIq = np.sqrt(Iq)
69    data = Data1D(q, Iq, dx=0.05*q, dy=dIq)
70    data.filename = "fake data"
71    data.qmin, data.qmax = q.min(), q.max()
72    return data
73
74
75def empty_data2D(qx, qy=None):
76    """
77    Create empty 2D data using the given mesh.
78
79    If *qy* is missing, create a square mesh with *qy=qx*.
80
81    Resolution dq/q is 5%.
82    """
83    from sas.dataloader.data_info import Data2D, Detector
84
85    if qy is None:
86        qy = qx
87    Qx,Qy = np.meshgrid(qx,qy)
88    Qx,Qy = Qx.flatten(), Qy.flatten()
89    Iq = 100*np.ones_like(Qx)
90    dIq = np.sqrt(Iq)
91    mask = np.ones(len(Iq), dtype='bool')
92
93    data = Data2D()
94    data.filename = "fake data"
95    data.qx_data = Qx
96    data.qy_data = Qy
97    data.data = Iq
98    data.err_data = dIq
99    data.mask = mask
100
101    # 5% dQ/Q resolution
102    data.dqx_data = 0.05*Qx
103    data.dqy_data = 0.05*Qy
104
105    detector = Detector()
106    detector.pixel_size.x = 5 # mm
107    detector.pixel_size.y = 5 # mm
108    detector.distance = 4 # m
109    data.detector.append(detector)
110    data.xbins = qx
111    data.ybins = qy
112    data.source.wavelength = 5 # angstroms
113    data.source.wavelength_unit = "A"
114    data.Q_unit = "1/A"
115    data.I_unit = "1/cm"
116    data.q_data = np.sqrt(Qx**2 + Qy**2)
117    data.xaxis("Q_x", "A^{-1}")
118    data.yaxis("Q_y", "A^{-1}")
119    data.zaxis("Intensity", r"\text{cm}^{-1}")
120    return data
121
122
123def set_beam_stop(data, radius, outer=None):
124    """
125    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
126    """
127    from sas.dataloader.manipulations import Ringcut
128    if hasattr(data, 'qx_data'):
129        data.mask = Ringcut(0, radius)(data)
130        if outer is not None:
131            data.mask += Ringcut(outer,np.inf)(data)
132    else:
133        data.mask = (data.x>=radius)
134        if outer is not None:
135            data.mask &= (data.x<outer)
136
137
138def set_half(data, half):
139    """
140    Select half of the data, either "right" or "left".
141    """
142    from sas.dataloader.manipulations import Boxcut
143    if half == 'right':
144        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
145    if half == 'left':
146        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
147
148
149def set_top(data, max):
150    """
151    Chop the top off the data, above *max*.
152    """
153    from sas.dataloader.manipulations import Boxcut
154    data.mask += Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=max)(data)
155
156
157def plot_data(data, iq, vmin=None, vmax=None, scale='log'):
158    """
159    Plot the target value for the data.  This could be the data itself,
160    the theory calculation, or the residuals.
161
162    *scale* can be 'log' for log scale data, or 'linear'.
163    """
164    from numpy.ma import masked_array, masked
165    import matplotlib.pyplot as plt
166    if hasattr(data, 'qx_data'):
167        iq = iq[:]
168        valid = np.isfinite(iq)
169        if scale == 'log':
170            valid[valid] = (iq[valid] > 0)
171            iq[valid] = np.log10(iq[valid])
172        iq[~valid|data.mask] = 0
173        #plottable = iq
174        plottable = masked_array(iq, ~valid|data.mask)
175        xmin, xmax = min(data.qx_data), max(data.qx_data)
176        ymin, ymax = min(data.qy_data), max(data.qy_data)
177        if vmin is None: vmin = iq[valid&~data.mask].min()
178        if vmax is None: vmax = iq[valid&~data.mask].max()
179        plt.imshow(plottable.reshape(128,128),
180                   interpolation='nearest', aspect=1, origin='upper',
181                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
182    else: # 1D data
183        if scale == 'linear':
184            idx = np.isfinite(iq)
185            plt.plot(data.x[idx], iq[idx])
186        else:
187            idx = np.isfinite(iq)
188            idx[idx] = (iq[idx]>0)
189            plt.loglog(data.x[idx], iq[idx])
190
191
192def _plot_result1D(data, theory, view):
193    """
194    Plot the data and residuals for 1D data.
195    """
196    import matplotlib.pyplot as plt
197    from numpy.ma import masked_array, masked
198    #print "not a number",sum(np.isnan(data.y))
199    #data.y[data.y<0.05] = 0.5
200    mdata = masked_array(data.y, data.mask)
201    mdata[np.isnan(mdata)] = masked
202    if view is 'log':
203        mdata[mdata <= 0] = masked
204    mtheory = masked_array(theory, mdata.mask)
205    mresid = masked_array((theory-data.y)/data.dy, mdata.mask)
206
207    plt.subplot(121)
208    plt.errorbar(data.x, mdata, yerr=data.dy)
209    plt.plot(data.x, mtheory, '-', hold=True)
210    plt.yscale(view)
211    plt.subplot(122)
212    plt.plot(data.x, mresid, 'x')
213    #plt.axhline(1, color='black', ls='--',lw=1, hold=True)
214    #plt.axhline(0, color='black', lw=1, hold=True)
215    #plt.axhline(-1, color='black', ls='--',lw=1, hold=True)
216
217def _plot_sesans(data, theory, view):
218    import matplotlib.pyplot as plt
219    resid = (theory - data.data)/data.err_data
220    plt.subplot(121)
221    plt.errorbar(data.SElength, data.data, yerr=data.err_data)
222    plt.plot(data.SElength, theory, '-', hold=True)
223    plt.xlabel('spin echo length (A)')
224    plt.ylabel('polarization')
225    plt.subplot(122)
226    plt.plot(data.SElength, resid, 'x')
227    plt.xlabel('spin echo length (A)')
228    plt.ylabel('residuals')
229
230def _plot_result2D(data, theory, view):
231    """
232    Plot the data and residuals for 2D data.
233    """
234    import matplotlib.pyplot as plt
235    resid = (theory-data.data)/data.err_data
236    plt.subplot(131)
237    plot_data(data, data.data, scale=view)
238    plt.colorbar()
239    plt.subplot(132)
240    plot_data(data, theory, scale=view)
241    plt.colorbar()
242    plt.subplot(133)
243    plot_data(data, resid, scale='linear')
244    plt.colorbar()
245
246class BumpsModel(object):
247    """
248    Return a bumps wrapper for a SAS model.
249
250    *data* is the data to be fitted.
251
252    *model* is the SAS model, e.g., from :func:`gen.opencl_model`.
253
254    *cutoff* is the integration cutoff, which avoids computing the
255    the SAS model where the polydispersity weight is low.
256
257    Model parameters can be initialized with additional keyword
258    arguments, or by assigning to model.parameter_name.value.
259
260    The resulting bumps model can be used directly in a FitProblem call.
261    """
262    def __init__(self, data, model, cutoff=1e-5, **kw):
263        from bumps.names import Parameter
264
265        # remember inputs so we can inspect from outside
266        self.data = data
267        self.model = model
268        self.cutoff = cutoff
269        if hasattr(data, 'SElength'):
270            self.data_type = 'sesans'
271        elif hasattr(data, 'qx_data'):
272            self.data_type = 'Iqxy'
273        else:
274            self.data_type = 'Iq'
275
276        partype = model.info['partype']
277
278        # interpret data
279        if self.data_type == 'sesans':
280            q = sesans.make_q(data.q_zmax, data.Rmax)
281            self.index = slice(None,None)
282            self.iq = data.data
283            self.diq = data.err_data
284            self._theory = np.zeros_like(q)
285            q_vectors = [q]
286        elif self.data_type == 'Iqxy':
287            self.index = (data.mask==0) & (~np.isnan(data.data))
288            self.iq = data.data[self.index]
289            self.diq = data.err_data[self.index]
290            self._theory = np.zeros_like(data.data)
291            if not partype['orientation'] and not partype['magnetic']:
292                q_vectors = [np.sqrt(data.qx_data**2+data.qy_data**2)]
293            else:
294                q_vectors = [data.qx_data, data.qy_data]
295        elif self.data_type == 'Iq':
296            self.index = (data.x>=data.qmin) & (data.x<=data.qmax) & ~np.isnan(data.y)
297            self.iq = data.y[self.index]
298            self.diq = data.dy[self.index]
299            self._theory = np.zeros_like(data.y)
300            q_vectors = [data.x]
301        else:
302            raise ValueError("Unknown data type") # never gets here
303
304        # Remember function inputs so we can delay loading the function and
305        # so we can save/restore state
306        self._fn_inputs = [v[self.index] for v in q_vectors]
307        self._fn = None
308
309        # define bumps parameters
310        pars = []
311        for p in model.info['parameters']:
312            name, default, limits, ptype = p[0], p[2], p[3], p[4]
313            value = kw.pop(name, default)
314            setattr(self, name, Parameter.default(value, name=name, limits=limits))
315            pars.append(name)
316        for name in partype['pd-2d']:
317            for xpart,xdefault,xlimits in [
318                    ('_pd', 0, limits),
319                    ('_pd_n', 35, (0,1000)),
320                    ('_pd_nsigma', 3, (0, 10)),
321                    ('_pd_type', 'gaussian', None),
322                ]:
323                xname = name+xpart
324                xvalue = kw.pop(xname, xdefault)
325                if xlimits is not None:
326                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
327                    pars.append(xname)
328                setattr(self, xname, xvalue)
329        self._parameter_names = pars
330        if kw:
331            raise TypeError("unexpected parameters: %s"%(", ".join(sorted(kw.keys()))))
332        self.update()
333
334    def update(self):
335        self._cache = {}
336
337    def numpoints(self):
338        return len(self.iq)
339
340    def parameters(self):
341        return dict((k,getattr(self,k)) for k in self._parameter_names)
342
343    def theory(self):
344        if 'theory' not in self._cache:
345            if self._fn is None:
346                input = self.model.make_input(self._fn_inputs)
347                self._fn = self.model(input)
348
349            pars = [getattr(self,p).value for p in self._fn.fixed_pars]
350            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
351            #print pars
352            self._theory[self.index] = self._fn(pars, pd_pars, self.cutoff)
353            #self._theory[:] = self._fn.eval(pars, pd_pars)
354            if self.data_type == 'sesans':
355                P = sesans.hankel(self.data.SElength, self.data.wavelength,
356                                  self.data.thickness, self._fn_inputs[0],
357                                  self._theory)
358                self._cache['theory'] = P
359            else:
360                self._cache['theory'] = self._theory
361        return self._cache['theory']
362
363    def residuals(self):
364        #if np.any(self.err ==0): print "zeros in err"
365        return (self.theory()[self.index]-self.iq)/self.diq
366
367    def nllf(self):
368        R = self.residuals()
369        #if np.any(np.isnan(R)): print "NaN in residuals"
370        return 0.5*np.sum(R**2)
371
372    def __call__(self):
373        return 2*self.nllf()/self.dof
374
375    def plot(self, view='log'):
376        """
377        Plot the data and residuals.
378        """
379        data, theory = self.data, self.theory()
380        if self.data_type == 'Iq':
381            _plot_result1D(data, theory, view)
382        elif self.data_type == 'Iqxy':
383            _plot_result2D(data, theory, view)
384        elif self.data_type == 'sesans':
385            _plot_sesans(data, theory, view)
386        else:
387            raise ValueError("Unknown data type")
388
389    def simulate_data(self, noise=None):
390        print "noise", noise
391        if noise is None:
392            noise = self.diq[self.index]
393        else:
394            noise = 0.01*noise
395            self.diq[self.index] = noise
396        y = self.theory()
397        y += y*np.random.randn(*y.shape)*noise
398        if self.data_type == 'Iq':
399            self.data.y[self.index] = y
400        elif self.data_type == 'Iqxy':
401            self.data.data[self.index] = y
402        elif self.data_type == 'sesans':
403            self.data.data[self.index] = y
404        else:
405            raise ValueError("Unknown model")
406
407    def save(self, basename):
408        pass
409
410    def _get_weights(self, par):
411        from . import weights
412
413        relative = self.model.info['partype']['pd-rel']
414        limits = self.model.info['limits']
415        disperser,value,npts,width,nsigma = [getattr(self, par+ext)
416                for ext in ('_pd_type','','_pd_n','_pd','_pd_nsigma')]
417        v,w = weights.get_weights(
418            disperser, int(npts.value), width.value, nsigma.value,
419            value.value, limits[par], par in relative)
420        return v,w/w.max()
421
422    def __getstate__(self):
423        # Can't pickle gpu functions, so instead make them lazy
424        state = self.__dict__.copy()
425        state['_fn'] = None
426        return state
427
428    def __setstate__(self, state):
429        self.__dict__ = state
430
431
432def demo():
433    data = load_data('DEC07086.DAT')
434    set_beam_stop(data, 0.004)
435    plot_data(data)
436    import matplotlib.pyplot as plt; plt.show()
437
438
439if __name__ == "__main__":
440    demo()
Note: See TracBrowser for help on using the repository browser.