source: sasmodels/sasmodels/bumps_model.py @ 750ffa5

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 750ffa5 was 750ffa5, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

allow test of dll using single precision

  • Property mode set to 100644
File size: 14.5 KB
Line 
1"""
2Sasmodels core.
3"""
4import datetime
5
6from sasmodels import sesans
7
8# CRUFT python 2.6
9if not hasattr(datetime.timedelta, 'total_seconds'):
10    def delay(dt):
11        """Return number date-time delta as number seconds"""
12        return dt.days * 86400 + dt.seconds + 1e-6 * dt.microseconds
13else:
14    def delay(dt):
15        """Return number date-time delta as number seconds"""
16        return dt.total_seconds()
17
18import numpy as np
19
20try:
21    from .kernelcl import load_model as _loader
22except RuntimeError, exc:
23    import warnings
24    warnings.warn(str(exc))
25    warnings.warn("OpenCL not available --- using ctypes instead")
26    from .kerneldll import load_model as _loader
27
28def load_model(modelname, dtype='single'):
29    """
30    Load model by name.
31    """
32    sasmodels = __import__('sasmodels.models.' + modelname)
33    module = getattr(sasmodels.models, modelname, None)
34    model = _loader(module, dtype=dtype)
35    return model
36
37
38def tic():
39    """
40    Timer function.
41
42    Use "toc=tic()" to start the clock and "toc()" to measure
43    a time interval.
44    """
45    then = datetime.datetime.now()
46    return lambda: delay(datetime.datetime.now() - then)
47
48
49def load_data(filename):
50    """
51    Load data using a sasview loader.
52    """
53    from sas.dataloader.loader import Loader
54    loader = Loader()
55    data = loader.load(filename)
56    if data is None:
57        raise IOError("Data %r could not be loaded" % filename)
58    return data
59
60
61def empty_data1D(q):
62    """
63    Create empty 1D data using the given *q* as the x value.
64
65    Resolutions dq/q is 5%.
66    """
67
68    from sas.dataloader.data_info import Data1D
69
70    Iq = 100 * np.ones_like(q)
71    dIq = np.sqrt(Iq)
72    data = Data1D(q, Iq, dx=0.05 * q, dy=dIq)
73    data.filename = "fake data"
74    data.qmin, data.qmax = q.min(), q.max()
75    return data
76
77
78def empty_data2D(qx, qy=None):
79    """
80    Create empty 2D data using the given mesh.
81
82    If *qy* is missing, create a square mesh with *qy=qx*.
83
84    Resolution dq/q is 5%.
85    """
86    from sas.dataloader.data_info import Data2D, Detector
87
88    if qy is None:
89        qy = qx
90    Qx, Qy = np.meshgrid(qx, qy)
91    Qx, Qy = Qx.flatten(), Qy.flatten()
92    Iq = 100 * np.ones_like(Qx)
93    dIq = np.sqrt(Iq)
94    mask = np.ones(len(Iq), dtype='bool')
95
96    data = Data2D()
97    data.filename = "fake data"
98    data.qx_data = Qx
99    data.qy_data = Qy
100    data.data = Iq
101    data.err_data = dIq
102    data.mask = mask
103
104    # 5% dQ/Q resolution
105    data.dqx_data = 0.05 * Qx
106    data.dqy_data = 0.05 * Qy
107
108    detector = Detector()
109    detector.pixel_size.x = 5 # mm
110    detector.pixel_size.y = 5 # mm
111    detector.distance = 4 # m
112    data.detector.append(detector)
113    data.xbins = qx
114    data.ybins = qy
115    data.source.wavelength = 5 # angstroms
116    data.source.wavelength_unit = "A"
117    data.Q_unit = "1/A"
118    data.I_unit = "1/cm"
119    data.q_data = np.sqrt(Qx ** 2 + Qy ** 2)
120    data.xaxis("Q_x", "A^{-1}")
121    data.yaxis("Q_y", "A^{-1}")
122    data.zaxis("Intensity", r"\text{cm}^{-1}")
123    return data
124
125
126def set_beam_stop(data, radius, outer=None):
127    """
128    Add a beam stop of the given *radius*.  If *outer*, make an annulus.
129    """
130    from sas.dataloader.manipulations import Ringcut
131    if hasattr(data, 'qx_data'):
132        data.mask = Ringcut(0, radius)(data)
133        if outer is not None:
134            data.mask += Ringcut(outer, np.inf)(data)
135    else:
136        data.mask = (data.x >= radius)
137        if outer is not None:
138            data.mask &= (data.x < outer)
139
140
141def set_half(data, half):
142    """
143    Select half of the data, either "right" or "left".
144    """
145    from sas.dataloader.manipulations import Boxcut
146    if half == 'right':
147        data.mask += \
148            Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
149    if half == 'left':
150        data.mask += \
151            Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
152
153
154def set_top(data, cutoff):
155    """
156    Chop the top off the data, above *cutoff*.
157    """
158    from sas.dataloader.manipulations import Boxcut
159    data.mask += \
160        Boxcut(x_min=-np.inf, x_max=np.inf, y_min=-np.inf, y_max=cutoff)(data)
161
162
163def plot_data(data, Iq, vmin=None, vmax=None, view='log'):
164    """
165    Plot the target value for the data.  This could be the data itself,
166    the theory calculation, or the residuals.
167
168    *scale* can be 'log' for log scale data, or 'linear'.
169    """
170    from numpy.ma import masked_array
171    import matplotlib.pyplot as plt
172    if hasattr(data, 'qx_data'):
173        Iq = Iq + 0
174        valid = np.isfinite(Iq)
175        if view == 'log':
176            valid[valid] = (Iq[valid] > 0)
177            Iq[valid] = np.log10(Iq[valid])
178        elif view == 'q4':
179            Iq[valid] = Iq*(data.qx_data[valid]**2+data.qy_data[valid]**2)**2
180        Iq[~valid | data.mask] = 0
181        #plottable = Iq
182        plottable = masked_array(Iq, ~valid | data.mask)
183        xmin, xmax = min(data.qx_data), max(data.qx_data)
184        ymin, ymax = min(data.qy_data), max(data.qy_data)
185        try:
186            if vmin is None: vmin = Iq[valid & ~data.mask].min()
187            if vmax is None: vmax = Iq[valid & ~data.mask].max()
188        except:
189            vmin, vmax = 0, 1
190        plt.imshow(plottable.reshape(128, 128),
191                   interpolation='nearest', aspect=1, origin='upper',
192                   extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
193    else: # 1D data
194        if view == 'linear' or view == 'q4':
195            #idx = np.isfinite(Iq)
196            scale = data.x**4 if view == 'q4' else 1.0
197            plt.plot(data.x, scale*Iq) #, '.')
198        else:
199            # Find the values that are finite and positive
200            idx = np.isfinite(Iq)
201            idx[idx] = (Iq[idx] > 0)
202            Iq[~idx] = np.nan
203            plt.loglog(data.x, Iq)
204
205
206def _plot_result1D(data, theory, view):
207    """
208    Plot the data and residuals for 1D data.
209    """
210    import matplotlib.pyplot as plt
211    from numpy.ma import masked_array, masked
212    #print "not a number",sum(np.isnan(data.y))
213    #data.y[data.y<0.05] = 0.5
214    mdata = masked_array(data.y, data.mask)
215    mdata[np.isnan(mdata)] = masked
216    if view is 'log':
217        mdata[mdata <= 0] = masked
218    mtheory = masked_array(theory, mdata.mask)
219    mresid = masked_array((theory - data.y) / data.dy, mdata.mask)
220
221    scale = data.x**4 if view == 'q4' else 1.0
222    plt.subplot(121)
223    plt.errorbar(data.x, scale*mdata, yerr=data.dy)
224    plt.plot(data.x, scale*mtheory, '-', hold=True)
225    plt.yscale('linear' if view == 'q4' else view)
226    plt.subplot(122)
227    plt.plot(data.x, mresid, 'x')
228
229# pylint: disable=unused-argument
230def _plot_sesans(data, theory, view):
231    import matplotlib.pyplot as plt
232    resid = (theory - data.y) / data.dy
233    plt.subplot(121)
234    plt.errorbar(data.x, data.y, yerr=data.dy)
235    plt.plot(data.x, theory, '-', hold=True)
236    plt.xlabel('spin echo length (A)')
237    plt.ylabel('polarization')
238    plt.subplot(122)
239    plt.plot(data.x, resid, 'x')
240    plt.xlabel('spin echo length (A)')
241    plt.ylabel('residuals')
242
243def _plot_result2D(data, theory, view):
244    """
245    Plot the data and residuals for 2D data.
246    """
247    import matplotlib.pyplot as plt
248    resid = (theory - data.data) / data.err_data
249    plt.subplot(131)
250    plot_data(data, data.data, view=view)
251    plt.colorbar()
252    plt.subplot(132)
253    plot_data(data, theory, view=view)
254    plt.colorbar()
255    plt.subplot(133)
256    plot_data(data, resid, view='linear')
257    plt.colorbar()
258
259class BumpsModel(object):
260    """
261    Return a bumps wrapper for a SAS model.
262
263    *data* is the data to be fitted.
264
265    *model* is the SAS model, e.g., from :func:`gen.opencl_model`.
266
267    *cutoff* is the integration cutoff, which avoids computing the
268    the SAS model where the polydispersity weight is low.
269
270    Model parameters can be initialized with additional keyword
271    arguments, or by assigning to model.parameter_name.value.
272
273    The resulting bumps model can be used directly in a FitProblem call.
274    """
275    def __init__(self, data, model, cutoff=1e-5, **kw):
276        from bumps.names import Parameter
277
278        # remember inputs so we can inspect from outside
279        self.data = data
280        self.model = model
281        self.cutoff = cutoff
282# TODO       if  isinstance(data,SESANSData1D)
283        if hasattr(data, 'lam'):
284            self.data_type = 'sesans'
285        elif hasattr(data, 'qx_data'):
286            self.data_type = 'Iqxy'
287        else:
288            self.data_type = 'Iq'
289
290        partype = model.info['partype']
291
292        # interpret data
293        if self.data_type == 'sesans':
294            q = sesans.make_q(data.sample.zacceptance, data.Rmax)
295            self.index = slice(None, None)
296            self.Iq = data.y
297            self.dIq = data.dy
298            self._theory = np.zeros_like(q)
299            q_vectors = [q]
300        elif self.data_type == 'Iqxy':
301            self.index = (data.mask == 0) & (~np.isnan(data.data))
302            self.Iq = data.data[self.index]
303            self.dIq = data.err_data[self.index]
304            self._theory = np.zeros_like(data.data)
305            if not partype['orientation'] and not partype['magnetic']:
306                q_vectors = [np.sqrt(data.qx_data ** 2 + data.qy_data ** 2)]
307            else:
308                q_vectors = [data.qx_data, data.qy_data]
309        elif self.data_type == 'Iq':
310            self.index = (data.x >= data.qmin) & (data.x <= data.qmax) & ~np.isnan(data.y)
311            self.Iq = data.y[self.index]
312            self.dIq = data.dy[self.index]
313            self._theory = np.zeros_like(data.y)
314            q_vectors = [data.x]
315        else:
316            raise ValueError("Unknown data type") # never gets here
317
318        # Remember function inputs so we can delay loading the function and
319        # so we can save/restore state
320        self._fn_inputs = [v[self.index] for v in q_vectors]
321        self._fn = None
322
323        # define bumps parameters
324        pars = []
325        for p in model.info['parameters']:
326            name, default, limits = p[0], p[2], p[3]
327            value = kw.pop(name, default)
328            setattr(self, name, Parameter.default(value, name=name, limits=limits))
329            pars.append(name)
330        for name in partype['pd-2d']:
331            for xpart, xdefault, xlimits in [
332                    ('_pd', 0, limits),
333                    ('_pd_n', 35, (0, 1000)),
334                    ('_pd_nsigma', 3, (0, 10)),
335                    ('_pd_type', 'gaussian', None),
336                ]:
337                xname = name + xpart
338                xvalue = kw.pop(xname, xdefault)
339                if xlimits is not None:
340                    xvalue = Parameter.default(xvalue, name=xname, limits=xlimits)
341                    pars.append(xname)
342                setattr(self, xname, xvalue)
343        self._parameter_names = pars
344        if kw:
345            raise TypeError("unexpected parameters: %s"
346                            % (", ".join(sorted(kw.keys()))))
347        self.update()
348
349    def update(self):
350        self._cache = {}
351
352    def numpoints(self):
353        """
354            Return the number of points
355        """
356        return len(self.Iq)
357
358    def parameters(self):
359        """
360            Return a dictionary of parameters
361        """
362        return dict((k, getattr(self, k)) for k in self._parameter_names)
363
364    def theory(self):
365        if 'theory' not in self._cache:
366            if self._fn is None:
367                q_input= self.model.make_input(self._fn_inputs)
368                self._fn = self.model(q_input)
369
370            fixed_pars = [getattr(self, p).value for p in self._fn.fixed_pars]
371            pd_pars = [self._get_weights(p) for p in self._fn.pd_pars]
372            #print fixed_pars,pd_pars
373            self._theory[self.index] = self._fn(fixed_pars, pd_pars, self.cutoff)
374            #self._theory[:] = self._fn.eval(pars, pd_pars)
375            if self.data_type == 'sesans':
376                result = sesans.hankel(self.data.x, self.data.lam * 1e-9,
377                                       self.data.sample.thickness / 10,
378                                       self._fn_inputs[0], self._theory)
379                self._cache['theory'] = result
380            else:
381                self._cache['theory'] = self._theory
382        return self._cache['theory']
383
384    def residuals(self):
385        #if np.any(self.err ==0): print "zeros in err"
386        return (self.theory()[self.index] - self.Iq) / self.dIq
387
388    def nllf(self):
389        delta = self.residuals()
390        #if np.any(np.isnan(R)): print "NaN in residuals"
391        return 0.5 * np.sum(delta ** 2)
392
393    #def __call__(self):
394    #    return 2 * self.nllf() / self.dof
395
396    def plot(self, view='log'):
397        """
398        Plot the data and residuals.
399        """
400        data, theory = self.data, self.theory()
401        if self.data_type == 'Iq':
402            _plot_result1D(data, theory, view)
403        elif self.data_type == 'Iqxy':
404            _plot_result2D(data, theory, view)
405        elif self.data_type == 'sesans':
406            _plot_sesans(data, theory, view)
407        else:
408            raise ValueError("Unknown data type")
409
410    def simulate_data(self, noise=None):
411        print "noise", noise
412        if noise is None:
413            noise = self.dIq[self.index]
414        else:
415            noise = 0.01 * noise
416            self.dIq[self.index] = noise
417        y = self.theory()
418        y += y * np.random.randn(*y.shape) * noise
419        if self.data_type == 'Iq':
420            self.data.y[self.index] = y
421        elif self.data_type == 'Iqxy':
422            self.data.data[self.index] = y
423        elif self.data_type == 'sesans':
424            self.data.y[self.index] = y
425        else:
426            raise ValueError("Unknown model")
427
428    def save(self, basename):
429        pass
430
431    def _get_weights(self, par):
432        """
433            Get parameter dispersion weights
434        """
435        from . import weights
436
437        relative = self.model.info['partype']['pd-rel']
438        limits = self.model.info['limits']
439        disperser, value, npts, width, nsigma = [
440            getattr(self, par + ext)
441            for ext in ('_pd_type', '', '_pd_n', '_pd', '_pd_nsigma')]
442        value, weight = weights.get_weights(
443            disperser, int(npts.value), width.value, nsigma.value,
444            value.value, limits[par], par in relative)
445        return value, weight / np.sum(weight)
446
447    def __getstate__(self):
448        # Can't pickle gpu functions, so instead make them lazy
449        state = self.__dict__.copy()
450        state['_fn'] = None
451        return state
452
453    def __setstate__(self, state):
454        # pylint: disable=attribute-defined-outside-init
455        self.__dict__ = state
456
457
458def demo():
459    data = load_data('DEC07086.DAT')
460    set_beam_stop(data, 0.004)
461    plot_data(data, data.data)
462    import matplotlib.pyplot as plt; plt.show()
463
464
465if __name__ == "__main__":
466    demo()
Note: See TracBrowser for help on using the repository browser.