Changeset acb8788 in sasview


Ignore:
Timestamp:
May 21, 2008 11:11:48 AM (17 years ago)
Author:
Gervaise Alina <gervyh@…>
Branches:
master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Children:
c14c503
Parents:
93c2ef2
Message:

mod

File:
1 edited

Legend:

Unmodified
Added
Removed
  • park_integration/test/FittingModule.py

    r4a0536a racb8788  
    11#class Fitting 
    2 from sans.guitools.fittings import Parameter 
    3 import sans.guitools.fittings 
     2from sans.guitools.plottables import Data1D 
     3from Loader import Load 
     4from scipy import optimize 
    45class Fitting: 
    56    """  
    67        Performs the Fit.he user determine what kind of data  
    78    """ 
    8     def __init__(self,data1,data2=None): 
    9         self.model = "" 
    10         self.data1 = data1 
    11         self.data2= data2 
     9    def __init__(self,data=[]): 
     10        #self.model is a list of all models to fit 
     11        self.model=[] 
     12        #the list of all data to fit  
     13        self.data = data 
     14        #dictionary of models parameters 
     15        self.parameters={} 
    1216        self.contraint =None 
     17        self.residuals=[] 
    1318         
    1419    def fit_engine(self): 
     
    1722        """ 
    1823        return True 
    19     def fit(self): 
     24    def fit(self,pars, qmin=None, qmax=None): 
    2025        """ 
    2126             Do the fit  
    2227        """ 
    23            
    24         #Display the fittings values 
    25         self.default_A = self.model.getParam('A')  
    26         self.default_B = self.model.getParam('B')  
    27         self.cstA = Parameter(self.model, 'A', self.default_A) 
    28         self.cstB  = Parameter(self.model, 'B', self.default_B) 
    29         xmin= min(self.data1.x) 
    30         xmax= max(self.data1.x) 
    31          
    32         chisqr, out, cov = sans.guitools.fittings.sansfit(self.model,  
    33         [self.cstA, self.cstB],self.data1.x, self.data1.y,self.data1.dy,xmin,xmax) 
    34          
     28         
     29        # Do the fit with 2 data set and one model  
     30        numberData= len(self.data) 
     31        numberModel= len(self.model) 
     32        if numberData==1 and numberModel==1: 
     33            if qmin==None: 
     34                xmin= min(self.data[0].x) 
     35            if qmax==None: 
     36                xmax= max(self.data[0].x) 
     37            
     38            #chisqr, out, cov = fitHelper(self.model[0],self.data[0],pars,xmin,xmax) 
     39            chisqr, out, cov =fitHelper(self.model[0], pars, self.data[0].x, 
     40                                 self.data[0].y, self.data[0].dy ,xmin,xmax) 
     41        else:# More than one data to fit with one model 
     42            xtemp=[] 
     43            ytemp=[] 
     44            dytemp=[] 
     45            for data in self.data: 
     46                for i in range(len(data.x)): 
     47                    if not data.x[i] in xtemp: 
     48                        xtemp.append(data.x[i]) 
     49                        
     50                    if not data.y[i] in ytemp: 
     51                        ytemp.append(data.y[i]) 
     52                         
     53                    if not data.dy[i] in dytemp: 
     54                        dytemp.append(data.dy[i]) 
     55            if qmin==None: 
     56                xmin= min(xtemp) 
     57            if qmax==None: 
     58                xmax= max(xtemp)       
     59            #chisqr, out, cov = fitHelper(self.model[0],  
     60            #temp,pars,min(temp.x),max(temp.x)) 
     61            chisqr, out, cov =fitHelper(self.model[0], pars, xtemp, 
     62                                 ytemp, dytemp ,xmin,xmax) 
    3563        return chisqr, out, cov 
    3664     
    3765    def set_model(self,model): 
    3866        """ Set model """ 
    39         self.model = model 
    40          
    41      
    42     def set_data(self,x,y,dx,dy): 
    43         """ Receive values from Loader class and set plottable variables""" 
    44         self.data1.x = x 
    45         self.data1.y = y 
    46         self.data1.dx= dx 
    47         self.data1.dy= dy 
     67        self.model.append(model) 
     68         
     69    def set_data(self,data): 
     70        """ Receive plottable and create a list of data to fit""" 
     71        self.data.append(data) 
     72         
    4873    def get_data(self): 
    49         """ return data""" 
    50         return self.data1 
     74        """ return list of data""" 
     75        return self.data 
    5176     
    5277    def add_contraint(self, contraint): 
    5378        """ User specify contraint to fit """ 
    5479        self.contraint = str(contraint) 
     80         
    5581    def get_contraint(self): 
    5682        """ return the contraint value """ 
    5783        return self.contraint 
    58          
    59      
     84     
     85def get_residuals(model,data,qmin=None,qmax=None): 
     86    """ 
     87        Calculates the vector of residuals for each point  
     88        in y for a given set of input parameters. 
     89        @param params: list of parameter values 
     90        @return: vector of residuals 
     91    """ 
     92    residuals = [] 
     93    
     94    for j in range(len(data.x)): 
     95        if data.x[j]> qmin and data.x[j]< qmax: 
     96            residuals.append( ( data.y[j] - model.runXY(data.x[j]) ) / data.dy[j]) 
     97     
     98    return residuals 
     99 
     100    
     101def chi2(params):  
     102    """ 
     103        Calculates chi^2 
     104        @param params: list of parameter values 
     105        @return: chi^2 
     106    """ 
     107    sum = 0 
     108    res = get_residuals(params) 
     109    for item in res: 
     110        sum += item*item 
     111    return sum  
     112     
     113     
     114    def residual(self): 
     115        return self.residuals 
     116     
     117def fitHelper(model,data,pars,qmin=None,qmax=None): 
     118    """ Do the actual fitting""" 
     119     
     120    p = [param() for param in pars] 
     121    out, cov_x, info, mesg, success = optimize.leastsq(get_residuals, p, full_output=1, warning=True) 
     122    print info, mesg, success 
     123    # Calculate chi squared 
     124    if len(pars)>1: 
     125        chisqr = self.chi2(out) 
     126    elif len(pars)==1: 
     127        chisqr = self.chi2([out]) 
     128         
     129    return chisqr, out, cov_x 
     130 
     131class Parameter: 
     132    """ 
     133        Class to handle model parameters 
     134    """ 
     135    def __init__(self, model, name, value=None): 
     136            self.model = model 
     137            self.name = name 
     138            if not value==None: 
     139                self.model.setParam(self.name, value) 
     140            
     141    def set(self, value): 
     142        """ 
     143            Set the value of the parameter 
     144        """ 
     145        self.model.setParam(self.name, value) 
     146 
     147    def __call__(self): 
     148        """  
     149            Return the current value of the parameter 
     150        """ 
     151        return self.model.getParam(self.name) 
     152     
     153def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None): 
     154    """ 
     155        Fit function 
     156        @param model: sans model object 
     157        @param pars: list of parameters 
     158        @param x: vector of x data 
     159        @param y: vector of y data 
     160        @param err_y: vector of y errors  
     161    """ 
     162    def f(params): 
     163        """ 
     164            Calculates the vector of residuals for each point  
     165            in y for a given set of input parameters. 
     166            @param params: list of parameter values 
     167            @return: vector of residuals 
     168        """ 
     169        i = 0 
     170        for p in pars: 
     171            p.set(params[i]) 
     172            i += 1 
     173         
     174        residuals = [] 
     175        for j in range(len(x)): 
     176            if x[j]>qmin and x[j]<qmax: 
     177                residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] ) 
     178        
     179        return residuals 
     180         
     181    def chi2(params): 
     182        """ 
     183            Calculates chi^2 
     184            @param params: list of parameter values 
     185            @return: chi^2 
     186        """ 
     187        sum = 0 
     188        res = f(params) 
     189        for item in res: 
     190            sum += item*item 
     191        return sum 
     192         
     193    p = [param() for param in pars] 
     194    out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True) 
     195    print info, mesg, success 
     196    # Calculate chi squared 
     197    if len(pars)>1: 
     198        chisqr = chi2(out) 
     199    elif len(pars)==1: 
     200        chisqr = chi2([out]) 
     201         
     202    return chisqr, out, cov_x     
     203 
     204       
     205if __name__ == "__main__":  
     206    load= Load() 
     207     
     208    # test fit one data set one model 
     209    load.set_filename("testdata_line.txt") 
     210    load.set_values() 
     211    data1 = Data1D(x=[], y=[], dx=None,dy=None) 
     212    data1.name = "data1" 
     213    load.load_data(data1) 
     214    Fit =Fitting() 
     215    Fit.set_data(data1) 
     216    from sans.guitools.LineModel import LineModel 
     217    model  = LineModel() 
     218    Fit.set_model(model) 
     219     
     220    default_A = model.getParam('A')  
     221    default_B = model.getParam('B')  
     222    cstA = Parameter(model, 'A', default_A) 
     223    cstB  = Parameter(model, 'B', default_B) 
     224     
     225    chisqr, out, cov=Fit.fit([cstA,cstB],None,None) 
     226    print"fit only one data",chisqr, out, cov  
     227     
     228    # test fit with 2 data and one model 
     229    load.set_filename("testdata1.txt") 
     230    load.set_values() 
     231    data2 = Data1D(x=[], y=[], dx=None,dy=None) 
     232    data2.name = "data2" 
     233     
     234    load.load_data(data2) 
     235    Fit.set_data(data2) 
     236     
     237    load.set_filename("testdata2.txt") 
     238    load.set_values() 
     239    data3 = Data1D(x=[], y=[], dx=None,dy=None) 
     240    data3.name = "data2" 
     241    load.load_data(data3) 
     242    Fit.set_data(data3) 
     243    chisqr, out, cov=Fit.fit([cstA,cstB],None,None) 
     244    print"fit two data",chisqr, out, cov  
Note: See TracChangeset for help on using the changeset viewer.