Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sans/fit/ParkFitting.py

    r8d074d9 r9d6d5ba  
    2424from sans.fit.AbstractFitEngine import FitEngine 
    2525from sans.fit.AbstractFitEngine import FResult 
    26  
    27 class SansParameter(park.Parameter): 
    28     """ 
    29     SANS model parameters for use in the PARK fitting service. 
    30     The parameter attribute value is redirected to the underlying 
    31     parameter value in the SANS model. 
    32     """ 
    33     def __init__(self, name, model, data): 
    34         """ 
    35             :param name: the name of the model parameter 
    36             :param model: the sans model to wrap as a park model 
    37         """ 
    38         park.Parameter.__init__(self, name) 
    39         #self._model, self._name = model, name 
    40         self.data = data 
    41         self.model = model 
    42         #set the value for the parameter of the given name 
    43         self.set(model.getParam(name)) 
    44  
    45         # TODO: model is missing parameter ranges for dispersion parameters 
    46         if name not in model.details: 
    47             #print "setting details for",name 
    48             model.details[name] = ["", None, None] 
    49  
    50     def _getvalue(self): 
    51         """ 
    52         override the _getvalue of park parameter 
    53  
    54         :return value the parameter associates with self.name 
    55  
    56         """ 
    57         return self.model.getParam(self.name) 
    58  
    59     def _setvalue(self, value): 
    60         """ 
    61         override the _setvalue pf park parameter 
    62  
    63         :param value: the value to set on a given parameter 
    64  
    65         """ 
    66         self.model.setParam(self.name, value) 
    67  
    68     value = property(_getvalue, _setvalue) 
    69  
    70     def _getrange(self): 
    71         """ 
    72         Override _getrange of park parameter 
    73         return the range of parameter 
    74         """ 
    75         #if not  self.name in self._model.getDispParamList(): 
    76         lo, hi = self.model.details[self.name][1:3] 
    77         if lo is None: lo = -numpy.inf 
    78         if hi is None: hi = numpy.inf 
    79         if lo > hi: 
    80             raise ValueError, "wrong fit range for parameters" 
    81  
    82         return lo, hi 
    83  
    84     def get_name(self): 
    85         """ 
    86         """ 
    87         return self._getname() 
    88  
    89     def _setrange(self, r): 
    90         """ 
    91         override _setrange of park parameter 
    92  
    93         :param r: the value of the range to set 
    94  
    95         """ 
    96         self.model.details[self.name][1:3] = r 
    97     range = property(_getrange, _setrange) 
    98  
    99  
    100 class ParkModel(park.Model): 
    101     """ 
    102     PARK wrapper for SANS models. 
    103     """ 
    104     def __init__(self, sans_model, sans_data=None, **kw): 
    105         """ 
    106         :param sans_model: the sans model to wrap using park interface 
    107  
    108         """ 
    109         park.Model.__init__(self, **kw) 
    110         self.model = sans_model 
    111         self.name = sans_model.name 
    112         self.data = sans_data 
    113         #list of parameters names 
    114         self.sansp = sans_model.getParamList() 
    115         #list of park parameter 
    116         self.parkp = [SansParameter(p, sans_model, sans_data) for p in self.sansp] 
    117         #list of parameter set 
    118         self.parameterset = park.ParameterSet(sans_model.name, pars=self.parkp) 
    119         self.pars = [] 
    120  
    121     def get_params(self, fitparams): 
    122         """ 
    123         return a list of value of paramter to fit 
    124  
    125         :param fitparams: list of paramaters name to fit 
    126  
    127         """ 
    128         list_params = [] 
    129         self.pars = fitparams 
    130         for item in fitparams: 
    131             for element in self.parkp: 
    132                 if element.name == str(item): 
    133                     list_params.append(element.value) 
    134         return list_params 
    135  
    136     def set_params(self, paramlist, params): 
    137         """ 
    138         Set value for parameters to fit 
    139  
    140         :param params: list of value for parameters to fit 
    141  
    142         """ 
    143         try: 
    144             for i in range(len(self.parkp)): 
    145                 for j in range(len(paramlist)): 
    146                     if self.parkp[i].name == paramlist[j]: 
    147                         self.parkp[i].value = params[j] 
    148                         self.model.setParam(self.parkp[i].name, params[j]) 
    149         except: 
    150             raise 
    151  
    152     def eval(self, x): 
    153         """ 
    154             Override eval method of park model. 
    155  
    156             :param x: the x value used to compute a function 
    157         """ 
    158         try: 
    159             return self.model.evalDistribution(x) 
    160         except: 
    161             raise 
    162  
    163     def eval_derivs(self, x, pars=[]): 
    164         """ 
    165         Evaluate the model and derivatives wrt pars at x. 
    166  
    167         pars is a list of the names of the parameters for which derivatives 
    168         are desired. 
    169  
    170         This method needs to be specialized in the model to evaluate the 
    171         model function.  Alternatively, the model can implement is own 
    172         version of residuals which calculates the residuals directly 
    173         instead of calling eval. 
    174         """ 
    175         return [] 
    176  
    177  
     26   
    17827class SansFitResult(fitresult.FitResult): 
    17928    def __init__(self, *args, **kwrds): 
     
    395244        return fitpars 
    396245     
    397     def extend_results_with_calculated_parameters(self, result): 
     246    def all_results(self, result): 
    398247        """ 
    399248        Extend result from the fit with the calculated parameters. 
     
    443292                # dividing residuals by N in order to be consistent with Scipy 
    444293                m.chisq = numpy.sum(m.residuals**2/N)  
    445                 resid.append(m.weight*m.residuals) 
     294                resid.append(m.weight*m.residuals/math.sqrt(N)) 
    446295        self.residuals = numpy.hstack(resid) 
    447296        N = len(self.residuals) 
    448297        self.degrees_of_freedom = N-k if N>k else 1 
    449298        self.chisq = numpy.sum(self.residuals**2) 
    450         return self.chisq/self.degrees_of_freedom 
     299        return self.chisq 
    451300     
    452301class ParkFit(FitEngine): 
     
    505354            if fproblem.get_to_fit() == 1: 
    506355                fitproblems.append(fproblem) 
    507         if len(fitproblems) == 0: 
     356        if len(fitproblems) == 0:  
    508357            raise RuntimeError, "No Assembly scheduled for Park fitting." 
     358            return 
    509359        for item in fitproblems: 
    510             model = item.get_model() 
    511             parkmodel = ParkModel(model.model, model.data) 
    512             parkmodel.pars = item.pars 
     360            parkmodel = item.get_model() 
    513361            if reset_flag: 
    514362                # reset the initial value; useful for batch 
     
    516364                    ind = item.pars.index(name) 
    517365                    parkmodel.model.setParam(name, item.vals[ind]) 
    518  
    519             # set the constraints into the model 
    520             for p,v in item.constraints: 
    521                 parkmodel.parameterset[str(p)].set(str(v)) 
    522366             
    523367            for p in parkmodel.parameterset: 
    524368                ## does not allow status change for constraint parameters 
    525369                if p.status != 'computed': 
    526                     if p.get_name() in item.pars: 
     370                    if p.get_name()in item.pars: 
    527371                        ## make parameters selected for  
    528372                        #fit will be between boundaries 
     
    539383    def fit(self, msg_q=None,  
    540384            q=None, handler=None, curr_thread=None,  
    541             ftol=1.49012e-8, reset_flag=False): 
     385                                        ftol=1.49012e-8, reset_flag=False): 
    542386        """ 
    543387        Performs fit with park.fit module.It can  perform fit with one model 
     
    563407        localfit = SansFitSimplex() 
    564408        localfit.ftol = ftol 
    565         localfit.xtol = 1e-6 
    566  
     409         
    567410        # See `park.fitresult.FitHandler` for details. 
    568411        fitter = SansFitMC(localfit=localfit, start_points=1) 
     
    573416        try: 
    574417            result = fit.fit(self.problem, fitter=fitter, handler=handler) 
    575             self.problem.extend_results_with_calculated_parameters(result) 
     418            self.problem.all_results(result) 
    576419             
    577420        except LinAlgError: 
    578421            raise ValueError, "SVD did not converge" 
    579  
    580         if result is None: 
    581             raise RuntimeError("park did not return a fit result") 
    582422     
    583423        for m in self.problem.parts: 
     
    587427            small_result.theory = theory 
    588428            small_result.residuals = residuals 
    589             small_result.index = m.data.idx 
    590             small_result.fitness = result.fitness 
    591  
    592             # Extract the parameters that are part of this model; make sure 
    593             # they match the fitted parameters for this model, and place them 
    594             # in the same order as they occur in the model. 
    595             pars = {} 
    596             for p in result.parameters: 
    597                 #if p.data.name == small_result.data.name and 
    598                 if p.model.name == small_result.model.name: 
    599                     model_name, par_name = p.name.split('.', 1) 
    600                     pars[par_name] = (p.value, p.stderr) 
    601             #assert len(pars.keys()) == len(m.model.pars) 
    602             v,dv = zip(*[pars[p] for p in m.model.pars]) 
    603             small_result.pvec = v 
    604             small_result.stderr = dv 
    605             small_result.param_list = m.model.pars 
    606  
    607             # normalize chisq by degrees of freedom 
    608             dof = len(small_result.residuals)-len(small_result.pvec) 
    609             small_result.fitness = numpy.sum(residuals**2)/dof 
    610  
     429            small_result.pvec = [] 
     430            small_result.cov = [] 
     431            small_result.stderr = [] 
     432            small_result.param_list = [] 
     433            small_result.residuals = m.residuals 
     434            if result is not None: 
     435                for p in result.parameters: 
     436                    if p.data.name == small_result.data.name and \ 
     437                            p.model.name == small_result.model.name: 
     438                        small_result.index = m.data.idx 
     439                        small_result.fitness = result.fitness 
     440                        small_result.pvec.append(p.value) 
     441                        small_result.stderr.append(p.stderr) 
     442                        name_split = p.name.split('.') 
     443                        name = name_split[1].strip() 
     444                        if len(name_split) > 2: 
     445                            name += '.' + name_split[2].strip() 
     446                        small_result.param_list.append(name) 
    611447            result_list.append(small_result)     
    612448        if q != None: 
Note: See TracChangeset for help on using the changeset viewer.