Changeset 6fe5100 in sasview for src/sans/fit/AbstractFitEngine.py


Ignore:
Timestamp:
Apr 6, 2014 7:29:59 AM (10 years ago)
Author:
pkienzle
Branches:
master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Children:
95d58d3
Parents:
960fdbb
Message:

Bumps first pass. Fitting works but no pretty pictures

File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sans/fit/AbstractFitEngine.py

    r6c00702 r6fe5100  
    33#import logging 
    44import sys 
     5import math 
    56import numpy 
    6 import math 
    7 import park 
     7 
    88from sans.dataloader.data_info import Data1D 
    99from sans.dataloader.data_info import Data2D 
    10 _SMALLVALUE = 1.0e-10     
    11      
    12 class SansParameter(park.Parameter): 
    13     """ 
    14     SANS model parameters for use in the PARK fitting service. 
    15     The parameter attribute value is redirected to the underlying 
    16     parameter value in the SANS model. 
    17     """ 
    18     def __init__(self, name, model, data): 
    19         """ 
    20             :param name: the name of the model parameter 
    21             :param model: the sans model to wrap as a park model 
    22         """ 
    23         park.Parameter.__init__(self, name) 
    24         self._model, self._name = model, name 
    25         self.data = data 
    26         self.model = model 
    27         #set the value for the parameter of the given name 
    28         self.set(model.getParam(name)) 
    29           
    30     def _getvalue(self): 
    31         """ 
    32         override the _getvalue of park parameter 
    33          
    34         :return value the parameter associates with self.name 
    35          
    36         """ 
    37         return self._model.getParam(self.name) 
    38      
    39     def _setvalue(self, value): 
    40         """ 
    41         override the _setvalue pf park parameter 
    42          
    43         :param value: the value to set on a given parameter 
    44          
    45         """ 
    46         self._model.setParam(self.name, value) 
    47          
    48     value = property(_getvalue, _setvalue) 
    49      
    50     def _getrange(self): 
    51         """ 
    52         Override _getrange of park parameter 
    53         return the range of parameter 
    54         """ 
    55         #if not  self.name in self._model.getDispParamList(): 
    56         lo, hi = self._model.details[self.name][1:3] 
    57         if lo is None: lo = -numpy.inf 
    58         if hi is None: hi = numpy.inf 
    59         if lo > hi: 
    60             raise ValueError, "wrong fit range for parameters" 
    61          
    62         return lo, hi 
    63      
    64     def get_name(self): 
    65         """ 
    66         """ 
    67         return self._getname() 
    68      
    69     def _setrange(self, r): 
    70         """ 
    71         override _setrange of park parameter 
    72          
    73         :param r: the value of the range to set 
    74          
    75         """ 
    76         self._model.details[self.name][1:3] = r 
    77     range = property(_getrange, _setrange) 
    78      
    79      
    80 class Model(park.Model): 
    81     """ 
    82     PARK wrapper for SANS models. 
     10_SMALLVALUE = 1.0e-10 
     11 
     12# Note: duplicated from park 
     13class FitHandler(object): 
     14    """ 
     15    Abstract interface for fit thread handler. 
     16 
     17    The methods in this class are called by the optimizer as the fit 
     18    progresses. 
     19 
     20    Note that it is up to the optimizer to call the fit handler correctly, 
     21    reporting all status changes and maintaining the 'done' flag. 
     22    """ 
     23    done = False 
     24    """True when the fit job is complete""" 
     25    result = None 
     26    """The current best result of the fit""" 
     27 
     28    def improvement(self): 
     29        """ 
     30        Called when a result is observed which is better than previous 
     31        results from the fit. 
     32 
     33        result is a FitResult object, with parameters, #calls and fitness. 
     34        """ 
     35    def error(self, msg): 
     36        """ 
     37        Model had an error; print traceback 
     38        """ 
     39    def progress(self, current, expected): 
     40        """ 
     41        Called each cycle of the fit, reporting the current and the 
     42        expected amount of work.   The meaning of these values is 
     43        optimizer dependent, but they can be converted into a percent 
     44        complete using (100*current)//expected. 
     45 
     46        Progress is updated each iteration of the fit, whatever that 
     47        means for the particular optimization algorithm.  It is called 
     48        after any calls to improvement for the iteration so that the 
     49        update handler can control I/O bandwidth by suppressing 
     50        intermediate improvements until the fit is complete. 
     51        """ 
     52    def finalize(self): 
     53        """ 
     54        Fit is complete; best results are reported 
     55        """ 
     56    def abort(self): 
     57        """ 
     58        Fit was aborted. 
     59        """ 
     60 
     61class Model: 
     62    """ 
     63    Fit wrapper for SANS models. 
    8364    """ 
    8465    def __init__(self, sans_model, sans_data=None, **kw): 
    8566        """ 
    8667        :param sans_model: the sans model to wrap using park interface 
    87          
    88         """ 
    89         park.Model.__init__(self, **kw) 
     68 
     69        """ 
    9070        self.model = sans_model 
    9171        self.name = sans_model.name 
    9272        self.data = sans_data 
    93         #list of parameters names 
    94         self.sansp = sans_model.getParamList() 
    95         #list of park parameter 
    96         self.parkp = [SansParameter(p, sans_model, sans_data) for p in self.sansp] 
    97         #list of parameter set 
    98         self.parameterset = park.ParameterSet(sans_model.name, pars=self.parkp) 
    99         self.pars = [] 
    100    
     73 
    10174    def get_params(self, fitparams): 
    10275        """ 
    10376        return a list of value of paramter to fit 
    104          
     77 
    10578        :param fitparams: list of paramaters name to fit 
    106          
    107         """ 
    108         list_params = [] 
    109         self.pars = [] 
    110         self.pars = fitparams 
    111         for item in fitparams: 
    112             for element in self.parkp: 
    113                 if element.name == str(item): 
    114                     list_params.append(element.value) 
    115         return list_params 
    116      
     79 
     80        """ 
     81        return [self.model.getParam(k) for k in fitparams] 
     82 
    11783    def set_params(self, paramlist, params): 
    11884        """ 
    11985        Set value for parameters to fit 
    120          
     86 
    12187        :param params: list of value for parameters to fit 
    122          
    123         """ 
    124         try: 
    125             for i in range(len(self.parkp)): 
    126                 for j in range(len(paramlist)): 
    127                     if self.parkp[i].name == paramlist[j]: 
    128                         self.parkp[i].value = params[j] 
    129                         self.model.setParam(self.parkp[i].name, params[j]) 
    130         except: 
    131             raise 
    132    
     88 
     89        """ 
     90        for k,v in zip(paramlist, params): 
     91            self.model.setParam(k,v) 
     92 
     93    def set(self, **kw): 
     94        self.set_params(*zip(*kw.items())) 
     95 
    13396    def eval(self, x): 
    13497        """ 
    13598            Override eval method of park model. 
    136          
     99 
    137100            :param x: the x value used to compute a function 
    138101        """ 
     
    141104        except: 
    142105            raise 
    143          
     106 
    144107    def eval_derivs(self, x, pars=[]): 
    145108        """ 
     
    154117        instead of calling eval. 
    155118        """ 
    156         return [] 
    157  
    158      
     119        raise NotImplementedError('no derivatives available') 
     120 
     121    def __call__(self, x): 
     122        return self.eval(x) 
     123 
    159124class FitData1D(Data1D): 
    160125    """ 
     
    185150        """ 
    186151        Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy) 
     152        self.num_points = len(x) 
    187153        self.sans_data = data 
    188154        self.smearer = smearer 
     
    298264        """ 
    299265        self.res_err_image = [] 
     266        self.num_points = data.size 
    300267        self.idx = [] 
    301268        self.qmin = None 
     
    409376 
    410377 
    411 class SansAssembly: 
    412     """ 
    413     Sans Assembly class a class wrapper to be call in optimizer.leastsq method 
    414     """ 
    415     def __init__(self, paramlist, model=None, data=None, fitresult=None, 
    416                  handler=None, curr_thread=None, msg_q=None): 
    417         """ 
    418         :param Model: the model wrapper fro sans -model 
    419         :param Data: the data wrapper for sans data 
    420          
    421         """ 
    422         self.model = model 
    423         self.data = data 
    424         self.paramlist = paramlist 
    425         self.msg_q = msg_q 
    426         self.curr_thread = curr_thread 
    427         self.handler = handler 
    428         self.fitresult = fitresult 
    429         self.res = [] 
    430         self.true_res = [] 
    431         self.func_name = "Functor" 
    432         self.theory = None 
    433          
    434     def chisq(self): 
    435         """ 
    436         Calculates chi^2 
    437          
    438         :param params: list of parameter values 
    439          
    440         :return: chi^2 
    441          
    442         """ 
    443         total = 0 
    444         for item in self.true_res: 
    445             total += item * item 
    446         if len(self.true_res) == 0: 
    447             return None 
    448         return total / len(self.true_res) 
    449      
    450     def __call__(self, params): 
    451         """ 
    452             Compute residuals 
    453             :param params: value of parameters to fit 
    454         """ 
    455         #import thread 
    456         self.model.set_params(self.paramlist, params) 
    457         #print "params", params 
    458         self.true_res, theory = self.data.residuals(self.model.eval) 
    459         self.theory = copy.deepcopy(theory) 
    460         # check parameters range 
    461         if self.check_param_range(): 
    462             # if the param value is outside of the bound 
    463             # just silent return res = inf 
    464             return self.res 
    465         self.res = self.true_res 
    466          
    467         if self.fitresult is not None: 
    468             self.fitresult.set_model(model=self.model) 
    469             self.fitresult.residuals = self.true_res 
    470             self.fitresult.iterations += 1 
    471             self.fitresult.theory = theory 
    472             
    473             #fitness = self.chisq(params=params) 
    474             fitness = self.chisq() 
    475             self.fitresult.pvec = params 
    476             self.fitresult.set_fitness(fitness=fitness) 
    477             if self.msg_q is not None: 
    478                 self.msg_q.put(self.fitresult) 
    479                  
    480             if self.handler is not None: 
    481                 self.handler.set_result(result=self.fitresult) 
    482                 self.handler.update_fit() 
    483  
    484             if self.curr_thread != None: 
    485                 try: 
    486                     self.curr_thread.isquit() 
    487                 except: 
    488                     #msg = "Fitting: Terminated...       Note: Forcing to stop " 
    489                     #msg += "fitting may cause a 'Functor error message' " 
    490                     #msg += "being recorded in the log file....." 
    491                     #self.handler.stop(msg) 
    492                     raise 
    493           
    494         return self.res 
    495      
    496     def check_param_range(self): 
    497         """ 
    498         Check the lower and upper bound of the parameter value 
    499         and set res to the inf if the value is outside of the 
    500         range 
    501         :limitation: the initial values must be within range. 
    502         """ 
    503  
    504         #time.sleep(0.01) 
    505         is_outofbound = False 
    506         # loop through the fit parameters 
    507         for p in self.model.parameterset: 
    508             param_name = p.get_name() 
    509             if param_name in self.paramlist: 
    510                  
    511                 # if the range was defined, check the range 
    512                 if numpy.isfinite(p.range[0]): 
    513                     if p.value == 0: 
    514                         # This value works on Scipy 
    515                         # Do not change numbers below 
    516                         value = _SMALLVALUE 
    517                     else: 
    518                         value = p.value 
    519                     # For leastsq, it needs a bit step back from the boundary 
    520                     val = p.range[0] - value * _SMALLVALUE 
    521                     if p.value < val: 
    522                         self.res *= 1e+6 
    523                          
    524                         is_outofbound = True 
    525                         break 
    526                 if numpy.isfinite(p.range[1]): 
    527                     # This value works on Scipy 
    528                     # Do not change numbers below 
    529                     if p.value == 0: 
    530                         value = _SMALLVALUE 
    531                     else: 
    532                         value = p.value 
    533                     # For leastsq, it needs a bit step back from the boundary 
    534                     val = p.range[1] + value * _SMALLVALUE 
    535                     if p.value > val: 
    536                         self.res *= 1e+6 
    537                         is_outofbound = True 
    538                         break 
    539  
    540         return is_outofbound 
    541      
    542      
     378 
    543379class FitEngine: 
    544380    def __init__(self): 
     
    754590     
    755591     
    756 IS_MAC = True 
    757 if sys.platform.count("win32") > 0: 
    758     IS_MAC = False 
    759  
    760  
    761592class FResult(object): 
    762593    """ 
     
    777608        self.index = [] 
    778609        self.parameters = None 
    779         self.is_mac = IS_MAC 
    780610        self.model = model 
    781611        self.data = data 
     
    803633        if self.pvec == None and self.model is None and self.param_list is None: 
    804634            return "No results" 
    805         n = len(self.model.parameterset) 
    806          
    807         result_param = zip(xrange(n), self.model.parameterset) 
    808         msg1 = ["[Iteration #: %s ]" % self.iterations] 
    809         msg3 = ["=== goodness of fit: %s ===" % (str(self.fitness))] 
    810         if not self.is_mac: 
    811             msg2 = ["P%-3d  %s......|.....%s" % \ 
    812                 (p[0], p[1], p[1].value)\ 
    813                   for p in result_param if p[1].name in self.param_list] 
    814             msg = msg1 + msg3 + msg2 
    815         else: 
    816             msg = msg1 + msg3 
    817         msg = "\n".join(msg) 
    818         return msg 
     635 
     636        pars = enumerate(self.model.model.getParamList()) 
     637        msg1 = "[Iteration #: %s ]" % self.iterations 
     638        msg3 = "=== goodness of fit: %s ===" % (str(self.fitness)) 
     639        msg2 = ["P%-3d  %s......|.....%s" % (i, v, self.model.model.getParam(v)) 
     640                for i,v in pars if v in self.param_list] 
     641        msg = [msg1, msg3] + msg2 
     642        return "\n".join(msg) 
    819643     
    820644    def print_summary(self): 
Note: See TracChangeset for help on using the changeset viewer.