source: sasmodels/multisasmodels.py @ 5d4777d

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 5d4777d was d772f5d, checked in by HMP1 <helen.park@…>, 10 years ago

Added 1D Fit, fixed fitting error

  • Property mode set to 100644
File size: 4.7 KB
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import sys
5import numpy as np
6import pyopencl as cl
7from bumps.names import Parameter
8from sans.dataloader.loader import Loader
9from sans.dataloader.manipulations import Ringcut, Boxcut
10
11
12def load_data(filename):
13    loader = Loader()
14    data = loader.load(filename)
15    if data is None:
16        raise IOError("Data %r could not be loaded"%filename)
17    return data
18
19
20def set_beam_stop(data, radius, outer=None):
21    data.mask = Ringcut(0, radius)(data)
22    if outer is not None:
23        data.mask += Ringcut(outer,np.inf)(data)
24
25def set_half(data, half):
26    if half == 'left':
27        data.mask += Boxcut(x_min=-np.inf, x_max=0.0, y_min=-np.inf, y_max=np.inf)(data)
28    if half == 'right':
29        data.mask += Boxcut(x_min=0.0, x_max=np.inf, y_min=-np.inf, y_max=np.inf)(data)
30
31
32
33def plot_data(data, iq, vmin=None, vmax=None):
34    from numpy.ma import masked_array
35    import matplotlib.pyplot as plt
36    img = masked_array(iq, data.mask)
37    xmin, xmax = min(data.qx_data), max(data.qx_data)
38    ymin, ymax = min(data.qy_data), max(data.qy_data)
39    plt.imshow(img.reshape(128,128),
40               interpolation='nearest', aspect=1, origin='upper',
41               extent=[xmin, xmax, ymin, ymax], vmin=vmin, vmax=vmax)
42
43def plot_result(data, theory, view='linear'):
44    import matplotlib.pyplot as plt
45    from numpy.ma import masked_array, masked
46    plt.subplot(1, 3, 1)
47    #print "not a number",sum(np.isnan(data.data))
48    #data.data[data.data<0.05] = 0.5
49    mdata = masked_array(data.data, data.mask)
50    mdata[np.isnan(mdata)] = masked
51    if view is 'log':
52        mdata[mdata <= 0] = masked
53        mdata = np.log10(mdata)
54        mtheory = masked_array(np.log10(theory), mdata.mask)
55    else:
56        mtheory = masked_array(theory, mdata.mask)
57    mresid = masked_array((theory-data.data)/data.err_data, data.mask)
58    vmin = min(mdata.min(), mtheory.min())
59    vmax = max(mdata.max(), mtheory.max())
60
61    plot_data(data, mdata, vmin=vmin, vmax=vmax)
62    plt.colorbar()
63    plt.subplot(1, 3, 2)
64    plot_data(data, mtheory, vmin=vmin, vmax=vmax)
65    plt.colorbar()
66    plt.subplot(1, 3, 3)
67    plot_data(data, mresid)
68    plt.colorbar()
69
70
71def demo():
72    data = load_data('JUN03289.DAT')
73    set_beam_stop(data, 0.004)
74    plot_data(data)
75    import matplotlib.pyplot as plt; plt.show()
76
77
78GPU_CONTEXT = None
79GPU_QUEUE = None
80def card():
81    global GPU_CONTEXT, GPU_QUEUE
82    if GPU_CONTEXT is None:
83        GPU_CONTEXT = cl.create_some_context()
84        GPU_QUEUE = cl.CommandQueue(GPU_CONTEXT)
85    return GPU_CONTEXT, GPU_QUEUE
86
87
88class SasModel(object):
89    def __init__(self, data, model, dtype='float32', **kw):
90        self.__dict__['_parameters'] = {}
91        self.index = (data.mask==0) & (~np.isnan(data.data))
92        self.iq = data.data[self.index]
93        self.diq = data.err_data[self.index]
94        self.data = data
95        self.qx = data.qx_data
96        self.qy = data.qy_data
97        self.gpu = model(self.qx, self.qy, dtype=dtype)
98        pd_pars = set(base+attr for base in model.PD_PARS for attr in ('_pd','_pd_n','_pd_nsigma'))
99        total_pars = set(model.PARS.keys()) | pd_pars
100        extra_pars = set(kw.keys()) - total_pars
101        if extra_pars:
102            raise TypeError("unexpected parameters %s"%(str(extra_pars,)))
103        pars = model.PARS.copy()
104        pars.update((base+'_pd', 0) for base in model.PD_PARS)
105        pars.update((base+'_pd_n', 35) for base in model.PD_PARS)
106        pars.update((base+'_pd_nsigma', 3) for base in model.PD_PARS)
107        pars.update(kw)
108        self._parameters = dict((k, Parameter.default(v, name=k)) for k, v in pars.items())
109
110    def set_result(self, result):
111        self.result = result
112        return self.result
113
114    def get_result(self):
115        return self.result
116
117    def numpoints(self):
118        return len(self.iq)
119
120    def parameters(self):
121        return self._parameters
122
123    def __getattr__(self, par):
124        return self._parameters[par]
125
126    def __setattr__(self, par, val):
127        if par in self._parameters:
128            self._parameters[par] = val
129        else:
130            self.__dict__[par] = val
131
132    def theory(self):
133        pars = dict((k,v.value) for k,v in self._parameters.items())
134        result = self.gpu.eval(pars)
135        return result
136
137    def residuals(self):
138        #if np.any(self.err ==0): print "zeros in err"
139        return (self.get_result()[self.index]-self.iq)/self.diq
140
141    def nllf(self):
142        R = self.residuals()
143        #if np.any(np.isnan(R)): print "NaN in residuals"
144        return 0.5*np.sum(R**2)
145
146    def __call__(self):
147        return 2*self.nllf()/self.dof
148
149    def plot(self, view='log'):
150        plot_result(self.data, self.get_result(), view=view)
151
152    def save(self, basename):
153        pass
154
155    def update(self):
156        pass
157
Note: See TracBrowser for help on using the repository browser.