Changeset 6fe5100 in sasview for src/sans/fit/ScipyFitting.py
- Timestamp:
- Apr 6, 2014 7:29:59 AM (11 years ago)
- Branches:
- master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
- Children:
- 95d58d3
- Parents:
- 960fdbb
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
src/sans/fit/ScipyFitting.py
r5777106 r6fe5100 1 2 3 1 """ 4 2 ScipyFitting module contains FitArrange , ScipyFit, … … 6 4 simple fit with scipy optimizer. 7 5 """ 6 import sys 7 import copy 8 8 9 9 import numpy 10 import sys11 12 10 13 11 from sans.fit.AbstractFitEngine import FitEngine 14 from sans.fit.AbstractFitEngine import SansAssembly 15 from sans.fit.AbstractFitEngine import FitAbort 16 from sans.fit.AbstractFitEngine import Model 17 from sans.fit.AbstractFitEngine import FResult 12 from sans.fit.AbstractFitEngine import FResult 13 14 class SansAssembly: 15 """ 16 Sans Assembly class a class wrapper to be call in optimizer.leastsq method 17 """ 18 def __init__(self, paramlist, model=None, data=None, fitresult=None, 19 handler=None, curr_thread=None, msg_q=None): 20 """ 21 :param Model: the model wrapper fro sans -model 22 :param Data: the data wrapper for sans data 23 24 """ 25 self.model = model 26 self.data = data 27 self.paramlist = paramlist 28 self.msg_q = msg_q 29 self.curr_thread = curr_thread 30 self.handler = handler 31 self.fitresult = fitresult 32 self.res = [] 33 self.true_res = [] 34 self.func_name = "Functor" 35 self.theory = None 36 37 def chisq(self): 38 """ 39 Calculates chi^2 40 41 :param params: list of parameter values 42 43 :return: chi^2 44 45 """ 46 total = 0 47 for item in self.true_res: 48 total += item * item 49 if len(self.true_res) == 0: 50 return None 51 return total / len(self.true_res) 52 53 def __call__(self, params): 54 """ 55 Compute residuals 56 :param params: value of parameters to fit 57 """ 58 #import thread 59 self.model.set_params(self.paramlist, params) 60 #print "params", params 61 self.true_res, theory = self.data.residuals(self.model.eval) 62 self.theory = copy.deepcopy(theory) 63 # check parameters range 64 if self.check_param_range(): 65 # if the param value is outside of the bound 66 # just silent return res = inf 67 return self.res 68 self.res = self.true_res 69 70 if self.fitresult is not None: 71 self.fitresult.set_model(model=self.model) 72 self.fitresult.residuals = self.true_res 73 self.fitresult.iterations += 1 74 self.fitresult.theory = theory 75 76 #fitness = self.chisq(params=params) 77 fitness = self.chisq() 78 self.fitresult.pvec = params 79 self.fitresult.set_fitness(fitness=fitness) 80 if self.msg_q is not None: 81 self.msg_q.put(self.fitresult) 82 83 if self.handler is not None: 84 self.handler.set_result(result=self.fitresult) 85 self.handler.update_fit() 86 87 if self.curr_thread != None: 88 try: 89 self.curr_thread.isquit() 90 except: 91 #msg = "Fitting: Terminated... Note: Forcing to stop " 92 #msg += "fitting may cause a 'Functor error message' " 93 #msg += "being recorded in the log file....." 94 #self.handler.stop(msg) 95 raise 96 97 return self.res 98 99 def check_param_range(self): 100 """ 101 Check the lower and upper bound of the parameter value 102 and set res to the inf if the value is outside of the 103 range 104 :limitation: the initial values must be within range. 105 """ 106 107 #time.sleep(0.01) 108 is_outofbound = False 109 # loop through the fit parameters 110 model = self.model.model 111 for p in self.paramlist: 112 value = model.getParam(p) 113 low,high = model.details[p][1:3] 114 if low is not None and numpy.isfinite(low): 115 if p.value == 0: 116 # This value works on Scipy 117 # Do not change numbers below 118 value = _SMALLVALUE 119 # For leastsq, it needs a bit step back from the boundary 120 val = low - value * _SMALLVALUE 121 if value < val: 122 self.res *= 1e+6 123 is_outofbound = True 124 break 125 if high is not None and numpy.isfinite(high): 126 # This value works on Scipy 127 # Do not change numbers below 128 if value == 0: 129 value = _SMALLVALUE 130 # For leastsq, it needs a bit step back from the boundary 131 val = high + value * _SMALLVALUE 132 if value > val: 133 self.res *= 1e+6 134 is_outofbound = True 135 break 136 137 return is_outofbound 18 138 19 139 class ScipyFit(FitEngine): … … 50 170 """ 51 171 FitEngine.__init__(self) 52 self.fit_arrange_dict = {}53 self.param_list = []54 172 self.curr_thread = None 55 173 #def fit(self, *args, **kw): … … 68 186 msg = "Scipy can't fit more than a single fit problem at a time." 69 187 raise RuntimeError, msg 70 return 71 elif len(fitproblem) == 0 : 188 elif len(fitproblem) == 0 : 72 189 raise RuntimeError, "No Assembly scheduled for Scipy fitting." 73 return74 190 model = fitproblem[0].get_model() 75 191 if reset_flag: … … 87 203 88 204 # Check the initial value if it is within range 89 self._check_param_range(model)205 _check_param_range(model.model, self.param_list) 90 206 91 207 result = FResult(model=model, data=data, param_list=self.param_list) … … 94 210 if handler is not None: 95 211 handler.set_result(result=result) 212 functor = SansAssembly(paramlist=self.param_list, 213 model=model, 214 data=data, 215 handler=handler, 216 fitresult=result, 217 curr_thread=curr_thread, 218 msg_q=msg_q) 96 219 try: 97 220 # This import must be here; otherwise it will be confused when more … … 99 222 from scipy import optimize 100 223 101 functor = SansAssembly(paramlist=self.param_list,102 model=model,103 data=data,104 handler=handler,105 fitresult=result,106 curr_thread=curr_thread,107 msg_q=msg_q)108 224 out, cov_x, _, mesg, success = optimize.leastsq(functor, 109 225 model.get_params(self.param_list), 110 111 226 ftol=ftol, 227 full_output=1) 112 228 except: 113 229 if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt: … … 142 258 143 259 144 def _check_param_range(self, model): 145 """ 146 Check parameter range and set the initial value inside 147 if it is out of range. 148 149 : model: park model object 150 """ 151 is_outofbound = False 152 # loop through parameterset 153 for p in model.parameterset: 154 param_name = p.get_name() 155 # proceed only if the parameter name is in the list of fitting 156 if param_name in self.param_list: 157 # if the range was defined, check the range 158 if numpy.isfinite(p.range[0]): 159 if p.value <= p.range[0]: 160 # 10 % backing up from the border if not zero 161 # for Scipy engine to work properly. 162 shift = self._get_zero_shift(p.range[0]) 163 new_value = p.range[0] + shift 164 p.value = new_value 165 is_outofbound = True 166 if numpy.isfinite(p.range[1]): 167 if p.value >= p.range[1]: 168 shift = self._get_zero_shift(p.range[1]) 169 # 10 % backing up from the border if not zero 170 # for Scipy engine to work properly. 171 new_value = p.range[1] - shift 172 # Check one more time if the new value goes below 173 # the low bound, If so, re-evaluate the value 174 # with the mean of the range. 175 if numpy.isfinite(p.range[0]): 176 if new_value < p.range[0]: 177 new_value = (p.range[0] + p.range[1]) / 2.0 178 # Todo: 179 # Need to think about when both min and max are same. 180 p.value = new_value 181 is_outofbound = True 182 183 return is_outofbound 184 185 def _get_zero_shift(self, range): 186 """ 187 Get 10% shift of the param value = 0 based on the range value 188 189 : param range: min or max value of the bounds 190 """ 191 if range == 0: 192 shift = 0.1 193 else: 194 shift = 0.1 * range 195 196 return shift 197 260 def _check_param_range(model, param_list): 261 """ 262 Check parameter range and set the initial value inside 263 if it is out of range. 264 265 : model: park model object 266 """ 267 # loop through parameterset 268 for p in param_list: 269 value = model.getParam(p) 270 low,high = model.details[p][1:3] 271 # if the range was defined, check the range 272 if low is not None and value <= low: 273 value = low + _get_zero_shift(low) 274 if high is not None and value > high: 275 value = high - _get_zero_shift(high) 276 # Check one more time if the new value goes below 277 # the low bound, If so, re-evaluate the value 278 # with the mean of the range. 279 if low is not None and value < low: 280 value = 0.5 * (low+high) 281 model.setParam(p, value) 282 283 def _get_zero_shift(limit): 284 """ 285 Get 10% shift of the param value = 0 based on the range value 286 287 : param range: min or max value of the bounds 288 """ 289 return 0.1 (limit if limit != 0.0 else 1.0) 290 198 291 199 292 #def profile(fn, *args, **kw):
Note: See TracChangeset
for help on using the changeset viewer.