source: sasview/park_integration/test/FittingModule.py @ e94d1d3

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since e94d1d3 was e94d1d3, checked in by Gervaise Alina <gervyh@…>, 16 years ago

more modif

  • Property mode set to 100644
File size: 7.8 KB
Line 
1#class Fitting
2from sans.guitools.plottables import Data1D
3from Loader import Load
4from scipy import optimize
5
6
7class FitArrange:
8    def __init__(self):
9        """
10            @param model: the model selected by the user
11            @param Ldata: a list of data what the user want to fit
12        """
13        self.model = None
14        self.dList =[]
15       
16    def set_model(self,model):
17        """ set the model """
18        self.model = model
19       
20    def add_data(self,data):
21        """
22            @param data: Data to add in the list
23            fill a self.dataList with data to fit
24        """
25        if not data in self.dList:
26            self.dList.append(data)
27           
28    def get_model(self):
29        return self.model   
30     
31    def get_data(self):
32        """ Return list of data"""
33        return self.dList
34     
35    def delete_data(self,data):
36        """
37            Remove one element from the list
38            @param data: Data to remove from the the lsit of data
39        """
40        if data in self.dList:
41            self.dList.remove(data)
42           
43class Fitting:
44    """
45        Performs the Fit.he user determine what kind of data
46    """
47    def __init__(self,data=[]):
48        #self.model is a list of all models to fit
49        #self.model={}
50        self.fitArrangeList={}
51        #the list of all data to fit
52        self.data = data
53        #list of models parameters
54        self.parameters=[]
55        self.contraint =None
56        self.residuals=[]
57        self.fitType =None
58       
59    def fit_engine(self,word):
60        """
61            Check the contraint value and specify what kind of fit to use
62        """
63        self.fitType = word
64        return True
65   
66    def fit(self,pars, qmin=None, qmax=None):
67        """
68             Do the fit
69        """
70        #for item in self.fitArrangeList.:
71       
72        fitproblem=self.fitArrangeList.values()[0]
73        listdata=[]
74        model =fitproblem.get_model()
75        listdata= fitproblem.get_data()
76        self.set_param(model, pars)
77        if listdata==[]:
78            raise ValueError, " data list missing"
79        else:
80            # Do the fit with more than one data set and one model
81            xtemp=[]
82            ytemp=[]
83            dytemp=[]
84           
85           
86            for data in listdata:
87                for i in range(len(data.x)):
88                    if not data.x[i] in xtemp:
89                        xtemp.append(data.x[i])
90                       
91                    if not data.y[i] in ytemp:
92                        ytemp.append(data.y[i])
93                       
94                    if not data.dy[i] in dytemp:
95                        dytemp.append(data.dy[i])
96            if qmin==None:
97                qmin= min(xtemp)
98            if qmax==None:
99                qmax= max(xtemp) 
100            chisqr, out, cov = fitHelper(model,self.parameters, xtemp,ytemp, dytemp ,qmin,qmax)
101            return chisqr, out, cov
102   
103    def set_model(self,model,Uid):
104        """ Set model """
105        #self.model[Uid] = model
106        fitproblem= FitArrange()
107        fitproblem.set_model(model)
108        self.fitArrangeList[Uid]=fitproblem
109       
110    def set_data(self,data,Uid):
111        """ Receive plottable and create a list of data to fit"""
112        #self.data.append(data)
113        if self.fitArrangeList.has_key(Uid):
114            self.fitArrangeList[Uid].add_data(data)
115        else:
116            fitproblem= FitArrange()
117            fitproblem.add_data(data)
118            self.fitArrangeList[Uid]=fitproblem
119           
120    def get_data(self):
121        """ return list of data"""
122        return self.data
123   
124    def set_param(self,model, pars):
125        """ Recieve a dictionary of parameter and save it """
126        self.parameters=[]
127        if model==None:
128            raise ValueError, "Cannot set parameters for empty model"
129        else:
130            #for key ,value in pars:
131            for key, value in pars.iteritems():
132                print "this is the key",key
133                print "this is the value",value
134                param = Parameter(model, key, value)
135                self.parameters.append(param)
136           
137    def add_contraint(self, contraint):
138        """ User specify contraint to fit """
139        self.contraint = str(contraint)
140       
141    def get_contraint(self):
142        """ return the contraint value """
143        return self.contraint
144
145
146
147class Parameter:
148    """
149        Class to handle model parameters
150    """
151    def __init__(self, model, name, value=None):
152            self.model = model
153            self.name = name
154            if not value==None:
155                self.model.setParam(self.name, value)
156           
157    def set(self, value):
158        """
159            Set the value of the parameter
160        """
161        self.model.setParam(self.name, value)
162
163    def __call__(self):
164        """
165            Return the current value of the parameter
166        """
167        return self.model.getParam(self.name)
168   
169def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None):
170    """
171        Fit function
172        @param model: sans model object
173        @param pars: list of parameters
174        @param x: vector of x data
175        @param y: vector of y data
176        @param err_y: vector of y errors
177    """
178    def f(params):
179        """
180            Calculates the vector of residuals for each point
181            in y for a given set of input parameters.
182            @param params: list of parameter values
183            @return: vector of residuals
184        """
185        i = 0
186        for p in pars:
187            p.set(params[i])
188            i += 1
189       
190        residuals = []
191        for j in range(len(x)):
192            if x[j]>qmin and x[j]<qmax:
193                residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] )
194       
195        return residuals
196       
197    def chi2(params):
198        """
199            Calculates chi^2
200            @param params: list of parameter values
201            @return: chi^2
202        """
203        sum = 0
204        res = f(params)
205        for item in res:
206            sum += item*item
207        return sum
208       
209    p = [param() for param in pars]
210    out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True)
211    print info, mesg, success
212    # Calculate chi squared
213    if len(pars)>1:
214        chisqr = chi2(out)
215    elif len(pars)==1:
216        chisqr = chi2([out])
217       
218    return chisqr, out, cov_x   
219
220     
221if __name__ == "__main__": 
222    load= Load()
223   
224    # test fit one data set one model
225    load.set_filename("testdata_line.txt")
226    load.set_values()
227    data1 = Data1D(x=[], y=[], dx=None,dy=None)
228    data1.name = "data1"
229    load.load_data(data1)
230    Fit =Fitting()
231   
232    from sans.guitools.LineModel import LineModel
233    model  = LineModel()
234    Fit.set_model(model,1)
235    Fit.set_data(data1,1)
236    #default_A = model.getParam('A')
237    #default_B = model.getParam('B')
238    #cstA = Parameter(model, 'A', default_A)
239    #cstB  = Parameter(model, 'B', default_B)
240   
241    #chisqr, out, cov=Fit.fit([cstA,cstB],None,None)
242    chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None)
243    print"fit only one data",chisqr, out, cov
244   
245    # test fit with 2 data and one model
246    Fit =Fitting()
247    Fit.set_model(model,2 )
248    load.set_filename("testdata1.txt")
249    load.set_values()
250    data2 = Data1D(x=[], y=[], dx=None,dy=None)
251    data2.name = "data2"
252   
253    load.load_data(data2)
254    Fit.set_data(data2,2)
255   
256    load.set_filename("testdata2.txt")
257    load.set_values()
258    data3 = Data1D(x=[], y=[], dx=None,dy=None)
259    data3.name = "data2"
260    load.load_data(data3)
261    Fit.set_data(data3,2)
262    chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None)
263    print"fit two data",chisqr, out, cov
264   
Note: See TracBrowser for help on using the repository browser.