source: sasview/park_integration/ScipyFitting.py @ 4dd63eb

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 4dd63eb was 4dd63eb, checked in by Gervaise Alina <gervyh@…>, 16 years ago

change on set_model

  • Property mode set to 100644
File size: 10.7 KB
Line 
1"""
2    @organization: ScipyFitting module contains FitArrange , ScipyFit,
3    Parameter classes.All listed classes work together to perform a
4    simple fit with scipy optimizer.
5"""
6from sans.guitools.plottables import Data1D
7from Loader import Load
8from scipy import optimize
9
10
class FitArrange:
    """
        Pair exactly one model with the data sets it should be fit against.

        A fit can be performed once a model has been set and at least one
        data set has been added.
        model: the model selected by the user (None until set_model is called)
        dList: the data sets to fit, without duplicates, in insertion order
    """
    def __init__(self):
        """ Start with no model and an empty data list. """
        self.model = None
        self.dList = []

    def set_model(self, model):
        """
            Store the model to fit (the reference itself is kept, not a copy).
            @param model: the model being set
        """
        self.model = model

    def add_data(self, data):
        """
            Append data to dList unless an equal entry is already present.
            @param data: data to add to the list
        """
        if data in self.dList:
            return
        self.dList.append(data)

    def get_model(self):
        """ @return: the stored model, or None if none was set """
        return self.model

    def get_data(self):
        """ @return: the list dList itself (not a copy) """
        return self.dList

    def remove_data(self, data):
        """
            Remove the first entry of dList equal to data; do nothing if absent.
            @param data: data to remove from dList
        """
        if data not in self.dList:
            return
        self.dList.remove(data)

    def remove_datalist(self):
        """
            Rebind dList to a fresh empty list; lists previously returned by
            get_data keep their contents.
        """
        self.dList = []

class ScipyFit:
    """
        ScipyFit performs the fit. This class can be used as follows:
        #Do the fit SCIPY
        create an engine: engine = ScipyFit()
        Data must be plottables providing x, y and dy fields.
        Use a sans model.

        Add data to the dictionary of FitArrange objects (fitArrangeList),
        where Uid is the key:
        engine.set_data(data, Uid)

        Add the model the same way. set_model also names the model and builds
        the list of parameters fit() will vary from {parameter name: value}:
        engine.set_model(model, "M1", Uid, {'A': 2, 'B': 4})

        engine.fit returns chisqr, [best parameter values], covariance matrix:
        chisqr1, out1, cov1 = engine.fit(qmin, qmax)
    """
    def __init__(self):
        """
            Create an empty dictionary (self.fitArrangeList) of FitArrange
            elements keyed by Uid, and an empty list of fit parameters.
        """
        self.fitArrangeList = {}
        # Filled by set_model(); initialised here so that fit() reports a
        # clear error instead of an AttributeError when no model was set.
        self.parameters = []

    def fit(self, qmin=None, qmax=None):
        """
            Performs fit with scipy optimizer. It can only perform fit with one
            model and a set of data.
            @note: Cannot perform more than one fit at the time; only the first
                FitArrange of fitArrangeList (dictionary order) is used.

            @param qmin: The minimum value of data's range to be fit
                (defaults to the smallest x of the data)
            @param qmax: The maximum value of data's range to be fit
                (defaults to the largest x of the data)
            @return chisqr: Value of the goodness of fit metric
            @return out: list of parameter with the best value found during fitting
            @return cov: Covariance matrix
            @raise ValueError: if no model/data was set, or the data list is
                empty or lacks dy
        """
        if not self.fitArrangeList:
            raise ValueError("No model or data has been set for the fit")
        # list() keeps this working whether values() returns a list or a view
        fitproblem = list(self.fitArrangeList.values())[0]
        model = fitproblem.get_model()
        if model is None:
            raise ValueError("No model has been set for the fit")
        listdata = fitproblem.get_data()

        # Concatenate dList set (contains one or more data) before fitting
        xtemp, ytemp, dytemp = self._concatenateData(listdata)

        # Use the whole data range when no boundaries were given
        if qmin is None:
            qmin = min(xtemp)
        if qmax is None:
            qmax = max(xtemp)

        chisqr, out, cov = fitHelper(model, self.parameters, xtemp, ytemp,
                                     dytemp, qmin, qmax)
        return chisqr, out, cov

    def _concatenateData(self, listdata=None):
        """
            Concatenate the x, y and dy fields of every data set in listdata.
            @param listdata: list of data; each must provide x, y and dy
                sequences of equal length

            @return xtemp, ytemp, dytemp: x, y, dy respectively of all data
                combined, in order. A point (xi, yi, dyi) identical to one
                already collected is skipped, so the three lists always stay
                aligned and have the same length.

            @raise ValueError: if listdata is empty or None
            @raise ValueError: if a non-empty data set has no dy values
                (the fit residuals are weighted by dy)
        """
        if not listdata:
            raise ValueError(" data list missing")
        xtemp = []
        ytemp = []
        dytemp = []
        # Deduplicate whole points, not single coordinates: filtering x, y and
        # dy separately would desynchronise the lists (e.g. a constant dy would
        # collapse to a single value).
        seen = set()
        for data in listdata:
            if len(data.x) > 0 and (data.dy is None or len(data.dy) == 0):
                raise ValueError("dy is missing will not be able to fit later on")
            for i in range(len(data.x)):
                point = (data.x[i], data.y[i], data.dy[i])
                if point in seen:
                    continue
                seen.add(point)
                xtemp.append(point[0])
                ytemp.append(point[1])
                dytemp.append(point[2])
        return xtemp, ytemp, dytemp

    def _get_fit_arrange(self, Uid):
        """ Return the FitArrange stored at Uid, creating and storing one if needed. """
        fitproblem = self.fitArrangeList.get(Uid)
        if fitproblem is None:
            # no fitArrange object has been created with this Uid yet
            fitproblem = FitArrange()
            self.fitArrangeList[Uid] = fitproblem
        return fitproblem

    def set_model(self, model, name, Uid, pars=None):
        """
            Set model in a FitArrange object stored in fitArrangeList under
            key Uid, and build the parameter list used by fit().
            @param model: model on which parameter values are set
            @param name: model name (assigned to model.name)
            @param Uid: unique key corresponding to a fitArrange object with model
            @param pars: dictionary of parameters name and initial value,
                pars={parameter's name: parameter's value}; only these
                parameters are varied by fit(). They replace any parameters
                set by a previous call (the list is shared by the engine).
            @raise ValueError: if model is None
        """
        if model is None:
            raise ValueError("Cannot set parameters for empty model")
        if pars is None:
            pars = {}
        model.name = name
        self.parameters = [Parameter(model, key, value)
                           for key, value in pars.items()]
        self._get_fit_arrange(Uid).set_model(model)

    def set_data(self, data, Uid):
        """
            Receive a plottable and add it to the data list of the FitArrange
            object stored in fitArrangeList under key Uid (created if needed).
            @param data: data added
            @param Uid: unique key corresponding to a fitArrange object with data
        """
        self._get_fit_arrange(Uid).add_data(data)

    def get_model(self, Uid):
        """
            @param Uid: Uid is key in the dictionary containing the model to return
            @return: the model at this Uid, or None if no FitArrange element
                was created with this Uid (or it has no model yet)
        """
        if Uid in self.fitArrangeList:
            return self.fitArrangeList[Uid].get_model()
        return None

    def remove_Fit_Problem(self, Uid):
        """ Remove the FitArrange stored at Uid; do nothing if there is none. """
        self.fitArrangeList.pop(Uid, None)

class Parameter:
    """
        Class to handle model parameters: a named handle on one parameter of
        a model, read through model.getParam and written through
        model.setParam.
    """
    def __init__(self, model, name, value=None):
        """
            @param model: model owning the parameter
            @param name: parameter name as known by the model
            @param value: initial value written to the model; None leaves the
                model's current value untouched
        """
        self.model = model
        self.name = name
        # Identity test (PEP 8): "value == None" is ambiguous for array-like
        # values and can be fooled by custom __eq__ implementations.
        if value is not None:
            self.model.setParam(self.name, value)

    def set(self, value):
        """
            Set the value of the parameter on the model
        """
        self.model.setParam(self.name, value)

    def __call__(self):
        """
            Return the current value of the parameter, read from the model
        """
        return self.model.getParam(self.name)

def fitHelper(model, pars, x, y, err_y, qmin=None, qmax=None):
    """
        Fit function: least-squares fit of model to the points (x, y) with
        scipy.optimize.leastsq, varying only the parameters in pars.
        @param model: sans model object (must provide runXY(x))
        @param pars: list of parameters; each is callable to read its current
            value and has set(value) to change it
        @param x: vector of x data
        @param y: vector of y data
        @param err_y: vector of y errors (residuals are divided by these)
        @param qmin: lower bound of the fitted x range, inclusive
            (None means no lower bound)
        @param qmax: upper bound of the fitted x range, inclusive
            (None means no upper bound)
        @return chisqr: Value of the goodness of fit metric (sum of squared
            weighted residuals at the best parameters)
        @return out: list of parameter with the best value found during fitting
        @return cov: Covariance matrix (None if leastsq could not estimate it)
        @raise ValueError: if pars is empty
        @note: on return, the parameters in pars hold the values in out.
    """
    import logging
    import numpy

    if not pars:
        raise ValueError("No parameters to fit")

    # Points inside [qmin, qmax]. The range is inclusive so the default range
    # (min(x), max(x)) used by callers keeps both end points.
    selected = [j for j in range(len(x))
                if (qmin is None or x[j] >= qmin)
                and (qmax is None or x[j] <= qmax)]

    def f(params):
        """
            Calculates the vector of residuals for each selected point
            in y for a given set of input parameters.
            @param params: list of parameter values, in the order of pars
            @return: vector of residuals
        """
        for param, value in zip(pars, params):
            param.set(value)
        return [(y[j] - model.runXY(x[j])) / err_y[j] for j in selected]

    def chi2(params):
        """
            Calculates chi^2
            @param params: list of parameter values
            @return: chi^2
        """
        return sum(res * res for res in f(params))

    p = [param() for param in pars]
    out, cov_x, _info, mesg, ier = optimize.leastsq(f, p, full_output=1)
    logging.getLogger(__name__).debug("leastsq finished (ier=%s): %s", ier, mesg)
    # leastsq may return a bare scalar for a single parameter (older scipy) or
    # a 1-element array (newer scipy); atleast_1d gives f() one scalar per
    # parameter either way. This also leaves the model set to the best values.
    chisqr = chi2(numpy.atleast_1d(out))
    return chisqr, out, cov_x

Note: See TracBrowser for help on using the repository browser.