source: sasview/park_integration/ScipyFitting.py @ 4dd63eb

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 4dd63eb was 4dd63eb, checked in by Gervaise Alina <gervyh@…>, 16 years ago

change on set_model

  • Property mode set to 100644
File size: 10.7 KB
RevLine 
[792db7d5]1"""
2    @organization: ScipyFitting module contains FitArrange , ScipyFit,
3    Parameter classes.All listed classes work together to perform a
4    simple fit with scipy optimizer.
5"""
[7705306]6from sans.guitools.plottables import Data1D
7from Loader import Load
8from scipy import optimize
[792db7d5]9
[7705306]10
class FitArrange:
    """
        Holds one model together with the list of data it is fit against.
        A fit needs exactly one model (self.model) and at least one entry
        in the data list (self.dList).
    """
    def __init__(self):
        """ Start with an empty data list and no model. """
        self.dList = []
        self.model = None

    def set_model(self, model):
        """
            Store the model to fit (the object itself is kept, not a copy).
            @param model: the model being set
        """
        self.model = model

    def add_data(self, data):
        """
            Append data to self.dList unless an equal entry is already there.
            @param data: data to add to the list
        """
        already_present = data in self.dList
        if not already_present:
            self.dList.append(data)

    def get_model(self):
        """ @return: the stored model (None until set_model is called) """
        return self.model

    def get_data(self):
        """ @return: the list of data self.dList """
        return self.dList

    def remove_data(self, data):
        """
            Remove data from self.dList; nothing happens if it is absent.
            @param data: data to remove from self.dList
        """
        if data not in self.dList:
            return
        self.dList.remove(data)

    def remove_datalist(self):
        """
            Empty the data list. A new list is bound, so lists previously
            returned by get_data() keep their contents.
        """
        self.dList = list()
[7705306]57           
class ScipyFit:
    """
        ScipyFit performs a fit with the scipy least-squares optimizer.
        Typical use:

            engine = ScipyFit()
            # data must be plottables with x, y and dy fields
            engine.set_data(data, Uid)
            # pars gives the parameters to fit and their starting values
            engine.set_model(model, "M1", Uid, {'A': 2, 'B': 4})
            chisqr, out, cov = engine.fit(qmin, qmax)

        Each Uid keys one FitArrange object holding a model and its data
        (self.fitArrangeList). fit() currently uses only the first FitArrange
        stored in that dictionary, with the parameters from the most recent
        set_model() call.
    """
    def __init__(self):
        """
            Create an empty dictionary (self.fitArrangeList) of FitArrange
            objects keyed by Uid, and an empty list of fit parameters.
        """
        self.fitArrangeList = {}
        # Filled by set_model(); initialised here so fit() can report a
        # clear error instead of an AttributeError if set_model was skipped.
        self.parameters = []

    def fit(self, qmin=None, qmax=None):
        """
            Perform the fit with the scipy optimizer. Only one model and its
            set of data can be fit at a time: the first FitArrange stored in
            self.fitArrangeList is used.

            @param qmin: minimum x value of the fit range
                (default: smallest x of the data)
            @param qmax: maximum x value of the fit range
                (default: largest x of the data)
            @return chisqr: value of the goodness of fit metric
            @return out: best parameter values found during fitting
            @return cov: covariance matrix
            @raise ValueError: if nothing was set, the first FitArrange has no
                model or no data, or set_model received no parameters
        """
        if not self.fitArrangeList:
            raise ValueError("no model or data set: nothing to fit")
        # list() keeps this working whether values() returns a list or a view.
        fitproblem = list(self.fitArrangeList.values())[0]
        model = fitproblem.get_model()
        if model is None:
            raise ValueError("no model set for the fit")
        if not self.parameters:
            raise ValueError("no parameters to fit: pass pars to set_model()")

        # Concatenate the data list (one or more data) before fitting.
        xtemp, ytemp, dytemp = self._concatenateData(fitproblem.get_data())

        # Default to the full x range of the data when no bounds are given.
        if qmin is None:
            qmin = min(xtemp)
        if qmax is None:
            qmax = max(xtemp)

        chisqr, out, cov = fitHelper(model, self.parameters,
                                     xtemp, ytemp, dytemp, qmin, qmax)
        return chisqr, out, cov

    def _concatenateData(self, listdata=None):
        """
            Concatenate the x, y and dy fields of every data in listdata.

            A point (xi, yi, dyi) equal to one already collected (from the
            same or an earlier data) is skipped as a whole, so the returned
            x, y and dy lists stay aligned point by point.

            @param listdata: list of data objects with x, y and dy fields
            @return xtemp, ytemp, dytemp: combined x, y and dy lists
            @raise ValueError: if listdata is empty, or a data that has
                points has no dy
        """
        if not listdata:
            raise ValueError(" data list missing")
        xtemp = []
        ytemp = []
        dytemp = []
        seen = set()
        for data in listdata:
            npts = len(data.x)
            if npts == 0:
                continue
            # Test against None explicitly: truth-testing a numpy array with
            # more than one element raises ValueError.
            if data.dy is None or len(data.dy) == 0:
                raise ValueError("dy is missing will not be able to fit later on")
            for i in range(npts):
                point = (data.x[i], data.y[i], data.dy[i])
                if point in seen:
                    continue
                seen.add(point)
                xtemp.append(point[0])
                ytemp.append(point[1])
                dytemp.append(point[2])
        return xtemp, ytemp, dytemp

    def set_model(self, model, name, Uid, pars=None):
        """
            Create the fit parameters from pars and store model in the
            FitArrange object keyed by Uid (creating that object if needed).

            Each entry of pars becomes a Parameter in self.parameters; a
            non-None value is written into the model as the starting value.
            Every call replaces self.parameters.

            @param model: model on which parameter values are set
            @param name: model name, stored in model.name
            @param Uid: unique key of the FitArrange object holding the model
            @param pars: dictionary {parameter's name: parameter's value};
                None means no parameters
            @raise ValueError: if model is None
        """
        self.parameters = []
        if model is None:
            raise ValueError("Cannot set parameters for empty model")
        if pars is None:
            pars = {}
        model.name = name
        for key, value in pars.items():
            self.parameters.append(Parameter(model, key, value))

        if Uid in self.fitArrangeList:
            # A FitArrange already exists for Uid (possibly holding only data).
            self.fitArrangeList[Uid].set_model(model)
        else:
            fitproblem = FitArrange()
            fitproblem.set_model(model)
            self.fitArrangeList[Uid] = fitproblem

    def set_data(self, data, Uid):
        """
            Add data to the FitArrange object keyed by Uid, creating that
            object if needed.

            @param data: data added (a plottable with x, y and dy)
            @param Uid: unique key of the FitArrange object holding the data
        """
        if Uid in self.fitArrangeList:
            # A FitArrange already exists for Uid (possibly holding only a model).
            self.fitArrangeList[Uid].add_data(data)
        else:
            fitproblem = FitArrange()
            fitproblem.add_data(data)
            self.fitArrangeList[Uid] = fitproblem

    def get_model(self, Uid):
        """
            @param Uid: key of the FitArrange object holding the model
            @return: the model stored at Uid, or None if no FitArrange
                was created with this Uid
        """
        if Uid in self.fitArrangeList:
            return self.fitArrangeList[Uid].get_model()
        return None

    def remove_Fit_Problem(self, Uid):
        """ Remove the FitArrange stored at Uid; unknown Uids are ignored. """
        self.fitArrangeList.pop(Uid, None)
214            return self.fitArrangeList[Uid].get_model()
215        else:
216            return None
[7705306]217   
218   
[4dd63eb]219   
220    def remove_Fit_Problem(self,Uid):
221        """remove   fitarrange in Uid"""
[cf3b781]222        if self.fitArrangeList.has_key(Uid):
[4dd63eb]223            del self.fitArrangeList[Uid]
[7705306]224               
225
226class Parameter:
227    """
228        Class to handle model parameters
229    """
230    def __init__(self, model, name, value=None):
231            self.model = model
232            self.name = name
233            if not value==None:
234                self.model.setParam(self.name, value)
235           
236    def set(self, value):
237        """
238            Set the value of the parameter
239        """
240        self.model.setParam(self.name, value)
241
242    def __call__(self):
243        """
244            Return the current value of the parameter
245        """
246        return self.model.getParam(self.name)
247   
248def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None):
249    """
250        Fit function
251        @param model: sans model object
252        @param pars: list of parameters
253        @param x: vector of x data
254        @param y: vector of y data
255        @param err_y: vector of y errors
[792db7d5]256        @return chisqr: Value of the goodness of fit metric
257        @return out: list of parameter with the best value found during fitting
258        @return cov: Covariance matrix
[7705306]259    """
260    def f(params):
261        """
262            Calculates the vector of residuals for each point
263            in y for a given set of input parameters.
264            @param params: list of parameter values
265            @return: vector of residuals
266        """
267        i = 0
268        for p in pars:
269            p.set(params[i])
270            i += 1
271       
272        residuals = []
273        for j in range(len(x)):
274            if x[j]>qmin and x[j]<qmax:
275                residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] )
[cf3b781]276           
[7705306]277        return residuals
278       
279    def chi2(params):
280        """
281            Calculates chi^2
282            @param params: list of parameter values
283            @return: chi^2
284        """
285        sum = 0
286        res = f(params)
287        for item in res:
288            sum += item*item
289        return sum
290       
291    p = [param() for param in pars]
292    out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True)
293    print info, mesg, success
294    # Calculate chi squared
295    if len(pars)>1:
296        chisqr = chi2(out)
297    elif len(pars)==1:
298        chisqr = chi2([out])
299       
300    return chisqr, out, cov_x   
301
Note: See TracBrowser for help on using the repository browser.