source: sasview/park_integration/ScipyFitting.py @ 4c718654

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 4c718654 was 4c718654, checked in by Mathieu Doucet <doucetm@…>, 16 years ago

Introduced abstract engine

  • Property mode set to 100644
File size: 8.7 KB
Line 
1"""
2    @organization: ScipyFitting module contains FitArrange , ScipyFit,
3    Parameter classes.All listed classes work together to perform a
4    simple fit with scipy optimizer.
5"""
6from sans.guitools.plottables import Data1D
7from Loader import Load
8from scipy import optimize
9from AbstractFitEngine import FitEngine, Parameter
10
class FitArrange:
    """
        Pairs one model with the list of data sets it is fit against.
        A fit needs exactly one model and at least one data set.
    """
    def __init__(self):
        """
            Create an empty arrangement.
            model: the model chosen by the user (None until set_model)
            dList: the data sets the user wants to fit (starts empty)
        """
        self.model = None
        self.dList = []

    def set_model(self, model):
        """
            Store the model to fit (the object itself, not a copy).
            @param model: the model being set
        """
        self.model = model

    def add_data(self, data):
        """
            Append data to self.dList unless an equal entry is already there.
            @param data: data to add to the list
        """
        already_present = data in self.dList
        if not already_present:
            self.dList.append(data)

    def get_model(self):
        """ @return: the stored model, or None if none was set """
        return self.model

    def get_data(self):
        """ @return: the data list dList itself (not a copy) """
        return self.dList

    def remove_data(self, data):
        """
            Remove the first entry equal to data from self.dList;
            do nothing when no such entry exists.
            @param data: data to remove from dList
        """
        if data not in self.dList:
            return
        self.dList.remove(data)

    def remove_datalist(self):
        """
            Empty the data list by binding a fresh list; a list previously
            returned by get_data() keeps its contents.
        """
        self.dList = []
57           
class ScipyFit(FitEngine):
    """
        ScipyFit performs the fit with the scipy optimizer. Usage:

        Create an engine:
            engine = ScipyFit()
        Data must be plottables; the model must be a sans model.

        Add data. Each Uid is a key of self.fitArrangeList whose value is a
        FitArrange object holding the model and its data:
            engine.set_data(data, Uid)

        Add the model under the same Uid, with its name and the parameters
        to fit as {parameter name: initial value}:
            engine.set_model(model, "M1", Uid, {'A': 2, 'B': 4})

        Run the fit. It returns the chi^2, the best parameter values and
        the covariance matrix:
            chisqr, out, cov = engine.fit(qmin, qmax)

        @note: only one fit problem is fitted at a time, and the fitted
            parameters are the ones given to the most recent set_model().
    """
    def __init__(self):
        """
            Create an empty dictionary (self.fitArrangeList) of FitArrange
            elements keyed by Uid, and an empty list of fit parameters.
        """
        self.fitArrangeList = {}
        # Filled by set_model(). Initialized here so that the attribute
        # always exists, even if fit() is called before set_model().
        self.parameters = []

    def fit(self, qmin=None, qmax=None):
        """
            Perform the fit with the scipy optimizer. Only one model and
            its set of data can be fit.
            @note: Cannot perform more than one fit at the time.

            @param qmin: minimum x value of the fit range
                (defaults to the smallest x of the data)
            @param qmax: maximum x value of the fit range
                (defaults to the largest x of the data)
            @return chisqr: Value of the goodness of fit metric
            @return out: list of parameter with the best value found during fitting
            @return cov: Covariance matrix
            @raise ValueError: if no fit problem or no model has been set
        """
        if not self.fitArrangeList:
            raise ValueError("Cannot perform fit: no model or data has been set")
        # Only the first FitArrange object (one model and a list of data)
        # is fitted. Dictionary order decides which one is "first".
        fitproblem = list(self.fitArrangeList.values())[0]
        model = fitproblem.get_model()
        if model is None:
            raise ValueError("Cannot perform fit: no model has been set")
        listdata = fitproblem.get_data()

        # Concatenate dList (one or more data sets) before fitting.
        # _concatenateData is not defined in this class; presumably it is
        # inherited from FitEngine and returns (x, y, dy) - verify.
        xtemp, ytemp, dytemp = self._concatenateData(listdata)

        # Use the full data range when no boundaries are given
        if qmin is None:
            qmin = min(xtemp)
        if qmax is None:
            qmax = max(xtemp)

        chisqr, out, cov = fitHelper(model, self.parameters, xtemp, ytemp,
                                     dytemp, qmin, qmax)
        return chisqr, out, cov

    def set_model(self, model, name, Uid, pars=None):
        """
            Store the parameters to fit and put the model into the
            FitArrange object keyed by Uid, creating that object if needed.
            @param model: model on which parameter values are set
            @param name: model name (assigned to model.name)
            @param Uid: unique key of the FitArrange object holding the model
            @param pars: dictionary {parameter's name: parameter's value};
                None means no parameters
            @raise ValueError: if model is None
        """
        if model is None:
            raise ValueError("Cannot set parameters for empty model")
        if pars is None:
            pars = {}
        model.name = name
        self.parameters = [Parameter(model, key, value)
                           for key, value in pars.items()]

        if Uid in self.fitArrangeList:
            # A FitArrange already exists at Uid (e.g. created by set_data)
            self.fitArrangeList[Uid].set_model(model)
        else:
            fitproblem = FitArrange()
            fitproblem.set_model(model)
            self.fitArrangeList[Uid] = fitproblem

    def set_data(self, data, Uid):
        """
            Add a plottable to the FitArrange object keyed by Uid,
            creating that object if needed.
            @param data: data added
            @param Uid: unique key of the FitArrange object holding the data
        """
        if Uid in self.fitArrangeList:
            # A FitArrange already exists at Uid (e.g. created by set_model)
            self.fitArrangeList[Uid].add_data(data)
        else:
            fitproblem = FitArrange()
            fitproblem.add_data(data)
            self.fitArrangeList[Uid] = fitproblem

    def get_model(self, Uid):
        """
            @param Uid: key of the FitArrange object holding the model
            @return: the model at this Uid, or None if no FitArrange
                element was created with this Uid
        """
        fitproblem = self.fitArrangeList.get(Uid)
        if fitproblem is None:
            return None
        return fitproblem.get_model()

    def remove_Fit_Problem(self, Uid):
        """ Remove the FitArrange object at Uid, if there is one """
        self.fitArrangeList.pop(Uid, None)
190     
def fitHelper(model, pars, x, y, err_y, qmin=None, qmax=None):
    """
        Fit function. Fits the model to the points (x, y, err_y) whose x
        lies in the closed range [qmin, qmax], using scipy.optimize.leastsq.

        @param model: sans model object; model.runXY(x) gives the model
            value at x
        @param pars: list of parameters; param() reads the current value
            and param.set(value) updates it
        @param x: vector of x data
        @param y: vector of y data
        @param err_y: vector of y errors (each residual is divided by it)
        @param qmin: lower bound of the fit range, inclusive; None means
            no lower bound
        @param qmax: upper bound of the fit range, inclusive; None means
            no upper bound
        @return chisqr: Value of the goodness of fit metric
        @return out: list of parameter with the best value found during fitting
        @return cov: Covariance matrix
        @raise ValueError: if no data point lies within [qmin, qmax]
    """
    import warnings
    import numpy

    # Select the points once instead of re-testing the range on every
    # residual evaluation. The bounds are inclusive so that the default
    # range [min(x), max(x)] keeps the end points.
    points = [(xi, yi, dyi) for xi, yi, dyi in zip(x, y, err_y)
              if (qmin is None or xi >= qmin) and (qmax is None or xi <= qmax)]
    if not points:
        raise ValueError("No data points in the fit range [%s, %s]"
                         % (qmin, qmax))

    def f(params):
        """
            Calculate the vector of residuals for each selected point
            for a given set of input parameters.
            @param params: list of parameter values, in the order of pars
            @return: vector of residuals
        """
        for param, value in zip(pars, params):
            param.set(value)
        return [(yi - model.runXY(xi)) / dyi for xi, yi, dyi in points]

    def chi2(params):
        """
            Calculate chi^2, the sum of the squared residuals.
            @param params: list of parameter values
            @return: chi^2
        """
        return sum(res * res for res in f(params))

    p = [param() for param in pars]
    out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1)
    # leastsq signals a solution with ier in 1..4. Anything else means
    # the fit did not converge, so warn rather than fail silently.
    if success not in (1, 2, 3, 4):
        warnings.warn("scipy leastsq did not converge: %s" % mesg)
    # leastsq may return a scalar for a single parameter (older scipy);
    # atleast_1d gives chi2 a sequence either way. Evaluating chi2 also
    # leaves the model set to the best-fit parameter values.
    chisqr = chi2(numpy.atleast_1d(out))
    return chisqr, out, cov_x
244
Note: See TracBrowser for help on using the repository browser.