source: sasmodels/sasmodels/bumps_model.py @ abb22f4

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since abb22f4 was abb22f4, checked in by Paul Kienzle <pkienzle@…>, 10 years ago

support deepcopy on bumps model

  • Property mode set to 100644
File size: 11.2 KB
Line 
1"""
2Sasmodels core.
3"""
4import sys, os
5import datetime
6
7import numpy as np
8
9try:
10    from .gpu import load_model as _loader
11except ImportError,exc:
12    warnings.warn(str(exc))
13    warnings.warn("OpenCL not available --- using ctypes instead")
14    from .dll import load_model as _loader
15
16def load_model(modelname, dtype='single'):
17    """
18    Load model by name.
19    """
20    sasmodels = __import__('sasmodels.models.'+modelname)
21    module = getattr(sasmodels.models, modelname, None)
22    model = _loader(module, dtype=dtype)
23    return model
24
25
26def tic():
27    """
28    Timer function.
29
30    Use "toc=tic()" to start the clock and "toc()" to measure
31    a time interval.
32    """
33    then = datetime.datetime.now()
34    return lambda: (datetime.datetime.now()-then).total_seconds()
35
36
37def load_data(filename):
38    """
39    Load data using a sasview loader.
40    """
41    from sans.dataloader.loader import Loader
42    loader = Loader()
43    data = loader.load(filename)
44    if data is None:
45        raise IOError("Data %r could not be loaded"%filename)
46    return data
47
48
49def empty_data1D(q):
50    """
51    Create empty 1D data using the given *q* as the x value.
52
53    Resolutions dq/q is 5%.
54    """
55
56    from sans.dataloader.data_info import Data1D
57
58    Iq = 100*np.ones_like(q)
59    dIq = np.sqrt(Iq)
60    data = Data1D(q, Iq, dx=0.05*q, dy=dIq)
61    data.filename = "fake data"
62    data.qmin, data.qmax = q.min(), q.max()
63    return data
64
65
66def empty_data2D(qx, qy=None):
67    """
68    Create empty 2D data using the given mesh.
69
70    If *qy* is missing, create a square mesh with *qy=qx*.
71
72    Resolution dq/q is 5%.
73    """
74    from sans.dataloader.data_info import Data2D, Detector
75
76    if qy is None:
77        qy = qx
78    Qx,Qy = np.meshgrid(qx,qy)
79    Qx,Qy = Qx.flatten(), Qy.flatten()
80    Iq = 100*np.ones_like(Qx)
81    dIq = np.sqrt(Iq)
82    mask = np.ones(len(Iq), dtype='bool')
83
84    data = Data2D()
85    data.filename = "fake data"
86    data.qx_data = Qx
87    data.qy_data = Qy
88    data.data = Iq
89    data.err_data = dIq
90    data.mask = mask
91
92    # 5% dQ/Q resolution
93    data.dqx_data = 0.05*Qx
94    data.dqy_data = 0.05*Qy
95
96    detector = Detector()
97    detector.pixel_size.x = 5 # mm
98    detector.pixel_size.y = 5 # mm
99    detector.distance = 4 # m
100    data.detector.append(detector)
101    data.xbins = qx
102    data.ybins = qy
103    data.source.wavelength = 5 # angstroms
104    data.source.wavelength_unit = "A"
105    data.Q_unit = "1/A"
106    data.I_unit = "1/cm"
107    data.q_data = np.sqrt(Qx**2 + Qy**2)
108    data.xaxis("Q_x", "A^{-1}")
109    data.yaxis("Q_y", "A^{-1}")
110    data.zaxis("Intensity", r"\text{cm}^{-1}")
111    return data
112
113
114def set_beam_stop(data, radius, outer=None):
115    """
116    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
117    """
118    from sans.dataloader.manipulations import Ringcut
119    if hasattr(data, 'qx_data'):
120        data.mask = Ringcut(0, radius)(data)
121        if outer is not None:
122            data.mask += Ringcut(outer,np.inf)(data)
123    else:
124        data.mask = (data.x>=radius)
125        if outer is not None:
126            data.mask &= (data.x<outer)
127
128
129def set_half(data, half):
130    """
131    Select half of the data, either "right" or "left".
132    """
133    from sans.dataloader.manipulations import Boxcut
134    if half == 'right':
135        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
136    if half == 'left':
137        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
138
139
140def set_top(data, max):
141    """
142    Chop the top off the data, above *max*.
143    """
144    from sans.dataloader.manipulations import Boxcut
145    data.mask += Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=max)(data)
146
147
148def plot_data(data, iq, vmin=None, vmax=None, scale='log'):
149    """
150    Plot the target value for the data.  This could be the data itself,
151    the theory calculation, or the residuals.
152
153    *scale* can be 'log' for log scale data, or 'linear'.
154    """
155    from numpy.ma import masked_array, masked
156    import matplotlib.pyplot as plt
157    if hasattr(data, 'qx_data'):
158        iq = iq[:]
159        valid = np.isfinite(iq)
160        if scale == 'log':
161            valid[valid] = (iq[valid] > 0)
162            iq[valid] = np.log10(iq[valid])
163        iq[~valid|data.mask] = 0
164        #plottable = iq
165        plottable = masked_array(iq, ~valid|data.mask)
166        xmin, xmax = min(data.qx_data), max(data.qx_data)
167        ymin, ymax = min(data.qy_data), max(data.qy_data)
168        plt.imshow(plottable.reshape(128,128),
169                   interpolation='nearest', aspect=1, origin='upper',
170                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
171    else: # 1D data
172        if scale == 'linear':
173            idx = np.isfinite(iq)
174            plt.plot(data.x[idx], iq[idx])
175        else:
176            idx = np.isfinite(iq)
177            idx[idx] = (iq[idx]>0)
178            plt.loglog(data.x[idx], iq[idx])
179
180
181def _plot_result1D(data, theory, view):
182    """
183    Plot the data and residuals for 1D data.
184    """
185    import matplotlib.pyplot as plt
186    from numpy.ma import masked_array, masked
187    #print "not a number",sum(np.isnan(data.y))
188    #data.y[data.y<0.05] = 0.5
189    mdata = masked_array(data.y, data.mask)
190    mdata[np.isnan(mdata)] = masked
191    if view is 'log':
192        mdata[mdata <= 0] = masked
193    mtheory = masked_array(theory, mdata.mask)
194    mresid = masked_array((theory-data.y)/data.dy, mdata.mask)
195
196    plt.subplot(121)
197    plt.errorbar(data.x, mdata, yerr=data.dy)
198    plt.plot(data.x, mtheory, '-', hold=True)
199    plt.yscale(view)
200    plt.subplot(122)
201    plt.plot(data.x, mresid, 'x')
202    #plt.axhline(1, color='black', ls='--',lw=1, hold=True)
203    #plt.axhline(0, color='black', lw=1, hold=True)
204    #plt.axhline(-1, color='black', ls='--',lw=1, hold=True)
205
206
207def _plot_result2D(data, theory, view):
208    """
209    Plot the data and residuals for 2D data.
210    """
211    import matplotlib.pyplot as plt
212    resid = (theory-data.data)/data.err_data
213    plt.subplot(131)
214    plot_data(data, data.data, scale=view)
215    plt.colorbar()
216    plt.subplot(132)
217    plot_data(data, theory, scale=view)
218    plt.colorbar()
219    plt.subplot(133)
220    plot_data(data, resid, scale='linear')
221    plt.colorbar()
222
223def plot_result(data, theory, view='log'):
224    """
225    Plot the data and residuals.
226    """
227    if hasattr(data, 'qx_data'):
228        _plot_result2D(data, theory, view)
229    else:
230        _plot_result1D(data, theory, view)
231
232
233class BumpsModel(object):
234    """
235    Return a bumps wrapper for a SAS model.
236
237    *data* is the data to be fitted.
238
239    *model* is the SAS model, e.g., from :func:`gen.opencl_model`.
240
241    *cutoff* is the integration cutoff, which avoids computing the
242    the SAS model where the polydispersity weight is low.
243
244    Model parameters can be initialized with additional keyword
245    arguments, or by assigning to model.parameter_name.value.
246
247    The resulting bumps model can be used directly in a FitProblem call.
248    """
249    def __init__(self, data, model, cutoff=1e-5, **kw):
250        from bumps.names import Parameter
251
252        # remember inputs so we can inspect from outside
253        self.data = data
254        self.model = model
255        self.cutoff = cutoff
256
257        partype = model.info['partype']
258
259        # interpret data
260        if hasattr(data, 'qx_data'):
261            self.index = (data.mask==0) & (~np.isnan(data.data))
262            self.iq = data.data[self.index]
263            self.diq = data.err_data[self.index]
264            self._theory = np.zeros_like(data.data)
265            if not partype['orientation'] and not partype['magnetic']:
266                q_vectors = [np.sqrt(data.qx_data**2+data.qy_data**2)]
267            else:
268                q_vectors = [data.qx_data, data.qy_data]
269        else:
270            self.index = (data.x>=data.qmin) & (data.x<=data.qmax) & ~np.isnan(data.y)
271            self.iq = data.y[self.index]
272            self.diq = data.dy[self.index]
273            self._theory = np.zeros_like(data.y)
274            q_vectors = [data.x]
275
276        # Remember function inputs so we can delay loading the function and
277        # so we can save/restore state
278        self._fn_inputs = [v[self.index] for v in q_vectors]
279        self._fn = None
280
281        # define bumps parameters
282        pars = []
283        for p in model.info['parameters']:
284            name, default, limits, ptype = p[0], p[2], p[3], p[4]
285            value = kw.pop(name, default)
286            setattr(self, name, Parameter.default(value, name=name, limits=limits))
287            pars.append(name)
288        for name in partype['pd-2d']:
289            for xpart,xdefault,xlimits in [
290                    ('_pd', 0, limits),
291                    ('_pd_n', 35, (0,1000)),
292                    ('_pd_nsigma', 3, (0, 10)),
293                    ('_pd_type', 'gaussian', None),
294                ]:
295                xname = name+xpart
296                xvalue = kw.pop(xname, xdefault)
297                if xlimits is not None:
298                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
299                    pars.append(xname)
300                setattr(self, xname, xvalue)
301        self._parameter_names = pars
302        if kw:
303            raise TypeError("unexpected parameters: %s"%(", ".join(sorted(kw.keys()))))
304        self.update()
305
306    def update(self):
307        self._cache = {}
308
309    def numpoints(self):
310        return len(self.iq)
311
312    def parameters(self):
313        return dict((k,getattr(self,k)) for k in self._parameter_names)
314
315    def theory(self):
316        if 'theory' not in self._cache:
317            if self._fn is None:
318                input = self.model.make_input(self._fn_inputs)
319                self._fn = self.model(input)
320
321            pars = [getattr(self,p).value for p in self._fn.fixed_pars]
322            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
323            #print pars
324            self._theory[self.index] = self._fn(pars, pd_pars, self.cutoff)
325            #self._theory[:] = self._fn.eval(pars, pd_pars)
326            self._cache['theory'] = self._theory
327        return self._cache['theory']
328
329    def residuals(self):
330        #if np.any(self.err ==0): print "zeros in err"
331        return (self.theory()[self.index]-self.iq)/self.diq
332
333    def nllf(self):
334        R = self.residuals()
335        #if np.any(np.isnan(R)): print "NaN in residuals"
336        return 0.5*np.sum(R**2)
337
338    def __call__(self):
339        return 2*self.nllf()/self.dof
340
341    def plot(self, view='log'):
342        plot_result(self.data, self.theory(), view=view)
343
344    def save(self, basename):
345        pass
346
347    def _get_weights(self, par):
348        from . import weights
349
350        relative = self.model.info['partype']['pd-rel']
351        limits = self.model.info['limits']
352        disperser,value,npts,width,nsigma = [getattr(self, par+ext)
353                for ext in ('_pd_type','','_pd_n','_pd','_pd_nsigma')]
354        v,w = weights.get_weights(
355            disperser, int(npts.value), width.value, nsigma.value,
356            value.value, limits[par], par in relative)
357        return v,w/w.max()
358
359    def __getstate__(self):
360        # Can't pickle gpu functions, so instead make them lazy
361        state = self.__dict__.copy()
362        state['_fn'] = None
363        return state
364
365    def __setstate__(self, state):
366        self.__dict__ = state
367
368
369def demo():
370    data = load_data('DEC07086.DAT')
371    set_beam_stop(data, 0.004)
372    plot_data(data)
373    import matplotlib.pyplot as plt; plt.show()
374
375
376if __name__ == "__main__":
377    demo()
Note: See TracBrowser for help on using the repository browser.