source: sasmodels/sasmodels/bumps_model.py @ 78356b31

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 78356b31 was 78356b31, checked in by Paul Kienzle <pkienzle@…>, 10 years ago

fix warning that pyopencl failed

  • Property mode set to 100644
File size: 11.2 KB
Line 
1"""
2Sasmodels core.
3"""
4import sys, os
5import datetime
6
7import numpy as np
8
9try:
10    from .gpu import load_model as _loader
11except ImportError,exc:
12    import warnings
13    warnings.warn(str(exc))
14    warnings.warn("OpenCL not available --- using ctypes instead")
15    from .dll import load_model as _loader
16
17def load_model(modelname, dtype='single'):
18    """
19    Load model by name.
20    """
21    sasmodels = __import__('sasmodels.models.'+modelname)
22    module = getattr(sasmodels.models, modelname, None)
23    model = _loader(module, dtype=dtype)
24    return model
25
26
27def tic():
28    """
29    Timer function.
30
31    Use "toc=tic()" to start the clock and "toc()" to measure
32    a time interval.
33    """
34    then = datetime.datetime.now()
35    return lambda: (datetime.datetime.now()-then).total_seconds()
36
37
38def load_data(filename):
39    """
40    Load data using a sasview loader.
41    """
42    from sans.dataloader.loader import Loader
43    loader = Loader()
44    data = loader.load(filename)
45    if data is None:
46        raise IOError("Data %r could not be loaded"%filename)
47    return data
48
49
50def empty_data1D(q):
51    """
52    Create empty 1D data using the given *q* as the x value.
53
54    Resolutions dq/q is 5%.
55    """
56
57    from sans.dataloader.data_info import Data1D
58
59    Iq = 100*np.ones_like(q)
60    dIq = np.sqrt(Iq)
61    data = Data1D(q, Iq, dx=0.05*q, dy=dIq)
62    data.filename = "fake data"
63    data.qmin, data.qmax = q.min(), q.max()
64    return data
65
66
67def empty_data2D(qx, qy=None):
68    """
69    Create empty 2D data using the given mesh.
70
71    If *qy* is missing, create a square mesh with *qy=qx*.
72
73    Resolution dq/q is 5%.
74    """
75    from sans.dataloader.data_info import Data2D, Detector
76
77    if qy is None:
78        qy = qx
79    Qx,Qy = np.meshgrid(qx,qy)
80    Qx,Qy = Qx.flatten(), Qy.flatten()
81    Iq = 100*np.ones_like(Qx)
82    dIq = np.sqrt(Iq)
83    mask = np.ones(len(Iq), dtype='bool')
84
85    data = Data2D()
86    data.filename = "fake data"
87    data.qx_data = Qx
88    data.qy_data = Qy
89    data.data = Iq
90    data.err_data = dIq
91    data.mask = mask
92
93    # 5% dQ/Q resolution
94    data.dqx_data = 0.05*Qx
95    data.dqy_data = 0.05*Qy
96
97    detector = Detector()
98    detector.pixel_size.x = 5 # mm
99    detector.pixel_size.y = 5 # mm
100    detector.distance = 4 # m
101    data.detector.append(detector)
102    data.xbins = qx
103    data.ybins = qy
104    data.source.wavelength = 5 # angstroms
105    data.source.wavelength_unit = "A"
106    data.Q_unit = "1/A"
107    data.I_unit = "1/cm"
108    data.q_data = np.sqrt(Qx**2 + Qy**2)
109    data.xaxis("Q_x", "A^{-1}")
110    data.yaxis("Q_y", "A^{-1}")
111    data.zaxis("Intensity", r"\text{cm}^{-1}")
112    return data
113
114
115def set_beam_stop(data, radius, outer=None):
116    """
117    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
118    """
119    from sans.dataloader.manipulations import Ringcut
120    if hasattr(data, 'qx_data'):
121        data.mask = Ringcut(0, radius)(data)
122        if outer is not None:
123            data.mask += Ringcut(outer,np.inf)(data)
124    else:
125        data.mask = (data.x>=radius)
126        if outer is not None:
127            data.mask &= (data.x<outer)
128
129
130def set_half(data, half):
131    """
132    Select half of the data, either "right" or "left".
133    """
134    from sans.dataloader.manipulations import Boxcut
135    if half == 'right':
136        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
137    if half == 'left':
138        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
139
140
141def set_top(data, max):
142    """
143    Chop the top off the data, above *max*.
144    """
145    from sans.dataloader.manipulations import Boxcut
146    data.mask += Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=max)(data)
147
148
149def plot_data(data, iq, vmin=None, vmax=None, scale='log'):
150    """
151    Plot the target value for the data.  This could be the data itself,
152    the theory calculation, or the residuals.
153
154    *scale* can be 'log' for log scale data, or 'linear'.
155    """
156    from numpy.ma import masked_array, masked
157    import matplotlib.pyplot as plt
158    if hasattr(data, 'qx_data'):
159        iq = iq[:]
160        valid = np.isfinite(iq)
161        if scale == 'log':
162            valid[valid] = (iq[valid] > 0)
163            iq[valid] = np.log10(iq[valid])
164        iq[~valid|data.mask] = 0
165        #plottable = iq
166        plottable = masked_array(iq, ~valid|data.mask)
167        xmin, xmax = min(data.qx_data), max(data.qx_data)
168        ymin, ymax = min(data.qy_data), max(data.qy_data)
169        plt.imshow(plottable.reshape(128,128),
170                   interpolation='nearest', aspect=1, origin='upper',
171                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
172    else: # 1D data
173        if scale == 'linear':
174            idx = np.isfinite(iq)
175            plt.plot(data.x[idx], iq[idx])
176        else:
177            idx = np.isfinite(iq)
178            idx[idx] = (iq[idx]>0)
179            plt.loglog(data.x[idx], iq[idx])
180
181
182def _plot_result1D(data, theory, view):
183    """
184    Plot the data and residuals for 1D data.
185    """
186    import matplotlib.pyplot as plt
187    from numpy.ma import masked_array, masked
188    #print "not a number",sum(np.isnan(data.y))
189    #data.y[data.y<0.05] = 0.5
190    mdata = masked_array(data.y, data.mask)
191    mdata[np.isnan(mdata)] = masked
192    if view is 'log':
193        mdata[mdata <= 0] = masked
194    mtheory = masked_array(theory, mdata.mask)
195    mresid = masked_array((theory-data.y)/data.dy, mdata.mask)
196
197    plt.subplot(121)
198    plt.errorbar(data.x, mdata, yerr=data.dy)
199    plt.plot(data.x, mtheory, '-', hold=True)
200    plt.yscale(view)
201    plt.subplot(122)
202    plt.plot(data.x, mresid, 'x')
203    #plt.axhline(1, color='black', ls='--',lw=1, hold=True)
204    #plt.axhline(0, color='black', lw=1, hold=True)
205    #plt.axhline(-1, color='black', ls='--',lw=1, hold=True)
206
207
208def _plot_result2D(data, theory, view):
209    """
210    Plot the data and residuals for 2D data.
211    """
212    import matplotlib.pyplot as plt
213    resid = (theory-data.data)/data.err_data
214    plt.subplot(131)
215    plot_data(data, data.data, scale=view)
216    plt.colorbar()
217    plt.subplot(132)
218    plot_data(data, theory, scale=view)
219    plt.colorbar()
220    plt.subplot(133)
221    plot_data(data, resid, scale='linear')
222    plt.colorbar()
223
224def plot_result(data, theory, view='log'):
225    """
226    Plot the data and residuals.
227    """
228    if hasattr(data, 'qx_data'):
229        _plot_result2D(data, theory, view)
230    else:
231        _plot_result1D(data, theory, view)
232
233
234class BumpsModel(object):
235    """
236    Return a bumps wrapper for a SAS model.
237
238    *data* is the data to be fitted.
239
240    *model* is the SAS model, e.g., from :func:`gen.opencl_model`.
241
242    *cutoff* is the integration cutoff, which avoids computing the
243    the SAS model where the polydispersity weight is low.
244
245    Model parameters can be initialized with additional keyword
246    arguments, or by assigning to model.parameter_name.value.
247
248    The resulting bumps model can be used directly in a FitProblem call.
249    """
250    def __init__(self, data, model, cutoff=1e-5, **kw):
251        from bumps.names import Parameter
252
253        # remember inputs so we can inspect from outside
254        self.data = data
255        self.model = model
256        self.cutoff = cutoff
257
258        partype = model.info['partype']
259
260        # interpret data
261        if hasattr(data, 'qx_data'):
262            self.index = (data.mask==0) & (~np.isnan(data.data))
263            self.iq = data.data[self.index]
264            self.diq = data.err_data[self.index]
265            self._theory = np.zeros_like(data.data)
266            if not partype['orientation'] and not partype['magnetic']:
267                q_vectors = [np.sqrt(data.qx_data**2+data.qy_data**2)]
268            else:
269                q_vectors = [data.qx_data, data.qy_data]
270        else:
271            self.index = (data.x>=data.qmin) & (data.x<=data.qmax) & ~np.isnan(data.y)
272            self.iq = data.y[self.index]
273            self.diq = data.dy[self.index]
274            self._theory = np.zeros_like(data.y)
275            q_vectors = [data.x]
276
277        # Remember function inputs so we can delay loading the function and
278        # so we can save/restore state
279        self._fn_inputs = [v[self.index] for v in q_vectors]
280        self._fn = None
281
282        # define bumps parameters
283        pars = []
284        for p in model.info['parameters']:
285            name, default, limits, ptype = p[0], p[2], p[3], p[4]
286            value = kw.pop(name, default)
287            setattr(self, name, Parameter.default(value, name=name, limits=limits))
288            pars.append(name)
289        for name in partype['pd-2d']:
290            for xpart,xdefault,xlimits in [
291                    ('_pd', 0, limits),
292                    ('_pd_n', 35, (0,1000)),
293                    ('_pd_nsigma', 3, (0, 10)),
294                    ('_pd_type', 'gaussian', None),
295                ]:
296                xname = name+xpart
297                xvalue = kw.pop(xname, xdefault)
298                if xlimits is not None:
299                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
300                    pars.append(xname)
301                setattr(self, xname, xvalue)
302        self._parameter_names = pars
303        if kw:
304            raise TypeError("unexpected parameters: %s"%(", ".join(sorted(kw.keys()))))
305        self.update()
306
307    def update(self):
308        self._cache = {}
309
310    def numpoints(self):
311        return len(self.iq)
312
313    def parameters(self):
314        return dict((k,getattr(self,k)) for k in self._parameter_names)
315
316    def theory(self):
317        if 'theory' not in self._cache:
318            if self._fn is None:
319                input = self.model.make_input(self._fn_inputs)
320                self._fn = self.model(input)
321
322            pars = [getattr(self,p).value for p in self._fn.fixed_pars]
323            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
324            #print pars
325            self._theory[self.index] = self._fn(pars, pd_pars, self.cutoff)
326            #self._theory[:] = self._fn.eval(pars, pd_pars)
327            self._cache['theory'] = self._theory
328        return self._cache['theory']
329
330    def residuals(self):
331        #if np.any(self.err ==0): print "zeros in err"
332        return (self.theory()[self.index]-self.iq)/self.diq
333
334    def nllf(self):
335        R = self.residuals()
336        #if np.any(np.isnan(R)): print "NaN in residuals"
337        return 0.5*np.sum(R**2)
338
339    def __call__(self):
340        return 2*self.nllf()/self.dof
341
342    def plot(self, view='log'):
343        plot_result(self.data, self.theory(), view=view)
344
345    def save(self, basename):
346        pass
347
348    def _get_weights(self, par):
349        from . import weights
350
351        relative = self.model.info['partype']['pd-rel']
352        limits = self.model.info['limits']
353        disperser,value,npts,width,nsigma = [getattr(self, par+ext)
354                for ext in ('_pd_type','','_pd_n','_pd','_pd_nsigma')]
355        v,w = weights.get_weights(
356            disperser, int(npts.value), width.value, nsigma.value,
357            value.value, limits[par], par in relative)
358        return v,w/w.max()
359
360    def __getstate__(self):
361        # Can't pickle gpu functions, so instead make them lazy
362        state = self.__dict__.copy()
363        state['_fn'] = None
364        return state
365
366    def __setstate__(self, state):
367        self.__dict__ = state
368
369
370def demo():
371    data = load_data('DEC07086.DAT')
372    set_beam_stop(data, 0.004)
373    plot_data(data)
374    import matplotlib.pyplot as plt; plt.show()
375
376
377if __name__ == "__main__":
378    demo()
Note: See TracBrowser for help on using the repository browser.