source: sasview/park_integration/FittingModule.py @ ad8bcd6

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since ad8bcd6 was 197ea24, checked in by Mathieu Doucet <doucetm@…>, 16 years ago

Example for taking out residuals

  • Property mode set to 100644
File size: 8.7 KB
Line 
1#class Fitting
2from sans.guitools.plottables import Data1D
3from Loader import Load
4from scipy import optimize
5
6
7class FitArrange:
8    def __init__(self):
9        """
10            Store a set of data for a given model to perform the Fit
11            @param model: the model selected by the user
12            @param Ldata: a list of data what the user want to fit
13        """
14        self.model = None
15        self.dList =[]
16       
17    def set_model(self,model):
18        """ set the model """
19        self.model = model
20       
21    def add_data(self,data):
22        """
23            @param data: Data to add in the list
24            fill a self.dataList with data to fit
25        """
26        if not data in self.dList:
27            self.dList.append(data)
28           
29    def get_model(self):
30        """ Return the model"""
31        return self.model   
32     
33    def get_data(self):
34        """ Return list of data"""
35        return self.dList
36     
37    def remove_data(self,data):
38        """
39            Remove one element from the list
40            @param data: Data to remove from the the lsit of data
41        """
42        if data in self.dList:
43            self.dList.remove(data)
44           
45class Fitting:
46    """
47        Performs the Fit.he user determine what kind of data
48    """
49    def __init__(self,data=[]):
50        #this is a dictionary of FitArrange elements
51        self.fitArrangeList={}
52        #the constraint of the Fit
53        self.constraint =None
54        #Specify the use of scipy or park fit
55        self.fitType =None
56       
57    def fit_engine(self,word):
58        """
59            Check the contraint value and specify what kind of fit to use
60        """
61        self.fitType = word
62        return True
63   
64    def fit(self,pars, qmin=None, qmax=None):
65        """
66             Do the fit
67        """
68        #for item in self.fitArrangeList.:
69       
70        fitproblem=self.fitArrangeList.values()[0]
71        listdata=[]
72        model = fitproblem.get_model()
73        listdata = fitproblem.get_data()
74       
75        parameters = self.set_param(model,pars)
76        if listdata==[]:
77            raise ValueError, " data list missing"
78        else:
79            # Do the fit with more than one data set and one model
80            xtemp=[]
81            ytemp=[]
82            dytemp=[]
83           
84            for data in listdata:
85                for i in range(len(data.x)):
86                    if not data.x[i] in xtemp:
87                        xtemp.append(data.x[i])
88                       
89                    if not data.y[i] in ytemp:
90                        ytemp.append(data.y[i])
91                       
92                    if not data.dy[i] in dytemp:
93                        dytemp.append(data.dy[i])
94            if qmin==None:
95                qmin= min(xtemp)
96            if qmax==None:
97                qmax= max(xtemp) 
98            chisqr, out, cov = fitHelper(model,parameters, xtemp,ytemp, dytemp ,qmin,qmax)
99            return chisqr, out, cov
100   
101    def set_model(self,model,Uid):
102        """ Set model """
103        if self.fitArrangeList.has_key(Uid):
104            self.fitArrangeList[Uid].set_model(model)
105        else:
106            fitproblem= FitArrange()
107            fitproblem.set_model(model)
108            self.fitArrangeList[Uid]=fitproblem
109       
110    def set_data(self,data,Uid):
111        """ Receive plottable and create a list of data to fit"""
112       
113        if self.fitArrangeList.has_key(Uid):
114            self.fitArrangeList[Uid].add_data(data)
115        else:
116            fitproblem= FitArrange()
117            fitproblem.add_data(data)
118            self.fitArrangeList[Uid]=fitproblem
119           
120    def get_model(self,Uid):
121        """ return list of data"""
122        return self.fitArrangeList[Uid]
123   
124    def set_param(self,model, pars):
125        """ Recieve a dictionary of parameter and save it """
126        parameters=[]
127        if model==None:
128            raise ValueError, "Cannot set parameters for empty model"
129        else:
130            #for key ,value in pars:
131            for key, value in pars.iteritems():
132                param = Parameter(model, key, value)
133                parameters.append(param)
134        return parameters
135   
136    def add_constraint(self, constraint):
137        """ User specify contraint to fit """
138        self.constraint = str(constraint)
139       
140    def get_constraint(self):
141        """ return the contraint value """
142        return self.constraint
143   
144    def set_constraint(self,constraint):
145        """
146            receive a string as a constraint
147            @param constraint: a string used to constraint some parameters to get a
148                specific value
149        """
150        self.constraint= constraint
151   
152   
153               
154
155class Parameter:
156    """
157        Class to handle model parameters
158    """
159    def __init__(self, model, name, value=None):
160            self.model = model
161            self.name = name
162            if not value==None:
163                self.model.setParam(self.name, value)
164           
165    def set(self, value):
166        """
167            Set the value of the parameter
168        """
169        self.model.setParam(self.name, value)
170
171    def __call__(self):
172        """
173            Return the current value of the parameter
174        """
175        return self.model.getParam(self.name)
176
177class FitHelper:
178   
179    def __init__(self,model, pars, x, y, err_y ,qmin=None, qmax=None):
180        self.x = x
181        self.y = y
182        self.model = model
183        self.err_y = err_y
184        self.qmin = qmin
185        self.qmax= qmax
186        self.pars = pars
187       
188    def __call__(self, params):
189        i = 0
190        for p in self.pars:
191            p.set(params[i])
192            i += 1
193       
194        residuals = []
195        for j in range(len(self.x)):
196            if self.x[j]>self.qmin and self.x[j]<self.qmax:
197                residuals.append( ( self.y[j] - self.model.runXY(self.x[j]) ) / self.err_y[j] )
198       
199        return residuals
200   
201   
202
203def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None):
204    """
205        Fit function
206        @param model: sans model object
207        @param pars: list of parameters
208        @param x: vector of x data
209        @param y: vector of y data
210        @param err_y: vector of y errors
211    """
212   
213    f = FitHelper(model, pars, x, y, err_y ,qmin, qmax)
214   
215    def ff(params):
216        """
217            Calculates the vector of residuals for each point
218            in y for a given set of input parameters.
219            @param params: list of parameter values
220            @return: vector of residuals
221        """
222        i = 0
223        for p in pars:
224            p.set(params[i])
225            i += 1
226       
227        residuals = []
228        for j in range(len(x)):
229            if x[j]>qmin and x[j]<qmax:
230                residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] )
231       
232        return residuals
233       
234    def chi2(params):
235        """
236            Calculates chi^2
237            @param params: list of parameter values
238            @return: chi^2
239        """
240        sum = 0
241        res = f(params)
242        for item in res:
243            sum += item*item
244        return sum
245       
246    p = [param() for param in pars]
247    out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True)
248    print info, mesg, success
249    # Calculate chi squared
250    if len(pars)>1:
251        chisqr = chi2(out)
252    elif len(pars)==1:
253        chisqr = chi2([out])
254       
255    return chisqr, out, cov_x   
256
257     
258if __name__ == "__main__": 
259    load= Load()
260   
261    # test fit one data set one model
262    load.set_filename("testdata_line.txt")
263    load.set_values()
264    data1 = Data1D(x=[], y=[], dx=None,dy=None)
265    data1.name = "data1"
266    load.load_data(data1)
267    Fit =Fitting()
268   
269    from LineModel import LineModel
270    model  = LineModel()
271    Fit.set_model(model,1)
272    Fit.set_data(data1,1)
273   
274    chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None)
275    print"fit only one data",chisqr, out, cov
276   
277    # test fit with 2 data and one model
278    Fit =Fitting()
279    Fit.set_model(model,2 )
280    load.set_filename("testdata1.txt")
281    load.set_values()
282    data2 = Data1D(x=[], y=[], dx=None,dy=None)
283    data2.name = "data2"
284   
285    load.load_data(data2)
286    Fit.set_data(data2,2)
287   
288    load.set_filename("testdata2.txt")
289    load.set_values()
290    data3 = Data1D(x=[], y=[], dx=None,dy=None)
291    data3.name = "data2"
292    load.load_data(data3)
293    Fit.set_data(data3,2)
294    chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None)
295    print"fit two data",chisqr, out, cov
296   
Note: See TracBrowser for help on using the repository browser.