Changeset 95d58d3 in sasview for src


Ignore:
Timestamp:
Apr 10, 2014 8:05:28 PM (11 years ago)
Author:
pkienzle
Branches:
master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Children:
90f49a8
Parents:
6fe5100
Message:

fix fit line test for bumps/scipy/park and enable it as part of test suite

Location:
src/sans/fit
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • src/sans/fit/AbstractFitEngine.py

    r6fe5100 r95d58d3  
    5858        Fit was aborted. 
    5959        """ 
     60 
     61    # TODO: not sure how these are used, but they are needed for running the fit 
     62    def update_fit(self, last=False): pass 
     63    def set_result(self, result=None): self.result = result 
    6064 
    6165class Model: 
     
    217221        """ 
    218222        return self.qmin, self.qmax 
    219          
     223 
     224    def size(self): 
     225        """ 
     226        Number of measurement points in data set after masking, etc. 
     227        """ 
     228        return len(self.x) 
     229 
    220230    def residuals(self, fn): 
    221231        """ 
     
    259269    def __init__(self, sans_data2d, data=None, err_data=None): 
    260270        Data2D.__init__(self, data=data, err_data=err_data) 
    261         """ 
    262             Data can be initital with a data (sans plottable) 
    263             or with vectors. 
    264         """ 
     271        # Data can be initialized with a sans plottable or with vectors. 
    265272        self.res_err_image = [] 
    266         self.num_points = data.size 
     273        self.num_points = 0 # will be set by set_data 
    267274        self.idx = [] 
    268275        self.qmin = None 
     
    306313        self.idx = (self.idx) & (self.mask) 
    307314        self.idx = (self.idx) & (numpy.isfinite(self.data)) 
     315        self.num_points = numpy.sum(self.idx) 
    308316 
    309317    def set_smearer(self, smearer): 
     
    339347        """ 
    340348        return self.qmin, self.qmax 
    341       
     349 
     350    def size(self): 
     351        """ 
     352        Number of measurement points in data set after masking, etc. 
     353        """ 
     354        return numpy.sum(self.idx) 
     355 
    342356    def residuals(self, fn): 
    343357        """ 
     
    410424            raise ValueError, "AbstractFitEngine: Need to set model to fit" 
    411425         
    412         new_model = model 
    413426        if not issubclass(model.__class__, Model): 
    414             new_model = Model(model, data) 
    415          
     427            model = Model(model, data) 
     428 
     429        sasmodel = model.model 
    416430        if len(constraints) > 0: 
    417431            for constraint in constraints: 
    418432                name, value = constraint 
    419433                try: 
    420                     new_model.parameterset[str(name)].set(str(value)) 
     434                    model.parameterset[str(name)].set(str(value)) 
    421435                except: 
    422436                    msg = "Fit Engine: Error occurs when setting the constraint" 
     
    427441            temp = [] 
    428442            for item in pars: 
    429                 if item in new_model.model.getParamList(): 
     443                if item in sasmodel.getParamList(): 
    430444                    temp.append(item) 
    431445                    self.param_list.append(item) 
     
    433447                     
    434448                    msg = "wrong parameter %s used " % str(item) 
    435                     msg += "to set model %s. Choose " % str(new_model.model.name) 
     449                    msg += "to set model %s. Choose " % str(sasmodel.name) 
    436450                    msg += "parameter name within %s" % \ 
    437                                 str(new_model.model.getParamList()) 
     451                                str(sasmodel.getParamList()) 
    438452                    raise ValueError, msg 
    439453               
    440454            #A fitArrange is already created but contains data_list only at id 
    441455            if self.fit_arrange_dict.has_key(id): 
    442                 self.fit_arrange_dict[id].set_model(new_model) 
     456                self.fit_arrange_dict[id].set_model(model) 
    443457                self.fit_arrange_dict[id].pars = pars 
    444458            else: 
    445459            #no fitArrange object has been create with this id 
    446460                fitproblem = FitArrange() 
    447                 fitproblem.set_model(new_model) 
     461                fitproblem.set_model(model) 
    448462                fitproblem.pars = pars 
    449463                self.fit_arrange_dict[id] = fitproblem 
    450464                vals = [] 
    451465                for name in pars: 
    452                     vals.append(new_model.model.getParam(name)) 
     466                    vals.append(sasmodel.getParam(name)) 
    453467                self.fit_arrange_dict[id].vals = vals 
    454468        else: 
     
    634648            return "No results" 
    635649 
    636         pars = enumerate(self.model.model.getParamList()) 
     650        sasmodel = self.model.model 
     651        pars = enumerate(sasmodel.getParamList()) 
    637652        msg1 = "[Iteration #: %s ]" % self.iterations 
    638653        msg3 = "=== goodness of fit: %s ===" % (str(self.fitness)) 
    639         msg2 = ["P%-3d  %s......|.....%s" % (i, v, self.model.model.getParam(v)) 
     654        msg2 = ["P%-3d  %s......|.....%s" % (i, v, sasmodel.getParam(v)) 
    640655                for i,v in pars if v in self.param_list] 
    641656        msg = [msg1, msg3] + msg2 
     
    645660        """ 
    646661        """ 
    647         print self 
     662        print str(self) 
  • src/sans/fit/BumpsFitting.py

    r6fe5100 r95d58d3  
    33""" 
    44import sys 
    5 import copy 
    65 
    76import numpy 
     
    1312from sans.fit.AbstractFitEngine import FResult 
    1413 
    15 class SansAssembly(object): 
     14class SasProblem(object): 
    1615    """ 
    17     Sans Assembly class a class wrapper to be call in optimizer.leastsq method 
     16    Wrap the SAS model in a form that can be understood by bumps. 
    1817    """ 
    19     def __init__(self, paramlist, model=None, data=None, fitresult=None, 
     18    def __init__(self, param_list, model=None, data=None, fitresult=None, 
    2019                 handler=None, curr_thread=None, msg_q=None): 
    2120        """ 
     
    2524        self.model = model 
    2625        self.data = data 
    27         self.paramlist = paramlist 
     26        self.param_list = param_list 
    2827        self.msg_q = msg_q 
    2928        self.curr_thread = curr_thread 
     
    3736    @property 
    3837    def dof(self): 
    39         return self.data.num_points - len(self.paramlist) 
     38        return self.data.num_points - len(self.param_list) 
    4039 
    4140    def summarize(self): 
    42         return "summarize" 
    43  
    44     def nllf(self, pvec=None): 
    45         residuals = self.residuals(pvec) 
     41        """ 
     42        Return a stylized list of parameter names and values with range bars 
     43        suitable for printing. 
     44        """ 
     45        output = [] 
     46        bounds = self.bounds() 
     47        for i,p in enumerate(self.getp()): 
     48            name = self.param_list[i] 
     49            low,high = bounds[:,i] 
     50            range = ",".join((("[%g"%low if numpy.isfinite(low) else "(-inf"), 
     51                              ("%g]"%high if numpy.isfinite(high) else "inf)"))) 
     52            if not numpy.isfinite(p): 
     53                bar = "*invalid* " 
     54            else: 
     55                bar = ['.']*10 
     56                if numpy.isfinite(high-low): 
     57                    position = int(9.999999999 * float(p-low)/float(high-low)) 
     58                    if position < 0: bar[0] = '<' 
     59                    elif position > 9: bar[9] = '>' 
     60                    else: bar[position] = '|' 
     61                bar = "".join(bar) 
     62            output.append("%40s %s %10g in %s"%(name,bar,p,range)) 
     63        return "\n".join(output) 
     64 
     65    def nllf(self, p=None): 
     66        residuals = self.residuals(p) 
    4667        return 0.5*numpy.sum(residuals**2) 
    4768 
    48     def setp(self, params): 
    49         self.model.set_params(self.paramlist, params) 
     69    def setp(self, p): 
     70        for k,v in zip(self.param_list, p): 
     71            self.model.setParam(k,v) 
     72        #self.model.set_params(self.param_list, params) 
    5073 
    5174    def getp(self): 
    52         return numpy.asarray(self.model.get_params(self.paramlist)) 
     75        return numpy.array([self.model.getParam(k) for k in self.param_list]) 
     76        #return numpy.asarray(self.model.get_params(self.param_list)) 
    5377 
    5478    def bounds(self): 
    55         return numpy.array([self._getrange(p) for p in self.paramlist]).T 
     79        return numpy.array([self._getrange(p) for p in self.param_list]).T 
    5680 
    5781    def labels(self): 
    58         return self.paramlist 
     82        return self.param_list 
    5983 
    6084    def _getrange(self, p): 
     
    6387        return the range of parameter 
    6488        """ 
    65         lo, hi = self.model.model.details[p][1:3] 
     89        lo, hi = self.model.details[p][1:3] 
    6690        if lo is None: lo = -numpy.inf 
    6791        if hi is None: hi = numpy.inf 
     
    6993 
    7094    def randomize(self, n): 
    71         pvec = self.getp() 
     95        p = self.getp() 
    7296        # since randn is symmetric and random, doesn't matter 
    7397        # point value is negative. 
    7498        # TODO: throw in bounds checking! 
    75         return numpy.random.randn(n, len(self.paramlist))*pvec + pvec 
     99        return numpy.random.randn(n, len(self.param_list))*p + p 
    76100 
    77101    def chisq(self): 
     
    84108 
    85109        """ 
    86         total = 0 
    87         for item in self.res: 
    88             total += item * item 
    89         if len(self.res) == 0: 
    90             return None 
    91         return total / len(self.res) 
     110        return numpy.sum(self.res**2)/self.dof 
    92111 
    93112    def residuals(self, params=None): 
     
    99118        #import thread 
    100119        #print "params", params 
    101         self.res, self.theory = self.data.residuals(self.model.eval) 
    102  
     120        self.res, self.theory = self.data.residuals(self.model.evalDistribution) 
     121 
     122        # TODO: this belongs in monitor not residuals calculation 
    103123        if self.fitresult is not None: 
    104             self.fitresult.set_model(model=self.model) 
     124            #self.fitresult.set_model(model=self.model) 
    105125            self.fitresult.residuals = self.res+0 
    106126            self.fitresult.iterations += 1 
     
    109129            #fitness = self.chisq(params=params) 
    110130            fitness = self.chisq() 
    111             self.fitresult.pvec = params 
     131            self.fitresult.p = params 
    112132            self.fitresult.set_fitness(fitness=fitness) 
    113133            if self.msg_q is not None: 
     
    131151    __call__ = residuals 
    132152 
    133     def check_param_range(self): 
     153    def _DEAD_check_param_range(self): 
    134154        """ 
    135155        Check the lower and upper bound of the parameter value 
     
    142162        is_outofbound = False 
    143163        # loop through the fit parameters 
    144         model = self.model.model 
    145         for p in self.paramlist: 
     164        model = self.model 
     165        for p in self.param_list: 
    146166            value = model.getParam(p) 
    147167            low,high = model.details[p][1:3] 
     
    196216            raise RuntimeError, msg 
    197217        elif len(fitproblem) == 0 : 
    198             raise RuntimeError, "No Assembly scheduled for Scipy fitting." 
     218            raise RuntimeError, "No problem scheduled for fitting." 
    199219        model = fitproblem[0].get_model() 
    200220        if reset_flag: 
     
    203223                ind = fitproblem[0].pars.index(name) 
    204224                model.setParam(name, fitproblem[0].vals[ind]) 
    205         listdata = [] 
    206225        listdata = fitproblem[0].get_data() 
    207226        # Concatenate dList set (contains one or more data)before fitting 
     
    209228 
    210229        self.curr_thread = curr_thread 
    211         ftol = ftol 
    212230 
    213231        result = FResult(model=model, data=data, param_list=self.param_list) 
     
    217235        if handler is not None: 
    218236            handler.set_result(result=result) 
    219         functor = SansAssembly(paramlist=self.param_list, 
    220                                model=model, 
    221                                data=data, 
    222                                handler=handler, 
    223                                fitresult=result, 
    224                                curr_thread=curr_thread, 
    225                                msg_q=msg_q) 
     237        problem = SasProblem(param_list=self.param_list, 
     238                              model=model.model, 
     239                              data=data, 
     240                              handler=handler, 
     241                              fitresult=result, 
     242                              curr_thread=curr_thread, 
     243                              msg_q=msg_q) 
    226244        try: 
    227             run_bumps(functor, result) 
     245            #run_bumps(problem, result, ftol) 
     246            run_scipy(problem, result, ftol) 
    228247        except: 
    229248            if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt: 
     
    245264        return [result] 
    246265 
    247 def run_bumps(problem, result): 
     266def run_bumps(problem, result, ftol): 
    248267    fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT] 
    249     fitdriver = fitters.FitDriver(fitopts.fitclass, problem=problem,  
    250         abort_test=lambda: False, **fitopts.options) 
     268    fitclass = fitopts.fitclass 
     269    options = fitopts.options.copy() 
     270    options['ftol'] = ftol 
     271    fitdriver = fitters.FitDriver(fitclass, problem=problem, 
     272                                  abort_test=lambda: False, **options) 
    251273    mapper = SerialMapper  
    252274    fitdriver.mapper = mapper.start_mapper(problem, None) 
     
    256278        import traceback; traceback.print_exc() 
    257279        raise 
    258     mapper.stop_mapper(fitdriver.mapper) 
    259     fitdriver.show() 
    260     #fitdriver.plot() 
    261     result.fitness = fbest * 2. / len(result.pars)  
    262     result.stderr  = numpy.ones(len(result.pars)) 
    263     result.pvec = best  
     280    finally: 
     281        mapper.stop_mapper(fitdriver.mapper) 
     282    #print "best,fbest",best,fbest,problem.dof 
     283    result.fitness = 2*fbest/problem.dof 
     284    #print "fitness",result.fitness 
     285    result.stderr  = fitdriver.stderr() 
     286    result.pvec = best 
     287    # TODO: track success better 
    264288    result.success = True 
    265289    result.theory = problem.theory 
    266290 
    267 def run_scipy(model, result): 
     291def run_scipy(model, result, ftol): 
    268292    # This import must be here; otherwise it will be confused when more 
    269293    # than one thread exist. 
    270294    from scipy import optimize 
    271295 
    272     out, cov_x, _, mesg, success = optimize.leastsq(functor, 
    273                                                     model.get_params(self.param_list), 
     296    out, cov_x, _, mesg, success = optimize.leastsq(model.residuals, 
     297                                                    model.getp(), 
    274298                                                    ftol=ftol, 
    275299                                                    full_output=1) 
     
    278302    else: 
    279303        stderr = [] 
    280     result.fitness = functor.chisqr() 
     304    result.fitness = model.chisq() 
    281305    result.stderr  = stderr 
    282306    result.pvec = out 
    283307    result.success = success 
    284     result.theory = functor.theory 
    285  
     308    result.theory = model.theory 
     309 
  • src/sans/fit/ParkFitting.py

    r6fe5100 r95d58d3  
    9393 
    9494 
    95 class Model(park.Model): 
     95class ParkModel(park.Model): 
    9696    """ 
    9797    PARK wrapper for SANS models. 
     
    391391        return fitpars 
    392392     
    393     def all_results(self, result): 
     393    def extend_results_with_calculated_parameters(self, result): 
    394394        """ 
    395395        Extend result from the fit with the calculated parameters. 
     
    439439                # dividing residuals by N in order to be consistent with Scipy 
    440440                m.chisq = numpy.sum(m.residuals**2/N)  
    441                 resid.append(m.weight*m.residuals/math.sqrt(N)) 
     441                resid.append(m.weight*m.residuals) 
    442442        self.residuals = numpy.hstack(resid) 
    443443        N = len(self.residuals) 
    444444        self.degrees_of_freedom = N-k if N>k else 1 
    445445        self.chisq = numpy.sum(self.residuals**2) 
    446         return self.chisq 
     446        return self.chisq/self.degrees_of_freedom 
    447447     
    448448class ParkFit(FitEngine): 
     
    505505            return 
    506506        for item in fitproblems: 
    507             parkmodel = item.get_model() 
     507            model = item.get_model() 
     508            parkmodel = ParkModel(model.model, model.data) 
    508509            if reset_flag: 
    509510                # reset the initial value; useful for batch 
     
    554555        localfit = SansFitSimplex() 
    555556        localfit.ftol = ftol 
    556          
     557        localfit.xtol = 1e-6 
     558 
    557559        # See `park.fitresult.FitHandler` for details. 
    558560        fitter = SansFitMC(localfit=localfit, start_points=1) 
     
    563565        try: 
    564566            result = fit.fit(self.problem, fitter=fitter, handler=handler) 
    565             self.problem.all_results(result) 
     567            self.problem.extend_results_with_calculated_parameters(result) 
    566568             
    567569        except LinAlgError: 
     
    592594                            name += '.' + name_split[2].strip() 
    593595                        small_result.param_list.append(name) 
     596                # normalize chisq by degrees of freedom 
     597                small_result.fitness /= len(small_result.residuals)-len(small_result.pvec) 
    594598            result_list.append(small_result)     
    595599        if q != None: 
  • src/sans/fit/ScipyFitting.py

    r6fe5100 r95d58d3  
    4949        if len(self.true_res) == 0: 
    5050            return None 
    51         return total / len(self.true_res) 
     51        return total / (len(self.true_res) - len(self.paramlist)) 
    5252 
    5353    def __call__(self, params): 
     
    205205        _check_param_range(model.model, self.param_list) 
    206206         
    207         result = FResult(model=model, data=data, param_list=self.param_list) 
     207        result = FResult(model=model.model, data=data, param_list=self.param_list) 
    208208        result.pars = fitproblem[0].pars 
    209209        result.fitter_id = self.fitter_id 
Note: See TracChangeset for help on using the changeset viewer.