Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sans/fit/AbstractFitEngine.py

    r6c00702 r8d074d9  
    33#import logging 
    44import sys 
     5import math 
    56import numpy 
    6 import math 
    7 import park 
     7 
    88from sans.dataloader.data_info import Data1D 
    99from sans.dataloader.data_info import Data2D 
    10 _SMALLVALUE = 1.0e-10     
    11      
    12 class SansParameter(park.Parameter): 
    13     """ 
    14     SANS model parameters for use in the PARK fitting service. 
    15     The parameter attribute value is redirected to the underlying 
    16     parameter value in the SANS model. 
    17     """ 
    18     def __init__(self, name, model, data): 
    19         """ 
    20             :param name: the name of the model parameter 
    21             :param model: the sans model to wrap as a park model 
    22         """ 
    23         park.Parameter.__init__(self, name) 
    24         self._model, self._name = model, name 
    25         self.data = data 
    26         self.model = model 
    27         #set the value for the parameter of the given name 
    28         self.set(model.getParam(name)) 
    29           
    30     def _getvalue(self): 
    31         """ 
    32         override the _getvalue of park parameter 
    33          
    34         :return value the parameter associates with self.name 
    35          
    36         """ 
    37         return self._model.getParam(self.name) 
    38      
    39     def _setvalue(self, value): 
    40         """ 
    41         override the _setvalue pf park parameter 
    42          
    43         :param value: the value to set on a given parameter 
    44          
    45         """ 
    46         self._model.setParam(self.name, value) 
    47          
    48     value = property(_getvalue, _setvalue) 
    49      
    50     def _getrange(self): 
    51         """ 
    52         Override _getrange of park parameter 
    53         return the range of parameter 
    54         """ 
    55         #if not  self.name in self._model.getDispParamList(): 
    56         lo, hi = self._model.details[self.name][1:3] 
    57         if lo is None: lo = -numpy.inf 
    58         if hi is None: hi = numpy.inf 
    59         if lo > hi: 
    60             raise ValueError, "wrong fit range for parameters" 
    61          
    62         return lo, hi 
    63      
    64     def get_name(self): 
    65         """ 
    66         """ 
    67         return self._getname() 
    68      
    69     def _setrange(self, r): 
    70         """ 
    71         override _setrange of park parameter 
    72          
    73         :param r: the value of the range to set 
    74          
    75         """ 
    76         self._model.details[self.name][1:3] = r 
    77     range = property(_getrange, _setrange) 
    78      
    79      
    80 class Model(park.Model): 
    81     """ 
    82     PARK wrapper for SANS models. 
     10_SMALLVALUE = 1.0e-10 
     11 
     12# Note: duplicated from park 
     13class FitHandler(object): 
     14    """ 
     15    Abstract interface for fit thread handler. 
     16 
     17    The methods in this class are called by the optimizer as the fit 
     18    progresses. 
     19 
     20    Note that it is up to the optimizer to call the fit handler correctly, 
     21    reporting all status changes and maintaining the 'done' flag. 
     22    """ 
     23    done = False 
     24    """True when the fit job is complete""" 
     25    result = None 
     26    """The current best result of the fit""" 
     27 
     28    def improvement(self): 
     29        """ 
     30        Called when a result is observed which is better than previous 
     31        results from the fit. 
     32 
     33        result is a FitResult object, with parameters, #calls and fitness. 
     34        """ 
     35    def error(self, msg): 
     36        """ 
     37        Model had an error; print traceback 
     38        """ 
     39    def progress(self, current, expected): 
     40        """ 
     41        Called each cycle of the fit, reporting the current and the 
     42        expected amount of work.   The meaning of these values is 
     43        optimizer dependent, but they can be converted into a percent 
     44        complete using (100*current)//expected. 
     45 
     46        Progress is updated each iteration of the fit, whatever that 
     47        means for the particular optimization algorithm.  It is called 
     48        after any calls to improvement for the iteration so that the 
     49        update handler can control I/O bandwidth by suppressing 
     50        intermediate improvements until the fit is complete. 
     51        """ 
     52    def finalize(self): 
     53        """ 
     54        Fit is complete; best results are reported 
     55        """ 
     56    def abort(self): 
     57        """ 
     58        Fit was aborted. 
     59        """ 
     60 
     61    # TODO: not sure how these are used, but they are needed for running the fit 
     62    def update_fit(self, last=False): pass 
     63    def set_result(self, result=None): self.result = result 
     64 
     65class Model: 
     66    """ 
     67    Fit wrapper for SANS models. 
    8368    """ 
    8469    def __init__(self, sans_model, sans_data=None, **kw): 
    8570        """ 
    8671        :param sans_model: the sans model to wrap using park interface 
    87          
    88         """ 
    89         park.Model.__init__(self, **kw) 
     72 
     73        """ 
    9074        self.model = sans_model 
    9175        self.name = sans_model.name 
    9276        self.data = sans_data 
    93         #list of parameters names 
    94         self.sansp = sans_model.getParamList() 
    95         #list of park parameter 
    96         self.parkp = [SansParameter(p, sans_model, sans_data) for p in self.sansp] 
    97         #list of parameter set 
    98         self.parameterset = park.ParameterSet(sans_model.name, pars=self.parkp) 
    99         self.pars = [] 
    100    
     77 
    10178    def get_params(self, fitparams): 
    10279        """ 
    10380        return a list of value of paramter to fit 
    104          
     81 
    10582        :param fitparams: list of paramaters name to fit 
    106          
    107         """ 
    108         list_params = [] 
    109         self.pars = [] 
    110         self.pars = fitparams 
    111         for item in fitparams: 
    112             for element in self.parkp: 
    113                 if element.name == str(item): 
    114                     list_params.append(element.value) 
    115         return list_params 
    116      
     83 
     84        """ 
     85        return [self.model.getParam(k) for k in fitparams] 
     86 
    11787    def set_params(self, paramlist, params): 
    11888        """ 
    11989        Set value for parameters to fit 
    120          
     90 
    12191        :param params: list of value for parameters to fit 
    122          
    123         """ 
    124         try: 
    125             for i in range(len(self.parkp)): 
    126                 for j in range(len(paramlist)): 
    127                     if self.parkp[i].name == paramlist[j]: 
    128                         self.parkp[i].value = params[j] 
    129                         self.model.setParam(self.parkp[i].name, params[j]) 
    130         except: 
    131             raise 
    132    
     92 
     93        """ 
     94        for k,v in zip(paramlist, params): 
     95            self.model.setParam(k,v) 
     96 
     97    def set(self, **kw): 
     98        self.set_params(*zip(*kw.items())) 
     99 
    133100    def eval(self, x): 
    134101        """ 
    135102            Override eval method of park model. 
    136          
     103 
    137104            :param x: the x value used to compute a function 
    138105        """ 
     
    141108        except: 
    142109            raise 
    143          
     110 
    144111    def eval_derivs(self, x, pars=[]): 
    145112        """ 
     
    154121        instead of calling eval. 
    155122        """ 
    156         return [] 
    157  
    158      
     123        raise NotImplementedError('no derivatives available') 
     124 
     125    def __call__(self, x): 
     126        return self.eval(x) 
     127 
    159128class FitData1D(Data1D): 
    160129    """ 
     
    185154        """ 
    186155        Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy) 
     156        self.num_points = len(x) 
    187157        self.sans_data = data 
    188158        self.smearer = smearer 
     
    251221        """ 
    252222        return self.qmin, self.qmax 
    253          
     223 
     224    def size(self): 
     225        """ 
     226        Number of measurement points in data set after masking, etc. 
     227        """ 
     228        return len(self.x) 
     229 
    254230    def residuals(self, fn): 
    255231        """ 
     
    293269    def __init__(self, sans_data2d, data=None, err_data=None): 
    294270        Data2D.__init__(self, data=data, err_data=err_data) 
    295         """ 
    296             Data can be initital with a data (sans plottable) 
    297             or with vectors. 
    298         """ 
     271        # Data can be initialized with a sans plottable or with vectors. 
    299272        self.res_err_image = [] 
     273        self.num_points = 0 # will be set by set_data 
    300274        self.idx = [] 
    301275        self.qmin = None 
     
    339313        self.idx = (self.idx) & (self.mask) 
    340314        self.idx = (self.idx) & (numpy.isfinite(self.data)) 
     315        self.num_points = numpy.sum(self.idx) 
    341316 
    342317    def set_smearer(self, smearer): 
     
    372347        """ 
    373348        return self.qmin, self.qmax 
    374       
     349 
     350    def size(self): 
     351        """ 
     352        Number of measurement points in data set after masking, etc. 
     353        """ 
     354        return numpy.sum(self.idx) 
     355 
    375356    def residuals(self, fn): 
    376357        """ 
     
    409390 
    410391 
    411 class SansAssembly: 
    412     """ 
    413     Sans Assembly class a class wrapper to be call in optimizer.leastsq method 
    414     """ 
    415     def __init__(self, paramlist, model=None, data=None, fitresult=None, 
    416                  handler=None, curr_thread=None, msg_q=None): 
    417         """ 
    418         :param Model: the model wrapper fro sans -model 
    419         :param Data: the data wrapper for sans data 
    420          
    421         """ 
    422         self.model = model 
    423         self.data = data 
    424         self.paramlist = paramlist 
    425         self.msg_q = msg_q 
    426         self.curr_thread = curr_thread 
    427         self.handler = handler 
    428         self.fitresult = fitresult 
    429         self.res = [] 
    430         self.true_res = [] 
    431         self.func_name = "Functor" 
    432         self.theory = None 
    433          
    434     def chisq(self): 
    435         """ 
    436         Calculates chi^2 
    437          
    438         :param params: list of parameter values 
    439          
    440         :return: chi^2 
    441          
    442         """ 
    443         total = 0 
    444         for item in self.true_res: 
    445             total += item * item 
    446         if len(self.true_res) == 0: 
    447             return None 
    448         return total / len(self.true_res) 
    449      
    450     def __call__(self, params): 
    451         """ 
    452             Compute residuals 
    453             :param params: value of parameters to fit 
    454         """ 
    455         #import thread 
    456         self.model.set_params(self.paramlist, params) 
    457         #print "params", params 
    458         self.true_res, theory = self.data.residuals(self.model.eval) 
    459         self.theory = copy.deepcopy(theory) 
    460         # check parameters range 
    461         if self.check_param_range(): 
    462             # if the param value is outside of the bound 
    463             # just silent return res = inf 
    464             return self.res 
    465         self.res = self.true_res 
    466          
    467         if self.fitresult is not None: 
    468             self.fitresult.set_model(model=self.model) 
    469             self.fitresult.residuals = self.true_res 
    470             self.fitresult.iterations += 1 
    471             self.fitresult.theory = theory 
    472             
    473             #fitness = self.chisq(params=params) 
    474             fitness = self.chisq() 
    475             self.fitresult.pvec = params 
    476             self.fitresult.set_fitness(fitness=fitness) 
    477             if self.msg_q is not None: 
    478                 self.msg_q.put(self.fitresult) 
    479                  
    480             if self.handler is not None: 
    481                 self.handler.set_result(result=self.fitresult) 
    482                 self.handler.update_fit() 
    483  
    484             if self.curr_thread != None: 
    485                 try: 
    486                     self.curr_thread.isquit() 
    487                 except: 
    488                     #msg = "Fitting: Terminated...       Note: Forcing to stop " 
    489                     #msg += "fitting may cause a 'Functor error message' " 
    490                     #msg += "being recorded in the log file....." 
    491                     #self.handler.stop(msg) 
    492                     raise 
    493           
    494         return self.res 
    495      
    496     def check_param_range(self): 
    497         """ 
    498         Check the lower and upper bound of the parameter value 
    499         and set res to the inf if the value is outside of the 
    500         range 
    501         :limitation: the initial values must be within range. 
    502         """ 
    503  
    504         #time.sleep(0.01) 
    505         is_outofbound = False 
    506         # loop through the fit parameters 
    507         for p in self.model.parameterset: 
    508             param_name = p.get_name() 
    509             if param_name in self.paramlist: 
    510                  
    511                 # if the range was defined, check the range 
    512                 if numpy.isfinite(p.range[0]): 
    513                     if p.value == 0: 
    514                         # This value works on Scipy 
    515                         # Do not change numbers below 
    516                         value = _SMALLVALUE 
    517                     else: 
    518                         value = p.value 
    519                     # For leastsq, it needs a bit step back from the boundary 
    520                     val = p.range[0] - value * _SMALLVALUE 
    521                     if p.value < val: 
    522                         self.res *= 1e+6 
    523                          
    524                         is_outofbound = True 
    525                         break 
    526                 if numpy.isfinite(p.range[1]): 
    527                     # This value works on Scipy 
    528                     # Do not change numbers below 
    529                     if p.value == 0: 
    530                         value = _SMALLVALUE 
    531                     else: 
    532                         value = p.value 
    533                     # For leastsq, it needs a bit step back from the boundary 
    534                     val = p.range[1] + value * _SMALLVALUE 
    535                     if p.value > val: 
    536                         self.res *= 1e+6 
    537                         is_outofbound = True 
    538                         break 
    539  
    540         return is_outofbound 
    541      
    542      
     392 
    543393class FitEngine: 
    544394    def __init__(self): 
     
    571421         
    572422        """ 
    573         if model == None: 
    574             raise ValueError, "AbstractFitEngine: Need to set model to fit" 
    575          
    576         new_model = model 
     423        if not pars: 
     424            raise ValueError("no fitting parameters") 
     425 
     426        if model is None: 
     427            raise ValueError("no model to fit") 
     428 
    577429        if not issubclass(model.__class__, Model): 
    578             new_model = Model(model, data) 
    579          
    580         if len(constraints) > 0: 
    581             for constraint in constraints: 
    582                 name, value = constraint 
    583                 try: 
    584                     new_model.parameterset[str(name)].set(str(value)) 
    585                 except: 
    586                     msg = "Fit Engine: Error occurs when setting the constraint" 
    587                     msg += " %s for parameter %s " % (value, name) 
    588                     raise ValueError, msg 
    589                  
    590         if len(pars) > 0: 
    591             temp = [] 
    592             for item in pars: 
    593                 if item in new_model.model.getParamList(): 
    594                     temp.append(item) 
    595                     self.param_list.append(item) 
    596                 else: 
    597                      
    598                     msg = "wrong parameter %s used " % str(item) 
    599                     msg += "to set model %s. Choose " % str(new_model.model.name) 
    600                     msg += "parameter name within %s" % \ 
    601                                 str(new_model.model.getParamList()) 
    602                     raise ValueError, msg 
    603                
    604             #A fitArrange is already created but contains data_list only at id 
    605             if self.fit_arrange_dict.has_key(id): 
    606                 self.fit_arrange_dict[id].set_model(new_model) 
    607                 self.fit_arrange_dict[id].pars = pars 
    608             else: 
    609             #no fitArrange object has been create with this id 
    610                 fitproblem = FitArrange() 
    611                 fitproblem.set_model(new_model) 
    612                 fitproblem.pars = pars 
    613                 self.fit_arrange_dict[id] = fitproblem 
    614                 vals = [] 
    615                 for name in pars: 
    616                     vals.append(new_model.model.getParam(name)) 
    617                 self.fit_arrange_dict[id].vals = vals 
    618         else: 
    619             raise ValueError, "park_integration:missing parameters" 
    620      
     430            model = Model(model, data) 
     431 
     432        sasmodel = model.model 
     433        available_parameters = sasmodel.getParamList() 
     434        for p in pars: 
     435            if p not in available_parameters: 
     436                raise ValueError("parameter %s not available in model %s; use one of [%s] instead" 
     437                                 %(p, sasmodel.name, ", ".join(available_parameters))) 
     438 
     439        if id not in self.fit_arrange_dict: 
     440            self.fit_arrange_dict[id] = FitArrange() 
     441 
     442        self.fit_arrange_dict[id].set_model(model) 
     443        self.fit_arrange_dict[id].pars = pars 
     444        self.fit_arrange_dict[id].vals = [sasmodel.getParam(name) for name in pars] 
     445        self.fit_arrange_dict[id].constraints = constraints 
     446 
     447        self.param_list.extend(pars) 
     448 
    621449    def set_data(self, data, id, smearer=None, qmin=None, qmax=None): 
    622450        """ 
     
    700528        self.vals = [] 
    701529        self.selected = 0 
    702          
     530 
    703531    def set_model(self, model): 
    704532        """ 
     
    752580        """ 
    753581        return self.selected 
    754      
    755      
    756 IS_MAC = True 
    757 if sys.platform.count("win32") > 0: 
    758     IS_MAC = False 
    759  
    760582 
    761583class FResult(object): 
     
    765587    def __init__(self, model=None, param_list=None, data=None): 
    766588        self.calls = None 
    767         self.pars = [] 
    768589        self.fitness = None 
    769590        self.chisqr = None 
     
    776597        self.residuals = [] 
    777598        self.index = [] 
    778         self.parameters = None 
    779         self.is_mac = IS_MAC 
    780599        self.model = model 
    781600        self.data = data 
     
    803622        if self.pvec == None and self.model is None and self.param_list is None: 
    804623            return "No results" 
    805         n = len(self.model.parameterset) 
    806          
    807         result_param = zip(xrange(n), self.model.parameterset) 
    808         msg1 = ["[Iteration #: %s ]" % self.iterations] 
    809         msg3 = ["=== goodness of fit: %s ===" % (str(self.fitness))] 
    810         if not self.is_mac: 
    811             msg2 = ["P%-3d  %s......|.....%s" % \ 
    812                 (p[0], p[1], p[1].value)\ 
    813                   for p in result_param if p[1].name in self.param_list] 
    814             msg = msg1 + msg3 + msg2 
    815         else: 
    816             msg = msg1 + msg3 
    817         msg = "\n".join(msg) 
    818         return msg 
     624 
     625        sasmodel = self.model.model 
     626        pars = enumerate(sasmodel.getParamList()) 
     627        msg1 = "[Iteration #: %s ]" % self.iterations 
     628        msg3 = "=== goodness of fit: %s ===" % (str(self.fitness)) 
     629        msg2 = ["P%-3d  %s......|.....%s" % (i, v, sasmodel.getParam(v)) 
     630                for i,v in pars if v in self.param_list] 
     631        msg = [msg1, msg3] + msg2 
     632        return "\n".join(msg) 
    819633     
    820634    def print_summary(self): 
    821635        """ 
    822636        """ 
    823         print self 
     637        print str(self) 
Note: See TracChangeset for help on using the changeset viewer.