source: sasmodels/sasmodels/bumps_model.py @ 7e224c2

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 7e224c2 was 7e224c2, checked in by Doucet, Mathieu <doucetm@…>, 9 years ago

pylint fixes

  • Property mode set to 100644
File size: 13.8 KB
Line 
1"""
2Sasmodels core.
3"""
4import datetime
5
6from sasmodels import sesans
7
8# CRUFT python 2.6
9if not hasattr(datetime.timedelta, 'total_seconds'):
10    def delay(dt): return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
11else:
12    def delay(dt): return dt.total_seconds()
13
14import numpy as np
15
16try:
17    from .kernelcl import load_model as _loader
18except RuntimeError, exc:
19    import warnings
20    warnings.warn(str(exc))
21    warnings.warn("OpenCL not available --- using ctypes instead")
22    from .kerneldll import load_model as _loader
23
24def load_model(modelname, dtype='single'):
25    """
26    Load model by name.
27    """
28    sasmodels = __import__('sasmodels.models.' + modelname)
29    module = getattr(sasmodels.models, modelname, None)
30    model = _loader(module, dtype=dtype)
31    return model
32
33
34def tic():
35    """
36    Timer function.
37
38    Use "toc=tic()" to start the clock and "toc()" to measure
39    a time interval.
40    """
41    then = datetime.datetime.now()
42    return lambda: delay(datetime.datetime.now() - then)
43
44
45def load_data(filename):
46    """
47    Load data using a sasview loader.
48    """
49    from sas.dataloader.loader import Loader
50    loader = Loader()
51    data = loader.load(filename)
52    if data is None:
53        raise IOError("Data %r could not be loaded" % filename)
54    return data
55
56
57def empty_data1D(q):
58    """
59    Create empty 1D data using the given *q* as the x value.
60
61    Resolutions dq/q is 5%.
62    """
63
64    from sas.dataloader.data_info import Data1D
65
66    Iq = 100 * np.ones_like(q)
67    dIq = np.sqrt(Iq)
68    data = Data1D(q, Iq, dx=0.05 * q, dy=dIq)
69    data.filename = "fake data"
70    data.qmin, data.qmax = q.min(), q.max()
71    return data
72
73
74def empty_data2D(qx, qy=None):
75    """
76    Create empty 2D data using the given mesh.
77
78    If *qy* is missing, create a square mesh with *qy=qx*.
79
80    Resolution dq/q is 5%.
81    """
82    from sas.dataloader.data_info import Data2D, Detector
83
84    if qy is None:
85        qy = qx
86    Qx, Qy = np.meshgrid(qx, qy)
87    Qx, Qy = Qx.flatten(), Qy.flatten()
88    Iq = 100 * np.ones_like(Qx)
89    dIq = np.sqrt(Iq)
90    mask = np.ones(len(Iq), dtype='bool')
91
92    data = Data2D()
93    data.filename = "fake data"
94    data.qx_data = Qx
95    data.qy_data = Qy
96    data.data = Iq
97    data.err_data = dIq
98    data.mask = mask
99
100    # 5% dQ/Q resolution
101    data.dqx_data = 0.05 * Qx
102    data.dqy_data = 0.05 * Qy
103
104    detector = Detector()
105    detector.pixel_size.x = 5 # mm
106    detector.pixel_size.y = 5 # mm
107    detector.distance = 4 # m
108    data.detector.append(detector)
109    data.xbins = qx
110    data.ybins = qy
111    data.source.wavelength = 5 # angstroms
112    data.source.wavelength_unit = "A"
113    data.Q_unit = "1/A"
114    data.I_unit = "1/cm"
115    data.q_data = np.sqrt(Qx ** 2 + Qy ** 2)
116    data.xaxis("Q_x", "A^{-1}")
117    data.yaxis("Q_y", "A^{-1}")
118    data.zaxis("Intensity", r"\text{cm}^{-1}")
119    return data
120
121
122def set_beam_stop(data, radius, outer=None):
123    """
124    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
125    """
126    from sas.dataloader.manipulations import Ringcut
127    if hasattr(data, 'qx_data'):
128        data.mask = Ringcut(0, radius)(data)
129        if outer is not None:
130            data.mask += Ringcut(outer, np.inf)(data)
131    else:
132        data.mask = (data.x >= radius)
133        if outer is not None:
134            data.mask &= (data.x < outer)
135
136
137def set_half(data, half):
138    """
139    Select half of the data, either "right" or "left".
140    """
141    from sas.dataloader.manipulations import Boxcut
142    if half == 'right':
143        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
144    if half == 'left':
145        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
146
147
148def set_top(data, max):
149    """
150    Chop the top off the data, above *max*.
151    """
152    from sas.dataloader.manipulations import Boxcut
153    data.mask += Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=max)(data)
154
155
156def plot_data(data, iq, vmin=None, vmax=None, scale='log'):
157    """
158    Plot the target value for the data.  This could be the data itself,
159    the theory calculation, or the residuals.
160
161    *scale* can be 'log' for log scale data, or 'linear'.
162    """
163    from numpy.ma import masked_array, masked
164    import matplotlib.pyplot as plt
165    if hasattr(data, 'qx_data'):
166        iq = iq + 0
167        valid = np.isfinite(iq)
168        if scale == 'log':
169            valid[valid] = (iq[valid] > 0)
170            iq[valid] = np.log10(iq[valid])
171        iq[~valid | data.mask] = 0
172        #plottable = iq
173        plottable = masked_array(iq, ~valid | data.mask)
174        xmin, xmax = min(data.qx_data), max(data.qx_data)
175        ymin, ymax = min(data.qy_data), max(data.qy_data)
176        try:
177            if vmin is None: vmin = iq[valid & ~data.mask].min()
178            if vmax is None: vmax = iq[valid & ~data.mask].max()
179        except:
180            vmin, vmax = 0, 1
181        plt.imshow(plottable.reshape(128, 128),
182                   interpolation='nearest', aspect=1, origin='upper',
183                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
184    else: # 1D data
185        if scale == 'linear':
186            idx = np.isfinite(iq)
187            plt.plot(data.x[idx], iq[idx])
188        else:
189            idx = np.isfinite(iq)
190            idx[idx] = (iq[idx] > 0)
191            plt.loglog(data.x[idx], iq[idx])
192
193
194def _plot_result1D(data, theory, view):
195    """
196    Plot the data and residuals for 1D data.
197    """
198    import matplotlib.pyplot as plt
199    from numpy.ma import masked_array, masked
200    #print "not a number",sum(np.isnan(data.y))
201    #data.y[data.y<0.05] = 0.5
202    mdata = masked_array(data.y, data.mask)
203    mdata[np.isnan(mdata)] = masked
204    if view is 'log':
205        mdata[mdata <= 0] = masked
206    mtheory = masked_array(theory, mdata.mask)
207    mresid = masked_array((theory - data.y) / data.dy, mdata.mask)
208
209    plt.subplot(121)
210    plt.errorbar(data.x, mdata, yerr=data.dy)
211    plt.plot(data.x, mtheory, '-', hold=True)
212    plt.yscale(view)
213    plt.subplot(122)
214    plt.plot(data.x, mresid, 'x')
215
216def _plot_sesans(data, theory, view):
217    import matplotlib.pyplot as plt
218    resid = (theory - data.y) / data.dy
219    plt.subplot(121)
220    plt.errorbar(data.x, data.y, yerr=data.dy)
221    plt.plot(data.x, theory, '-', hold=True)
222    plt.xlabel('spin echo length (A)')
223    plt.ylabel('polarization')
224    plt.subplot(122)
225    plt.plot(data.x, resid, 'x')
226    plt.xlabel('spin echo length (A)')
227    plt.ylabel('residuals')
228
229def _plot_result2D(data, theory, view):
230    """
231    Plot the data and residuals for 2D data.
232    """
233    import matplotlib.pyplot as plt
234    resid = (theory - data.data) / data.err_data
235    plt.subplot(131)
236    plot_data(data, data.data, scale=view)
237    plt.colorbar()
238    plt.subplot(132)
239    plot_data(data, theory, scale=view)
240    plt.colorbar()
241    plt.subplot(133)
242    plot_data(data, resid, scale='linear')
243    plt.colorbar()
244
245class BumpsModel(object):
246    """
247    Return a bumps wrapper for a SAS model.
248
249    *data* is the data to be fitted.
250
251    *model* is the SAS model, e.g., from :func:`gen.opencl_model`.
252
253    *cutoff* is the integration cutoff, which avoids computing the
254    the SAS model where the polydispersity weight is low.
255
256    Model parameters can be initialized with additional keyword
257    arguments, or by assigning to model.parameter_name.value.
258
259    The resulting bumps model can be used directly in a FitProblem call.
260    """
261    def __init__(self, data, model, cutoff=1e-5, **kw):
262        from bumps.names import Parameter
263
264        # remember inputs so we can inspect from outside
265        self.data = data
266        self.model = model
267        self.cutoff = cutoff
268# TODO       if  isinstance(data,SESANSData1D)
269        if hasattr(data, 'lam'):
270            self.data_type = 'sesans'
271        elif hasattr(data, 'qx_data'):
272            self.data_type = 'Iqxy'
273        else:
274            self.data_type = 'Iq'
275
276        partype = model.info['partype']
277
278        # interpret data
279        if self.data_type == 'sesans':
280            q = sesans.make_q(data.sample.zacceptance, data.Rmax)
281            self.index = slice(None, None)
282            self.iq = data.y
283            self.diq = data.dy
284            self._theory = np.zeros_like(q)
285            q_vectors = [q]
286        elif self.data_type == 'Iqxy':
287            self.index = (data.mask == 0) & (~np.isnan(data.data))
288            self.iq = data.data[self.index]
289            self.diq = data.err_data[self.index]
290            self._theory = np.zeros_like(data.data)
291            if not partype['orientation'] and not partype['magnetic']:
292                q_vectors = [np.sqrt(data.qx_data ** 2 + data.qy_data ** 2)]
293            else:
294                q_vectors = [data.qx_data, data.qy_data]
295        elif self.data_type == 'Iq':
296            self.index = (data.x >= data.qmin) & (data.x <= data.qmax) & ~np.isnan(data.y)
297            self.iq = data.y[self.index]
298            self.diq = data.dy[self.index]
299            self._theory = np.zeros_like(data.y)
300            q_vectors = [data.x]
301        else:
302            raise ValueError("Unknown data type") # never gets here
303
304        # Remember function inputs so we can delay loading the function and
305        # so we can save/restore state
306        self._fn_inputs = [v[self.index] for v in q_vectors]
307        self._fn = None
308
309        # define bumps parameters
310        pars = []
311        for p in model.info['parameters']:
312            name, default, limits, ptype = p[0], p[2], p[3], p[4]
313            value = kw.pop(name, default)
314            setattr(self, name, Parameter.default(value, name=name, limits=limits))
315            pars.append(name)
316        for name in partype['pd-2d']:
317            for xpart, xdefault, xlimits in [
318                    ('_pd', 0, limits),
319                    ('_pd_n', 35, (0, 1000)),
320                    ('_pd_nsigma', 3, (0, 10)),
321                    ('_pd_type', 'gaussian', None),
322                ]:
323                xname = name + xpart
324                xvalue = kw.pop(xname, xdefault)
325                if xlimits is not None:
326                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
327                    pars.append(xname)
328                setattr(self, xname, xvalue)
329        self._parameter_names = pars
330        if kw:
331            raise TypeError("unexpected parameters: %s" % (", ".join(sorted(kw.keys()))))
332        self.update()
333
334    def update(self):
335        self._cache = {}
336
337    def numpoints(self):
338        """
339            Return the number of points
340        """
341        return len(self.iq)
342
343    def parameters(self):
344        """
345            Return a dictionary of parameters
346        """
347        return dict((k, getattr(self, k)) for k in self._parameter_names)
348
349    def theory(self):
350        if 'theory' not in self._cache:
351            if self._fn is None:
352                input_value = self.model.make_input(self._fn_inputs)
353                self._fn = self.model(input_value)
354
355            fixed_pars = [getattr(self, p).value for p in self._fn.fixed_pars]
356            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
357            #print fixed_pars,pd_pars
358            self._theory[self.index] = self._fn(fixed_pars, pd_pars, self.cutoff)
359            #self._theory[:] = self._fn.eval(pars, pd_pars)
360            if self.data_type == 'sesans':
361                P = sesans.hankel(self.data.x, self.data.lam * 1e-9,
362                                  self.data.sample.thickness / 10, self._fn_inputs[0],
363                                  self._theory)
364                self._cache['theory'] = P
365            else:
366                self._cache['theory'] = self._theory
367        return self._cache['theory']
368
369    def residuals(self):
370        #if np.any(self.err ==0): print "zeros in err"
371        return (self.theory()[self.index] - self.iq) / self.diq
372
373    def nllf(self):
374        R = self.residuals()
375        #if np.any(np.isnan(R)): print "NaN in residuals"
376        return 0.5 * np.sum(R ** 2)
377
378    def __call__(self):
379        return 2 * self.nllf() / self.dof
380
381    def plot(self, view='log'):
382        """
383        Plot the data and residuals.
384        """
385        data, theory = self.data, self.theory()
386        if self.data_type == 'Iq':
387            _plot_result1D(data, theory, view)
388        elif self.data_type == 'Iqxy':
389            _plot_result2D(data, theory, view)
390        elif self.data_type == 'sesans':
391            _plot_sesans(data, theory, view)
392        else:
393            raise ValueError("Unknown data type")
394
395    def simulate_data(self, noise=None):
396        print "noise", noise
397        if noise is None:
398            noise = self.diq[self.index]
399        else:
400            noise = 0.01 * noise
401            self.diq[self.index] = noise
402        y = self.theory()
403        y += y * np.random.randn(*y.shape) * noise
404        if self.data_type == 'Iq':
405            self.data.y[self.index] = y
406        elif self.data_type == 'Iqxy':
407            self.data.data[self.index] = y
408        elif self.data_type == 'sesans':
409            self.data.y[self.index] = y
410        else:
411            raise ValueError("Unknown model")
412
413    def save(self, basename):
414        pass
415
416    def _get_weights(self, par):
417        """
418            Get parameter dispersion weights
419        """
420        from . import weights
421
422        relative = self.model.info['partype']['pd-rel']
423        limits = self.model.info['limits']
424        disperser, value, npts, width, nsigma = \
425            [getattr(self, par + ext) for ext in ('_pd_type', '', '_pd_n', '_pd', '_pd_nsigma')]
426        v, w = weights.get_weights(
427            disperser, int(npts.value), width.value, nsigma.value,
428            value.value, limits[par], par in relative)
429        return v, w / w.max()
430
431    def __getstate__(self):
432        # Can't pickle gpu functions, so instead make them lazy
433        state = self.__dict__.copy()
434        state['_fn'] = None
435        return state
436
437    def __setstate__(self, state):
438        self.__dict__ = state
439
440
441def demo():
442    data = load_data('DEC07086.DAT')
443    set_beam_stop(data, 0.004)
444    plot_data(data)
445    import matplotlib.pyplot as plt; plt.show()
446
447
448if __name__ == "__main__":
449    demo()
Note: See TracBrowser for help on using the repository browser.