source: sasview/park_integration/ScipyFitting.py @ 83ca047

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 83ca047 was 792db7d5, checked in by Gervaise Alina <gervyh@…>, 16 years ago

More tests added; most of them are failing because of uncertainty, and the scipy and park results are also a little bit different.

  • Property mode set to 100644
File size: 11.7 KB
Line 
1"""
    @organization: The ScipyFitting module contains the FitArrange, ScipyFit
    and Parameter classes. All listed classes work together to perform a
    simple fit with the scipy optimizer.
5"""
6from sans.guitools.plottables import Data1D
7from Loader import Load
8from scipy import optimize
9
10
class FitArrange:
    """
        Class FitArrange contains a set of data for a given model
        to perform the fit. FitArrange must contain exactly one model
        and at least one data set for the fit to be performed.
    """
    def __init__(self):
        """
            Create an empty FitArrange.
            model: the model selected by the user (None until set_model is called)
            dList: the list of data the user wants to fit
        """
        self.model = None
        self.dList = []

    def set_model(self, model):
        """
            set_model saves the model (the object itself, not a copy)
            @param model: the model being set
        """
        self.model = model

    def add_data(self, data):
        """
            add_data appends data to self.dList unless it is already in it
            @param data: data to add to the list
        """
        if data not in self.dList:
            self.dList.append(data)

    def get_model(self):
        """ @return: saved model, or None if no model has been set """
        return self.model

    def get_data(self):
        """ @return: list of data dList """
        return self.dList

    def remove_data(self, data):
        """
            Remove one element from the list; does nothing if data is absent.
            @param data: data to remove from dList
        """
        if data in self.dList:
            self.dList.remove(data)

    def remove_datalist(self):
        """ Empty the complete data list dList by binding a new empty list """
        self.dList = []

    def remove_model(self):
        """
            Remove the saved model (reset it to None). The data list is kept.
        """
        self.model = None
57           
class ScipyFit:
    """
        ScipyFit performs the fit. This class can be used as follows:

        # Do the fit with scipy
        Create an engine: engine = ScipyFit()
        Data must be plottables with x, y and dy fields.
        Use a sans model.

        Add data: it is stored in the FitArrange object kept in the
        dictionary fitArrangeList under the key Uid.
        engine.set_data(data, Uid)

        Add the model the same way; it is stored in the FitArrange object at Uid.
        engine.set_model(model, Uid)

        Set model parameters: the model is renamed (e.g. "M1") and the values
        in {model.parameter.name: value} are written into it.
        @note: set_param() needs a model, so set_model() must be called first.
            ScipyFit.fit() calls set_param() itself.
        engine.set_param(model, "M1", {'A': 2, 'B': 4})

        engine.fit returns chisqr, [best parameter values], covariance matrix:
        chisqr1, out1, cov1 = engine.fit({model.parameter.name: value}, qmin, qmax)
    """
    def __init__(self):
        """
            Create an empty dictionary (self.fitArrangeList) of FitArrange
            elements keyed by Uid.
        """
        self.fitArrangeList = {}

    def fit(self, pars, qmin=None, qmax=None):
        """
            Perform the fit with the scipy optimizer. Only one model and its
            set of data can be fit: the first value stored in fitArrangeList
            is used.
            @note: Cannot perform more than one fit at a time.

            @param pars: dictionary of the model's parameter names and their
                starting values
            @param qmin: minimum x value of the data range to fit; defaults to
                the smallest x of the data
            @param qmax: maximum x value of the data range to fit; defaults to
                the largest x of the data
            @return chisqr: value of the goodness-of-fit metric
            @return out: best parameter values found during fitting
            @return cov: covariance matrix as returned by scipy's leastsq
            @raise ValueError: if no model/data was set, or the data list is empty
        """
        if not self.fitArrangeList:
            raise ValueError("No model or data has been set for the fit")
        # list() so this also works where dict.values() is a view (Python 3)
        fitproblem = list(self.fitArrangeList.values())[0]
        model = fitproblem.get_model()
        listdata = fitproblem.get_data()
        if model is None:
            # Reject here: reading model.name below would raise AttributeError
            raise ValueError("Cannot set parameters for empty model")

        # Create list of Parameter instances and save parameter values in model
        parameters = self.set_param(model, model.name, pars)

        # Concatenate the data sets (one or more) before fitting
        xtemp, ytemp, dytemp = self._concatenateData(listdata)

        # Default the fit range to the full extent of the data
        if qmin is None:
            qmin = min(xtemp)
        if qmax is None:
            qmax = max(xtemp)

        chisqr, out, cov = fitHelper(model, parameters, xtemp, ytemp, dytemp,
                                     qmin, qmax)
        return chisqr, out, cov

    def _concatenateData(self, listdata=None):
        """
            _concatenateData concatenates each field (x, y, dy) of all the
            data contained in listdata.
            @param listdata: list of data (objects with x, y and dy sequences)

            @return xtemp, ytemp, dytemp: x, y, dy respectively of all data
                combined. If a point (xi, yi, dyi) appears more than once
                (e.g. in two data sets), only its first appearance is kept.
                The three returned lists always have the same length.

            @raise ValueError: if listdata is empty or None
            @raise ValueError: if a non-empty data in listdata has no dy values
                (the fit divides each residual by dy)
        """
        if not listdata:
            raise ValueError(" data list missing")
        xtemp = []
        ytemp = []
        dytemp = []
        # Points already added. Whole (x, y, dy) points are deduplicated:
        # filtering x, y and dy separately would let the lists drift out of
        # step, e.g. two points sharing the same dy would leave only one dy.
        seen = set()
        for data in listdata:
            if len(data.x) == 0:
                continue
            # 'is None' rather than truthiness, in case dy is an array whose
            # truth value is ambiguous
            if data.dy is None or len(data.dy) == 0:
                raise ValueError("dy is missing will not be able to fit later on")
            for i, xi in enumerate(data.x):
                point = (xi, data.y[i], data.dy[i])
                if point in seen:
                    continue
                seen.add(point)
                xtemp.append(xi)
                ytemp.append(data.y[i])
                dytemp.append(data.dy[i])
        return xtemp, ytemp, dytemp

    def set_model(self, model, Uid):
        """
            Set model in a FitArrange object and add that object to the
            dictionary with key Uid.
            @param model: the model added
            @param Uid: unique key corresponding to a FitArrange object with model
        """
        if Uid in self.fitArrangeList:
            # A FitArrange already exists at Uid (possibly holding only data)
            self.fitArrangeList[Uid].set_model(model)
        else:
            # No FitArrange object has been created with this Uid yet
            fitproblem = FitArrange()
            fitproblem.set_model(model)
            self.fitArrangeList[Uid] = fitproblem

    def set_data(self, data, Uid):
        """
            Receive a plottable, add it to the list of data to fit held by the
            FitArrange object at key Uid (creating that object if needed).
            @param data: data added
            @param Uid: unique key corresponding to a FitArrange object with data
        """
        if Uid in self.fitArrangeList:
            # A FitArrange already exists at Uid (possibly holding only a model)
            self.fitArrangeList[Uid].add_data(data)
        else:
            # No FitArrange object has been created with this Uid yet
            fitproblem = FitArrange()
            fitproblem.add_data(data)
            self.fitArrangeList[Uid] = fitproblem

    def get_model(self, Uid):
        """
            @param Uid: key in the dictionary of the FitArrange holding the model
            @return: the model at this Uid, or None if no FitArrange element
                was created with this Uid
        """
        if Uid in self.fitArrangeList:
            return self.fitArrangeList[Uid].get_model()
        return None

    def set_param(self, model, name, pars):
        """
            Receive a dictionary of parameters and save it in the model.
            @param model: model on which parameter values are set
            @param name: name given to the model (assigned to model.name)
            @param pars: dictionary of parameter names and values,
                pars = {parameter's name: parameter's value}
            @return: list of Parameter instances, one per entry of pars
            @raise ValueError: if model is None
        """
        if model is None:
            raise ValueError("Cannot set parameters for empty model")
        model.name = name
        return [Parameter(model, key, value) for key, value in pars.items()]

    def remove_data(self, Uid, data=None):
        """
            Remove one or all data. If data is None, the whole list of data at
            Uid is removed; otherwise only data is removed from that list.
            Does nothing if no FitArrange exists at Uid.
            @param Uid: unique id of the FitArrange object holding the data
            @param data: data to be removed
        """
        fitproblem = self.fitArrangeList.get(Uid)
        if fitproblem is None:
            return
        if data is None:
            fitproblem.remove_datalist()
        else:
            fitproblem.remove_data(data)

    def remove_model(self, Uid):
        """
            Remove the model in the FitArrange object with Uid.
            Does nothing if no FitArrange exists at Uid.
            @param Uid: unique id corresponding to the FitArrange object
                where the model must be removed.
        """
        if Uid in self.fitArrangeList:
            # Clear through set_model(None), which every FitArrange provides
            self.fitArrangeList[Uid].set_model(None)
245               
246
class Parameter:
    """
        Class to handle model parameters: a named handle that writes a
        model parameter through model.setParam and reads it back through
        model.getParam.
    """
    def __init__(self, model, name, value=None):
        """
            @param model: model owning the parameter
            @param name: parameter name passed to model.setParam/getParam
            @param value: if not None, written to the model immediately
        """
        self.model = model
        self.name = name
        # 'is not None' instead of '== None': the latter is an element-wise
        # comparison for array-like values and its negation is ambiguous
        if value is not None:
            self.model.setParam(self.name, value)

    def set(self, value):
        """
            Set the value of the parameter in the model
        """
        self.model.setParam(self.name, value)

    def __call__(self):
        """
            Return the current value of the parameter, read from the model
        """
        return self.model.getParam(self.name)
268   
def fitHelper(model, pars, x, y, err_y, qmin=None, qmax=None):
    """
        Fit function: weighted least-squares fit of the model parameters
        to (x, y) with scipy.optimize.leastsq.
        @param model: sans model object (must provide runXY(x))
        @param pars: list of parameters (callable to read the current value,
            .set(value) to write it into the model)
        @param x: vector of x data
        @param y: vector of y data
        @param err_y: vector of y errors; each residual is divided by it
        @param qmin: lowest x used in the fit (inclusive); None = no lower bound
        @param qmax: highest x used in the fit (inclusive); None = no upper bound
        @return chisqr: value of the goodness-of-fit metric at the optimum
        @return out: best parameter values found during fitting
        @return cov: cov_x as returned by leastsq (None if the matrix is singular)
        @note: on return the model's parameters are left at the best-fit values.
    """
    # Select the points inside the fit range once rather than on every
    # residual evaluation. Bounds are inclusive so that a range of
    # [min(x), max(x)] keeps the end points.
    selected = [j for j, xj in enumerate(x)
                if (qmin is None or xj >= qmin) and (qmax is None or xj <= qmax)]

    def f(params):
        """
            Calculates the vector of residuals for each selected point
            in y for a given set of input parameters.
            @param params: list of parameter values
            @return: vector of residuals
        """
        for param, value in zip(pars, params):
            param.set(value)
        return [(y[j] - model.runXY(x[j])) / err_y[j] for j in selected]

    def chi2(params):
        """
            Calculates chi^2
            @param params: list of parameter values
            @return: chi^2
        """
        return sum(r * r for r in f(params))

    p = [param() for param in pars]
    out, cov_x, _info, _mesg, _ier = optimize.leastsq(f, p, full_output=1)
    # The best-fit vector may come back as a scalar (older scipy, single
    # parameter) or as a 1-d array; normalize to a sequence. Evaluating chi^2
    # there also sets the model parameters to the best-fit values.
    try:
        best = list(out)
    except TypeError:
        best = [out]
    chisqr = chi2(best)

    return chisqr, out, cov_x
322
Note: See TracBrowser for help on using the repository browser.