Changeset 48882d1 in sasview for park_integration


Ignore:
Timestamp:
Aug 22, 2008 5:51:05 PM (16 years ago)
Author:
Gervaise Alina <gervyh@…>
Branches:
master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Children:
3c404d3
Parents:
d6513cd
Message:

park fitting with new model and new data

Location:
park_integration
Files:
2 added
1 deleted
6 edited

Legend:

Unmodified
Added
Removed
  • park_integration/AbstractFitEngine.py

    r985c88b r48882d1  
    11 
     2import park,numpy 
     3 
     4class SansParameter(park.Parameter): 
     5    """ 
     6        SANS model parameters for use in the PARK fitting service. 
     7        The parameter attribute value is redirected to the underlying 
     8        parameter value in the SANS model. 
     9    """ 
     10    def __init__(self, name, model): 
     11         self._model, self._name = model,name 
     12         self.set(model.getParam(name)) 
     13          
     14    def _getvalue(self): return self._model.getParam(self.name) 
     15     
     16    def _setvalue(self,value):  
     17        self._model.setParam(self.name, value) 
     18         
     19    value = property(_getvalue,_setvalue) 
     20     
     21    def _getrange(self): 
     22        lo,hi = self._model.details[self.name][1:] 
     23        if lo is None: lo = -numpy.inf 
     24        if hi is None: hi = numpy.inf 
     25        return lo,hi 
     26     
     27    def _setrange(self,r): 
     28        self._model.details[self.name][1:] = r 
     29    range = property(_getrange,_setrange) 
     30 
     31 
     32class Model(object): 
     33    """ 
     34        PARK wrapper for SANS models. 
     35    """ 
     36    def __init__(self, sans_model): 
     37        self.model = sans_model 
     38        #print "ParkFitting:sans model",self.model 
     39        self.sansp = sans_model.getParamList() 
     40        #print "ParkFitting: sans model parameter list",sansp 
     41        self.parkp = [SansParameter(p,sans_model) for p in self.sansp] 
     42        #print "ParkFitting: park model parameter ",self.parkp 
     43        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp) 
     44        self.pars=[] 
     45         
     46    def getParams(self,fitparams): 
     47        list=[] 
     48        self.pars=[] 
     49        self.pars=fitparams 
     50        for item in fitparams: 
     51            for element in self.parkp: 
     52                 if element.name ==str(item): 
     53                     list.append(element.value) 
     54        #print "abstractfitengine: getparams",list 
     55        return list 
     56     
     57    def setParams(self, params): 
     58        list=[] 
     59        for item in self.parkp: 
     60            list.append(item.name) 
     61        list.sort() 
     62        for i in range(len(params)): 
     63            #self.parkp[i].value = params[i] 
     64            #print "abstractfitengine: set-params",list[i],params[i] 
     65             
     66            self.model.setParam(list[i],params[i]) 
     67   
     68    def eval(self,x): 
     69        #print "eval",self.parameterset[0].value,self.parameterset[1].value 
     70        return self.model.runXY(x) 
     71        
     72 
     73class Data(object): 
     74    """ Wrapper class  for SANS data """ 
     75    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None): 
     76         
     77        if  sans_data !=None: 
     78            self.x= sans_data.x 
     79            self.y= sans_data.y 
     80            self.dx= sans_data.dx 
     81            self.dy= sans_data.dy 
     82            
     83        elif (x!=None and y!=None and dy!=None): 
     84                self.x=x 
     85                self.y=y 
     86                self.dx=dx 
     87                self.dy=dy 
     88        else: 
     89            raise ValueError,\ 
     90            "Data is missing x, y or dy, impossible to compute residuals later on" 
     91        self.qmin=None 
     92        self.qmax=None 
     93        
     94    def setFitRange(self,mini=None,maxi=None): 
     95        """ to set the fit range""" 
     96        self.qmin=mini 
     97        self.qmax=maxi 
     98    def getFitRange(self): 
     99         return self.qmin, self.qmax 
     100    def residuals(self, fn): 
     101        """ @param fn: function that return model value 
     102            @return residuals 
     103        """ 
     104        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)] 
     105        if self.qmin==None and self.qmax==None:  
     106            fx =[fn(v) for v in x] 
     107            return (y - fx)/dy 
     108        else: 
     109            idx = (x>=self.qmin) & (x <= self.qmax) 
     110            fx = [fn(item)for item in x[idx ]] 
     111            return (y[idx] - fx)/dy[idx] 
     112           
     113             
     114          
     115    def residuals_deriv(self, model, pars=[]): 
     116        """  
     117            @return residuals derivatives . 
     118            @note: in this case just return empty array 
     119        """ 
     120        return [] 
     121     
     122class sansAssembly: 
     123    def __init__(self,Model=None , Data=None): 
     124       self.model = Model 
     125       self.data  = Data 
     126       self.res=[] 
     127    def chisq(self, params): 
     128        """ 
     129            Calculates chi^2 
     130            @param params: list of parameter values 
     131            @return: chi^2 
     132        """ 
     133        sum = 0 
     134        for item in self.res: 
     135            sum += item*item 
     136        return sum 
     137    def __call__(self,params): 
     138        self.model.setParams(params) 
     139        self.res= self.data.residuals(self.model.eval) 
     140        return self.res 
     141     
    2142class FitEngine: 
    3143    def __init__(self): 
     
    8148            @param listdata: list of data  
    9149             
    10             @return xtemp, ytemp,dytemp:  x,y,dy respectively of data all combined 
    11                 if xi,yi,dyi of two or more data are the same the second appearance of xi,yi, 
    12                 dyi is ignored in the concatenation. 
     150            @return Data: 
    13151                 
    14152            @raise: if listdata is empty  will return None 
     
    22160            ytemp=[] 
    23161            dytemp=[] 
     162            self.mini=None 
     163            self.maxi=None 
    24164                
    25165            for data in listdata: 
     166                mini,maxi=data.getFitRange() 
     167                if self.mini==None and self.maxi==None: 
     168                    self.mini=mini 
     169                    self.maxi=maxi 
     170                else: 
     171                    if mini < self.mini: 
     172                        self.mini=mini 
     173                    if self.maxi < maxi: 
     174                        self.maxi=maxi 
     175                         
     176                     
    26177                for i in range(len(data.x)): 
    27178                    xtemp.append(data.x[i]) 
     
    31182                    else: 
    32183                        raise RuntimeError, "Fit._concatenateData: y-errors missing" 
    33             return xtemp, ytemp,dytemp 
    34      
     184            #return xtemp, ytemp,dytemp 
     185            data= Data(x=xtemp,y=ytemp,dy=dytemp) 
     186            data.setFitRange(self.mini, self.maxi) 
     187            return data 
    35188    def set_model(self,model,name,Uid,pars=[]): 
    36         """  
    37        
    38             Receive a dictionary of parameter and save it Parameter list 
    39             For scipy.fit use. 
    40             Set model in a FitArrange object and add that object in a dictionary 
    41             with key Uid. 
    42             @param model: model on with parameter values are set 
    43             @param name: model name 
    44             @param Uid: unique key corresponding to a fitArrange object with model 
    45             @param pars: dictionary of paramaters name and value 
    46             pars={parameter's name: parameter's value} 
    47              
    48         """ 
    49         print "AbstractFitEngine:  fitting parmater",pars 
    50         
    51189        if len(pars) >0: 
    52             self.parameters=[] 
     190            self.paramList = [] 
    53191            if model==None: 
    54192                raise ValueError, "AbstractFitEngine: Specify parameters to fit" 
    55193            else: 
    56                 model.name=name 
    57                 for param_name in pars: 
    58                     value=model.getParam(param_name) 
    59                     if value==None: 
    60                         raise ValueError ,"%s has not value set"%param_name 
    61                     param = Parameter(model,param_name,value) 
    62                     self.parameters.append(param) 
    63                     
    64                     self.paramList.append(param_name) 
    65             print "AbstractFitEngine: self.paramList2", self.paramList 
     194                model.name = name 
     195                self.paramList=pars 
    66196            #A fitArrange is already created but contains dList only at Uid 
    67197            if self.fitArrangeList.has_key(Uid): 
     
    69199            else: 
    70200            #no fitArrange object has been create with this Uid 
    71                 fitproblem= FitArrange() 
     201                fitproblem = FitArrange() 
    72202                fitproblem.set_model(model) 
    73                 self.fitArrangeList[Uid]=fitproblem 
     203                self.fitArrangeList[Uid] = fitproblem 
    74204        else: 
    75205            raise ValueError, "park_integration:missing parameters" 
    76          
    77          
    78     def set_data(self,data,Uid): 
     206     
     207    def set_data(self,data,Uid,qmin=None,qmax=None): 
    79208        """ Receives plottable, creates a list of data to fit,set data 
    80209            in a FitArrange object and adds that object in a dictionary  
     
    83212            @param Uid: unique key corresponding to a fitArrange object with data 
    84213            """ 
     214        if qmin !=None and qmax !=None: 
     215            data.setFitRange(mini=qmin,maxi=qmax) 
    85216        #A fitArrange is already created but contains model only at Uid 
    86217        if self.fitArrangeList.has_key(Uid): 
     
    90221            fitproblem= FitArrange() 
    91222            fitproblem.add_data(data) 
    92             self.fitArrangeList[Uid]=fitproblem 
    93              
     223            self.fitArrangeList[Uid]=fitproblem     
     224    
    94225    def get_model(self,Uid): 
    95226        """  
     
    107238        if self.fitArrangeList.has_key(Uid): 
    108239            del self.fitArrangeList[Uid] 
    109              
    110        
    111     
    112     
    113 class Parameter: 
    114     """ 
    115         Class to handle model parameters 
    116     """ 
    117     def __init__(self, model, name, value=None): 
    118             self.model = model 
    119             self.name = name 
    120             if not value==None: 
    121                 self.model.setParam(self.name, value) 
    122             
    123     def set(self, value): 
    124         """ 
    125             Set the value of the parameter 
    126         """ 
    127         self.model.setParam(self.name, value) 
    128  
    129     def __call__(self): 
    130         """  
    131             Return the current value of the parameter 
    132         """ 
    133         return self.model.getParam(self.name) 
     240 
    134241     
    135242class FitArrange: 
  • park_integration/Fitting.py

    r9855699 r48882d1  
    5353        """Perform the fit """ 
    5454        return self._engine.fit(qmin,qmax) 
    55     def set_model(self,model,name,Uid,pars={}): 
    56         """ Set model """ 
    57         self._engine.set_model(model,name,Uid, pars) 
    58     def set_data(self,data,Uid): 
    59         """ Receive plottable and create a list of data to fit""" 
    60         self._engine.set_data(data,Uid) 
     55    def set_model(self,model,name,Uid,pars=[]): 
     56         self._engine.set_model(model,name,Uid,pars) 
     57    
     58    def set_data(self,data,Uid,qmin=None, qmax=None): 
     59        self._engine.set_data(data,Uid,qmin,qmax) 
    6160    def get_model(self,Uid): 
    6261        """ return list of data""" 
  • park_integration/ParkFitting.py

    ree5b04c r48882d1  
    66import time 
    77import numpy 
    8  
    98import park 
    109from park import fit,fitresult 
    1110from park import assembly 
    1211from park.fitmc import FitSimplex, FitMC 
    13  
    1412from sans.guitools.plottables import Data1D 
    1513from Loader import Load 
    16 from AbstractFitEngine import FitEngine, Parameter, FitArrange 
    17 class SansParameter(park.Parameter): 
    18     """ 
    19         SANS model parameters for use in the PARK fitting service. 
    20         The parameter attribute value is redirected to the underlying 
    21         parameter value in the SANS model. 
    22     """ 
    23     def __init__(self, name, model): 
    24          self._model, self._name = model,name 
    25          self.set(model.getParam(name)) 
    26           
    27     def _getvalue(self): return self._model.getParam(self.name) 
    28      
    29     def _setvalue(self,value):  
    30         self._model.setParam(self.name, value) 
    31          
    32     value = property(_getvalue,_setvalue) 
    33      
    34     def _getrange(self): 
    35         lo,hi = self._model.details[self.name][1:] 
    36         if lo is None: lo = -numpy.inf 
    37         if hi is None: hi = numpy.inf 
    38         return lo,hi 
    39      
    40     def _setrange(self,r): 
    41         self._model.details[self.name][1:] = r 
    42     range = property(_getrange,_setrange) 
     14from AbstractFitEngine import FitEngine,FitArrange,Model 
    4315 
    44  
    45 class Model(object): 
    46     """ 
    47         PARK wrapper for SANS models. 
    48     """ 
    49     def __init__(self, sans_model): 
    50         self.model = sans_model 
    51         #print "ParkFitting:sans model",self.model 
    52         sansp = sans_model.getParamList() 
    53         #print "ParkFitting: sans model parameter list",sansp 
    54         parkp = [SansParameter(p,sans_model) for p in sansp] 
    55         #print "ParkFitting: park model parameter ",parkp 
    56         self.parameterset = park.ParameterSet(sans_model.name,pars=parkp) 
    57          
    58     def eval(self,x): 
    59         #print "eval",self.parameterset[0].value,self.parameterset[1].value 
    60         #print "model run ",self.model.run(x) 
    61         return self.model.run(x) 
    62      
    63 class Data(object): 
    64     """ Wrapper class  for SANS data """ 
    65     def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None): 
    66         if not sans_data==None: 
    67             self.x= sans_data.x 
    68             self.y= sans_data.y 
    69             self.dx= sans_data.dx 
    70             self.dy= sans_data.dy 
    71         else: 
    72             if x!=None and y!=None and dy!=None: 
    73                 self.x=x 
    74                 self.y=y 
    75                 self.dx=dx 
    76                 self.dy=dy 
    77             else: 
    78                 raise ValueError,\ 
    79                 "Data is missing x, y or dy, impossible to compute residuals later on" 
    80         self.qmin=None 
    81         self.qmax=None 
    82         
    83     def setFitRange(self,mini=None,maxi=None): 
    84         """ to set the fit range""" 
    85         self.qmin=mini 
    86         self.qmax=maxi 
    87          
    88     def residuals(self, fn): 
    89         """ @param fn: function that return model value 
    90             @return residuals 
    91         """ 
    92         x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)] 
    93         if self.qmin==None and self.qmax==None:  
    94             self.fx = fn(x) 
    95             return (y - fn(x))/dy 
    96          
    97         else: 
    98             self.fx = fn(x[idx]) 
    99             idx = x>=self.qmin & x <= self.qmax 
    100             return (y[idx] - fn(x[idx]))/dy[idx] 
    101              
    102           
    103     def residuals_deriv(self, model, pars=[]): 
    104         """  
    105             @return residuals derivatives . 
    106             @note: in this case just return empty array 
    107         """ 
    108         return [] 
    109  
    110              
    11116class ParkFit(FitEngine): 
    11217    """  
     
    15459        i=0 
    15560        for k,value in self.fitArrangeList.iteritems(): 
    156             sansmodel=value.get_model() 
     61            #sansmodel=value.get_model() 
    15762            #wrap sans model 
    158             parkmodel = Model(sansmodel) 
     63            #parkmodel = Model(sansmodel) 
     64            parkmodel = value.get_model() 
    15965            #print "ParkFitting: createproblem: just create a model",parkmodel.parameterset 
    16066            for p in parkmodel.parameterset: 
     
    16268                #if p.isfixed(): 
    16369                #print 'parameters',p.name 
    164                 #print "self.paramList",self.paramList 
     70                print "parkfitting: self.paramList",self.paramList 
    16571                if p.isfixed() and p._getname()in self.paramList: 
    16672                    p.set([-numpy.inf,numpy.inf]) 
    16773            i+=1     
    16874            Ldata=value.get_data() 
    169             x,y,dy=self._concatenateData(Ldata) 
    170             #wrap sansdata 
    171             parkdata=Data(x,y,dy,None) 
     75            parkdata=self._concatenateData(Ldata) 
     76             
    17277            couple=(parkmodel,parkdata) 
    17378            #print "Parkfitting: fitness",couple    
     
    204109        localfit.ftol = 1e-8 
    205110        fitter = FitMC(localfit=localfit) 
    206          
     111        print "ParkFitting: result1" 
    207112        result = fit.fit(self.problem, 
    208113                     fitter=fitter, 
     
    212117            #for p in result.parameters: 
    213118            #    print "fit in park fitting", p.name, p.value,p.stderr 
    214             return result.fitness,result.pvec,result.cov,result 
     119            #return result.fitness,result.pvec,result.cov,result 
     120            return result 
    215121        else: 
    216122            raise ValueError, "SVD did not converge" 
  • park_integration/ScipyFitting.py

    ree5b04c r48882d1  
    44    simple fit with scipy optimizer. 
    55""" 
     6#import scipy.linalg 
     7import numpy  
    68from sans.guitools.plottables import Data1D 
    79from Loader import Load 
    810from scipy import optimize 
    9 from AbstractFitEngine import FitEngine, Parameter 
    10 from AbstractFitEngine import FitArrange 
    1111 
     12from AbstractFitEngine import FitEngine, sansAssembly 
     13from AbstractFitEngine import FitArrange,Data 
     14class fitresult: 
     15    """ 
     16        Storing fit result 
     17    """ 
     18    calls     = None 
     19    fitness   = None 
     20    chisqr    = None 
     21    pvec      = None 
     22    cov       = None 
     23    info      = None 
     24    mesg      = None 
     25    success   = None 
     26    stderr    = None 
     27    parameters= None 
     28     
    1229class ScipyFit(FitEngine): 
    1330    """  
     
    4259        self.fitArrangeList={} 
    4360        self.paramList=[] 
    44          
    4561    def fit(self,qmin=None, qmax=None): 
    46         """ 
    47             Performs fit with scipy optimizer.It can only perform fit with one model 
    48             and a set of data. 
    49             @note: Cannot perform more than one fit at the time. 
    50              
    51             @param pars: Dictionary of parameter names for the model and their values 
    52             @param qmin: The minimum value of data's range to be fit 
    53             @param qmax: The maximum value of data's range to be fit 
    54             @return chisqr: Value of the goodness of fit metric 
    55             @return out: list of parameter with the best value found during fitting 
    56             @return cov: Covariance matrix 
    57         """ 
    58         # Protect against simultanous fitting attempts 
     62         # Protect against simultanous fitting attempts 
    5963        if len(self.fitArrangeList)>1:  
    6064            raise RuntimeError, "Scipy can't fit more than a single fit problem at a time." 
     
    6670        listdata = fitproblem.get_data() 
    6771        # Concatenate dList set (contains one or more data)before fitting 
    68         xtemp,ytemp,dytemp=self._concatenateData( listdata) 
     72        data=self._concatenateData( listdata) 
    6973        #Assign a fit range is not boundaries were given 
    7074        if qmin==None: 
    71             qmin= min(xtemp) 
     75            qmin= min(data.x) 
    7276        if qmax==None: 
    73             qmax= max(xtemp)  
    74         #perform the fit  
    75         chisqr, out, cov = fitHelper(model,self.parameters, xtemp,ytemp, dytemp ,qmin,qmax) 
    76         return chisqr, out, cov 
     77            qmax= max(data.x)  
     78        functor= sansAssembly(model,data) 
     79        print "scipyfitting:param list",model.getParams(self.paramList) 
     80        print "scipyfitting:functor",functor(model.getParams(self.paramList)) 
    7781     
    78  
    79 def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None): 
    80     """ 
    81         Fit function 
    82         @param model: sans model object 
    83         @param pars: list of parameters 
    84         @param x: vector of x data 
    85         @param y: vector of y data 
    86         @param err_y: vector of y errors  
    87         @return chisqr: Value of the goodness of fit metric 
    88         @return out: list of parameter with the best value found during fitting 
    89         @return cov: Covariance matrix 
    90     """ 
    91     def f(params): 
    92         """ 
    93             Calculates the vector of residuals for each point  
    94             in y for a given set of input parameters. 
    95             @param params: list of parameter values 
    96             @return: vector of residuals 
    97         """ 
    98         i = 0 
    99         for p in pars: 
    100             p.set(params[i]) 
    101             i += 1 
     82        out, cov_x, info, mesg, success = optimize.leastsq(functor,model.getParams(self.paramList), full_output=1, warning=True) 
     83        chisqr = functor.chisq(out) 
    10284         
    103         residuals = [] 
    104         for j in range(len(x)): 
    105             if x[j] >= qmin and x[j] <= qmax: 
    106                 residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] ) 
     85        print "scipyfitting: info",mesg 
     86        print"scipyfitting : success",success 
     87        print "scipyfitting: out", out 
     88        print "scipyfitting: cov_x", cov_x 
     89        print "scipyfitting: chisqr", chisqr 
     90         
     91        if not (numpy.isnan(out).any()): 
     92                result = fitresult() 
     93                result.fitness = chisqr 
     94                result.cov  = cov_x 
     95                 
     96                result.pvec = out 
     97                result.success =success 
     98                
     99                return result 
     100        else:   
     101            raise ValueError, "SVD did not converge" 
     102         
     103        
     104               
    107105             
    108         return residuals 
    109          
    110     def chi2(params): 
    111         """ 
    112             Calculates chi^2 
    113             @param params: list of parameter values 
    114             @return: chi^2 
    115         """ 
    116         sum = 0 
    117         res = f(params) 
    118         for item in res: 
    119             sum += item*item 
    120         return sum 
    121          
    122     p = [param() for param in pars] 
    123     out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True) 
    124     #print info, mesg, success 
    125     # Calculate chi squared 
    126     if len(pars)>1: 
    127         chisqr = chi2(out) 
    128     elif len(pars)==1: 
    129         chisqr = chi2([out]) 
    130          
    131     return chisqr, out, cov_x     
    132  
     106       
  • park_integration/test/test_large_model.py

    rf44dbc7 r48882d1  
    55from sans.guitools.plottables import Theory1D 
    66from sans.guitools.plottables import Data1D 
    7  
     7from sans.fit.AbstractFitEngine import Data,Model 
    88import math 
    99class testFitModule(unittest.TestCase): 
    1010    """ test fitting """ 
    11      
    12        
    13     def testfit_11Data_1Model(self): 
    14         """ test fitting for one data and one model park vs scipy""" 
     11    def test_cylinder_park(self): 
     12        """ test fitting large model with park""" 
    1513        #load data 
    1614        from sans.fit.Loader import Load 
     
    1816        load.set_filename("cyl_testdata.txt") 
    1917        load.set_values() 
    20         data1 = Data1D(x=[], y=[],dx=None, dy=None) 
    21         load.load_data(data1) 
     18        data11 = Data1D(x=[], y=[],dx=None, dy=None) 
     19        load.load_data(data11) 
     20        data1=Data(sans_data=data11) 
    2221         
    23         load.set_filename("testdata_line1.txt") 
    24         load.set_values() 
    25         data2 = Data1D(x=[], y=[],dx=None, dy=None) 
    26         load.load_data(data2) 
    2722         
    2823        #Importing the Fit module 
     
    3328        from sans.models.CylinderModel import CylinderModel 
    3429        model1  = CylinderModel() 
    35         #model2  = CylinderModel() 
     30        model =Model(model1) 
    3631         
    3732        #Do the fit SCIPY 
    3833        fitter.set_data(data1,1) 
    3934        import math 
    40         pars1={'background':0,'contrast':3*math.pow(10, -6),\ 
    41                 'cyl_phi':1,'cyl_theta':1,'length':400,'radius':20,'scale':1} 
    42         fitter.set_model(model1,"M1",1,pars1) 
     35        #pars1=['background','contrast', 'length'] 
     36        pars1=['background','contrast',\ 
     37                'cyl_phi','cyl_theta','length','radius','scale'] 
     38        pars1.sort() 
     39        fitter.set_model(model,"M1",1,pars1) 
     40        fitter.set_data(data1,1) 
     41       
     42        result=fitter.fit() 
     43        print "park",result.fitness,result.cov, result.pvec 
     44        self.assert_(result.fitness) 
    4345         
    44         #fitter.set_data(data2,2) 
    45         #fitter.set_model(model1,"M1",2,pars1) 
     46      
     47    def test_cylinder_scipy(self): 
     48        """ test fitting large model with scipy""" 
     49        #load data 
     50        from sans.fit.Loader import Load 
     51        load= Load() 
     52        load.set_filename("cyl_testdata.txt") 
     53        load.set_values() 
     54        data11 = Data1D(x=[], y=[],dx=None, dy=None) 
     55        load.load_data(data11) 
     56        data1=Data(sans_data=data11) 
    4657         
    47         chisqr1, out1, cov1=fitter.fit() 
    48         print "park",chisqr1, out1, cov1 
    49         self.assert_(chisqr1) 
    5058         
    51         
     59        #Importing the Fit module 
     60        from sans.fit.Fitting import Fit 
     61        fitter= Fit('scipy') 
     62         
     63        # Receives the type of model for the fitting 
     64        from sans.models.CylinderModel import CylinderModel 
     65        model1  = CylinderModel() 
     66        model =Model(model1) 
     67         
     68        #Do the fit SCIPY 
     69        fitter.set_data(data1,1) 
     70        import math 
     71        #pars1=['background','contrast', 'length'] 
     72        pars1=['background','contrast',\ 
     73                'cyl_phi','cyl_theta','length','radius','scale'] 
     74        pars1.sort() 
     75        fitter.set_model(model,"M1",1,pars1) 
     76        fitter.set_data(data1,1) 
    5277       
     78        result=fitter.fit() 
     79        print "scipy",result.fitness,result.cov, result.pvec 
     80        self.assert_(result.fitness) 
     81         
     82     
  • park_integration/test/testfitting.py

    r985c88b r48882d1  
    55from sans.guitools.plottables import Theory1D 
    66from sans.guitools.plottables import Data1D 
    7  
     7from sans.fit.AbstractFitEngine import Data, Model 
    88import math 
    99class testFitModule(unittest.TestCase): 
     
    1515        from sans.fit.Loader import Load 
    1616        load= Load() 
    17          
    1817        load.set_filename("testdata_line.txt") 
    1918        self.assertEqual(load.get_filename(),"testdata_line.txt") 
     
    2322        dx=[] 
    2423        dy=[] 
    25          
    2624        x,y,dx,dy = load.get_values() 
    27          
    2825        # test that values have been loaded 
    2926        self.assertNotEqual(x, None) 
     
    5956        load.set_filename("testdata_line.txt") 
    6057        load.set_values() 
    61         data1 = Data1D(x=[], y=[],dx=None, dy=None) 
    62         load.load_data(data1) 
     58        data11 = Data1D(x=[], y=[],dx=None, dy=None) 
     59        load.load_data(data11) 
    6360         
    6461        #Importing the Fit module 
     
    6865        # Receives the type of model for the fitting 
    6966        from sans.guitools.LineModel import LineModel 
    70         model1  = LineModel() 
    71         model2  = LineModel() 
     67        model11  = LineModel() 
     68        model22  = LineModel() 
    7269         
    7370        #Do the fit SCIPY 
    74         model1.setParam( 'A', 2) 
    75         model1.setParam( 'B', 1) 
     71        model11.setParam( 'A', 2) 
     72        model11.setParam( 'B', 1) 
     73        data1=Data(sans_data=data11) 
     74        model1 =Model(model11) 
     75        model2 =Model(model22) 
     76       
    7677        fitter.set_data(data1,1) 
    7778        fitter.set_model(model1,"M1",1,['A','B']) 
    7879         
    79         chisqr1, out1, cov1=fitter.fit() 
     80        result= fitter.fit() 
     81        out1=result.pvec 
     82        chisqr1=result.fitness 
     83        cov1=result.cov 
     84        print "scipy",chisqr1, out1, cov1 
    8085        """ testing SCIPy results""" 
    8186        self.assert_(math.fabs(out1[1]-2.5)/math.sqrt(cov1[1][1]) < 2) 
     
    8792        #Do the fit 
    8893        fitter.set_data(data1,1) 
    89         model2.setParam( 'A', 2) 
    90         model2.setParam( 'B', 1) 
     94        model2.setParams( [2,1]) 
     95     
    9196        fitter.set_model(model2,"M1",1,['A','B']) 
    9297        
    93         chisqr2, out2, cov2,result=fitter.fit(None,None) 
    94          
     98        result2=fitter.fit(None,None) 
     99        out2=result2.pvec 
     100        chisqr2=result2.fitness 
     101        cov2=result2.cov 
    95102        self.assert_(math.fabs(out2[1]-2.5)/math.sqrt(cov2[1][1]) < 2) 
    96103        self.assert_(math.fabs(out2[0]-4.0)/math.sqrt(cov2[0][0]) < 2) 
     
    103110        self.assertAlmostEquals(cov1[1][1], cov2[1][1],1) 
    104111        self.assertAlmostEquals(chisqr1, chisqr2) 
     112          
     113    def testfit_1Data_1Model(self): 
     114        """ test fitting for one data and one model cipy""" 
     115        #load data 
     116        from sans.fit.Loader import Load 
     117        load= Load() 
     118        load.set_filename("testdata_line.txt") 
     119        load.set_values() 
     120        data11 = Data1D(x=[], y=[],dx=None, dy=None) 
     121        load.load_data(data11) 
     122        data1=Data(sans_data=data11) 
     123     
     124        #Importing the Fit module 
     125        from sans.fit.Fitting import Fit 
     126        fitter= Fit('scipy') 
     127         
     128        # Receives the type of model for the fitting 
     129        from sans.guitools.LineModel import LineModel 
     130        model1  = LineModel() 
     131        model =Model(model1) 
     132         
     133        #Do the fit SCIPY 
     134        fitter.set_data(data1,1) 
     135        import math 
    105136        
     137        pars1=['A','B'] 
     138        pars1.sort() 
     139        fitter.set_model(model,"M1",1,pars1) 
     140        result=fitter.fit() 
     141        print "scipy",result.fitness,result.cov, result.pvec 
     142     
     143        self.assert_(result.fitness) 
     144         
     145        
     146     
    106147       
Note: See TracChangeset for help on using the changeset viewer.