Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sans/fit/ParkFitting.py

    r9d6d5ba r8d074d9  
    2424from sans.fit.AbstractFitEngine import FitEngine 
    2525from sans.fit.AbstractFitEngine import FResult 
    26    
     26 
     27class SansParameter(park.Parameter): 
     28    """ 
     29    SANS model parameters for use in the PARK fitting service. 
     30    The parameter attribute value is redirected to the underlying 
     31    parameter value in the SANS model. 
     32    """ 
     33    def __init__(self, name, model, data): 
     34        """ 
     35            :param name: the name of the model parameter 
     36            :param model: the sans model to wrap as a park model 
     37        """ 
     38        park.Parameter.__init__(self, name) 
     39        #self._model, self._name = model, name 
     40        self.data = data 
     41        self.model = model 
     42        #set the value for the parameter of the given name 
     43        self.set(model.getParam(name)) 
     44 
     45        # TODO: model is missing parameter ranges for dispersion parameters 
     46        if name not in model.details: 
     47            #print "setting details for",name 
     48            model.details[name] = ["", None, None] 
     49 
     50    def _getvalue(self): 
     51        """ 
     52        override the _getvalue of park parameter 
     53 
     54        :return value the parameter associates with self.name 
     55 
     56        """ 
     57        return self.model.getParam(self.name) 
     58 
     59    def _setvalue(self, value): 
     60        """ 
     61        override the _setvalue pf park parameter 
     62 
     63        :param value: the value to set on a given parameter 
     64 
     65        """ 
     66        self.model.setParam(self.name, value) 
     67 
     68    value = property(_getvalue, _setvalue) 
     69 
     70    def _getrange(self): 
     71        """ 
     72        Override _getrange of park parameter 
     73        return the range of parameter 
     74        """ 
     75        #if not  self.name in self._model.getDispParamList(): 
     76        lo, hi = self.model.details[self.name][1:3] 
     77        if lo is None: lo = -numpy.inf 
     78        if hi is None: hi = numpy.inf 
     79        if lo > hi: 
     80            raise ValueError, "wrong fit range for parameters" 
     81 
     82        return lo, hi 
     83 
     84    def get_name(self): 
     85        """ 
     86        """ 
     87        return self._getname() 
     88 
     89    def _setrange(self, r): 
     90        """ 
     91        override _setrange of park parameter 
     92 
     93        :param r: the value of the range to set 
     94 
     95        """ 
     96        self.model.details[self.name][1:3] = r 
     97    range = property(_getrange, _setrange) 
     98 
     99 
     100class ParkModel(park.Model): 
     101    """ 
     102    PARK wrapper for SANS models. 
     103    """ 
     104    def __init__(self, sans_model, sans_data=None, **kw): 
     105        """ 
     106        :param sans_model: the sans model to wrap using park interface 
     107 
     108        """ 
     109        park.Model.__init__(self, **kw) 
     110        self.model = sans_model 
     111        self.name = sans_model.name 
     112        self.data = sans_data 
     113        #list of parameters names 
     114        self.sansp = sans_model.getParamList() 
     115        #list of park parameter 
     116        self.parkp = [SansParameter(p, sans_model, sans_data) for p in self.sansp] 
     117        #list of parameter set 
     118        self.parameterset = park.ParameterSet(sans_model.name, pars=self.parkp) 
     119        self.pars = [] 
     120 
     121    def get_params(self, fitparams): 
     122        """ 
     123        return a list of value of paramter to fit 
     124 
     125        :param fitparams: list of paramaters name to fit 
     126 
     127        """ 
     128        list_params = [] 
     129        self.pars = fitparams 
     130        for item in fitparams: 
     131            for element in self.parkp: 
     132                if element.name == str(item): 
     133                    list_params.append(element.value) 
     134        return list_params 
     135 
     136    def set_params(self, paramlist, params): 
     137        """ 
     138        Set value for parameters to fit 
     139 
     140        :param params: list of value for parameters to fit 
     141 
     142        """ 
     143        try: 
     144            for i in range(len(self.parkp)): 
     145                for j in range(len(paramlist)): 
     146                    if self.parkp[i].name == paramlist[j]: 
     147                        self.parkp[i].value = params[j] 
     148                        self.model.setParam(self.parkp[i].name, params[j]) 
     149        except: 
     150            raise 
     151 
     152    def eval(self, x): 
     153        """ 
     154            Override eval method of park model. 
     155 
     156            :param x: the x value used to compute a function 
     157        """ 
     158        try: 
     159            return self.model.evalDistribution(x) 
     160        except: 
     161            raise 
     162 
     163    def eval_derivs(self, x, pars=[]): 
     164        """ 
     165        Evaluate the model and derivatives wrt pars at x. 
     166 
     167        pars is a list of the names of the parameters for which derivatives 
     168        are desired. 
     169 
     170        This method needs to be specialized in the model to evaluate the 
     171        model function.  Alternatively, the model can implement is own 
     172        version of residuals which calculates the residuals directly 
     173        instead of calling eval. 
     174        """ 
     175        return [] 
     176 
     177 
    27178class SansFitResult(fitresult.FitResult): 
    28179    def __init__(self, *args, **kwrds): 
     
    244395        return fitpars 
    245396     
    246     def all_results(self, result): 
     397    def extend_results_with_calculated_parameters(self, result): 
    247398        """ 
    248399        Extend result from the fit with the calculated parameters. 
     
    292443                # dividing residuals by N in order to be consistent with Scipy 
    293444                m.chisq = numpy.sum(m.residuals**2/N)  
    294                 resid.append(m.weight*m.residuals/math.sqrt(N)) 
     445                resid.append(m.weight*m.residuals) 
    295446        self.residuals = numpy.hstack(resid) 
    296447        N = len(self.residuals) 
    297448        self.degrees_of_freedom = N-k if N>k else 1 
    298449        self.chisq = numpy.sum(self.residuals**2) 
    299         return self.chisq 
     450        return self.chisq/self.degrees_of_freedom 
    300451     
    301452class ParkFit(FitEngine): 
     
    354505            if fproblem.get_to_fit() == 1: 
    355506                fitproblems.append(fproblem) 
    356         if len(fitproblems) == 0:  
     507        if len(fitproblems) == 0: 
    357508            raise RuntimeError, "No Assembly scheduled for Park fitting." 
    358             return 
    359509        for item in fitproblems: 
    360             parkmodel = item.get_model() 
     510            model = item.get_model() 
     511            parkmodel = ParkModel(model.model, model.data) 
     512            parkmodel.pars = item.pars 
    361513            if reset_flag: 
    362514                # reset the initial value; useful for batch 
     
    364516                    ind = item.pars.index(name) 
    365517                    parkmodel.model.setParam(name, item.vals[ind]) 
     518 
     519            # set the constraints into the model 
     520            for p,v in item.constraints: 
     521                parkmodel.parameterset[str(p)].set(str(v)) 
    366522             
    367523            for p in parkmodel.parameterset: 
    368524                ## does not allow status change for constraint parameters 
    369525                if p.status != 'computed': 
    370                     if p.get_name()in item.pars: 
     526                    if p.get_name() in item.pars: 
    371527                        ## make parameters selected for  
    372528                        #fit will be between boundaries 
     
    383539    def fit(self, msg_q=None,  
    384540            q=None, handler=None, curr_thread=None,  
    385                                         ftol=1.49012e-8, reset_flag=False): 
     541            ftol=1.49012e-8, reset_flag=False): 
    386542        """ 
    387543        Performs fit with park.fit module.It can  perform fit with one model 
     
    407563        localfit = SansFitSimplex() 
    408564        localfit.ftol = ftol 
    409          
     565        localfit.xtol = 1e-6 
     566 
    410567        # See `park.fitresult.FitHandler` for details. 
    411568        fitter = SansFitMC(localfit=localfit, start_points=1) 
     
    416573        try: 
    417574            result = fit.fit(self.problem, fitter=fitter, handler=handler) 
    418             self.problem.all_results(result) 
     575            self.problem.extend_results_with_calculated_parameters(result) 
    419576             
    420577        except LinAlgError: 
    421578            raise ValueError, "SVD did not converge" 
     579 
     580        if result is None: 
     581            raise RuntimeError("park did not return a fit result") 
    422582     
    423583        for m in self.problem.parts: 
     
    427587            small_result.theory = theory 
    428588            small_result.residuals = residuals 
    429             small_result.pvec = [] 
    430             small_result.cov = [] 
    431             small_result.stderr = [] 
    432             small_result.param_list = [] 
    433             small_result.residuals = m.residuals 
    434             if result is not None: 
    435                 for p in result.parameters: 
    436                     if p.data.name == small_result.data.name and \ 
    437                             p.model.name == small_result.model.name: 
    438                         small_result.index = m.data.idx 
    439                         small_result.fitness = result.fitness 
    440                         small_result.pvec.append(p.value) 
    441                         small_result.stderr.append(p.stderr) 
    442                         name_split = p.name.split('.') 
    443                         name = name_split[1].strip() 
    444                         if len(name_split) > 2: 
    445                             name += '.' + name_split[2].strip() 
    446                         small_result.param_list.append(name) 
     589            small_result.index = m.data.idx 
     590            small_result.fitness = result.fitness 
     591 
     592            # Extract the parameters that are part of this model; make sure 
     593            # they match the fitted parameters for this model, and place them 
     594            # in the same order as they occur in the model. 
     595            pars = {} 
     596            for p in result.parameters: 
     597                #if p.data.name == small_result.data.name and 
     598                if p.model.name == small_result.model.name: 
     599                    model_name, par_name = p.name.split('.', 1) 
     600                    pars[par_name] = (p.value, p.stderr) 
     601            #assert len(pars.keys()) == len(m.model.pars) 
     602            v,dv = zip(*[pars[p] for p in m.model.pars]) 
     603            small_result.pvec = v 
     604            small_result.stderr = dv 
     605            small_result.param_list = m.model.pars 
     606 
     607            # normalize chisq by degrees of freedom 
     608            dof = len(small_result.residuals)-len(small_result.pvec) 
     609            small_result.fitness = numpy.sum(residuals**2)/dof 
     610 
    447611            result_list.append(small_result)     
    448612        if q != None: 
Note: See TracChangeset for help on using the changeset viewer.