source: sasview/park_integration/ScipyFitting.py @ 9a3adab

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 9a3adab was cf3b781, checked in by Gervaise Alina <gervyh@…>, 16 years ago

need more tests.but usecase 3 implemented

  • Property mode set to 100644
File size: 7.9 KB
Line 
1#class Fitting
2from sans.guitools.plottables import Data1D
3from Loader import Load
4from scipy import optimize
5#from Fitting import Fit
6
7class FitArrange:
8    def __init__(self):
9        """
10            Store a set of data for a given model to perform the Fit
11            @param model: the model selected by the user
12            @param Ldata: a list of data what the user want to fit
13        """
14        self.model = None
15        self.dList =[]
16       
17    def set_model(self,model):
18        """ set the model """
19        self.model = model
20       
21    def add_data(self,data):
22        """
23            @param data: Data to add in the list
24            fill a self.dataList with data to fit
25        """
26        if not data in self.dList:
27            self.dList.append(data)
28           
29    def get_model(self):
30        """ Return the model"""
31        return self.model   
32     
33    def get_data(self):
34        """ Return list of data"""
35        return self.dList
36     
37    def remove_data(self,data):
38        """
39            Remove one element from the list
40            @param data: Data to remove from the the lsit of data
41        """
42        if data in self.dList:
43            self.dList.remove(data)
44           
45class ScipyFit:
46    """
47        Performs the Fit.he user determine what kind of data
48    """
49    def __init__(self,data=[]):
50        #this is a dictionary of FitArrange elements
51        self.fitArrangeList={}
52        #the constraint of the Fit
53        self.constraint =None
54        #Specify the use of scipy or park fit
55        self.fitType =None
56       
57 
58   
59    def fit(self,pars, qmin=None, qmax=None):
60        """
61             Do the fit
62        """
63        #for item in self.fitArrangeList.:
64       
65        fitproblem=self.fitArrangeList.values()[0]
66        listdata=[]
67        model = fitproblem.get_model()
68        listdata = fitproblem.get_data()
69       
70        parameters = self.set_param(model,model.name,pars)
71       
72        # Do the fit with  data set (contains one or more data) and one model
73        xtemp,ytemp,dytemp=self._concatenateData( listdata)
74        print "dytemp",dytemp
75        if qmin==None:
76            qmin= min(xtemp)
77        if qmax==None:
78            qmax= max(xtemp) 
79        chisqr, out, cov = fitHelper(model,parameters, xtemp,ytemp, dytemp ,qmin,qmax)
80        return chisqr, out, cov
81   
82    def _concatenateData(self, listdata=[]):
83        """ concatenate each fields of all data contains ins listdata"""
84        if listdata==[]:
85            raise ValueError, " data list missing"
86        else:
87            xtemp=[]
88            ytemp=[]
89            dytemp=[]
90               
91            for data in listdata:
92                for i in range(len(data.x)):
93                    if not data.x[i] in xtemp:
94                        xtemp.append(data.x[i])
95                       
96                    if not data.y[i] in ytemp:
97                        ytemp.append(data.y[i])
98                       
99                    if not data.dy[i] in dytemp:
100                        dytemp.append(data.dy[i])
101            return xtemp, ytemp,dytemp
102       
103    def set_model(self,model,Uid):
104        """ Set model """
105        if self.fitArrangeList.has_key(Uid):
106            self.fitArrangeList[Uid].set_model(model)
107        else:
108            fitproblem= FitArrange()
109            fitproblem.set_model(model)
110            self.fitArrangeList[Uid]=fitproblem
111       
112    def set_data(self,data,Uid):
113        """ Receive plottable and create a list of data to fit"""
114       
115        if self.fitArrangeList.has_key(Uid):
116            self.fitArrangeList[Uid].add_data(data)
117        else:
118            fitproblem= FitArrange()
119            fitproblem.add_data(data)
120            self.fitArrangeList[Uid]=fitproblem
121           
122    def get_model(self,Uid):
123        """ return list of data"""
124        return self.fitArrangeList[Uid]
125   
126    def set_param(self,model,name, pars):
127        """ Recieve a dictionary of parameter and save it """
128        parameters=[]
129        if model==None:
130            raise ValueError, "Cannot set parameters for empty model"
131        else:
132            model.name=name
133            for key, value in pars.iteritems():
134                param = Parameter(model, key, value)
135                parameters.append(param)
136        return parameters
137   
138    def add_constraint(self, constraint):
139        """ User specify contraint to fit """
140        self.constraint = str(constraint)
141       
142    def get_constraint(self):
143        """ return the contraint value """
144        return self.constraint
145   
146    def set_constraint(self,constraint):
147        """
148            receive a string as a constraint
149            @param constraint: a string used to constraint some parameters to get a
150                specific value
151        """
152        self.constraint= constraint
153   
154    def createProblem(self):
155        """
156            Check the contraint value and specify what kind of fit to use
157        """
158        mylist=[]
159        for k,value in self.fitArrangeList.iteritems():
160            couple=()
161            model=value.get_model()
162            data=value.get_data()
163            couple=(model,data)
164            mylist.append(couple)
165        #print mylist
166        return mylist
167    def remove_data(self,Uid,data=None):
168        """ remove one or all data"""
169        if data==None:# remove all element in data list
170            if self.fitArrangeList.has_key(Uid):
171                self.fitArrangeList[Uid].remove_datalist()
172        else:
173            if self.fitArrangeList.has_key(Uid):
174                self.fitArrangeList[Uid].remove_data(data)
175               
176    def remove_model(self,Uid):
177        """ remove model """
178        if self.fitArrangeList.has_key(Uid):
179            self.fitArrangeList[Uid].remove_model()
180               
181
182class Parameter:
183    """
184        Class to handle model parameters
185    """
186    def __init__(self, model, name, value=None):
187            self.model = model
188            self.name = name
189            if not value==None:
190                self.model.setParam(self.name, value)
191           
192    def set(self, value):
193        """
194            Set the value of the parameter
195        """
196        self.model.setParam(self.name, value)
197
198    def __call__(self):
199        """
200            Return the current value of the parameter
201        """
202        return self.model.getParam(self.name)
203   
204def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None):
205    """
206        Fit function
207        @param model: sans model object
208        @param pars: list of parameters
209        @param x: vector of x data
210        @param y: vector of y data
211        @param err_y: vector of y errors
212    """
213    def f(params):
214        """
215            Calculates the vector of residuals for each point
216            in y for a given set of input parameters.
217            @param params: list of parameter values
218            @return: vector of residuals
219        """
220        i = 0
221        for p in pars:
222            p.set(params[i])
223            i += 1
224       
225        residuals = []
226        for j in range(len(x)):
227            if x[j]>qmin and x[j]<qmax:
228                residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] )
229           
230        return residuals
231       
232    def chi2(params):
233        """
234            Calculates chi^2
235            @param params: list of parameter values
236            @return: chi^2
237        """
238        sum = 0
239        res = f(params)
240        for item in res:
241            sum += item*item
242        return sum
243       
244    p = [param() for param in pars]
245    out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True)
246    print info, mesg, success
247    # Calculate chi squared
248    if len(pars)>1:
249        chisqr = chi2(out)
250    elif len(pars)==1:
251        chisqr = chi2([out])
252       
253    return chisqr, out, cov_x   
254
Note: See TracBrowser for help on using the repository browser.