source: sasview/park_integration/ScipyFitting.py @ 73b1c72

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 73b1c72 was 4408fb0, checked in by Gervaise Alina <gervyh@…>, 16 years ago

files moved

  • Property mode set to 100644
File size: 8.6 KB
Line 
1#class Fitting
2from sans.guitools.plottables import Data1D
3from Loader import Load
4from scipy import optimize
5#from Fitting import Fit
6
7class FitArrange:
8    def __init__(self):
9        """
10            Store a set of data for a given model to perform the Fit
11            @param model: the model selected by the user
12            @param Ldata: a list of data what the user want to fit
13        """
14        self.model = None
15        self.dList =[]
16       
17    def set_model(self,model):
18        """ set the model """
19        self.model = model
20       
21    def add_data(self,data):
22        """
23            @param data: Data to add in the list
24            fill a self.dataList with data to fit
25        """
26        if not data in self.dList:
27            self.dList.append(data)
28           
29    def get_model(self):
30        """ Return the model"""
31        return self.model   
32     
33    def get_data(self):
34        """ Return list of data"""
35        return self.dList
36     
37    def remove_data(self,data):
38        """
39            Remove one element from the list
40            @param data: Data to remove from the the lsit of data
41        """
42        if data in self.dList:
43            self.dList.remove(data)
44           
45class ScipyFit:
46    """
47        Performs the Fit.he user determine what kind of data
48    """
49    def __init__(self,data=[]):
50        #this is a dictionary of FitArrange elements
51        self.fitArrangeList={}
52        #the constraint of the Fit
53        self.constraint =None
54        #Specify the use of scipy or park fit
55        self.fitType =None
56       
57 
58   
59    def fit(self,pars, qmin=None, qmax=None):
60        """
61             Do the fit
62        """
63        #for item in self.fitArrangeList.:
64       
65        fitproblem=self.fitArrangeList.values()[0]
66        listdata=[]
67        model = fitproblem.get_model()
68        listdata = fitproblem.get_data()
69       
70        parameters = self.set_param(model,pars)
71       
72        # Do the fit with  data set (contains one or more data) and one model
73        xtemp,ytemp,dytemp=self._concatenateData( listdata)
74        if qmin==None:
75            qmin= min(xtemp)
76        if qmax==None:
77            qmax= max(xtemp) 
78        chisqr, out, cov = fitHelper(model,parameters, xtemp,ytemp, dytemp ,qmin,qmax)
79        return chisqr, out, cov
80   
81    def _concatenateData(self, listdata=[]):
82        """ concatenate each fields of all data contains ins listdata"""
83        if listdata==[]:
84            raise ValueError, " data list missing"
85        else:
86            xtemp=[]
87            ytemp=[]
88            dytemp=[]
89               
90            for data in listdata:
91                for i in range(len(data.x)):
92                    if not data.x[i] in xtemp:
93                        xtemp.append(data.x[i])
94                       
95                    if not data.y[i] in ytemp:
96                        ytemp.append(data.y[i])
97                       
98                    if not data.dy[i] in dytemp:
99                        dytemp.append(data.dy[i])
100            return xtemp, ytemp,dytemp
101       
102    def set_model(self,model,Uid):
103        """ Set model """
104        if self.fitArrangeList.has_key(Uid):
105            self.fitArrangeList[Uid].set_model(model)
106        else:
107            fitproblem= FitArrange()
108            fitproblem.set_model(model)
109            self.fitArrangeList[Uid]=fitproblem
110       
111    def set_data(self,data,Uid):
112        """ Receive plottable and create a list of data to fit"""
113       
114        if self.fitArrangeList.has_key(Uid):
115            self.fitArrangeList[Uid].add_data(data)
116        else:
117            fitproblem= FitArrange()
118            fitproblem.add_data(data)
119            self.fitArrangeList[Uid]=fitproblem
120           
121    def get_model(self,Uid):
122        """ return list of data"""
123        return self.fitArrangeList[Uid]
124   
125    def set_param(self,model, pars):
126        """ Recieve a dictionary of parameter and save it """
127        parameters=[]
128        if model==None:
129            raise ValueError, "Cannot set parameters for empty model"
130        else:
131            #for key ,value in pars:
132            for key, value in pars.iteritems():
133                param = Parameter(model, key, value)
134                parameters.append(param)
135        return parameters
136   
137    def add_constraint(self, constraint):
138        """ User specify contraint to fit """
139        self.constraint = str(constraint)
140       
141    def get_constraint(self):
142        """ return the contraint value """
143        return self.constraint
144   
145    def set_constraint(self,constraint):
146        """
147            receive a string as a constraint
148            @param constraint: a string used to constraint some parameters to get a
149                specific value
150        """
151        self.constraint= constraint
152   
153    def createProblem(self):
154        """
155            Check the contraint value and specify what kind of fit to use
156        """
157        mylist=[]
158        for k,value in self.fitArrangeList.iteritems():
159            couple=()
160            model=value.get_model()
161            data=value.get_data()
162            couple=(model,data)
163            mylist.append(couple)
164        #print mylist
165        return mylist
166   
167               
168
169class Parameter:
170    """
171        Class to handle model parameters
172    """
173    def __init__(self, model, name, value=None):
174            self.model = model
175            self.name = name
176            if not value==None:
177                self.model.setParam(self.name, value)
178           
179    def set(self, value):
180        """
181            Set the value of the parameter
182        """
183        self.model.setParam(self.name, value)
184
185    def __call__(self):
186        """
187            Return the current value of the parameter
188        """
189        return self.model.getParam(self.name)
190   
191def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None):
192    """
193        Fit function
194        @param model: sans model object
195        @param pars: list of parameters
196        @param x: vector of x data
197        @param y: vector of y data
198        @param err_y: vector of y errors
199    """
200    def f(params):
201        """
202            Calculates the vector of residuals for each point
203            in y for a given set of input parameters.
204            @param params: list of parameter values
205            @return: vector of residuals
206        """
207        i = 0
208        for p in pars:
209            p.set(params[i])
210            i += 1
211       
212        residuals = []
213        for j in range(len(x)):
214            if x[j]>qmin and x[j]<qmax:
215                residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] )
216       
217        return residuals
218       
219    def chi2(params):
220        """
221            Calculates chi^2
222            @param params: list of parameter values
223            @return: chi^2
224        """
225        sum = 0
226        res = f(params)
227        for item in res:
228            sum += item*item
229        return sum
230       
231    p = [param() for param in pars]
232    out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True)
233    print info, mesg, success
234    # Calculate chi squared
235    if len(pars)>1:
236        chisqr = chi2(out)
237    elif len(pars)==1:
238        chisqr = chi2([out])
239       
240    return chisqr, out, cov_x   
241
242     
243if __name__ == "__main__": 
244    load= Load()
245   
246    # test fit one data set one model
247    load.set_filename("testdata_line.txt")
248    load.set_values()
249    data1 = Data1D(x=[], y=[], dx=None,dy=None)
250    data1.name = "data1"
251    load.load_data(data1)
252    fitter =ScipyFit()
253    from sans.guitools.LineModel import LineModel
254    model  = LineModel()
255    fitter.set_model(model,1)
256    fitter.set_data(data1,1)
257   
258    chisqr, out, cov=fitter.fit({'A':2,'B':1},None,None)
259    print "my list of param",fitter.createProblem()
260    print"fit only one data",chisqr, out, cov
261    print "this model list of param",model.getParamList()
262    # test fit with 2 data and one model
263    fitter =ScipyFit()
264   
265    fitter.set_model(model,2 )
266    load.set_filename("testdata1.txt")
267    load.set_values()
268    data2 = Data1D(x=[], y=[], dx=None,dy=None)
269    data2.name = "data2"
270   
271    load.load_data(data2)
272    fitter.set_data(data2,2)
273   
274    load.set_filename("testdata2.txt")
275    load.set_values()
276    data3 = Data1D(x=[], y=[], dx=None,dy=None)
277    data3.name = "data2"
278    load.load_data(data3)
279    fitter.set_data(data3,2)
280    chisqr, out, cov=fitter.fit({'A':2,'B':1},None,None)
281    print"fit two data",chisqr, out, cov
282   
Note: See TracBrowser for help on using the repository browser.