Changeset 95d58d3 in sasview for src/sans/fit/BumpsFitting.py


Ignore:
Timestamp:
Apr 10, 2014 8:05:28 PM (10 years ago)
Author:
pkienzle
Branches:
master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Children:
90f49a8
Parents:
6fe5100
Message:

fix fit line test for bumps/scipy/park and enable it as part of test suite

File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sans/fit/BumpsFitting.py

    r6fe5100 r95d58d3  
    33""" 
    44import sys 
    5 import copy 
    65 
    76import numpy 
     
    1312from sans.fit.AbstractFitEngine import FResult 
    1413 
    15 class SansAssembly(object): 
     14class SasProblem(object): 
    1615    """ 
    17     Sans Assembly class a class wrapper to be call in optimizer.leastsq method 
     16    Wrap the SAS model in a form that can be understood by bumps. 
    1817    """ 
    19     def __init__(self, paramlist, model=None, data=None, fitresult=None, 
     18    def __init__(self, param_list, model=None, data=None, fitresult=None, 
    2019                 handler=None, curr_thread=None, msg_q=None): 
    2120        """ 
     
    2524        self.model = model 
    2625        self.data = data 
    27         self.paramlist = paramlist 
     26        self.param_list = param_list 
    2827        self.msg_q = msg_q 
    2928        self.curr_thread = curr_thread 
     
    3736    @property 
    3837    def dof(self): 
    39         return self.data.num_points - len(self.paramlist) 
     38        return self.data.num_points - len(self.param_list) 
    4039 
    4140    def summarize(self): 
    42         return "summarize" 
    43  
    44     def nllf(self, pvec=None): 
    45         residuals = self.residuals(pvec) 
     41        """ 
     42        Return a stylized list of parameter names and values with range bars 
     43        suitable for printing. 
     44        """ 
     45        output = [] 
     46        bounds = self.bounds() 
     47        for i,p in enumerate(self.getp()): 
     48            name = self.param_list[i] 
     49            low,high = bounds[:,i] 
     50            range = ",".join((("[%g"%low if numpy.isfinite(low) else "(-inf"), 
     51                              ("%g]"%high if numpy.isfinite(high) else "inf)"))) 
     52            if not numpy.isfinite(p): 
     53                bar = "*invalid* " 
     54            else: 
     55                bar = ['.']*10 
     56                if numpy.isfinite(high-low): 
     57                    position = int(9.999999999 * float(p-low)/float(high-low)) 
     58                    if position < 0: bar[0] = '<' 
     59                    elif position > 9: bar[9] = '>' 
     60                    else: bar[position] = '|' 
     61                bar = "".join(bar) 
     62            output.append("%40s %s %10g in %s"%(name,bar,p,range)) 
     63        return "\n".join(output) 
     64 
     65    def nllf(self, p=None): 
     66        residuals = self.residuals(p) 
    4667        return 0.5*numpy.sum(residuals**2) 
    4768 
    48     def setp(self, params): 
    49         self.model.set_params(self.paramlist, params) 
     69    def setp(self, p): 
     70        for k,v in zip(self.param_list, p): 
     71            self.model.setParam(k,v) 
     72        #self.model.set_params(self.param_list, params) 
    5073 
    5174    def getp(self): 
    52         return numpy.asarray(self.model.get_params(self.paramlist)) 
     75        return numpy.array([self.model.getParam(k) for k in self.param_list]) 
     76        #return numpy.asarray(self.model.get_params(self.param_list)) 
    5377 
    5478    def bounds(self): 
    55         return numpy.array([self._getrange(p) for p in self.paramlist]).T 
     79        return numpy.array([self._getrange(p) for p in self.param_list]).T 
    5680 
    5781    def labels(self): 
    58         return self.paramlist 
     82        return self.param_list 
    5983 
    6084    def _getrange(self, p): 
     
    6387        return the range of parameter 
    6488        """ 
    65         lo, hi = self.model.model.details[p][1:3] 
     89        lo, hi = self.model.details[p][1:3] 
    6690        if lo is None: lo = -numpy.inf 
    6791        if hi is None: hi = numpy.inf 
     
    6993 
    7094    def randomize(self, n): 
    71         pvec = self.getp() 
     95        p = self.getp() 
    7296        # since randn is symmetric and random, doesn't matter 
    7397        # point value is negative. 
    7498        # TODO: throw in bounds checking! 
    75         return numpy.random.randn(n, len(self.paramlist))*pvec + pvec 
     99        return numpy.random.randn(n, len(self.param_list))*p + p 
    76100 
    77101    def chisq(self): 
     
    84108 
    85109        """ 
    86         total = 0 
    87         for item in self.res: 
    88             total += item * item 
    89         if len(self.res) == 0: 
    90             return None 
    91         return total / len(self.res) 
     110        return numpy.sum(self.res**2)/self.dof 
    92111 
    93112    def residuals(self, params=None): 
     
    99118        #import thread 
    100119        #print "params", params 
    101         self.res, self.theory = self.data.residuals(self.model.eval) 
    102  
     120        self.res, self.theory = self.data.residuals(self.model.evalDistribution) 
     121 
     122        # TODO: this belongs in monitor not residuals calculation 
    103123        if self.fitresult is not None: 
    104             self.fitresult.set_model(model=self.model) 
     124            #self.fitresult.set_model(model=self.model) 
    105125            self.fitresult.residuals = self.res+0 
    106126            self.fitresult.iterations += 1 
     
    109129            #fitness = self.chisq(params=params) 
    110130            fitness = self.chisq() 
    111             self.fitresult.pvec = params 
     131            self.fitresult.p = params 
    112132            self.fitresult.set_fitness(fitness=fitness) 
    113133            if self.msg_q is not None: 
     
    131151    __call__ = residuals 
    132152 
    133     def check_param_range(self): 
     153    def _DEAD_check_param_range(self): 
    134154        """ 
    135155        Check the lower and upper bound of the parameter value 
     
    142162        is_outofbound = False 
    143163        # loop through the fit parameters 
    144         model = self.model.model 
    145         for p in self.paramlist: 
     164        model = self.model 
     165        for p in self.param_list: 
    146166            value = model.getParam(p) 
    147167            low,high = model.details[p][1:3] 
     
    196216            raise RuntimeError, msg 
    197217        elif len(fitproblem) == 0 : 
    198             raise RuntimeError, "No Assembly scheduled for Scipy fitting." 
     218            raise RuntimeError, "No problem scheduled for fitting." 
    199219        model = fitproblem[0].get_model() 
    200220        if reset_flag: 
     
    203223                ind = fitproblem[0].pars.index(name) 
    204224                model.setParam(name, fitproblem[0].vals[ind]) 
    205         listdata = [] 
    206225        listdata = fitproblem[0].get_data() 
    207226        # Concatenate dList set (contains one or more data)before fitting 
     
    209228 
    210229        self.curr_thread = curr_thread 
    211         ftol = ftol 
    212230 
    213231        result = FResult(model=model, data=data, param_list=self.param_list) 
     
    217235        if handler is not None: 
    218236            handler.set_result(result=result) 
    219         functor = SansAssembly(paramlist=self.param_list, 
    220                                model=model, 
    221                                data=data, 
    222                                handler=handler, 
    223                                fitresult=result, 
    224                                curr_thread=curr_thread, 
    225                                msg_q=msg_q) 
     237        problem = SasProblem(param_list=self.param_list, 
     238                              model=model.model, 
     239                              data=data, 
     240                              handler=handler, 
     241                              fitresult=result, 
     242                              curr_thread=curr_thread, 
     243                              msg_q=msg_q) 
    226244        try: 
    227             run_bumps(functor, result) 
     245            #run_bumps(problem, result, ftol) 
     246            run_scipy(problem, result, ftol) 
    228247        except: 
    229248            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt: 
     
    245264        return [result] 
    246265 
    247 def run_bumps(problem, result): 
     266def run_bumps(problem, result, ftol): 
    248267    fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT] 
    249     fitdriver = fitters.FitDriver(fitopts.fitclass, problem=problem,  
    250         abort_test=lambda: False, **fitopts.options) 
     268    fitclass = fitopts.fitclass 
     269    options = fitopts.options.copy() 
     270    options['ftol'] = ftol 
     271    fitdriver = fitters.FitDriver(fitclass, problem=problem, 
     272                                  abort_test=lambda: False, **options) 
    251273    mapper = SerialMapper  
    252274    fitdriver.mapper = mapper.start_mapper(problem, None) 
     
    256278        import traceback; traceback.print_exc() 
    257279        raise 
    258     mapper.stop_mapper(fitdriver.mapper) 
    259     fitdriver.show() 
    260     #fitdriver.plot() 
    261     result.fitness = fbest * 2. / len(result.pars)  
    262     result.stderr  = numpy.ones(len(result.pars)) 
    263     result.pvec = best  
     280    finally: 
     281        mapper.stop_mapper(fitdriver.mapper) 
     282    #print "best,fbest",best,fbest,problem.dof 
     283    result.fitness = 2*fbest/problem.dof 
     284    #print "fitness",result.fitness 
     285    result.stderr  = fitdriver.stderr() 
     286    result.pvec = best 
     287    # TODO: track success better 
    264288    result.success = True 
    265289    result.theory = problem.theory 
    266290 
    267 def run_scipy(model, result): 
     291def run_scipy(model, result, ftol): 
    268292    # This import must be here; otherwise it will be confused when more 
    269293    # than one thread exist. 
    270294    from scipy import optimize 
    271295 
    272     out, cov_x, _, mesg, success = optimize.leastsq(functor, 
    273                                                     model.get_params(self.param_list), 
     296    out, cov_x, _, mesg, success = optimize.leastsq(model.residuals, 
     297                                                    model.getp(), 
    274298                                                    ftol=ftol, 
    275299                                                    full_output=1) 
     
    278302    else: 
    279303        stderr = [] 
    280     result.fitness = functor.chisqr() 
     304    result.fitness = model.chisq() 
    281305    result.stderr  = stderr 
    282306    result.pvec = out 
    283307    result.success = success 
    284     result.theory = functor.theory 
    285  
     308    result.theory = model.theory 
     309 
Note: See TracChangeset for help on using the changeset viewer.