- Timestamp:
- May 15, 2014 11:23:22 AM (10 years ago)
- Branches:
- master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
- Children:
- 4e9f227
- Parents:
- 76f132a
- Location:
- src/sans
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
src/sans/fit/BumpsFitting.py
r8d074d9 re3efa6b3 8 8 from bumps import fitters 9 9 from bumps.mapper import SerialMapper 10 from bumps import parameter 11 from bumps.fitproblem import FitProblem 10 12 11 13 from sans.fit.AbstractFitEngine import FitEngine … … 21 23 22 24 def __call__(self, history): 25 if self.handler is None: return 23 26 self.handler.progress(history.step[0], self.max_step) 24 27 if len(history.step)>1 and history.step[1] > history.step[0]: … … 46 49 self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1])) 47 50 except: 48 self.convergence.append((best, )) 49 50 class SasProblem(object): 51 """ 52 Wrap the SAS model in a form that can be understood by bumps. 53 """ 54 def __init__(self, param_list, model=None, data=None, fitresult=None, 55 handler=None, curr_thread=None, msg_q=None): 56 """ 57 :param Model: the model wrapper fro sans -model 58 :param Data: the data wrapper for sans data 59 """ 51 self.convergence.append((best, best,best,best,best,best)) 52 53 54 class SasFitness(object): 55 """ 56 Wrap SAS model as a bumps fitness object 57 """ 58 def __init__(self, name, model, data, fitted=[], **kw): 59 self.name = name 60 60 self.model = model 61 61 self.data = data 62 self.param_list = param_list 63 self.res = None 64 self.theory = None 65 66 @property 67 def name(self): 68 return self.model.name 69 70 @property 71 def dof(self): 72 return self.data.num_points - len(self.param_list) 73 74 def summarize(self): 75 """ 76 Return a stylized list of parameter names and values with range bars 77 suitable for printing. 78 """ 79 output = [] 80 bounds = self.bounds() 81 for i,p in enumerate(self.getp()): 82 name = self.param_list[i] 83 low,high = bounds[:,i] 84 range = ",".join((("[%g"%low if numpy.isfinite(low) else "(-inf"), 85 ("%g]"%high if numpy.isfinite(high) else "inf)"))) 86 if not numpy.isfinite(p): 87 bar = "*invalid* " 62 self._define_pars() 63 self._init_pars(kw) 64 self.set_fitted(fitted) 65 self._dirty = True 66 67 def _define_pars(self): 68 self._pars = {} 69 for k in self.model.getParamList(): 70 name = ".".join((self.name,k)) 71 value = self.model.getParam(k) 72 bounds = self.model.details.get(k,["",None,None])[1:3] 73 self._pars[k] = parameter.Parameter(value=value, bounds=bounds, 74 fixed=True, name=name) 75 76 def _init_pars(self, kw): 77 for k,v in kw.items(): 78 # dispersion parameters initialized with _field instead of .field 79 if k.endswith('_width'): k = k[:-6]+'.width' 80 elif k.endswith('_npts'): k = k[:-5]+'.npts' 81 elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas' 82 elif k.endswith('_type'): k = k[:-5]+'.type' 83 if k not in self._pars: 84 formatted_pars = ", ".join(sorted(self._pars.keys())) 85 raise KeyError("invalid parameter %r for %s--use one of: %s" 86 %(k, self.model, formatted_pars)) 87 if '.' in k and not k.endswith('.width'): 88 self.model.setParam(k, v) 89 elif isinstance(v, parameter.BaseParameter): 90 self._pars[k] = v 91 elif isinstance(v, (tuple,list)): 92 low, high = v 93 self._pars[k].value = (low+high)/2 94 self._pars[k].range(low,high) 88 95 else: 89 bar = ['.']*10 90 if numpy.isfinite(high-low): 91 position = int(9.999999999 * float(p-low)/float(high-low)) 92 if position < 0: bar[0] = '<' 93 elif position > 9: bar[9] = '>' 94 else: bar[position] = '|' 95 bar = "".join(bar) 96 output.append("%40s %s %10g in %s"%(name,bar,p,range)) 97 return "\n".join(output) 98 99 def nllf(self, p=None): 100 residuals = self.residuals(p) 101 return 0.5*numpy.sum(residuals**2) 102 103 def setp(self, p): 104 for k,v in zip(self.param_list, p): 105 self.model.setParam(k,v) 106 #self.model.set_params(self.param_list, params) 107 108 def getp(self): 109 return numpy.array([self.model.getParam(k) for k in self.param_list]) 110 #return numpy.asarray(self.model.get_params(self.param_list)) 111 112 def bounds(self): 113 return numpy.array([self._getrange(p) for p in self.param_list]).T 114 115 def labels(self): 116 return self.param_list 117 118 def _getrange(self, p): 119 """ 120 Override _getrange of park parameter 121 return the range of parameter 122 """ 123 lo, hi = self.model.details.get(p,["",None,None])[1:3] 124 if lo is None: lo = -numpy.inf 125 if hi is None: hi = numpy.inf 126 return lo, hi 127 128 def randomize(self, n): 129 p = self.getp() 130 # since randn is symmetric and random, doesn't matter 131 # point value is negative. 132 # TODO: throw in bounds checking! 133 return numpy.random.randn(n, len(self.param_list))*p + p 134 135 def chisq(self): 136 """ 137 Calculates chi^2 138 139 :param params: list of parameter values 140 141 :return: chi^2 142 143 """ 144 return numpy.sum(self.res**2)/self.dof 145 146 def residuals(self, params=None): 147 """ 148 Compute residuals 149 :param params: value of parameters to fit 150 """ 151 if params is not None: self.setp(params) 152 #import thread 153 #print "params", params 154 self.res, self.theory = self.data.residuals(self.model.evalDistribution) 155 return self.res 156 157 BOUNDS_PENALTY = 1e6 # cost for going out of bounds on unbounded fitters 158 class MonitoredSasProblem(SasProblem): 159 """ 160 SAS problem definition for optimizers which do not have monitoring or bounds. 161 """ 162 def __init__(self, param_list, model=None, data=None, fitresult=None, 163 handler=None, curr_thread=None, msg_q=None, update_rate=1): 164 """ 165 :param Model: the model wrapper fro sans -model 166 :param Data: the data wrapper for sans data 167 """ 168 SasProblem.__init__(self, param_list, model, data) 169 self.msg_q = msg_q 170 self.curr_thread = curr_thread 171 self.handler = handler 172 self.fitresult = fitresult 173 #self.last_update = time.time() 174 #self.func_name = "Functor" 175 #self.name = "Fill in proper name!" 176 177 def residuals(self, p): 178 """ 179 Cost function for scipy.optimize.leastsq, which does not have a monitor 180 built into the algorithm, and instead relies on a monitor built into 181 the cost function. 182 """ 183 # Note: technically, unbounded fitters and unmonitored fitters are 184 self.setp(p) 185 186 # Compute penalty for being out of bounds which increases the farther 187 # you get out of bounds. This allows derivative following algorithms 188 # to point back toward the feasible region. 189 penalty = self.bounds_penalty() 190 if penalty > 0: 191 self.theory = numpy.ones(self.data.num_points) 192 self.res = self.theory*(penalty/self.data.num_points) + BOUNDS_PENALTY 193 return self.res 194 195 # If no penalty, then we are not out of bounds and we can use the 196 # normal residual calculation 197 SasProblem.residuals(self, p) 198 199 # send update to the application 200 if True: 201 #self.fitresult.set_model(model=self.model) 202 # copy residuals into fit results 203 self.fitresult.residuals = self.res+0 204 self.fitresult.iterations += 1 205 self.fitresult.theory = self.theory+0 206 207 self.fitresult.p = numpy.array(p) # force copy, and coversion to array 208 self.fitresult.set_fitness(fitness=self.chisq()) 209 if self.msg_q is not None: 210 self.msg_q.put(self.fitresult) 211 212 if self.handler is not None: 213 self.handler.set_result(result=self.fitresult) 214 self.handler.update_fit() 215 216 if self.curr_thread != None: 217 try: 218 self.curr_thread.isquit() 219 except: 220 #msg = "Fitting: Terminated... Note: Forcing to stop " 221 #msg += "fitting may cause a 'Functor error message' " 222 #msg += "being recorded in the log file....." 223 #self.handler.stop(msg) 224 raise 225 226 return self.res 227 228 def bounds_penalty(self): 229 from numpy import sum, where 230 p, bounds = self.getp(), self.bounds() 231 return (sum(where(p<bounds[:,0], bounds[:,0]-p, 0)**2) 232 + sum(where(p>bounds[:,1], bounds[:,1]-p, 0)**2) ) 96 self._pars[k].value = v 97 self.update() 98 99 def set_fitted(self, param_list): 100 """ 101 Flag a set of parameters as fitted parameters. 102 """ 103 for k,p in self._pars.items(): 104 p.fixed = (k not in param_list) 105 self.fitted_pars = [self._pars[k] for k in param_list] 106 self.fitted_par_names = param_list 107 108 # ===== Fitness interface ==== 109 def parameters(self): 110 return self._pars 111 112 def update(self): 113 for k,v in self._pars.items(): 114 self.model.setParam(k,v.value) 115 self._dirty = True 116 117 def _recalculate(self): 118 if self._dirty: 119 self._residuals, self._theory = self.data.residuals(self.model.evalDistribution) 120 self._dirty = False 121 122 def numpoints(self): 123 return numpy.sum(self.data.idx) # number of fitted points 124 125 def nllf(self): 126 return 0.5*numpy.sum(self.residuals()**2) 127 128 def theory(self): 129 self._recalculate() 130 return self._theory 131 132 def residuals(self): 133 self._recalculate() 134 return self._residuals 135 136 # Not implementing the data methods for now: 137 # 138 # resynth_data/restore_data/save/plot 233 139 234 140 class BumpsFit(FitEngine): … … 247 153 q=None, handler=None, curr_thread=None, 248 154 ftol=1.49012e-8, reset_flag=False): 249 """ 250 """ 251 fitproblem = [] 252 for fproblem in self.fit_arrange_dict.itervalues(): 253 if fproblem.get_to_fit() == 1: 254 fitproblem.append(fproblem) 255 if len(fitproblem) > 1 : 256 msg = "Bumps can't fit more than a single fit problem at a time." 257 raise RuntimeError, msg 258 elif len(fitproblem) == 0 : 259 raise RuntimeError, "No problem scheduled for fitting." 260 model = fitproblem[0].get_model() 261 if reset_flag: 262 # reset the initial value; useful for batch 263 for name in fitproblem[0].pars: 264 ind = fitproblem[0].pars.index(name) 265 model.setParam(name, fitproblem[0].vals[ind]) 266 data = fitproblem[0].get_data() 267 268 self.curr_thread = curr_thread 269 270 result = FResult(model=model, data=data, param_list=self.param_list) 271 result.pars = fitproblem[0].pars 272 result.fitter_id = self.fitter_id 273 result.index = data.idx 274 if handler is not None: 275 handler.set_result(result=result) 276 277 if True: # bumps 278 problem = SasProblem(param_list=self.param_list, 279 model=model.model, 280 data=data) 281 run_bumps(problem, result, ftol, 282 handler, curr_thread, msg_q) 283 else: # scipy levenburg marquardt 284 problem = SasProblem(param_list=self.param_list, 285 model=model.model, 286 data=data, 287 handler=handler, 288 fitresult=result, 289 curr_thread=curr_thread, 290 msg_q=msg_q) 291 run_levenburg_marquardt(problem, result, ftol) 292 155 # Build collection of bumps fitness calculators 156 models = [ SasFitness(name="M%d"%(i+1), 157 model=M.get_model().model, 158 data=M.get_data(), 159 fitted=M.pars) 160 for i,M in enumerate(self.fit_arrange_dict.values()) 161 if M.get_to_fit() == 1 ] 162 problem = FitProblem(models) 163 164 # Run the fit 165 result = run_bumps(problem, handler, curr_thread) 293 166 if handler is not None: 294 167 handler.update_fit(last=True) 168 169 # TODO: shouldn't reference internal parameters 170 varying = problem._parameters 171 # collect the results 172 all_results = [] 173 for M in problem.models: 174 fitness = M.fitness 175 fitted_index = [varying.index(p) for p in fitness.fitted_pars] 176 R = FResult(model=fitness.model, data=fitness.data, 177 param_list=fitness.fitted_par_names) 178 R.theory = fitness.theory() 179 R.residuals = fitness.residuals() 180 R.fitter_id = self.fitter_id 181 R.stderr = result['stderr'][fitted_index] 182 R.pvec = result['value'][fitted_index] 183 R.success = result['success'] 184 R.fitness = numpy.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index)) 185 R.convergence = result['convergence'] 186 if result['uncertainty'] is not None: 187 R.uncertainty_state = result['uncertainty'] 188 all_results.append(R) 189 295 190 if q is not None: 296 q.put( result)191 q.put(all_results) 297 192 return q 298 #if success < 1 or success > 5: 299 # result.fitness = None 300 return [result] 301 302 def run_bumps(problem, result, ftol, handler, curr_thread, msg_q): 193 else: 194 return all_results 195 196 def run_bumps(problem, handler, curr_thread): 303 197 def abort_test(): 304 198 if curr_thread is None: return False … … 313 207 fitclass = fitopts.fitclass 314 208 options = fitopts.options.copy() 315 max_step s= fitopts.options.get('steps', 0) + fitopts.options.get('burn', 0)316 if 'monitors' not in options:317 options['monitors'] = [BumpsMonitor(handler, max_steps)]318 options['monitors'] += [ ConvergenceMonitor() ]319 options['ftol'] = ftol209 max_step = fitopts.options.get('steps', 0) + fitopts.options.get('burn', 0) 210 options['monitors'] = [ 211 BumpsMonitor(handler, max_step), 212 ConvergenceMonitor(), 213 ] 320 214 fitdriver = fitters.FitDriver(fitclass, problem=problem, 321 215 abort_test=abort_test, **options) 322 216 mapper = SerialMapper 323 217 fitdriver.mapper = mapper.start_mapper(problem, None) 218 import time; T0 = time.time() 324 219 try: 325 220 best, fbest = fitdriver.fit() … … 329 224 finally: 330 225 mapper.stop_mapper(fitdriver.mapper) 331 #print "best,fbest",best,fbest,problem.dof 332 result.fitness = 2*fbest/problem.dof 333 #print "fitness",result.fitness 334 result.stderr = fitdriver.stderr() 335 result.pvec = best 336 # TODO: track success better 337 result.success = True 338 result.theory = problem.theory 339 # For the convergence plot 340 pop = numpy.asarray(options['monitors'][-1].convergence) 341 result.convergence = 2*pop/problem.dof 342 # Bumps uncertainty state 343 try: result.uncertainty_state = fitdriver.fitter.state 344 except AttributeError: pass 345 346 def run_levenburg_marquardt(problem, result, ftol): 347 # This import must be here; otherwise it will be confused when more 348 # than one thread exist. 349 from scipy import optimize 350 351 out, cov_x, _, mesg, success = optimize.leastsq(problem.residuals, 352 problem.getp(), 353 ftol=ftol, 354 full_output=1) 355 if cov_x is not None and numpy.isfinite(cov_x).all(): 356 stderr = numpy.sqrt(numpy.diag(cov_x)) 357 else: 358 stderr = [] 359 result.fitness = problem.chisq() 360 result.stderr = stderr 361 result.pvec = out 362 result.success = success 363 result.theory = problem.theory 364 226 227 228 convergence_list = options['monitors'][-1].convergence 229 convergence = (2*numpy.asarray(convergence_list)/problem.dof 230 if convergence_list else numpy.empty((0,1),'d')) 231 return { 232 'value': best, 233 'stderr': fitdriver.stderr(), 234 'success': True, # better success reporting in bumps 235 'convergence': convergence, 236 'uncertainty': getattr(fitdriver.fitter, 'state', None), 237 } 238 -
src/sans/fit/Fitting.py
r6fe5100 re3efa6b3 32 32 33 33 """ 34 def __init__(self, engine='scipy' ):34 def __init__(self, engine='scipy', *args, **kw): 35 35 """ 36 36 """ … … 38 38 self._engine = None 39 39 self.fitter_id = None 40 self.set_engine(engine )40 self.set_engine(engine, *args, **kw) 41 41 42 42 def __setattr__(self, name, value): … … 55 55 self.__dict__[name] = value 56 56 57 def set_engine(self, word ):57 def set_engine(self, word, *args, **kw): 58 58 """ 59 59 Select the type of Fit … … 66 66 """ 67 67 try: 68 self._engine = ENGINES[word]( )68 self._engine = ENGINES[word](*args, **kw) 69 69 except KeyError, exc: 70 70 raise KeyError("fit engine should be one of scipy, park or bumps") -
src/sans/perspectives/fitting/fit_thread.py
ra855fec re3efa6b3 18 18 19 19 def __init__(self, 20 21 22 23 24 batch_inputs=None,25 20 fn, 21 page_id, 22 handler, 23 batch_outputs, 24 batch_inputs=None, 25 pars=None, 26 26 completefn = None, 27 27 updatefn = None, … … 30 30 ftol = None, 31 31 reset_flag = False): 32 CalcThread.__init__(self,completefn, 32 CalcThread.__init__(self, 33 completefn, 33 34 updatefn, 34 35 yieldtime, … … 80 81 list_map_get_attr.append(map_getattr) 81 82 #from multiprocessing import Pool 82 inputs = zip(list_map_get_attr, self.fitter, list_fit_function,83 83 inputs = zip(list_map_get_attr, self.fitter, list_fit_function, 84 list_q, list_q, list_handler,list_curr_thread,list_ftol, 84 85 list_reset_flag) 85 86 result = map(map_apply, inputs) … … 87 88 self.complete(result=result, 88 89 batch_inputs=self.batch_inputs, 89 90 batch_outputs=self.batch_outputs, 90 91 page_id=self.page_id, 91 92 pars = self.pars,
Note: See TracChangeset
for help on using the changeset viewer.