source: sasview/park_integration/docs/FittingModule.py @ a55fac1

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since a55fac1 was f24b8f4, checked in by Gervaise Alina <gervyh@…>, 16 years ago

made some chagnges on the unit test

  • Property mode set to 100644
File size: 7.9 KB
RevLine 
[4a0536a]1#class Fitting
[acb8788]2from sans.guitools.plottables import Data1D
3from Loader import Load
4from scipy import optimize
[0954398]5
6
7class FitArrange:
8    def __init__(self):
9        """
[f24b8f4]10            Store a set of data for a given model to perform the Fit
[0954398]11            @param model: the model selected by the user
12            @param Ldata: a list of data what the user want to fit
13        """
14        self.model = None
15        self.dList =[]
16       
17    def set_model(self,model):
18        """ set the model """
19        self.model = model
20       
21    def add_data(self,data):
22        """
23            @param data: Data to add in the list
24            fill a self.dataList with data to fit
25        """
26        if not data in self.dList:
27            self.dList.append(data)
28           
29    def get_model(self):
[eb06cbe]30        """ Return the model"""
[0954398]31        return self.model   
32     
33    def get_data(self):
34        """ Return list of data"""
35        return self.dList
36     
[eb06cbe]37    def remove_data(self,data):
[0954398]38        """
39            Remove one element from the list
40            @param data: Data to remove from the the lsit of data
41        """
42        if data in self.dList:
43            self.dList.remove(data)
44           
[4a0536a]45class Fitting:
46    """
47        Performs the Fit.he user determine what kind of data
48    """
[acb8788]49    def __init__(self,data=[]):
[f24b8f4]50        #this is a dictionary of FitArrange elements
[0954398]51        self.fitArrangeList={}
[f24b8f4]52        #the constraint of the Fit
[eb06cbe]53        self.constraint =None
[f24b8f4]54        #Specify the use of scipy or park fit
[73bbe35]55        self.fitType =None
[4a0536a]56       
[73bbe35]57    def fit_engine(self,word):
[4a0536a]58        """
59            Check the contraint value and specify what kind of fit to use
60        """
[73bbe35]61        self.fitType = word
[4a0536a]62        return True
[0954398]63   
[acb8788]64    def fit(self,pars, qmin=None, qmax=None):
[4a0536a]65        """
66             Do the fit
67        """
[e94d1d3]68        #for item in self.fitArrangeList.:
[0954398]69       
70        fitproblem=self.fitArrangeList.values()[0]
71        listdata=[]
[f24b8f4]72        model = fitproblem.get_model()
73        listdata = fitproblem.get_data()
74       
75        parameters = self.set_param(model,pars)
[e94d1d3]76        if listdata==[]:
77            raise ValueError, " data list missing"
78        else:
79            # Do the fit with more than one data set and one model
80            xtemp=[]
81            ytemp=[]
82            dytemp=[]
[f24b8f4]83           
[e94d1d3]84            for data in listdata:
85                for i in range(len(data.x)):
86                    if not data.x[i] in xtemp:
87                        xtemp.append(data.x[i])
88                       
89                    if not data.y[i] in ytemp:
90                        ytemp.append(data.y[i])
91                       
92                    if not data.dy[i] in dytemp:
93                        dytemp.append(data.dy[i])
94            if qmin==None:
95                qmin= min(xtemp)
96            if qmax==None:
97                qmax= max(xtemp) 
[f24b8f4]98            chisqr, out, cov = fitHelper(model,parameters, xtemp,ytemp, dytemp ,qmin,qmax)
[e94d1d3]99            return chisqr, out, cov
[4a0536a]100   
[0954398]101    def set_model(self,model,Uid):
[4a0536a]102        """ Set model """
[f24b8f4]103        if self.fitArrangeList.has_key(Uid):
104            self.fitArrangeList[Uid].set_model(model)
105        else:
106            fitproblem= FitArrange()
107            fitproblem.set_model(model)
108            self.fitArrangeList[Uid]=fitproblem
[acb8788]109       
[0954398]110    def set_data(self,data,Uid):
[acb8788]111        """ Receive plottable and create a list of data to fit"""
[f24b8f4]112       
[0954398]113        if self.fitArrangeList.has_key(Uid):
114            self.fitArrangeList[Uid].add_data(data)
115        else:
116            fitproblem= FitArrange()
117            fitproblem.add_data(data)
118            self.fitArrangeList[Uid]=fitproblem
119           
[f24b8f4]120    def get_model(self,Uid):
[acb8788]121        """ return list of data"""
[f24b8f4]122        return self.fitArrangeList[Uid]
[4a0536a]123   
[73bbe35]124    def set_param(self,model, pars):
125        """ Recieve a dictionary of parameter and save it """
[f24b8f4]126        parameters=[]
[e94d1d3]127        if model==None:
128            raise ValueError, "Cannot set parameters for empty model"
129        else:
130            #for key ,value in pars:
131            for key, value in pars.iteritems():
132                param = Parameter(model, key, value)
[f24b8f4]133                parameters.append(param)
134        return parameters
135   
136    def add_constraint(self, constraint):
[4a0536a]137        """ User specify contraint to fit """
[f24b8f4]138        self.constraint = str(constraint)
[acb8788]139       
[f24b8f4]140    def get_constraint(self):
[4a0536a]141        """ return the contraint value """
[f24b8f4]142        return self.constraint
143   
144    def set_constraint(self,constraint):
145        """
146            receive a string as a constraint
147            @param constraint: a string used to constraint some parameters to get a
148                specific value
149        """
150        self.constraint= constraint
151   
152   
153               
[acb8788]154
155class Parameter:
156    """
157        Class to handle model parameters
158    """
159    def __init__(self, model, name, value=None):
160            self.model = model
161            self.name = name
162            if not value==None:
163                self.model.setParam(self.name, value)
164           
165    def set(self, value):
166        """
167            Set the value of the parameter
168        """
169        self.model.setParam(self.name, value)
170
171    def __call__(self):
172        """
173            Return the current value of the parameter
174        """
175        return self.model.getParam(self.name)
176   
177def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None):
178    """
179        Fit function
180        @param model: sans model object
181        @param pars: list of parameters
182        @param x: vector of x data
183        @param y: vector of y data
184        @param err_y: vector of y errors
185    """
186    def f(params):
187        """
188            Calculates the vector of residuals for each point
189            in y for a given set of input parameters.
190            @param params: list of parameter values
191            @return: vector of residuals
192        """
193        i = 0
194        for p in pars:
195            p.set(params[i])
196            i += 1
197       
198        residuals = []
199        for j in range(len(x)):
200            if x[j]>qmin and x[j]<qmax:
201                residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] )
202       
203        return residuals
204       
205    def chi2(params):
206        """
207            Calculates chi^2
208            @param params: list of parameter values
209            @return: chi^2
210        """
211        sum = 0
212        res = f(params)
213        for item in res:
214            sum += item*item
215        return sum
216       
217    p = [param() for param in pars]
218    out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True)
219    print info, mesg, success
220    # Calculate chi squared
221    if len(pars)>1:
222        chisqr = chi2(out)
223    elif len(pars)==1:
224        chisqr = chi2([out])
[4a0536a]225       
[acb8788]226    return chisqr, out, cov_x   
227
228     
229if __name__ == "__main__": 
230    load= Load()
231   
232    # test fit one data set one model
233    load.set_filename("testdata_line.txt")
234    load.set_values()
235    data1 = Data1D(x=[], y=[], dx=None,dy=None)
236    data1.name = "data1"
237    load.load_data(data1)
238    Fit =Fitting()
[0954398]239   
[acb8788]240    from sans.guitools.LineModel import LineModel
241    model  = LineModel()
[e94d1d3]242    Fit.set_model(model,1)
[0954398]243    Fit.set_data(data1,1)
[acb8788]244   
[73bbe35]245    chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None)
[acb8788]246    print"fit only one data",chisqr, out, cov
247   
248    # test fit with 2 data and one model
[0954398]249    Fit =Fitting()
250    Fit.set_model(model,2 )
[acb8788]251    load.set_filename("testdata1.txt")
252    load.set_values()
253    data2 = Data1D(x=[], y=[], dx=None,dy=None)
254    data2.name = "data2"
255   
256    load.load_data(data2)
[0954398]257    Fit.set_data(data2,2)
[acb8788]258   
259    load.set_filename("testdata2.txt")
260    load.set_values()
261    data3 = Data1D(x=[], y=[], dx=None,dy=None)
262    data3.name = "data2"
263    load.load_data(data3)
[0954398]264    Fit.set_data(data3,2)
[73bbe35]265    chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None)
[0954398]266    print"fit two data",chisqr, out, cov
267   
Note: See TracBrowser for help on using the repository browser.