Changeset fbc51ef in sasview for park_integration/test


Ignore:
Timestamp:
Jun 19, 2008 4:08:51 PM (16 years ago)
Author:
Gervaise Alina <gervyh@…>
Branches:
master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Children:
7681bac
Parents:
3701620
Message:

nothing much changed

File:
1 moved

Legend:

Unmodified
Added
Removed
  • park_integration/test/ParkFitting.py

    r197ea24 rfbc51ef  
    11#class Fitting 
     2import time 
     3 
     4import numpy 
     5import park 
     6from scipy import optimize 
     7from park import fit,fitresult 
     8from park import assembly 
     9 
    210from sans.guitools.plottables import Data1D 
     11#from sans.guitools import plottables 
    312from Loader import Load 
    4 from scipy import optimize 
    5  
    6  
     13 
     14class SansParameter(park.Parameter): 
     15    """ 
     16    SANS model parameters for use in the PARK fitting service. 
     17    The parameter attribute value is redirected to the underlying 
     18    parameter value in the SANS model. 
     19    """ 
     20    def __init__(self, name, model): 
     21         self._model, self._name = model,name 
     22    def _getvalue(self): return self._model.getParam(self.name) 
     23    def _setvalue(self,value): self._model.setParam(self.name, value) 
     24    value = property(_getvalue,_setvalue) 
     25    def _getrange(self): 
     26        lo,hi = self._model.details[self.name][1:] 
     27        if lo is None: lo = -numpy.inf 
     28        if hi is None: hi = numpy.inf 
     29        return lo,hi 
     30    def _setrange(self,r): 
     31        self._model.details[self.name][1:] = r 
     32    range = property(_getrange,_setrange) 
     33 
     34class Model(object): 
     35    """ 
     36        PARK wrapper for SANS models. 
     37    """ 
     38    def __init__(self, sans_model): 
     39        self.model = sans_model 
     40        sansp = sans_model.getParamList() 
     41        parkp = [SansParameter(p,sans_model) for p in sansp] 
     42        self.parameterset = park.ParameterSet(sans_model.name,pars=parkp) 
     43    def eval(self,x): 
     44        return self.model.run(x) 
     45     
     46class Data(object): 
     47    """ Wrapper class  for SANS data """ 
     48    def __init__(self, sans_data): 
     49        self.x= sans_data.x 
     50        self.y= sans_data.y 
     51        self.dx= sans_data.dx 
     52        self.dy= sans_data.dy 
     53        self.qmin=None 
     54        self.qmax=None 
     55        
     56    def setFitRange(self,mini=None,maxi=None): 
     57        """ to set the fit range""" 
     58        self.qmin=mini 
     59        self.qmax=maxi 
     60         
     61    def residuals(self, fn): 
     62         
     63        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)] 
     64        if self.qmin==None and self.qmax==None:  
     65            return (y - fn(x))/dy 
     66         
     67        else: 
     68            idx = x>=self.qmin & x <= self.qmax 
     69            return (y[idx] - fn(x[idx]))/dy[idx] 
     70             
     71          
     72    def residuals_deriv(self, model, pars=[]): 
     73        """ Return residual derivatives .in this case just return empty array""" 
     74        return [] 
     75     
    776class FitArrange: 
    877    def __init__(self): 
     
    43112            self.dList.remove(data) 
    44113             
    45 class Fitting: 
     114class ParkFit: 
    46115    """  
    47116        Performs the Fit.he user determine what kind of data  
     
    55124        self.fitType =None 
    56125         
    57     def fit_engine(self,word): 
     126    def createProblem(self,pars={}): 
    58127        """ 
    59128            Check the contraint value and specify what kind of fit to use 
    60         """ 
    61         self.fitType = word 
    62         return True 
     129            return (M1,D1) 
     130        """ 
     131        mylist=[] 
     132        for k,value in self.fitArrangeList.iteritems(): 
     133            couple=() 
     134            model=value.get_model() 
     135            parameters= self.set_param(model, pars) 
     136            model = Model(model) 
     137            #print "model created",model.parameterset[0].value,model.parameterset[1].value 
     138            # Make all parameters fitting parameters 
     139            for p in model.parameterset: 
     140                p.set([-numpy.inf,numpy.inf]) 
     141                #p.set([-10,10]) 
     142            Ldata=value.get_data() 
     143            data=self._concatenateData(Ldata) 
     144            #print "this data",data 
     145            #print "data.residuals in createProblem",Ldata[0].residuals 
     146            #print "data.residuals in createProblem",data.residuals 
     147            #couple1=(model,Ldata[0]) 
     148            #mylist.append(couple1) 
     149            couple=(model,data) 
     150            mylist.append(couple) 
     151        #print mylist 
     152        return mylist 
     153        #return model,data 
    63154     
    64155    def fit(self,pars, qmin=None, qmax=None): 
     
    66157             Do the fit  
    67158        """ 
    68         #for item in self.fitArrangeList.: 
    69          
    70         fitproblem=self.fitArrangeList.values()[0] 
    71         listdata=[] 
    72         model = fitproblem.get_model() 
    73         listdata = fitproblem.get_data() 
    74          
    75         parameters = self.set_param(model,pars) 
     159         
     160        modelList=self.createProblem(pars) 
     161        #model,data=self.createProblem() 
     162        #fitness=assembly.Fitness(model,data) 
     163         
     164        problem =  park.Assembly(modelList) 
     165        #print "problem :",problem[0].parameterset,problem[0].parameterset.fitted 
     166        #problem[0].parameterset['A'].set([0,1000]) 
     167        #print "problem :",problem[0].parameterset,problem[0].parameterset.fitted 
     168        fit.fit(problem, handler= fitresult.ConsoleUpdate(improvement_delta=0.1)) 
     169        #return fit.fit(problem) 
     170        #fit.fit(problem, handler= fitresult.ConsoleUpdate(improvement_delta=0.1)) 
     171        
     172     
     173    def set_model(self,model,Uid): 
     174        """ Set model """ 
     175         
     176        if self.fitArrangeList.has_key(Uid): 
     177            self.fitArrangeList[Uid].set_model(model) 
     178        else: 
     179            fitproblem= FitArrange() 
     180            fitproblem.set_model(model) 
     181            self.fitArrangeList[Uid]=fitproblem 
     182         
     183    def set_data(self,data,Uid): 
     184        """ Receive plottable and create a list of data to fit""" 
     185        data=Data(data) 
     186        if self.fitArrangeList.has_key(Uid): 
     187            self.fitArrangeList[Uid].add_data(data) 
     188        else: 
     189            fitproblem= FitArrange() 
     190            fitproblem.add_data(data) 
     191            self.fitArrangeList[Uid]=fitproblem 
     192             
     193    def get_model(self,Uid): 
     194        """ return list of data""" 
     195        return self.fitArrangeList[Uid] 
     196     
     197    def set_param(self,model, pars): 
     198        """ Recieve a dictionary of parameter and save it """ 
     199        parameters=[] 
     200        if model==None: 
     201            raise ValueError, "Cannot set parameters for empty model" 
     202        else: 
     203            #for key ,value in pars: 
     204            for key, value in pars.iteritems(): 
     205                param = Parameter(model, key, value) 
     206                parameters.append(param) 
     207        return parameters 
     208     
     209    def add_constraint(self, constraint): 
     210        """ User specify contraint to fit """ 
     211        self.constraint = str(constraint) 
     212         
     213    def get_constraint(self): 
     214        """ return the contraint value """ 
     215        return self.constraint 
     216    
     217    def set_constraint(self,constraint): 
     218        """  
     219            receive a string as a constraint 
     220            @param constraint: a string used to constraint some parameters to get a  
     221                specific value 
     222        """ 
     223        self.constraint= constraint 
     224    def _concatenateData(self, listdata=[]): 
     225        """ concatenate each fields of all Data contains ins listdata 
     226         return data 
     227        """ 
    76228        if listdata==[]: 
    77229            raise ValueError, " data list missing" 
    78230        else: 
    79             # Do the fit with more than one data set and one model  
    80231            xtemp=[] 
    81232            ytemp=[] 
    82233            dytemp=[] 
    83             
     234            resid=[] 
     235            resid_deriv=[] 
     236             
    84237            for data in listdata: 
    85238                for i in range(len(data.x)): 
     
    92245                    if not data.dy[i] in dytemp: 
    93246                        dytemp.append(data.dy[i]) 
    94             if qmin==None: 
    95                 qmin= min(xtemp) 
    96             if qmax==None: 
    97                 qmax= max(xtemp)   
    98             chisqr, out, cov = fitHelper(model,parameters, xtemp,ytemp, dytemp ,qmin,qmax) 
    99             return chisqr, out, cov 
    100      
    101     def set_model(self,model,Uid): 
    102         """ Set model """ 
    103         if self.fitArrangeList.has_key(Uid): 
    104             self.fitArrangeList[Uid].set_model(model) 
    105         else: 
    106             fitproblem= FitArrange() 
    107             fitproblem.set_model(model) 
    108             self.fitArrangeList[Uid]=fitproblem 
    109          
    110     def set_data(self,data,Uid): 
    111         """ Receive plottable and create a list of data to fit""" 
    112          
    113         if self.fitArrangeList.has_key(Uid): 
    114             self.fitArrangeList[Uid].add_data(data) 
    115         else: 
    116             fitproblem= FitArrange() 
    117             fitproblem.add_data(data) 
    118             self.fitArrangeList[Uid]=fitproblem 
    119              
    120     def get_model(self,Uid): 
    121         """ return list of data""" 
    122         return self.fitArrangeList[Uid] 
    123      
    124     def set_param(self,model, pars): 
    125         """ Recieve a dictionary of parameter and save it """ 
    126         parameters=[] 
    127         if model==None: 
    128             raise ValueError, "Cannot set parameters for empty model" 
    129         else: 
    130             #for key ,value in pars: 
    131             for key, value in pars.iteritems(): 
    132                 param = Parameter(model, key, value) 
    133                 parameters.append(param) 
    134         return parameters 
    135      
    136     def add_constraint(self, constraint): 
    137         """ User specify contraint to fit """ 
    138         self.constraint = str(constraint) 
    139          
    140     def get_constraint(self): 
    141         """ return the contraint value """ 
    142         return self.constraint 
    143     
    144     def set_constraint(self,constraint): 
    145         """  
    146             receive a string as a constraint 
    147             @param constraint: a string used to constraint some parameters to get a  
    148                 specific value 
    149         """ 
    150         self.constraint= constraint 
    151      
    152     
    153                  
    154  
     247                     
     248                    
     249            newplottable= Data1D(xtemp,ytemp,None,dytemp) 
     250            newdata=Data(newplottable) 
     251            
     252            #print "this is new data",newdata.dy 
     253            return newdata 
    155254class Parameter: 
    156255    """ 
     
    174273        """ 
    175274        return self.model.getParam(self.name) 
    176  
    177 class FitHelper: 
    178      
    179     def __init__(self,model, pars, x, y, err_y ,qmin=None, qmax=None): 
    180         self.x = x 
    181         self.y = y 
    182         self.model = model 
    183         self.err_y = err_y 
    184         self.qmin = qmin 
    185         self.qmax= qmax 
    186         self.pars = pars 
    187          
    188     def __call__(self, params): 
    189         i = 0 
    190         for p in self.pars: 
    191             p.set(params[i]) 
    192             i += 1 
    193          
    194         residuals = [] 
    195         for j in range(len(self.x)): 
    196             if self.x[j]>self.qmin and self.x[j]<self.qmax: 
    197                 residuals.append( ( self.y[j] - self.model.runXY(self.x[j]) ) / self.err_y[j] ) 
    198         
    199         return residuals 
    200      
    201      
    202  
    203 def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None): 
    204     """ 
    205         Fit function 
    206         @param model: sans model object 
    207         @param pars: list of parameters 
    208         @param x: vector of x data 
    209         @param y: vector of y data 
    210         @param err_y: vector of y errors  
    211     """ 
    212      
    213     f = FitHelper(model, pars, x, y, err_y ,qmin, qmax) 
    214      
    215     def ff(params): 
    216         """ 
    217             Calculates the vector of residuals for each point  
    218             in y for a given set of input parameters. 
    219             @param params: list of parameter values 
    220             @return: vector of residuals 
    221         """ 
    222         i = 0 
    223         for p in pars: 
    224             p.set(params[i]) 
    225             i += 1 
    226          
    227         residuals = [] 
    228         for j in range(len(x)): 
    229             if x[j]>qmin and x[j]<qmax: 
    230                 residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] ) 
    231         
    232         return residuals 
    233          
    234     def chi2(params): 
    235         """ 
    236             Calculates chi^2 
    237             @param params: list of parameter values 
    238             @return: chi^2 
    239         """ 
    240         sum = 0 
    241         res = f(params) 
    242         for item in res: 
    243             sum += item*item 
    244         return sum 
    245          
    246     p = [param() for param in pars] 
    247     out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True) 
    248     print info, mesg, success 
    249     # Calculate chi squared 
    250     if len(pars)>1: 
    251         chisqr = chi2(out) 
    252     elif len(pars)==1: 
    253         chisqr = chi2([out]) 
    254          
    255     return chisqr, out, cov_x     
     275     
    256276 
    257277       
     
    265285    data1.name = "data1" 
    266286    load.load_data(data1) 
    267     Fit =Fitting() 
    268      
    269     from LineModel import LineModel 
     287    fitter =ParkFit() 
     288     
     289    from sans.guitools.LineModel import LineModel 
    270290    model  = LineModel() 
    271     Fit.set_model(model,1) 
    272     Fit.set_data(data1,1) 
    273      
    274     chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None) 
    275     print"fit only one data",chisqr, out, cov  
    276      
    277     # test fit with 2 data and one model 
    278     Fit =Fitting() 
    279     Fit.set_model(model,2 ) 
    280     load.set_filename("testdata1.txt") 
    281     load.set_values() 
    282     data2 = Data1D(x=[], y=[], dx=None,dy=None) 
    283     data2.name = "data2" 
    284      
    285     load.load_data(data2) 
    286     Fit.set_data(data2,2) 
    287      
    288     load.set_filename("testdata2.txt") 
    289     load.set_values() 
    290     data3 = Data1D(x=[], y=[], dx=None,dy=None) 
    291     data3.name = "data2" 
    292     load.load_data(data3) 
    293     Fit.set_data(data3,2) 
    294     chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None) 
    295     print"fit two data",chisqr, out, cov  
    296      
     291    fitter.set_model(model,1) 
     292    fitter.set_data(data1,1) 
     293     
     294    print"PARK fit result \n",fitter.fit({'A':2,'B':1},None,None) 
     295    
     296     
     297    
     298     
Note: See TracChangeset for help on using the changeset viewer.